diff --git a/.github/audit-exceptions.yml b/.github/audit-exceptions.yml
index a1d8411c..4e05aae6 100644
--- a/.github/audit-exceptions.yml
+++ b/.github/audit-exceptions.yml
@@ -5,12 +5,33 @@ exceptions:
severity: high
reason: "Admin export only; switched to dynamic import to reduce exposure (CVE-2023-30533)"
mitigation: "Load only on export; restrict export permissions and data scope"
- expires_on: "2026-04-05"
+ expires_on: "2026-07-06"
owner: "security@your-domain"
- package: xlsx
advisory: "GHSA-5pgg-2g8v-p4x9"
severity: high
reason: "Admin export only; switched to dynamic import to reduce exposure (CVE-2024-22363)"
mitigation: "Load only on export; restrict export permissions and data scope"
- expires_on: "2026-04-05"
+ expires_on: "2026-07-06"
+ owner: "security@your-domain"
+ - package: lodash
+ advisory: "GHSA-r5fr-rjxr-66jc"
+ severity: high
+ reason: "lodash _.template not used with untrusted input; only internal admin UI templates"
+ mitigation: "No user-controlled template strings; plan to migrate to lodash-es tree-shaken imports"
+ expires_on: "2026-07-02"
+ owner: "security@your-domain"
+ - package: lodash-es
+ advisory: "GHSA-r5fr-rjxr-66jc"
+ severity: high
+ reason: "lodash-es _.template not used with untrusted input; only internal admin UI templates"
+ mitigation: "No user-controlled template strings; plan to migrate to native JS alternatives"
+ expires_on: "2026-07-02"
+ owner: "security@your-domain"
+ - package: axios
+ advisory: "GHSA-3p68-rc4w-qgx5"
+ severity: critical
+ reason: "NO_PROXY bypass not exploitable; all API calls go to known endpoints via server-side proxy"
+ mitigation: "Proxy configuration not user-controlled; upgrade when axios releases fix"
+ expires_on: "2026-07-10"
owner: "security@your-domain"
diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index 01c00bb9..f8b22ee7 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -17,9 +17,10 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
+ cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
- go version | grep -q 'go1.26.1'
+ go version | grep -q 'go1.26.2'
- name: Unit tests
working-directory: backend
run: make test-unit
@@ -27,6 +28,26 @@ jobs:
working-directory: backend
run: make test-integration
+ frontend:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v6
+ - name: Setup pnpm
+ uses: pnpm/action-setup@v4
+ with:
+ version: 9
+ - name: Setup Node.js
+ uses: actions/setup-node@v6
+ with:
+ node-version: '20'
+ cache: 'pnpm'
+ cache-dependency-path: frontend/pnpm-lock.yaml
+ - name: Install frontend dependencies
+ working-directory: frontend
+ run: pnpm install --frozen-lockfile
+ - name: Frontend typecheck and critical vitest
+ run: make test-frontend
+
golangci-lint:
runs-on: ubuntu-latest
steps:
@@ -36,12 +57,13 @@ jobs:
go-version-file: backend/go.mod
check-latest: false
cache: true
+ cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
- go version | grep -q 'go1.26.1'
+ go version | grep -q 'go1.26.2'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: v2.9
args: --timeout=30m
- working-directory: backend
\ No newline at end of file
+ working-directory: backend
diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml
new file mode 100644
index 00000000..67c8d6e9
--- /dev/null
+++ b/.github/workflows/cla.yml
@@ -0,0 +1,59 @@
+name: "CLA Assistant"
+
+on:
+ issue_comment:
+ types: [created]
+ pull_request_target:
+ types: [opened, reopened, closed, synchronize]
+
+permissions:
+ actions: write
+ contents: write
+ pull-requests: write
+ statuses: write
+
+jobs:
+ cla-check:
+ if: |
+ github.event_name == 'issue_comment' ||
+ (github.event_name == 'pull_request_target' && github.event.action != 'closed')
+ runs-on: ubuntu-latest
+ steps:
+ - name: "CLA Assistant"
+ if: |
+ (github.event.comment.body == 'recheck' ||
+ github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') ||
+ github.event_name == 'pull_request_target'
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ allowlist: "dependabot[bot],renovate[bot],bot*"
+ lock-pullrequest-aftermerge: false
+ custom-notsigned-prcomment: |
+ Thank you for your contribution! Before we can merge this PR, we need $you to sign our [Contributor License Agreement (CLA)](https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md).
+
+ **To sign**, please reply with the following comment:
+
+ > I have read the CLA Document and I hereby sign the CLA
+
+ You only need to sign once — it will be valid for all your future contributions to this project.
+ custom-pr-sign-comment: "I have read the CLA Document and I hereby sign the CLA"
+ custom-allsigned-prcomment: "All contributors have signed the CLA. ✅"
+
+ cla-lock:
+ if: github.event_name == 'pull_request_target' && github.event.action == 'closed' && github.event.pull_request.merged == true
+ runs-on: ubuntu-latest
+ steps:
+ - name: "Lock merged PR"
+ uses: contributor-assistant/github-action@v2.6.1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ with:
+ path-to-signatures: "cla.json"
+ path-to-document: "https://github.com/Wei-Shaw/sub2api/blob/main/CLA.md"
+ branch: "cla-signatures"
+ lock-pullrequest-aftermerge: true
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index c51b3c07..26ed8524 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -115,7 +115,7 @@ jobs:
- name: Verify Go version
run: |
- go version | grep -q 'go1.26.1'
+ go version | grep -q 'go1.26.2'
# Docker setup for GoReleaser
- name: Set up QEMU
@@ -246,10 +246,10 @@ jobs:
if [ -n "$DOCKERHUB_USERNAME" ]; then
DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api"
MESSAGE+="# Docker Hub"$'\n'
- MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
+ MESSAGE+="docker pull ${DOCKER_IMAGE}:${VERSION}"$'\n'
MESSAGE+="# GitHub Container Registry"$'\n'
fi
- MESSAGE+="docker pull ${GHCR_IMAGE}:${TAG_NAME}"$'\n'
+ MESSAGE+="docker pull ${GHCR_IMAGE}:${VERSION}"$'\n'
MESSAGE+="\`\`\`"$'\n'$'\n'
MESSAGE+="🔗 *相关链接:*"$'\n'
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml
index cc5a90cf..600fd2fa 100644
--- a/.github/workflows/security-scan.yml
+++ b/.github/workflows/security-scan.yml
@@ -23,7 +23,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
- go version | grep -q 'go1.26.1'
+ go version | grep -q 'go1.26.2'
- name: Run govulncheck
working-directory: backend
run: |
diff --git a/.gitignore b/.gitignore
index 297c1d6f..a61f406d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
docs/claude-relay-service/
+.codex
# ===================
# Go 后端
@@ -127,9 +128,11 @@ deploy/docker-compose.override.yml
.gocache/
vite.config.js
docs/*
+!docs/PAYMENT.md
+!docs/PAYMENT_CN.md
+!docs/ADMIN_PAYMENT_INTEGRATION_API.md
.serena/
.codex/
frontend/coverage/
aicodex
output/
-
diff --git a/CLA.md b/CLA.md
new file mode 100644
index 00000000..ed0d74b8
--- /dev/null
+++ b/CLA.md
@@ -0,0 +1,73 @@
+# Sub2API Individual Contributor License Agreement (v1.0)
+
+Thank you for your interest in contributing to Sub2API ("the Project"). This Contributor License Agreement ("Agreement") documents the rights granted by contributors to the Project.
+
+By signing this Agreement, you accept and agree to the following terms and conditions for your present and future contributions submitted to the Project.
+
+## 1. Definitions
+
+- **"You" (or "Your")** means the copyright owner or legal entity authorized by the copyright owner that is making this Agreement.
+- **"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to the Project for inclusion in, or documentation of, any of the products owned or managed by the Project. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Project or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Project for the purpose of discussing and improving the Project, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
+- **"Project Owner"** means Wesley Liddick, or any individual or legal entity to whom Wesley Liddick has explicitly assigned or transferred ownership of the Project in writing, and their respective successors and assigns.
+
+## 2. Grant of Copyright License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. This license includes, without limitation, the right to sublicense, assign, and transfer these rights to any third party, including without limitation any successor, assignee, or acquiring entity of the Project or the Project Owner, and to use Your Contributions under any license, including proprietary or commercial licenses.
+
+## 3. Moral Rights
+
+To the fullest extent permitted by applicable law, You irrevocably waive and agree not to assert any moral rights (including rights of attribution and integrity) that You may have in Your Contributions, and agree that the Project Owner and its licensees may use, modify, and distribute Your Contributions without attribution or other obligations arising from moral rights.
+
+## 4. Grant of Patent License
+
+Subject to the terms and conditions of this Agreement, You hereby grant to the Project Owner a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer Your Contributions, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Project to which such Contribution(s) was submitted.
+
+## 5. Representations and Warranties
+
+You represent and warrant that:
+
+(a) You are legally entitled to grant the above licenses.
+
+(b) If Your employer(s) has rights to intellectual property that You create that includes Your Contributions, You have received permission to make Contributions on behalf of that employer, or that Your employer has waived such rights for Your Contributions to the Project.
+
+(c) Each of Your Contributions is Your original creation, or You have sufficient rights to submit it under the terms of this Agreement. You agree to provide, upon request, reasonable documentation or explanation of any third-party materials included in Your Contributions.
+
+## 6. No Warranty
+
+Your Contributions are provided on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
+
+## 7. No Obligation
+
+You understand that the decision to include Your Contribution in any product or project is entirely at the discretion of the Project Owner, and this Agreement does not obligate the Project Owner to use Your Contribution.
+
+## 8. Retention of Rights
+
+You retain ownership of the copyright in Your Contributions. This Agreement does not transfer any copyright or other intellectual property rights from You to the Project Owner. This Agreement only grants the licenses described above.
+
+## 9. Term and Termination
+
+This Agreement shall remain in effect indefinitely. You may terminate this Agreement prospectively by providing written notice to the Project Owner, but such termination shall not affect the licenses granted for Contributions submitted prior to the effective date of termination. The licenses granted herein for Contributions submitted prior to termination are perpetual and irrevocable.
+
+## 10. Electronic Signature
+
+You agree that Your electronic signature (including but not limited to typing a specific phrase in a pull request, issue, or other electronic communication) is legally binding and has the same force and effect as a handwritten signature. You consent to the use of electronic means to enter into this Agreement and acknowledge that this Agreement is enforceable as if executed in a traditional written format.
+
+## 11. General Provisions
+
+**Entire Agreement.** This Agreement constitutes the entire agreement between You and the Project Owner with respect to Your Contributions and supersedes all prior or contemporaneous understandings regarding such subject matter.
+
+**Severability.** If any provision of this Agreement is held to be unenforceable or invalid, that provision will be enforced to the maximum extent possible and the remaining provisions will remain in full force and effect.
+
+**No Waiver.** The failure of the Project Owner to enforce any provision of this Agreement shall not constitute a waiver of that provision or any other provision.
+
+**Amendment.** This Agreement may only be modified by a written instrument signed by both parties. Modifications to this Agreement apply only to Contributions submitted after the modified Agreement is published and accepted by You. Prior Contributions remain governed by the version of the Agreement in effect at the time of submission.
+
+**Notification.** Notices under this Agreement shall be sent to the Project Owner via a GitHub issue on the Project repository. Notices are effective upon receipt.
+
+---
+
+**By signing this CLA, you acknowledge that you have read and understood this Agreement and agree to be bound by its terms.**
+
+To sign, reply in the pull request with:
+
+> I have read the CLA Document and I hereby sign the CLA
diff --git a/Dockerfile b/Dockerfile
index a16eb958..890bda0b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
-ARG GOLANG_IMAGE=golang:1.26.1-alpine
+ARG GOLANG_IMAGE=golang:1.26.2-alpine
ARG ALPINE_IMAGE=alpine:3.21
ARG POSTGRES_IMAGE=postgres:18-alpine
ARG GOPROXY=https://goproxy.cn,direct
diff --git a/LICENSE b/LICENSE
index 7a94ca9d..153d416d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,165 @@
-MIT License
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
-Copyright (c) 2025 Wesley Liddick
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
\ No newline at end of file
diff --git a/Makefile b/Makefile
index fd6a5a9a..d00d0c4f 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,12 @@
-.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan
+.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-frontend-critical test-datamanagementd secret-scan
+
+FRONTEND_CRITICAL_VITEST := \
+ src/views/auth/__tests__/LinuxDoCallbackView.spec.ts \
+ src/views/auth/__tests__/WechatCallbackView.spec.ts \
+ src/views/user/__tests__/PaymentView.spec.ts \
+ src/views/user/__tests__/PaymentResultView.spec.ts \
+ src/components/user/profile/__tests__/ProfileInfoCard.spec.ts \
+ src/views/admin/__tests__/SettingsView.spec.ts
# 一键编译前后端
build: build-backend build-frontend
@@ -24,6 +32,10 @@ test-backend:
test-frontend:
@pnpm --dir frontend run lint:check
@pnpm --dir frontend run typecheck
+ @$(MAKE) test-frontend-critical
+
+test-frontend-critical:
+ @pnpm --dir frontend exec vitest run $(FRONTEND_CRITICAL_VITEST)
test-datamanagementd:
@cd datamanagement && go test ./...
diff --git a/README.md b/README.md
index 99753e45..3e609d65 100644
--- a/README.md
+++ b/README.md
@@ -42,20 +42,65 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
- **Smart Scheduling** - Intelligent account selection with sticky sessions
- **Concurrency Control** - Per-user and per-account concurrency limits
- **Rate Limiting** - Configurable request and token rate limits
+- **Built-in Payment System** - Supports EasyPay, Alipay, WeChat Pay, and Stripe for user self-service top-up, no separate payment service needed ([Configuration Guide](docs/PAYMENT.md))
- **Admin Dashboard** - Web interface for monitoring and management
-- **External System Integration** - Embed external systems (e.g. payment, ticketing) via iframe to extend the admin dashboard
+- **External System Integration** - Embed external systems (e.g. ticketing) via iframe to extend the admin dashboard
-## Don't Want to Self-Host?
+## ❤️ Sponsors
+
+> [Want to appear here?](mailto:support@pincc.ai)
PinCC is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.
+
Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using this link and enter the "sub2api" promo code during first recharge to get 10% off.
+
+
+
+Thanks to Poixe AI for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive sub2api referral link and receive a bonus of $5 USD on your first top-up.
+
+
+
+
+Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click here to register!
+
+
+
+
+Thanks to SilkAPI for sponsoring this project! SilkAPI is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.
+
+
+
+
+Thanks to YLS Code for sponsoring this project! YLS Code is dedicated to building secure enterprise-grade Coding Agent productivity services, offering stable and fast Codex / Claude / Gemini subscription services along with pay-as-you-go API options for flexible choices. Register now for a limited-time 3-day Codex trial bonus!
+
+
+
+
+Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for sub2api users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!
+
+
+
+
+Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via this link, you'll receive an extra 10% bonus credit on your first top-up!
+
+
+
+
+Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)
+
+
+
+
+Thanks to Bestproxy for sponsoring this project! Bestproxy provides high-purity residential IPs with dedicated one-IP-per-account support. By combining real home networks with fingerprint isolation, it enables link environment isolation and reduces the probability of association-based risk control.
+
+
## Ecosystem
@@ -64,7 +109,7 @@ Community projects that extend or integrate with Sub2API:
| Project | Description | Features |
|---------|-------------|----------|
-| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | Self-service payment system | Self-service top-up and subscription purchase; supports YiPay protocol, WeChat Pay, Alipay, Stripe; embeddable via iframe |
+| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~Self-service payment system~~ | **Now Built-in** — Payment is now integrated into Sub2API, no separate deployment needed. See [Payment Configuration Guide](docs/PAYMENT.md) |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | Mobile admin console | Cross-platform app (iOS/Android/Web) for user management, account management, monitoring dashboard, and multi-backend switching; built with Expo + React Native |
## Tech Stack
@@ -578,7 +623,9 @@ sub2api/
## License
-MIT License
+This project is licensed under the [GNU Lesser General Public License v3.0](LICENSE) (or later).
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_CN.md b/README_CN.md
index 8b6feaba..add32a17 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -41,20 +41,65 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
- **智能调度** - 智能账号选择,支持粘性会话
- **并发控制** - 用户级和账号级并发限制
- **速率限制** - 可配置的请求和 Token 速率限制
+- **内置支付系统** - 支持 EasyPay 易支付、支付宝官方、微信官方、Stripe,用户自助充值,无需独立部署支付服务([配置指南](docs/PAYMENT_CN.md))
- **管理后台** - Web 界面进行监控和管理
-- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如支付、工单等),扩展管理后台功能
+- **外部系统集成** - 支持通过 iframe 嵌入外部系统(如工单等),扩展管理后台功能
-## 不想自建?试试官方中转
+## ❤️ 赞助商
+
+> [想出现在这里?](mailto:support@pincc.ai)
PinCC 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。
+
感谢 PackyCode 赞助了本项目!PackyCode 是一家稳定、高效的API中转服务商,提供 Claude Code、Codex、Gemini 等多种中转服务。PackyCode 为本软件的用户提供了特别优惠,使用此链接 注册并在充值时填写"sub2api"优惠码,首次充值可以享受9折优惠!
+
+
+
+感谢 Poixe AI 赞助了本项目!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 此链接 专属链接注册,充值额外赠送 $5 美金
+
+
+
+
+感谢 CTok.ai 赞助了本项目!CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群,为开发者提供稳定的服务保障和持续的技术支持,让 AI 辅助编程真正成为开发者的生产力工具。点击这里 注册!
+
+
+
+
+感谢 丝绸API 赞助了本项目! 丝绸API 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。
+
+
+
+
+感谢 伊莉思Code 赞助了本项目! 伊莉思Code 致力于构建安全的企业级Coding Agent生产力服务,提供稳定快速的 Codex / Claude / Gemini 订阅服务与即用即付API多种方案灵活选择,限时注册赠送 3 天 Codex 试用福利!
+
+
+
+
+感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定性中转服务,企业级并发、快速开票、7×24 小时专属技术支持。Claude Code / Codex / Gemini 官方通道低至原价 38% / 2% / 9%,充值更享额外折扣!AICodeMirror 为 sub2api 用户提供专属福利:通过此链接 注册,首次充值立享 8 折优惠,企业客户最高可享 75 折!
+
+
+
+
+感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过此链接 注册,首次充值可额外获得 10% 赠送额度!
+
+
+
+
+感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充 注册下单的用户,可享GPT 官网订阅一折 的震撼价格!
+
+
+
+
+感谢 Bestproxy 赞助了本项目!Bestproxy 是一家提供高纯度住宅IP,支持一号一IP独享,结合真实家庭网络与指纹隔离,可实现链路环境隔离,降低关联风控概率。
+
+
## 生态项目
@@ -63,7 +108,7 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
| 项目 | 说明 | 功能 |
|------|------|------|
-| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | 自助支付系统 | 用户自助充值、自助订阅购买;兼容易支付协议、微信官方支付、支付宝官方支付、Stripe;支持 iframe 嵌入管理后台 |
+| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~自助支付系统~~ | **已内置** — 支付功能已集成到 Sub2API 中,无需独立部署。详见 [支付配置指南](docs/PAYMENT_CN.md) |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | 移动端管理控制台 | 跨平台应用(iOS/Android/Web),支持用户管理、账号管理、监控看板、多后端切换;基于 Expo + React Native 构建 |
## 技术栈
@@ -639,7 +684,9 @@ sub2api/
## 许可证
-MIT License
+本项目基于 [GNU 宽通用公共许可证 v3.0](LICENSE)(或更高版本)授权。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/README_JA.md b/README_JA.md
index 1266bd84..ccd595b9 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -42,10 +42,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
- **スマートスケジューリング** - スティッキーセッション付きのインテリジェントなアカウント選択
- **同時実行制御** - ユーザーごと・アカウントごとの同時実行数制限
- **レート制限** - 設定可能なリクエスト数およびトークンレート制限
+- **内蔵決済システム** - EasyPay、Alipay、WeChat Pay、Stripe に対応。ユーザーのセルフサービスチャージが可能で、別途決済サービスのデプロイは不要([設定ガイド](docs/PAYMENT.md))
- **管理ダッシュボード** - 監視・管理のための Web インターフェース
-- **外部システム連携** - 外部システム(決済、チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能
+- **外部システム連携** - 外部システム(チケット管理など)を iframe 経由で管理ダッシュボードに埋め込み可能
-## セルフホストが不要な方へ
+## ❤️ スポンサー
+
+> [こちらに掲載しませんか?](mailto:support@pincc.ai)
@@ -56,6 +59,47 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
PackyCode のご支援に感謝します!PackyCode は Claude Code、Codex、Gemini などのリレーサービスを提供する信頼性の高い API 中継プラットフォームです。本ソフト利用者向けに特別割引があります:このリンク で登録し、チャージ時に「sub2api」クーポンを入力すると 10% オフになります。
+
+
+
+Poixe AI のご支援に感謝します!Poixe AI は信頼性の高い LLM API サービスを提供しています。プラットフォームの API エンドポイントを活用して、AI 搭載プロダクトをシームレスに構築できます。また、ベンダーとして AI API リソースをプラットフォームに提供し、収益を得ることも可能です。専用の sub2api 紹介リンクから登録すると、初回チャージ時に $5 USD のボーナスがもらえます。
+
+
+
+
+CTok.ai のご支援に感謝します!CTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。こちら から登録!
+
+
+
+
+SilkAPI のご支援に感謝します!SilkAPI は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。
+
+
+
+
+YLS Code のご支援に感謝します!YLS Code は安全なエンタープライズグレードの Coding Agent 生産性サービスの構築に取り組んでおり、安定かつ高速な Codex / Claude / Gemini サブスクリプションサービスと従量課金 API の柔軟なプランを提供しています。期間限定で新規登録者に 3 日間の Codex 試用特典をプレゼント中!
+
+
+
+
+AICodeMirror のご支援に感謝します!AICodeMirror は Claude Code / Codex / Gemini CLI の公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時実行、迅速な請求書発行、24時間年中無休の専属テクニカルサポートを備えています。Claude Code / Codex / Gemini の公式チャネルを定価の 38% / 2% / 9% で利用可能、チャージ時にはさらに追加割引!AICodeMirror は sub2api ユーザー向けに特別特典を提供中:こちらのリンク から登録すると、初回チャージが 20% オフ、法人のお客様は最大 25% オフ!
+
+
+
+
+AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:こちらのリンク から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!
+
+
+
+
+本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ 経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!
+
+
+
+
+Bestproxy のご支援に感謝します!Bestproxy は高純度の住宅IPを提供し、1アカウント1IP専有をサポートしています。実際の家庭ネットワークとフィンガープリント分離を組み合わせることで、リンク環境の分離を実現し、関連付けによるリスク管理の確率を低減します。
+
+
## エコシステム
@@ -64,7 +108,7 @@ Sub2API を拡張・統合するコミュニティプロジェクト:
| プロジェクト | 説明 | 機能 |
|---------|-------------|----------|
-| [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) | セルフサービス決済システム | セルフサービスによるチャージおよびサブスクリプション購入。YiPay プロトコル、WeChat Pay、Alipay、Stripe 対応。iframe での埋め込み可能 |
+| ~~[Sub2ApiPay](https://github.com/touwaeriol/sub2apipay)~~ | ~~セルフサービス決済システム~~ | **内蔵済み** — 決済機能は Sub2API に統合されました。別途デプロイは不要です。[決済設定ガイド](docs/PAYMENT.md)をご参照ください |
| [sub2api-mobile](https://github.com/ckken/sub2api-mobile) | モバイル管理コンソール | ユーザー管理、アカウント管理、監視ダッシュボード、マルチバックエンド切り替えが可能なクロスプラットフォームアプリ(iOS/Android/Web)。Expo + React Native で構築 |
## 技術スタック
@@ -578,7 +622,9 @@ sub2api/
## ライセンス
-MIT License
+本プロジェクトは [GNU Lesser General Public License v3.0](LICENSE)(またはそれ以降のバージョン)の下でライセンスされています。
+
+Copyright (c) 2026 Wesley Liddick
---
diff --git a/assets/partners/logos/AICodeMirror.jpg b/assets/partners/logos/AICodeMirror.jpg
new file mode 100644
index 00000000..1c98b223
Binary files /dev/null and b/assets/partners/logos/AICodeMirror.jpg differ
diff --git a/assets/partners/logos/aigocode.png b/assets/partners/logos/aigocode.png
new file mode 100644
index 00000000..6dd5965a
Binary files /dev/null and b/assets/partners/logos/aigocode.png differ
diff --git a/assets/partners/logos/bestproxy.png b/assets/partners/logos/bestproxy.png
new file mode 100644
index 00000000..87c58670
Binary files /dev/null and b/assets/partners/logos/bestproxy.png differ
diff --git a/assets/partners/logos/bmoplus.jpg b/assets/partners/logos/bmoplus.jpg
new file mode 100644
index 00000000..1a9b4d8b
Binary files /dev/null and b/assets/partners/logos/bmoplus.jpg differ
diff --git a/assets/partners/logos/ctok.png b/assets/partners/logos/ctok.png
new file mode 100644
index 00000000..cf6fcf17
Binary files /dev/null and b/assets/partners/logos/ctok.png differ
diff --git a/assets/partners/logos/poixe.png b/assets/partners/logos/poixe.png
new file mode 100644
index 00000000..aa89cb06
Binary files /dev/null and b/assets/partners/logos/poixe.png differ
diff --git a/assets/partners/logos/silkapi.png b/assets/partners/logos/silkapi.png
new file mode 100644
index 00000000..97afbda9
Binary files /dev/null and b/assets/partners/logos/silkapi.png differ
diff --git a/assets/partners/logos/ylscode.png b/assets/partners/logos/ylscode.png
new file mode 100644
index 00000000..4d374f04
Binary files /dev/null and b/assets/partners/logos/ylscode.png differ
diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go
index 7eabde62..9386678d 100644
--- a/backend/cmd/jwtgen/main.go
+++ b/backend/cmd/jwtgen/main.go
@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
- authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(client, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 9e3db2aa..841597f0 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.106
+0.1.119
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 7fc648ac..9bfa2717 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -13,6 +13,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -35,6 +36,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Business layer ProviderSets
repository.ProviderSet,
service.ProviderSet,
+ payment.ProviderSet,
middleware.ProviderSet,
handler.ProviderSet,
@@ -76,7 +78,6 @@ func provideCleanup(
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
- soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
@@ -95,6 +96,8 @@ func provideCleanup(
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
+ paymentOrderExpiry *service.PaymentOrderExpiryService,
+ channelMonitorRunner *service.ChannelMonitorRunner,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -125,12 +128,6 @@ func provideCleanup(
}
return nil
}},
- {"SoraMediaCleanupService", func() error {
- if soraMediaCleanup != nil {
- soraMediaCleanup.Stop()
- }
- return nil
- }},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
@@ -237,6 +234,18 @@ func provideCleanup(
}
return nil
}},
+ {"PaymentOrderExpiryService", func() error {
+ if paymentOrderExpiry != nil {
+ paymentOrderExpiry.Stop()
+ }
+ return nil
+ }},
+ {"ChannelMonitorRunner", func() error {
+ if channelMonitorRunner != nil {
+ channelMonitorRunner.Stop()
+ }
+ return nil
+ }},
}
infraSteps := []cleanupStep{
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index ce898a4a..f767bbea 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -49,7 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db)
- settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
+ proxyRepository := repository.NewProxyRepository(client, db)
+ settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
@@ -59,16 +61,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingCache := repository.NewBillingCache(redisClient)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
- billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig)
+ userRPMCache := repository.NewUserRPMCache(redisClient)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
+ billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
- apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
- authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
- userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
+ affiliateRepository := repository.NewAffiliateRepository(client, db)
+ affiliateService := service.NewAffiliateService(affiliateRepository, settingService, apiKeyAuthCacheInvalidator, billingCacheService)
+ authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService)
+ userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
@@ -78,10 +82,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
- userHandler := handler.NewUserHandler(userService)
+ userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
- usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService)
@@ -90,6 +93,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
announcementReadRepository := repository.NewAnnouncementReadRepository(client)
announcementService := service.NewAnnouncementService(announcementRepository, announcementReadRepository, userRepository, userSubscriptionRepository)
announcementHandler := handler.NewAnnouncementHandler(announcementService)
+ channelMonitorRepository := repository.NewChannelMonitorRepository(client, db)
+ channelMonitorService := service.ProvideChannelMonitorService(channelMonitorRepository, secretEncryptor)
+ channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService)
dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
@@ -99,22 +105,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
- schedulerCache := repository.NewSchedulerCache(redisClient)
+ schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
- soraAccountRepository := repository.NewSoraAccountRepository(db)
- proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
- adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
+ adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
+ sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
+ rpmCache := repository.NewRPMCache(redisClient)
+ groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
+ groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
- openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
@@ -123,32 +130,29 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
+ openAI403CounterCache := repository.NewOpenAI403CounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
- oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
- rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
+ rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient)
- geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
- gatewayCache := repository.NewGatewayCache(redisClient)
- schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
- schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
- antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
- internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
- antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
+ oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
+ geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
+ gatewayCache := repository.NewGatewayCache(redisClient)
+ schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
+ schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
+ antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
+ internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
+ antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
- sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
- rpmCache := repository.NewRPMCache(redisClient)
- groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
- groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
@@ -165,6 +169,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
+ usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -173,20 +178,27 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
- claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI)
+ claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore()
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService)
- openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI)
- openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
+ channelRepository := repository.NewChannelRepository(db)
+ channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
+ modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
+ balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
+ openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
+ openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
- soraS3Storage := service.NewSoraS3Storage(settingService)
- settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient)
- soraGenerationRepository := repository.NewSoraGenerationRepository(db)
- soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
- soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
- settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
+ encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
+ if err != nil {
+ return nil, err
+ }
+ paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
+ registry := payment.ProvideRegistry()
+ defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
+ paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
+ settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@@ -213,22 +225,27 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler)
+ channelHandler := admin.NewChannelHandler(channelService, billingService)
+ channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
+ channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
+ channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
+ channelMonitorRequestTemplateHandler := admin.NewChannelMonitorRequestTemplateHandler(channelMonitorRequestTemplateService)
+ paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
+ affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, affiliateHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
- soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
- soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
- soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
- soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
- soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
+ handlerPaymentHandler := handler.NewPaymentHandler(paymentService, paymentConfigService, channelService)
+ paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry)
+ availableChannelHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelHandler, idempotencyCoordinator, idempotencyCleanupService)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -237,14 +254,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
- opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
+ opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
- soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService)
+ paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
+ channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -279,7 +297,6 @@ func provideCleanup(
opsCleanup *service.OpsCleanupService,
opsScheduledReport *service.OpsScheduledReportService,
opsSystemLogSink *service.OpsSystemLogSink,
- soraMediaCleanup *service.SoraMediaCleanupService,
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
@@ -298,6 +315,8 @@ func provideCleanup(
openAIGateway *service.OpenAIGatewayService,
scheduledTestRunner *service.ScheduledTestRunnerService,
backupSvc *service.BackupService,
+ paymentOrderExpiry *service.PaymentOrderExpiryService,
+ channelMonitorRunner *service.ChannelMonitorRunner,
) func() {
return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -327,12 +346,6 @@ func provideCleanup(
}
return nil
}},
- {"SoraMediaCleanupService", func() error {
- if soraMediaCleanup != nil {
- soraMediaCleanup.Stop()
- }
- return nil
- }},
{"OpsAlertEvaluatorService", func() error {
if opsAlertEvaluator != nil {
opsAlertEvaluator.Stop()
@@ -439,6 +452,18 @@ func provideCleanup(
}
return nil
}},
+ {"PaymentOrderExpiryService", func() error {
+ if paymentOrderExpiry != nil {
+ paymentOrderExpiry.Stop()
+ }
+ return nil
+ }},
+ {"ChannelMonitorRunner", func() error {
+ if channelMonitorRunner != nil {
+ channelMonitorRunner.Stop()
+ }
+ return nil
+ }},
}
infraSteps := []cleanupStep{
diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go
index 9d2a54b9..5ccd67fb 100644
--- a/backend/cmd/server/wire_gen_test.go
+++ b/backend/cmd/server/wire_gen_test.go
@@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second)
pricingSvc := service.NewPricingService(cfg, nil)
emailQueueSvc := service.NewEmailQueueService(nil, 1)
- billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
+ billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg)
schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg)
opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil)
@@ -57,7 +57,6 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
&service.OpsCleanupService{},
&service.OpsScheduledReportService{},
opsSystemLogSinkSvc,
- &service.SoraMediaCleanupService{},
schedulerSnapshotSvc,
tokenRefreshSvc,
accountExpirySvc,
@@ -76,6 +75,8 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) {
nil, // openAIGateway
nil, // scheduledTestRunner
nil, // backupSvc
+ nil, // paymentOrderExpiry
+ nil, // channelMonitorRunner
)
require.NotPanics(t, func() {
diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go
new file mode 100644
index 00000000..5ccfcf19
--- /dev/null
+++ b/backend/ent/authidentity.go
@@ -0,0 +1,266 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentity is the model entity for the AuthIdentity schema.
+type AuthIdentity struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // UserID holds the value of the "user_id" field.
+ UserID int64 `json:"user_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // VerifiedAt holds the value of the "verified_at" field.
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ // Issuer holds the value of the "issuer" field.
+ Issuer *string `json:"issuer,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityQuery when eager-loading is set.
+ Edges AuthIdentityEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityEdges struct {
+ // User holds the value of the user edge.
+ User *User `json:"user,omitempty"`
+ // Channels holds the value of the channels edge.
+ Channels []*AuthIdentityChannel `json:"channels,omitempty"`
+ // AdoptionDecisions holds the value of the adoption_decisions edge.
+ AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [3]bool
+}
+
+// UserOrErr returns the User value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityEdges) UserOrErr() (*User, error) {
+ if e.User != nil {
+ return e.User, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "user"}
+}
+
+// ChannelsOrErr returns the Channels value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) {
+ if e.loadedTypes[1] {
+ return e.Channels, nil
+ }
+ return nil, &NotLoadedError{edge: "channels"}
+}
+
+// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge
+// was not loaded in eager-loading.
+func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) {
+ if e.loadedTypes[2] {
+ return e.AdoptionDecisions, nil
+ }
+ return nil, &NotLoadedError{edge: "adoption_decisions"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentity) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentity.FieldID, authidentity.FieldUserID:
+ values[i] = new(sql.NullInt64)
+ case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer:
+ values[i] = new(sql.NullString)
+ case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentity fields.
+func (_m *AuthIdentity) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentity.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentity.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentity.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentity.FieldUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field user_id", values[i])
+ } else if value.Valid {
+ _m.UserID = value.Int64
+ }
+ case authidentity.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentity.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentity.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case authidentity.FieldVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field verified_at", values[i])
+ } else if value.Valid {
+ _m.VerifiedAt = new(time.Time)
+ *_m.VerifiedAt = value.Time
+ }
+ case authidentity.FieldIssuer:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field issuer", values[i])
+ } else if value.Valid {
+ _m.Issuer = new(string)
+ *_m.Issuer = value.String
+ }
+ case authidentity.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentity) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryUser queries the "user" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryUser() *UserQuery {
+ return NewAuthIdentityClient(_m.config).QueryUser(_m)
+}
+
+// QueryChannels queries the "channels" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery {
+ return NewAuthIdentityClient(_m.config).QueryChannels(_m)
+}
+
+// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity.
+func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m)
+}
+
+// Update returns a builder for updating this AuthIdentity.
+// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne {
+ return NewAuthIdentityClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentity) Unwrap() *AuthIdentity {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentity is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentity) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentity(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("user_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UserID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.VerifiedAt; v != nil {
+ builder.WriteString("verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.Issuer; v != nil {
+ builder.WriteString("issuer=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentities is a parsable slice of AuthIdentity.
+type AuthIdentities []*AuthIdentity
diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go
new file mode 100644
index 00000000..c90be759
--- /dev/null
+++ b/backend/ent/authidentity/authidentity.go
@@ -0,0 +1,209 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentity type in the database.
+ Label = "auth_identity"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldUserID holds the string denoting the user_id field in the database.
+ FieldUserID = "user_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldVerifiedAt holds the string denoting the verified_at field in the database.
+ FieldVerifiedAt = "verified_at"
+ // FieldIssuer holds the string denoting the issuer field in the database.
+ FieldIssuer = "issuer"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeUser holds the string denoting the user edge name in mutations.
+ EdgeUser = "user"
+ // EdgeChannels holds the string denoting the channels edge name in mutations.
+ EdgeChannels = "channels"
+ // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations.
+ EdgeAdoptionDecisions = "adoption_decisions"
+ // Table holds the table name of the authidentity in the database.
+ Table = "auth_identities"
+ // UserTable is the table that holds the user relation/edge.
+ UserTable = "auth_identities"
+ // UserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ UserInverseTable = "users"
+ // UserColumn is the table column denoting the user relation/edge.
+ UserColumn = "user_id"
+ // ChannelsTable is the table that holds the channels relation/edge.
+ ChannelsTable = "auth_identity_channels"
+ // ChannelsInverseTable is the table name for the AuthIdentityChannel entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package.
+ ChannelsInverseTable = "auth_identity_channels"
+ // ChannelsColumn is the table column denoting the channels relation/edge.
+ ChannelsColumn = "identity_id"
+ // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge.
+ AdoptionDecisionsTable = "identity_adoption_decisions"
+ // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionsInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge.
+ AdoptionDecisionsColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentity fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldUserID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldVerifiedAt,
+ FieldIssuer,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentity queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByUserID orders the results by the user_id field.
+func ByUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByVerifiedAt orders the results by the verified_at field.
+func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc()
+}
+
+// ByIssuer orders the results by the issuer field.
+func ByIssuer(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIssuer, opts...).ToFunc()
+}
+
+// ByUserField orders the results by user field.
+func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByChannelsCount orders the results by channels count.
+func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...)
+ }
+}
+
+// ByChannels orders the results by channels terms.
+func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByAdoptionDecisionsCount orders the results by adoption_decisions count.
+func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...)
+ }
+}
+
+// ByAdoptionDecisions orders the results by adoption_decisions terms.
+func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(UserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+}
+func newChannelsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(ChannelsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+}
+func newAdoptionDecisionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+}
diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go
new file mode 100644
index 00000000..3dbf3178
--- /dev/null
+++ b/backend/ent/authidentity/where.go
@@ -0,0 +1,600 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentity
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
+func UserID(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ.
+func VerifiedAt(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ.
+func Issuer(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// UserIDEQ applies the EQ predicate on the "user_id" field.
+func UserIDEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserIDNEQ applies the NEQ predicate on the "user_id" field.
+func UserIDNEQ(v int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v))
+}
+
+// UserIDIn applies the In predicate on the "user_id" field.
+func UserIDIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...))
+}
+
+// UserIDNotIn applies the NotIn predicate on the "user_id" field.
+func UserIDNotIn(vs ...int64) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// VerifiedAtEQ applies the EQ predicate on the "verified_at" field.
+func VerifiedAtEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field.
+func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIn applies the In predicate on the "verified_at" field.
+func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field.
+func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...))
+}
+
+// VerifiedAtGT applies the GT predicate on the "verified_at" field.
+func VerifiedAtGT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtGTE applies the GTE predicate on the "verified_at" field.
+func VerifiedAtGTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLT applies the LT predicate on the "verified_at" field.
+func VerifiedAtLT(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v))
+}
+
+// VerifiedAtLTE applies the LTE predicate on the "verified_at" field.
+func VerifiedAtLTE(v time.Time) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v))
+}
+
+// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field.
+func VerifiedAtIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt))
+}
+
+// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field.
+func VerifiedAtNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt))
+}
+
+// IssuerEQ applies the EQ predicate on the "issuer" field.
+func IssuerEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v))
+}
+
+// IssuerNEQ applies the NEQ predicate on the "issuer" field.
+func IssuerNEQ(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v))
+}
+
+// IssuerIn applies the In predicate on the "issuer" field.
+func IssuerIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...))
+}
+
+// IssuerNotIn applies the NotIn predicate on the "issuer" field.
+func IssuerNotIn(vs ...string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...))
+}
+
+// IssuerGT applies the GT predicate on the "issuer" field.
+func IssuerGT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v))
+}
+
+// IssuerGTE applies the GTE predicate on the "issuer" field.
+func IssuerGTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v))
+}
+
+// IssuerLT applies the LT predicate on the "issuer" field.
+func IssuerLT(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v))
+}
+
+// IssuerLTE applies the LTE predicate on the "issuer" field.
+func IssuerLTE(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v))
+}
+
+// IssuerContains applies the Contains predicate on the "issuer" field.
+func IssuerContains(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v))
+}
+
+// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field.
+func IssuerHasPrefix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v))
+}
+
+// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field.
+func IssuerHasSuffix(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v))
+}
+
+// IssuerIsNil applies the IsNil predicate on the "issuer" field.
+func IssuerIsNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer))
+}
+
+// IssuerNotNil applies the NotNil predicate on the "issuer" field.
+func IssuerNotNil() predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer))
+}
+
+// IssuerEqualFold applies the EqualFold predicate on the "issuer" field.
+func IssuerEqualFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v))
+}
+
+// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field.
+func IssuerContainsFold(v string) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v))
+}
+
+// HasUser applies the HasEdge predicate on the "user" edge.
+func HasUser() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasUserWith applies the HasEdge predicate on the "user" edge with the given conditions (other predicates).
+func HasUserWith(preds ...predicate.User) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasChannels applies the HasEdge predicate on the "channels" edge.
+func HasChannels() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasChannelsWith applies the HasEdge predicate on the "channels" edge with the given conditions (other predicates).
+func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newChannelsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge.
+func HasAdoptionDecisions() predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with the given conditions (other predicates).
+func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity {
+ return predicate.AuthIdentity(func(s *sql.Selector) {
+ step := newAdoptionDecisionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentity) predicate.AuthIdentity {
+ return predicate.AuthIdentity(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go
new file mode 100644
index 00000000..e287705c
--- /dev/null
+++ b/backend/ent/authidentity_create.go
@@ -0,0 +1,1036 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityCreate is the builder for creating a AuthIdentity entity.
+type AuthIdentityCreate struct {
+ config
+ mutation *AuthIdentityMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetUserID sets the "user_id" field.
+func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate {
+ _c.mutation.SetUserID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate {
+ _c.mutation.SetVerifiedAt(v)
+ return _c
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetIssuer sets the "issuer" field.
+func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate {
+ _c.mutation.SetIssuer(v)
+ return _c
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate {
+ if v != nil {
+ _c.SetIssuer(*v)
+ }
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate {
+ return _c.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddChannelIDs(ids...)
+ return _c
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate {
+ _c.mutation.AddAdoptionDecisionIDs(ids...)
+ return _c
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentity in the database.
+func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentity.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentity.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentity.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)}
+ }
+ if _, ok := _c.mutation.UserID(); !ok {
+ return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)}
+ }
+ if len(_c.mutation.UserIDs()) == 0 {
+ return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentity{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ _node.VerifiedAt = &value
+ }
+ if value, ok := _c.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ _node.Issuer = &value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.UserID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentity node.
+ AuthIdentityUpsertOne struct {
+ create *AuthIdentityCreate
+ }
+
+ // AuthIdentityUpsert is the "OnConflict" setter.
+ AuthIdentityUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUpdatedAt)
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldUserID, v)
+ return u
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldUserID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldProviderSubject)
+ return u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldVerifiedAt, v)
+ return u
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldVerifiedAt)
+ return u
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldIssuer, v)
+ return u
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldIssuer)
+ return u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert {
+ u.SetNull(authidentity.FieldIssuer)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert {
+ u.Set(authidentity.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert {
+ u.SetExcluded(authidentity.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk.
+type AuthIdentityCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentity entities in the database.
+func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentity, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentity.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentity nodes.
+type AuthIdentityUpsertBulk struct {
+ create *AuthIdentityCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentity.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentity.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetUserID sets the "user_id" field.
+func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetVerifiedAt(v)
+ })
+}
+
+// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateVerifiedAt()
+ })
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearVerifiedAt()
+ })
+}
+
+// SetIssuer sets the "issuer" field.
+func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetIssuer(v)
+ })
+}
+
+// UpdateIssuer sets the "issuer" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateIssuer()
+ })
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.ClearIssuer()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk {
+ return u.Update(func(s *AuthIdentityUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go
new file mode 100644
index 00000000..4f1f6f3c
--- /dev/null
+++ b/backend/ent/authidentity_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityDelete is the builder for deleting a AuthIdentity entity.
+type AuthIdentityDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list of predicates to the AuthIdentityDelete builder.
+func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity.
+type AuthIdentityDeleteOne struct {
+ _d *AuthIdentityDelete
+}
+
+// Where appends a list of predicates to the AuthIdentityDeleteOne builder.
+func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentity.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go
new file mode 100644
index 00000000..ff27ef3c
--- /dev/null
+++ b/backend/ent/authidentity_query.go
@@ -0,0 +1,797 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityQuery is the builder for querying AuthIdentity entities.
+type AuthIdentityQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentity.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentity
+ withUser *UserQuery
+ withChannels *AuthIdentityChannelQuery
+ withAdoptionDecisions *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityQuery builder.
+func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryUser chains the current query on the "user" edge.
+func (_q *AuthIdentityQuery) QueryUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryChannels chains the current query on the "channels" edge.
+func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge.
+func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentity entity from the query.
+// Returns a *NotFoundError when no AuthIdentity was found.
+func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentity.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentity ID from the query.
+// Returns a *NotFoundError when no AuthIdentity ID was found.
+func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentity.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentity entity is found.
+// Returns a *NotFoundError when no AuthIdentity entities are found.
+func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentity.Label}
+ default:
+ return nil, &NotSingularError{authidentity.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentity ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentity ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentity.Label}
+ default:
+ err = &NotSingularError{authidentity.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentities.
+func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]()
+ return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentity IDs.
+func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentity.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentity{}, _q.predicates...),
+ withUser: _q.withUser.Clone(),
+ withChannels: _q.withChannels.Clone(),
+ withAdoptionDecisions: _q.withAdoptionDecisions.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithUser tells the query-builder to eager-load the nodes that are connected to
+// the "user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withUser = query
+ return _q
+}
+
+// WithChannels tells the query-builder to eager-load the nodes that are connected to
+// the "channels" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery {
+ query := (&AuthIdentityChannelClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withChannels = query
+ return _q
+}
+
+// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecisions = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// GroupBy(authidentity.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentity.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentity.Query().
+// Select(authidentity.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q}
+ sbuild.label = authidentity.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentitySelect configured with the given aggregations.
+func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentity.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) {
+ var (
+ nodes = []*AuthIdentity{}
+ _spec = _q.querySpec()
+ loadedTypes = [3]bool{
+ _q.withUser != nil,
+ _q.withChannels != nil,
+ _q.withAdoptionDecisions != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentity).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentity{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withUser; query != nil {
+ if err := _q.loadUser(ctx, query, nodes, nil,
+ func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withChannels; query != nil {
+ if err := _q.loadChannels(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} },
+ func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecisions; query != nil {
+ if err := _q.loadAdoptionDecisions(ctx, query, nodes,
+ func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} },
+ func(n *AuthIdentity, e *IdentityAdoptionDecision) {
+ n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentity)
+ for i := range nodes {
+ fk := nodes[i].UserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID)
+ }
+ query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*AuthIdentity)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.IdentityID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for i := range fields {
+ if fields[i] != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withUser != nil {
+ _spec.Node.AddColumnOnce(authidentity.FieldUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentity.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentity.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities.
+type AuthIdentityGroupBy struct {
+ selector
+ build *AuthIdentityQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities.
+type AuthIdentitySelect struct {
+ *AuthIdentityQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go
new file mode 100644
index 00000000..c457470b
--- /dev/null
+++ b/backend/ent/authidentity_update.go
@@ -0,0 +1,923 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// AuthIdentityUpdate is the builder for updating AuthIdentity entities.
+type AuthIdentityUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// Where appends a list predicates to the AuthIdentityUpdate builder.
+func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity.
+type AuthIdentityUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne {
+ _u.mutation.SetVerifiedAt(v)
+ return _u
+}
+
+// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne {
+ _u.mutation.ClearVerifiedAt()
+ return _u
+}
+
+// SetIssuer sets the "issuer" field.
+func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne {
+ _u.mutation.SetIssuer(v)
+ return _u
+}
+
+// SetNillableIssuer sets the "issuer" field if the given value is not nil.
+func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne {
+ if v != nil {
+ _u.SetIssuer(*v)
+ }
+ return _u
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne {
+ _u.mutation.ClearIssuer()
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne {
+ return _u.SetUserID(v.ID)
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddChannelIDs(ids...)
+ return _u
+}
+
+// AddChannels adds the "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddChannelIDs(ids...)
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.AddAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAdoptionDecisionIDs(ids...)
+}
+
+// Mutation returns the AuthIdentityMutation object of the builder.
+func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity.
+func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne {
+ _u.mutation.ClearChannels()
+ return _u
+}
+
+// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveChannelIDs(ids...)
+ return _u
+}
+
+// RemoveChannels removes "channels" edges to AuthIdentityChannel entities.
+func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveChannelIDs(ids...)
+}
+
+// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity.
+func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne {
+ _u.mutation.ClearAdoptionDecisions()
+ return _u
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne {
+ _u.mutation.RemoveAdoptionDecisionIDs(ids...)
+ return _u
+}
+
+// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities.
+func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAdoptionDecisionIDs(ids...)
+}
+
+// Where appends a list of predicates to the AuthIdentityUpdateOne builder.
+func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentity entity.
+func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentity.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentity.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentity.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := authidentity.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID)
+ for _, f := range fields {
+ if !authidentity.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentity.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentity.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.VerifiedAt(); ok {
+ _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.VerifiedAtCleared() {
+ _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.Issuer(); ok {
+ _spec.SetField(authidentity.FieldIssuer, field.TypeString, value)
+ }
+ if _u.mutation.IssuerCleared() {
+ _spec.ClearField(authidentity.FieldIssuer, field.TypeString)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentity.UserTable,
+ Columns: []string{authidentity.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.ChannelsTable,
+ Columns: []string{authidentity.ChannelsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: authidentity.AdoptionDecisionsTable,
+ Columns: []string{authidentity.AdoptionDecisionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentity{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentity.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go
new file mode 100644
index 00000000..1ff3e5d1
--- /dev/null
+++ b/backend/ent/authidentitychannel.go
@@ -0,0 +1,228 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema.
+type AuthIdentityChannel struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID int64 `json:"identity_id,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // Channel holds the value of the "channel" field.
+ Channel string `json:"channel,omitempty"`
+ // ChannelAppID holds the value of the "channel_app_id" field.
+ ChannelAppID string `json:"channel_app_id,omitempty"`
+ // ChannelSubject holds the value of the "channel_subject" field.
+ ChannelSubject string `json:"channel_subject,omitempty"`
+ // Metadata holds the value of the "metadata" field.
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set.
+ Edges AuthIdentityChannelEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph.
+type AuthIdentityChannelEdges struct {
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldMetadata:
+ values[i] = new([]byte)
+ case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject:
+ values[i] = new(sql.NullString)
+ case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the AuthIdentityChannel fields.
+func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case authidentitychannel.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case authidentitychannel.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case authidentitychannel.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case authidentitychannel.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = value.Int64
+ }
+ case authidentitychannel.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case authidentitychannel.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case authidentitychannel.FieldChannel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel", values[i])
+ } else if value.Valid {
+ _m.Channel = value.String
+ }
+ case authidentitychannel.FieldChannelAppID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_app_id", values[i])
+ } else if value.Valid {
+ _m.ChannelAppID = value.String
+ }
+ case authidentitychannel.FieldChannelSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_subject", values[i])
+ } else if value.Valid {
+ _m.ChannelSubject = value.String
+ }
+ case authidentitychannel.FieldMetadata:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field metadata", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Metadata); err != nil {
+ return fmt.Errorf("unmarshal field metadata: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel.
+// This includes values selected through modifiers, order, etc.
+func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity.
+func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery {
+ return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this AuthIdentityChannel.
+// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne {
+ return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: AuthIdentityChannel is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *AuthIdentityChannel) String() string {
+ var builder strings.Builder
+ builder.WriteString("AuthIdentityChannel(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.IdentityID))
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("channel=")
+ builder.WriteString(_m.Channel)
+ builder.WriteString(", ")
+ builder.WriteString("channel_app_id=")
+ builder.WriteString(_m.ChannelAppID)
+ builder.WriteString(", ")
+ builder.WriteString("channel_subject=")
+ builder.WriteString(_m.ChannelSubject)
+ builder.WriteString(", ")
+ builder.WriteString("metadata=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Metadata))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// AuthIdentityChannels is a parsable slice of AuthIdentityChannel.
+type AuthIdentityChannels []*AuthIdentityChannel
diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go
new file mode 100644
index 00000000..7dcc98bb
--- /dev/null
+++ b/backend/ent/authidentitychannel/authidentitychannel.go
@@ -0,0 +1,153 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the authidentitychannel type in the database.
+ Label = "auth_identity_channel"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldChannel holds the string denoting the channel field in the database.
+ FieldChannel = "channel"
+ // FieldChannelAppID holds the string denoting the channel_app_id field in the database.
+ FieldChannelAppID = "channel_app_id"
+ // FieldChannelSubject holds the string denoting the channel_subject field in the database.
+ FieldChannelSubject = "channel_subject"
+ // FieldMetadata holds the string denoting the metadata field in the database.
+ FieldMetadata = "metadata"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the authidentitychannel in the database.
+ Table = "auth_identity_channels"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "auth_identity_channels"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for authidentitychannel fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldIdentityID,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldChannel,
+ FieldChannelAppID,
+ FieldChannelSubject,
+ FieldMetadata,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ ChannelValidator func(string) error
+ // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ ChannelAppIDValidator func(string) error
+ // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ ChannelSubjectValidator func(string) error
+ // DefaultMetadata holds the default value on creation for the "metadata" field.
+ DefaultMetadata func() map[string]interface{}
+)
+
+// OrderOption defines the ordering options for the AuthIdentityChannel queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByChannel orders the results by the channel field.
+func ByChannel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannel, opts...).ToFunc()
+}
+
+// ByChannelAppID orders the results by the channel_app_id field.
+func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelAppID, opts...).ToFunc()
+}
+
+// ByChannelSubject orders the results by the channel_subject field.
+func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelSubject, opts...).ToFunc()
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go
new file mode 100644
index 00000000..827dc384
--- /dev/null
+++ b/backend/ent/authidentitychannel/where.go
@@ -0,0 +1,559 @@
+// Code generated by ent, DO NOT EDIT.
+
+package authidentitychannel
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ.
+func Channel(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ.
+func ChannelAppID(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ.
+func ChannelSubject(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ChannelEQ applies the EQ predicate on the "channel" field.
+func ChannelEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v))
+}
+
+// ChannelNEQ applies the NEQ predicate on the "channel" field.
+func ChannelNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v))
+}
+
+// ChannelIn applies the In predicate on the "channel" field.
+func ChannelIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...))
+}
+
+// ChannelNotIn applies the NotIn predicate on the "channel" field.
+func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...))
+}
+
+// ChannelGT applies the GT predicate on the "channel" field.
+func ChannelGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v))
+}
+
+// ChannelGTE applies the GTE predicate on the "channel" field.
+func ChannelGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v))
+}
+
+// ChannelLT applies the LT predicate on the "channel" field.
+func ChannelLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v))
+}
+
+// ChannelLTE applies the LTE predicate on the "channel" field.
+func ChannelLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v))
+}
+
+// ChannelContains applies the Contains predicate on the "channel" field.
+func ChannelContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v))
+}
+
+// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field.
+func ChannelHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v))
+}
+
+// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field.
+func ChannelHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v))
+}
+
+// ChannelEqualFold applies the EqualFold predicate on the "channel" field.
+func ChannelEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v))
+}
+
+// ChannelContainsFold applies the ContainsFold predicate on the "channel" field.
+func ChannelContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v))
+}
+
+// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field.
+func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field.
+func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v))
+}
+
+// ChannelAppIDIn applies the In predicate on the "channel_app_id" field.
+func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field.
+func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...))
+}
+
+// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field.
+func ChannelAppIDGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field.
+func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field.
+func ChannelAppIDLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v))
+}
+
+// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field.
+func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field.
+func ChannelAppIDContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field.
+func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field.
+func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v))
+}
+
+// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field.
+func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v))
+}
+
+// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field.
+func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v))
+}
+
+// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field.
+func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field.
+func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v))
+}
+
+// ChannelSubjectIn applies the In predicate on the "channel_subject" field.
+func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field.
+func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...))
+}
+
+// ChannelSubjectGT applies the GT predicate on the "channel_subject" field.
+func ChannelSubjectGT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field.
+func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLT applies the LT predicate on the "channel_subject" field.
+func ChannelSubjectLT(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v))
+}
+
+// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field.
+func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field.
+func ChannelSubjectContains(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field.
+func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field.
+func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v))
+}
+
+// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field.
+func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v))
+}
+
+// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field.
+func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v))
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel {
+ return predicate.AuthIdentityChannel(sql.NotPredicates(p))
+}
diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go
new file mode 100644
index 00000000..4ce28479
--- /dev/null
+++ b/backend/ent/authidentitychannel_create.go
@@ -0,0 +1,932 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+)
+
+// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity.
+type AuthIdentityChannelCreate struct {
+ config
+ mutation *AuthIdentityChannelMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetChannel sets the "channel" field.
+func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannel(v)
+ return _c
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelAppID(v)
+ return _c
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate {
+ _c.mutation.SetChannelSubject(v)
+ return _c
+}
+
+// SetMetadata sets the "metadata" field.
+func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate {
+ _c.mutation.SetMetadata(v)
+ return _c
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation {
+ return _c.mutation
+}
+
+// Save creates the AuthIdentityChannel in the database.
+func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *AuthIdentityChannelCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := authidentitychannel.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ v := authidentitychannel.DefaultMetadata()
+ _c.mutation.SetMetadata(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *AuthIdentityChannelCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)}
+ }
+ if _, ok := _c.mutation.IdentityID(); !ok {
+ return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)}
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Channel(); !ok {
+ return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)}
+ }
+ if v, ok := _c.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelAppID(); !ok {
+ return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)}
+ }
+ if v, ok := _c.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ChannelSubject(); !ok {
+ return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)}
+ }
+ if v, ok := _c.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Metadata(); !ok {
+ return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)}
+ }
+ if len(_c.mutation.IdentityIDs()) == 0 {
+ return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)}
+ }
+ return nil
+}
+
+func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) {
+ var (
+ _node = &AuthIdentityChannel{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ _node.Channel = value
+ }
+ if value, ok := _c.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ _node.ChannelAppID = value
+ }
+ if value, ok := _c.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ _node.ChannelSubject = value
+ }
+ if value, ok := _c.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ _node.Metadata = value
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing
+ // one AuthIdentityChannel node.
+ AuthIdentityChannelUpsertOne struct {
+ create *AuthIdentityChannelCreate
+ }
+
+ // AuthIdentityChannelUpsert is the "OnConflict" setter.
+ AuthIdentityChannelUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldUpdatedAt)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldIdentityID)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldProviderKey)
+ return u
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannel, v)
+ return u
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannel)
+ return u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelAppID, v)
+ return u
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelAppID)
+ return u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldChannelSubject, v)
+ return u
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldChannelSubject)
+ return u
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert {
+ u.Set(authidentitychannel.FieldMetadata, v)
+ return u
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert {
+ u.SetExcluded(authidentitychannel.FieldMetadata)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk.
+type AuthIdentityChannelCreateBulk struct {
+ config
+ err error
+ builders []*AuthIdentityChannelCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the AuthIdentityChannel entities in the database.
+func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*AuthIdentityChannel, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*AuthIdentityChannelMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.AuthIdentityChannel.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.AuthIdentityChannelUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = opts
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &AuthIdentityChannelUpsertBulk{
+ create: _c,
+ }
+}
+
+// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing
+// a bulk of AuthIdentityChannel nodes.
+type AuthIdentityChannelUpsertBulk struct {
+ create *AuthIdentityChannelCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(authidentitychannel.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.AuthIdentityChannel.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict
+// documentation for more info.
+func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&AuthIdentityChannelUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetChannel sets the "channel" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannel(v)
+ })
+}
+
+// UpdateChannel sets the "channel" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannel()
+ })
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelAppID(v)
+ })
+}
+
+// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelAppID()
+ })
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetChannelSubject(v)
+ })
+}
+
+// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateChannelSubject()
+ })
+}
+
+// SetMetadata sets the "metadata" field.
+func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.SetMetadata(v)
+ })
+}
+
+// UpdateMetadata sets the "metadata" field to the value that was provided on create.
+func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk {
+ return u.Update(func(s *AuthIdentityChannelUpsert) {
+ s.UpdateMetadata()
+ })
+}
+
+// Exec executes the query.
+func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go
new file mode 100644
index 00000000..1a4acac5
--- /dev/null
+++ b/backend/ent/authidentitychannel_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity.
+type AuthIdentityChannelDelete struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list of predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity.
+type AuthIdentityChannelDeleteOne struct {
+ _d *AuthIdentityChannelDelete
+}
+
+// Where appends a list predicates to the AuthIdentityChannelDelete builder.
+func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go
new file mode 100644
index 00000000..7a202b7f
--- /dev/null
+++ b/backend/ent/authidentitychannel_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities.
+type AuthIdentityChannelQuery struct {
+ config
+ ctx *QueryContext
+ order []authidentitychannel.OrderOption
+ inters []Interceptor
+ predicates []predicate.AuthIdentityChannel
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the AuthIdentityChannelQuery builder.
+func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first AuthIdentityChannel entity from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel was found.
+func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{authidentitychannel.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first AuthIdentityChannel ID from the query.
+// Returns a *NotFoundError when no AuthIdentityChannel ID was found.
+func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{authidentitychannel.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found.
+// Returns a *NotFoundError when no AuthIdentityChannel entities are found.
+func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{authidentitychannel.Label}
+ default:
+ return nil, &NotSingularError{authidentitychannel.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query.
+// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{authidentitychannel.Label}
+ default:
+ err = &NotSingularError{authidentitychannel.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of AuthIdentityChannels.
+func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]()
+ return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of AuthIdentityChannel IDs.
+func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery {
+ if _q == nil {
+ return nil
+ }
+ return &AuthIdentityChannelQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]authidentitychannel.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// GroupBy(authidentitychannel.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &AuthIdentityChannelGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = authidentitychannel.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.AuthIdentityChannel.Query().
+// Select(authidentitychannel.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q}
+ sbuild.label = authidentitychannel.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations.
+func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) {
+ var (
+ nodes = []*AuthIdentityChannel{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*AuthIdentityChannel).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &AuthIdentityChannel{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*AuthIdentityChannel)
+ for i := range nodes {
+ fk := nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for i := range fields {
+ if fields[i] != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(authidentitychannel.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = authidentitychannel.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities.
+type AuthIdentityChannelGroupBy struct {
+ selector
+ build *AuthIdentityChannelQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities.
+type AuthIdentityChannelSelect struct {
+ *AuthIdentityChannelQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v)
+}
+
+func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go
new file mode 100644
index 00000000..b550c454
--- /dev/null
+++ b/backend/ent/authidentitychannel_update.go
@@ -0,0 +1,581 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities.
+type AuthIdentityChannelUpdate struct {
+ config
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdate) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity.
+type AuthIdentityChannelUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *AuthIdentityChannelMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetChannel sets the "channel" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannel(v)
+ return _u
+}
+
+// SetNillableChannel sets the "channel" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannel(*v)
+ }
+ return _u
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelAppID(v)
+ return _u
+}
+
+// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelAppID(*v)
+ }
+ return _u
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetChannelSubject(v)
+ return _u
+}
+
+// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil.
+func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne {
+ if v != nil {
+ _u.SetChannelSubject(*v)
+ }
+ return _u
+}
+
+// SetMetadata sets the "metadata" field.
+func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne {
+ _u.mutation.SetMetadata(v)
+ return _u
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the AuthIdentityChannelMutation object of the builder.
+func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation {
+ return _u.mutation
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the AuthIdentityChannelUpdate builder.
+func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated AuthIdentityChannel entity.
+func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *AuthIdentityChannelUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := authidentitychannel.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *AuthIdentityChannelUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := authidentitychannel.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := authidentitychannel.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Channel(); ok {
+ if err := authidentitychannel.ChannelValidator(v); err != nil {
+ return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelAppID(); ok {
+ if err := authidentitychannel.ChannelAppIDValidator(v); err != nil {
+ return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ChannelSubject(); ok {
+ if err := authidentitychannel.ChannelSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)}
+ }
+ }
+ if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`)
+ }
+ return nil
+}
+
+func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID)
+ for _, f := range fields {
+ if !authidentitychannel.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != authidentitychannel.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Channel(); ok {
+ _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelAppID(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ChannelSubject(); ok {
+ _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Metadata(); ok {
+ _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: authidentitychannel.IdentityTable,
+ Columns: []string{authidentitychannel.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &AuthIdentityChannel{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{authidentitychannel.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitor.go b/backend/ent/channelmonitor.go
new file mode 100644
index 00000000..dbb73362
--- /dev/null
+++ b/backend/ent/channelmonitor.go
@@ -0,0 +1,359 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitor is the model entity for the ChannelMonitor schema.
+type ChannelMonitor struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // Provider holds the value of the "provider" field.
+ Provider channelmonitor.Provider `json:"provider,omitempty"`
+ // Provider base origin, e.g. https://api.openai.com
+ Endpoint string `json:"endpoint,omitempty"`
+ // AES-256-GCM encrypted API key
+ APIKeyEncrypted string `json:"-"`
+ // PrimaryModel holds the value of the "primary_model" field.
+ PrimaryModel string `json:"primary_model,omitempty"`
+ // Additional model names to test alongside primary_model
+ ExtraModels []string `json:"extra_models,omitempty"`
+ // GroupName holds the value of the "group_name" field.
+ GroupName string `json:"group_name,omitempty"`
+ // Enabled holds the value of the "enabled" field.
+ Enabled bool `json:"enabled,omitempty"`
+ // IntervalSeconds holds the value of the "interval_seconds" field.
+ IntervalSeconds int `json:"interval_seconds,omitempty"`
+ // LastCheckedAt holds the value of the "last_checked_at" field.
+ LastCheckedAt *time.Time `json:"last_checked_at,omitempty"`
+ // CreatedBy holds the value of the "created_by" field.
+ CreatedBy int64 `json:"created_by,omitempty"`
+ // TemplateID holds the value of the "template_id" field.
+ TemplateID *int64 `json:"template_id,omitempty"`
+ // ExtraHeaders holds the value of the "extra_headers" field.
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // BodyOverrideMode holds the value of the "body_override_mode" field.
+ BodyOverrideMode string `json:"body_override_mode,omitempty"`
+ // BodyOverride holds the value of the "body_override" field.
+ BodyOverride map[string]interface{} `json:"body_override,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorQuery when eager-loading is set.
+ Edges ChannelMonitorEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorEdges struct {
+ // History holds the value of the history edge.
+ History []*ChannelMonitorHistory `json:"history,omitempty"`
+ // DailyRollups holds the value of the daily_rollups edge.
+ DailyRollups []*ChannelMonitorDailyRollup `json:"daily_rollups,omitempty"`
+ // RequestTemplate holds the value of the request_template edge.
+ RequestTemplate *ChannelMonitorRequestTemplate `json:"request_template,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [3]bool
+}
+
+// HistoryOrErr returns the History value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorEdges) HistoryOrErr() ([]*ChannelMonitorHistory, error) {
+ if e.loadedTypes[0] {
+ return e.History, nil
+ }
+ return nil, &NotLoadedError{edge: "history"}
+}
+
+// DailyRollupsOrErr returns the DailyRollups value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorEdges) DailyRollupsOrErr() ([]*ChannelMonitorDailyRollup, error) {
+ if e.loadedTypes[1] {
+ return e.DailyRollups, nil
+ }
+ return nil, &NotLoadedError{edge: "daily_rollups"}
+}
+
+// RequestTemplateOrErr returns the RequestTemplate value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorEdges) RequestTemplateOrErr() (*ChannelMonitorRequestTemplate, error) {
+ if e.RequestTemplate != nil {
+ return e.RequestTemplate, nil
+ } else if e.loadedTypes[2] {
+ return nil, &NotFoundError{label: channelmonitorrequesttemplate.Label}
+ }
+ return nil, &NotLoadedError{edge: "request_template"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitor) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitor.FieldExtraModels, channelmonitor.FieldExtraHeaders, channelmonitor.FieldBodyOverride:
+ values[i] = new([]byte)
+ case channelmonitor.FieldEnabled:
+ values[i] = new(sql.NullBool)
+ case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy, channelmonitor.FieldTemplateID:
+ values[i] = new(sql.NullInt64)
+ case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
+ values[i] = new(sql.NullString)
+ case channelmonitor.FieldCreatedAt, channelmonitor.FieldUpdatedAt, channelmonitor.FieldLastCheckedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitor fields.
+func (_m *ChannelMonitor) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitor.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitor.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case channelmonitor.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case channelmonitor.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ _m.Name = value.String
+ }
+ case channelmonitor.FieldProvider:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider", values[i])
+ } else if value.Valid {
+ _m.Provider = channelmonitor.Provider(value.String)
+ }
+ case channelmonitor.FieldEndpoint:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field endpoint", values[i])
+ } else if value.Valid {
+ _m.Endpoint = value.String
+ }
+ case channelmonitor.FieldAPIKeyEncrypted:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field api_key_encrypted", values[i])
+ } else if value.Valid {
+ _m.APIKeyEncrypted = value.String
+ }
+ case channelmonitor.FieldPrimaryModel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field primary_model", values[i])
+ } else if value.Valid {
+ _m.PrimaryModel = value.String
+ }
+ case channelmonitor.FieldExtraModels:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_models", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraModels); err != nil {
+ return fmt.Errorf("unmarshal field extra_models: %w", err)
+ }
+ }
+ case channelmonitor.FieldGroupName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field group_name", values[i])
+ } else if value.Valid {
+ _m.GroupName = value.String
+ }
+ case channelmonitor.FieldEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field enabled", values[i])
+ } else if value.Valid {
+ _m.Enabled = value.Bool
+ }
+ case channelmonitor.FieldIntervalSeconds:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field interval_seconds", values[i])
+ } else if value.Valid {
+ _m.IntervalSeconds = int(value.Int64)
+ }
+ case channelmonitor.FieldLastCheckedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_checked_at", values[i])
+ } else if value.Valid {
+ _m.LastCheckedAt = new(time.Time)
+ *_m.LastCheckedAt = value.Time
+ }
+ case channelmonitor.FieldCreatedBy:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field created_by", values[i])
+ } else if value.Valid {
+ _m.CreatedBy = value.Int64
+ }
+ case channelmonitor.FieldTemplateID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field template_id", values[i])
+ } else if value.Valid {
+ _m.TemplateID = new(int64)
+ *_m.TemplateID = value.Int64
+ }
+ case channelmonitor.FieldExtraHeaders:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_headers", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil {
+ return fmt.Errorf("unmarshal field extra_headers: %w", err)
+ }
+ }
+ case channelmonitor.FieldBodyOverrideMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override_mode", values[i])
+ } else if value.Valid {
+ _m.BodyOverrideMode = value.String
+ }
+ case channelmonitor.FieldBodyOverride:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil {
+ return fmt.Errorf("unmarshal field body_override: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitor.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitor) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryHistory queries the "history" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryHistory() *ChannelMonitorHistoryQuery {
+ return NewChannelMonitorClient(_m.config).QueryHistory(_m)
+}
+
+// QueryDailyRollups queries the "daily_rollups" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryDailyRollups() *ChannelMonitorDailyRollupQuery {
+ return NewChannelMonitorClient(_m.config).QueryDailyRollups(_m)
+}
+
+// QueryRequestTemplate queries the "request_template" edge of the ChannelMonitor entity.
+func (_m *ChannelMonitor) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery {
+ return NewChannelMonitorClient(_m.config).QueryRequestTemplate(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitor.
+// Note that you need to call ChannelMonitor.Unwrap() before calling this method if this ChannelMonitor
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitor) Update() *ChannelMonitorUpdateOne {
+ return NewChannelMonitorClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitor entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitor) Unwrap() *ChannelMonitor {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitor is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitor) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitor(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(_m.Name)
+ builder.WriteString(", ")
+ builder.WriteString("provider=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Provider))
+ builder.WriteString(", ")
+ builder.WriteString("endpoint=")
+ builder.WriteString(_m.Endpoint)
+ builder.WriteString(", ")
+ builder.WriteString("api_key_encrypted=")
+ builder.WriteString(", ")
+ builder.WriteString("primary_model=")
+ builder.WriteString(_m.PrimaryModel)
+ builder.WriteString(", ")
+ builder.WriteString("extra_models=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraModels))
+ builder.WriteString(", ")
+ builder.WriteString("group_name=")
+ builder.WriteString(_m.GroupName)
+ builder.WriteString(", ")
+ builder.WriteString("enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
+ builder.WriteString(", ")
+ builder.WriteString("interval_seconds=")
+ builder.WriteString(fmt.Sprintf("%v", _m.IntervalSeconds))
+ builder.WriteString(", ")
+ if v := _m.LastCheckedAt; v != nil {
+ builder.WriteString("last_checked_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("created_by=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy))
+ builder.WriteString(", ")
+ if v := _m.TemplateID; v != nil {
+ builder.WriteString("template_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("extra_headers=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders))
+ builder.WriteString(", ")
+ builder.WriteString("body_override_mode=")
+ builder.WriteString(_m.BodyOverrideMode)
+ builder.WriteString(", ")
+ builder.WriteString("body_override=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitors is a parsable slice of ChannelMonitor.
+type ChannelMonitors []*ChannelMonitor
diff --git a/backend/ent/channelmonitor/channelmonitor.go b/backend/ent/channelmonitor/channelmonitor.go
new file mode 100644
index 00000000..e5a6bfe7
--- /dev/null
+++ b/backend/ent/channelmonitor/channelmonitor.go
@@ -0,0 +1,304 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitor
+
+import (
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitor type in the database.
+ Label = "channel_monitor"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldProvider holds the string denoting the provider field in the database.
+ FieldProvider = "provider"
+ // FieldEndpoint holds the string denoting the endpoint field in the database.
+ FieldEndpoint = "endpoint"
+ // FieldAPIKeyEncrypted holds the string denoting the api_key_encrypted field in the database.
+ FieldAPIKeyEncrypted = "api_key_encrypted"
+ // FieldPrimaryModel holds the string denoting the primary_model field in the database.
+ FieldPrimaryModel = "primary_model"
+ // FieldExtraModels holds the string denoting the extra_models field in the database.
+ FieldExtraModels = "extra_models"
+ // FieldGroupName holds the string denoting the group_name field in the database.
+ FieldGroupName = "group_name"
+ // FieldEnabled holds the string denoting the enabled field in the database.
+ FieldEnabled = "enabled"
+ // FieldIntervalSeconds holds the string denoting the interval_seconds field in the database.
+ FieldIntervalSeconds = "interval_seconds"
+ // FieldLastCheckedAt holds the string denoting the last_checked_at field in the database.
+ FieldLastCheckedAt = "last_checked_at"
+ // FieldCreatedBy holds the string denoting the created_by field in the database.
+ FieldCreatedBy = "created_by"
+ // FieldTemplateID holds the string denoting the template_id field in the database.
+ FieldTemplateID = "template_id"
+ // FieldExtraHeaders holds the string denoting the extra_headers field in the database.
+ FieldExtraHeaders = "extra_headers"
+ // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database.
+ FieldBodyOverrideMode = "body_override_mode"
+ // FieldBodyOverride holds the string denoting the body_override field in the database.
+ FieldBodyOverride = "body_override"
+ // EdgeHistory holds the string denoting the history edge name in mutations.
+ EdgeHistory = "history"
+ // EdgeDailyRollups holds the string denoting the daily_rollups edge name in mutations.
+ EdgeDailyRollups = "daily_rollups"
+ // EdgeRequestTemplate holds the string denoting the request_template edge name in mutations.
+ EdgeRequestTemplate = "request_template"
+ // Table holds the table name of the channelmonitor in the database.
+ Table = "channel_monitors"
+ // HistoryTable is the table that holds the history relation/edge.
+ HistoryTable = "channel_monitor_histories"
+ // HistoryInverseTable is the table name for the ChannelMonitorHistory entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitorhistory" package.
+ HistoryInverseTable = "channel_monitor_histories"
+ // HistoryColumn is the table column denoting the history relation/edge.
+ HistoryColumn = "monitor_id"
+ // DailyRollupsTable is the table that holds the daily_rollups relation/edge.
+ DailyRollupsTable = "channel_monitor_daily_rollups"
+ // DailyRollupsInverseTable is the table name for the ChannelMonitorDailyRollup entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitordailyrollup" package.
+ DailyRollupsInverseTable = "channel_monitor_daily_rollups"
+ // DailyRollupsColumn is the table column denoting the daily_rollups relation/edge.
+ DailyRollupsColumn = "monitor_id"
+ // RequestTemplateTable is the table that holds the request_template relation/edge.
+ RequestTemplateTable = "channel_monitors"
+ // RequestTemplateInverseTable is the table name for the ChannelMonitorRequestTemplate entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitorrequesttemplate" package.
+ RequestTemplateInverseTable = "channel_monitor_request_templates"
+ // RequestTemplateColumn is the table column denoting the request_template relation/edge.
+ RequestTemplateColumn = "template_id"
+)
+
+// Columns holds all SQL columns for channelmonitor fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldName,
+ FieldProvider,
+ FieldEndpoint,
+ FieldAPIKeyEncrypted,
+ FieldPrimaryModel,
+ FieldExtraModels,
+ FieldGroupName,
+ FieldEnabled,
+ FieldIntervalSeconds,
+ FieldLastCheckedAt,
+ FieldCreatedBy,
+ FieldTemplateID,
+ FieldExtraHeaders,
+ FieldBodyOverrideMode,
+ FieldBodyOverride,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // NameValidator is a validator for the "name" field. It is called by the builders before save.
+ NameValidator func(string) error
+ // EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
+ EndpointValidator func(string) error
+ // APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
+ APIKeyEncryptedValidator func(string) error
+ // PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save.
+ PrimaryModelValidator func(string) error
+ // DefaultExtraModels holds the default value on creation for the "extra_models" field.
+ DefaultExtraModels []string
+ // DefaultGroupName holds the default value on creation for the "group_name" field.
+ DefaultGroupName string
+ // GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save.
+ GroupNameValidator func(string) error
+ // DefaultEnabled holds the default value on creation for the "enabled" field.
+ DefaultEnabled bool
+ // IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
+ IntervalSecondsValidator func(int) error
+ // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field.
+ DefaultExtraHeaders map[string]string
+ // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field.
+ DefaultBodyOverrideMode string
+ // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ BodyOverrideModeValidator func(string) error
+)
+
+// Provider defines the type for the "provider" enum field.
+type Provider string
+
+// Provider values.
+const (
+ ProviderOpenai Provider = "openai"
+ ProviderAnthropic Provider = "anthropic"
+ ProviderGemini Provider = "gemini"
+)
+
+func (pr Provider) String() string {
+ return string(pr)
+}
+
+// ProviderValidator is a validator for the "provider" field enum values. It is called by the builders before save.
+func ProviderValidator(pr Provider) error {
+ switch pr {
+ case ProviderOpenai, ProviderAnthropic, ProviderGemini:
+ return nil
+ default:
+ return fmt.Errorf("channelmonitor: invalid enum value for provider field: %q", pr)
+ }
+}
+
+// OrderOption defines the ordering options for the ChannelMonitor queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByProvider orders the results by the provider field.
+func ByProvider(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProvider, opts...).ToFunc()
+}
+
+// ByEndpoint orders the results by the endpoint field.
+func ByEndpoint(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEndpoint, opts...).ToFunc()
+}
+
+// ByAPIKeyEncrypted orders the results by the api_key_encrypted field.
+func ByAPIKeyEncrypted(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAPIKeyEncrypted, opts...).ToFunc()
+}
+
+// ByPrimaryModel orders the results by the primary_model field.
+func ByPrimaryModel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPrimaryModel, opts...).ToFunc()
+}
+
+// ByGroupName orders the results by the group_name field.
+func ByGroupName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldGroupName, opts...).ToFunc()
+}
+
+// ByEnabled orders the results by the enabled field.
+func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEnabled, opts...).ToFunc()
+}
+
+// ByIntervalSeconds orders the results by the interval_seconds field.
+func ByIntervalSeconds(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIntervalSeconds, opts...).ToFunc()
+}
+
+// ByLastCheckedAt orders the results by the last_checked_at field.
+func ByLastCheckedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastCheckedAt, opts...).ToFunc()
+}
+
+// ByCreatedBy orders the results by the created_by field.
+func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
+}
+
+// ByTemplateID orders the results by the template_id field.
+func ByTemplateID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTemplateID, opts...).ToFunc()
+}
+
+// ByBodyOverrideMode orders the results by the body_override_mode field.
+func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc()
+}
+
+// ByHistoryCount orders the results by history count.
+func ByHistoryCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newHistoryStep(), opts...)
+ }
+}
+
+// ByHistory orders the results by history terms.
+func ByHistory(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newHistoryStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByDailyRollupsCount orders the results by daily_rollups count.
+func ByDailyRollupsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newDailyRollupsStep(), opts...)
+ }
+}
+
+// ByDailyRollups orders the results by daily_rollups terms.
+func ByDailyRollups(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newDailyRollupsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByRequestTemplateField orders the results by request_template field.
+func ByRequestTemplateField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newRequestTemplateStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newHistoryStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(HistoryInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, HistoryTable, HistoryColumn),
+ )
+}
+func newDailyRollupsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(DailyRollupsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn),
+ )
+}
+func newRequestTemplateStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(RequestTemplateInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn),
+ )
+}
diff --git a/backend/ent/channelmonitor/where.go b/backend/ent/channelmonitor/where.go
new file mode 100644
index 00000000..755d83a3
--- /dev/null
+++ b/backend/ent/channelmonitor/where.go
@@ -0,0 +1,885 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitor
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v))
+}
+
+// Endpoint applies equality check predicate on the "endpoint" field. It's identical to EndpointEQ.
+func Endpoint(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
+}
+
+// APIKeyEncrypted applies equality check predicate on the "api_key_encrypted" field. It's identical to APIKeyEncryptedEQ.
+func APIKeyEncrypted(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIKeyEncrypted, v))
+}
+
+// PrimaryModel applies equality check predicate on the "primary_model" field. It's identical to PrimaryModelEQ.
+func PrimaryModel(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldPrimaryModel, v))
+}
+
+// GroupName applies equality check predicate on the "group_name" field. It's identical to GroupNameEQ.
+func GroupName(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldGroupName, v))
+}
+
+// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
+func Enabled(v bool) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEnabled, v))
+}
+
+// IntervalSeconds applies equality check predicate on the "interval_seconds" field. It's identical to IntervalSecondsEQ.
+func IntervalSeconds(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldIntervalSeconds, v))
+}
+
+// LastCheckedAt applies equality check predicate on the "last_checked_at" field. It's identical to LastCheckedAtEQ.
+func LastCheckedAt(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldLastCheckedAt, v))
+}
+
+// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
+func CreatedBy(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// TemplateID applies equality check predicate on the "template_id" field. It's identical to TemplateIDEQ.
+func TemplateID(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v))
+}
+
+// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ.
+func BodyOverrideMode(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldName, v))
+}
+
+// ProviderEQ applies the EQ predicate on the "provider" field.
+func ProviderEQ(v Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldProvider, v))
+}
+
+// ProviderNEQ applies the NEQ predicate on the "provider" field.
+func ProviderNEQ(v Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldProvider, v))
+}
+
+// ProviderIn applies the In predicate on the "provider" field.
+func ProviderIn(vs ...Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldProvider, vs...))
+}
+
+// ProviderNotIn applies the NotIn predicate on the "provider" field.
+func ProviderNotIn(vs ...Provider) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldProvider, vs...))
+}
+
+// EndpointEQ applies the EQ predicate on the "endpoint" field.
+func EndpointEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
+}
+
+// EndpointNEQ applies the NEQ predicate on the "endpoint" field.
+func EndpointNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldEndpoint, v))
+}
+
+// EndpointIn applies the In predicate on the "endpoint" field.
+func EndpointIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldEndpoint, vs...))
+}
+
+// EndpointNotIn applies the NotIn predicate on the "endpoint" field.
+func EndpointNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldEndpoint, vs...))
+}
+
+// EndpointGT applies the GT predicate on the "endpoint" field.
+func EndpointGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldEndpoint, v))
+}
+
+// EndpointGTE applies the GTE predicate on the "endpoint" field.
+func EndpointGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldEndpoint, v))
+}
+
+// EndpointLT applies the LT predicate on the "endpoint" field.
+func EndpointLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldEndpoint, v))
+}
+
+// EndpointLTE applies the LTE predicate on the "endpoint" field.
+func EndpointLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldEndpoint, v))
+}
+
+// EndpointContains applies the Contains predicate on the "endpoint" field.
+func EndpointContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldEndpoint, v))
+}
+
+// EndpointHasPrefix applies the HasPrefix predicate on the "endpoint" field.
+func EndpointHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldEndpoint, v))
+}
+
+// EndpointHasSuffix applies the HasSuffix predicate on the "endpoint" field.
+func EndpointHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldEndpoint, v))
+}
+
+// EndpointEqualFold applies the EqualFold predicate on the "endpoint" field.
+func EndpointEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldEndpoint, v))
+}
+
+// EndpointContainsFold applies the ContainsFold predicate on the "endpoint" field.
+func EndpointContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldEndpoint, v))
+}
+
+// APIKeyEncryptedEQ applies the EQ predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedNEQ applies the NEQ predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedIn applies the In predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldAPIKeyEncrypted, vs...))
+}
+
+// APIKeyEncryptedNotIn applies the NotIn predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldAPIKeyEncrypted, vs...))
+}
+
+// APIKeyEncryptedGT applies the GT predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedGTE applies the GTE predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedLT applies the LT predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedLTE applies the LTE predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedContains applies the Contains predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedHasPrefix applies the HasPrefix predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedHasSuffix applies the HasSuffix predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedEqualFold applies the EqualFold predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldAPIKeyEncrypted, v))
+}
+
+// APIKeyEncryptedContainsFold applies the ContainsFold predicate on the "api_key_encrypted" field.
+func APIKeyEncryptedContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldAPIKeyEncrypted, v))
+}
+
+// PrimaryModelEQ applies the EQ predicate on the "primary_model" field.
+func PrimaryModelEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldPrimaryModel, v))
+}
+
+// PrimaryModelNEQ applies the NEQ predicate on the "primary_model" field.
+func PrimaryModelNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldPrimaryModel, v))
+}
+
+// PrimaryModelIn applies the In predicate on the "primary_model" field.
+func PrimaryModelIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldPrimaryModel, vs...))
+}
+
+// PrimaryModelNotIn applies the NotIn predicate on the "primary_model" field.
+func PrimaryModelNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldPrimaryModel, vs...))
+}
+
+// PrimaryModelGT applies the GT predicate on the "primary_model" field.
+func PrimaryModelGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldPrimaryModel, v))
+}
+
+// PrimaryModelGTE applies the GTE predicate on the "primary_model" field.
+func PrimaryModelGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldPrimaryModel, v))
+}
+
+// PrimaryModelLT applies the LT predicate on the "primary_model" field.
+func PrimaryModelLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldPrimaryModel, v))
+}
+
+// PrimaryModelLTE applies the LTE predicate on the "primary_model" field.
+func PrimaryModelLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldPrimaryModel, v))
+}
+
+// PrimaryModelContains applies the Contains predicate on the "primary_model" field.
+func PrimaryModelContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldPrimaryModel, v))
+}
+
+// PrimaryModelHasPrefix applies the HasPrefix predicate on the "primary_model" field.
+func PrimaryModelHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldPrimaryModel, v))
+}
+
+// PrimaryModelHasSuffix applies the HasSuffix predicate on the "primary_model" field.
+func PrimaryModelHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldPrimaryModel, v))
+}
+
+// PrimaryModelEqualFold applies the EqualFold predicate on the "primary_model" field.
+func PrimaryModelEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldPrimaryModel, v))
+}
+
+// PrimaryModelContainsFold applies the ContainsFold predicate on the "primary_model" field.
+func PrimaryModelContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldPrimaryModel, v))
+}
+
+// GroupNameEQ applies the EQ predicate on the "group_name" field.
+func GroupNameEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldGroupName, v))
+}
+
+// GroupNameNEQ applies the NEQ predicate on the "group_name" field.
+func GroupNameNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldGroupName, v))
+}
+
+// GroupNameIn applies the In predicate on the "group_name" field.
+func GroupNameIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldGroupName, vs...))
+}
+
+// GroupNameNotIn applies the NotIn predicate on the "group_name" field.
+func GroupNameNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldGroupName, vs...))
+}
+
+// GroupNameGT applies the GT predicate on the "group_name" field.
+func GroupNameGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldGroupName, v))
+}
+
+// GroupNameGTE applies the GTE predicate on the "group_name" field.
+func GroupNameGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldGroupName, v))
+}
+
+// GroupNameLT applies the LT predicate on the "group_name" field.
+func GroupNameLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldGroupName, v))
+}
+
+// GroupNameLTE applies the LTE predicate on the "group_name" field.
+func GroupNameLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldGroupName, v))
+}
+
+// GroupNameContains applies the Contains predicate on the "group_name" field.
+func GroupNameContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldGroupName, v))
+}
+
+// GroupNameHasPrefix applies the HasPrefix predicate on the "group_name" field.
+func GroupNameHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldGroupName, v))
+}
+
+// GroupNameHasSuffix applies the HasSuffix predicate on the "group_name" field.
+func GroupNameHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldGroupName, v))
+}
+
+// GroupNameIsNil applies the IsNil predicate on the "group_name" field.
+func GroupNameIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldGroupName))
+}
+
+// GroupNameNotNil applies the NotNil predicate on the "group_name" field.
+func GroupNameNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldGroupName))
+}
+
+// GroupNameEqualFold applies the EqualFold predicate on the "group_name" field.
+func GroupNameEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldGroupName, v))
+}
+
+// GroupNameContainsFold applies the ContainsFold predicate on the "group_name" field.
+func GroupNameContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldGroupName, v))
+}
+
+// EnabledEQ applies the EQ predicate on the "enabled" field.
+func EnabledEQ(v bool) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldEnabled, v))
+}
+
+// EnabledNEQ applies the NEQ predicate on the "enabled" field.
+func EnabledNEQ(v bool) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldEnabled, v))
+}
+
+// IntervalSecondsEQ applies the EQ predicate on the "interval_seconds" field.
+func IntervalSecondsEQ(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsNEQ applies the NEQ predicate on the "interval_seconds" field.
+func IntervalSecondsNEQ(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsIn applies the In predicate on the "interval_seconds" field.
+func IntervalSecondsIn(vs ...int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldIntervalSeconds, vs...))
+}
+
+// IntervalSecondsNotIn applies the NotIn predicate on the "interval_seconds" field.
+func IntervalSecondsNotIn(vs ...int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldIntervalSeconds, vs...))
+}
+
+// IntervalSecondsGT applies the GT predicate on the "interval_seconds" field.
+func IntervalSecondsGT(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsGTE applies the GTE predicate on the "interval_seconds" field.
+func IntervalSecondsGTE(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsLT applies the LT predicate on the "interval_seconds" field.
+func IntervalSecondsLT(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldIntervalSeconds, v))
+}
+
+// IntervalSecondsLTE applies the LTE predicate on the "interval_seconds" field.
+func IntervalSecondsLTE(v int) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldIntervalSeconds, v))
+}
+
+// LastCheckedAtEQ applies the EQ predicate on the "last_checked_at" field.
+func LastCheckedAtEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtNEQ applies the NEQ predicate on the "last_checked_at" field.
+func LastCheckedAtNEQ(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtIn applies the In predicate on the "last_checked_at" field.
+func LastCheckedAtIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldLastCheckedAt, vs...))
+}
+
+// LastCheckedAtNotIn applies the NotIn predicate on the "last_checked_at" field.
+func LastCheckedAtNotIn(vs ...time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldLastCheckedAt, vs...))
+}
+
+// LastCheckedAtGT applies the GT predicate on the "last_checked_at" field.
+func LastCheckedAtGT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtGTE applies the GTE predicate on the "last_checked_at" field.
+func LastCheckedAtGTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtLT applies the LT predicate on the "last_checked_at" field.
+func LastCheckedAtLT(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtLTE applies the LTE predicate on the "last_checked_at" field.
+func LastCheckedAtLTE(v time.Time) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldLastCheckedAt, v))
+}
+
+// LastCheckedAtIsNil applies the IsNil predicate on the "last_checked_at" field.
+func LastCheckedAtIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldLastCheckedAt))
+}
+
+// LastCheckedAtNotNil applies the NotNil predicate on the "last_checked_at" field.
+func LastCheckedAtNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldLastCheckedAt))
+}
+
+// CreatedByEQ applies the EQ predicate on the "created_by" field.
+func CreatedByEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
+func CreatedByNEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldCreatedBy, v))
+}
+
+// CreatedByIn applies the In predicate on the "created_by" field.
+func CreatedByIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
+func CreatedByNotIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByGT applies the GT predicate on the "created_by" field.
+func CreatedByGT(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldCreatedBy, v))
+}
+
+// CreatedByGTE applies the GTE predicate on the "created_by" field.
+func CreatedByGTE(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldCreatedBy, v))
+}
+
+// CreatedByLT applies the LT predicate on the "created_by" field.
+func CreatedByLT(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldCreatedBy, v))
+}
+
+// CreatedByLTE applies the LTE predicate on the "created_by" field.
+func CreatedByLTE(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldCreatedBy, v))
+}
+
+// TemplateIDEQ applies the EQ predicate on the "template_id" field.
+func TemplateIDEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldTemplateID, v))
+}
+
+// TemplateIDNEQ applies the NEQ predicate on the "template_id" field.
+func TemplateIDNEQ(v int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldTemplateID, v))
+}
+
+// TemplateIDIn applies the In predicate on the "template_id" field.
+func TemplateIDIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldTemplateID, vs...))
+}
+
+// TemplateIDNotIn applies the NotIn predicate on the "template_id" field.
+func TemplateIDNotIn(vs ...int64) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldTemplateID, vs...))
+}
+
+// TemplateIDIsNil applies the IsNil predicate on the "template_id" field.
+func TemplateIDIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldTemplateID))
+}
+
+// TemplateIDNotNil applies the NotNil predicate on the "template_id" field.
+func TemplateIDNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldTemplateID))
+}
+
+// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field.
+func BodyOverrideModeEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field.
+func BodyOverrideModeNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field.
+func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field.
+func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field.
+func BodyOverrideModeGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field.
+func BodyOverrideModeGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field.
+func BodyOverrideModeLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field.
+func BodyOverrideModeLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field.
+func BodyOverrideModeContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field.
+func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field.
+func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field.
+func BodyOverrideIsNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIsNull(FieldBodyOverride))
+}
+
+// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field.
+func BodyOverrideNotNil() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotNull(FieldBodyOverride))
+}
+
+// HasHistory applies the HasEdge predicate on the "history" edge.
+func HasHistory() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, HistoryTable, HistoryColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasHistoryWith applies the HasEdge predicate on the "history" edge with the given conditions (other predicates).
+func HasHistoryWith(preds ...predicate.ChannelMonitorHistory) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newHistoryStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasDailyRollups applies the HasEdge predicate on the "daily_rollups" edge.
+func HasDailyRollups() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, DailyRollupsTable, DailyRollupsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasDailyRollupsWith applies the HasEdge predicate on the "daily_rollups" edge with the given conditions (other predicates).
+func HasDailyRollupsWith(preds ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newDailyRollupsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasRequestTemplate applies the HasEdge predicate on the "request_template" edge.
+func HasRequestTemplate() predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, RequestTemplateTable, RequestTemplateColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasRequestTemplateWith applies the HasEdge predicate on the "request_template" edge with the given conditions (other predicates).
+func HasRequestTemplateWith(preds ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(func(s *sql.Selector) {
+ step := newRequestTemplateStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitor) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitor) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitor_create.go b/backend/ent/channelmonitor_create.go
new file mode 100644
index 00000000..2f70c300
--- /dev/null
+++ b/backend/ent/channelmonitor_create.go
@@ -0,0 +1,1610 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorCreate is the builder for creating a ChannelMonitor entity.
+type ChannelMonitorCreate struct {
+ config
+ mutation *ChannelMonitorMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *ChannelMonitorCreate) SetCreatedAt(v time.Time) *ChannelMonitorCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableCreatedAt(v *time.Time) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *ChannelMonitorCreate) SetUpdatedAt(v time.Time) *ChannelMonitorCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableUpdatedAt(v *time.Time) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetName sets the "name" field.
+func (_c *ChannelMonitorCreate) SetName(v string) *ChannelMonitorCreate {
+ _c.mutation.SetName(v)
+ return _c
+}
+
+// SetProvider sets the "provider" field.
+func (_c *ChannelMonitorCreate) SetProvider(v channelmonitor.Provider) *ChannelMonitorCreate {
+ _c.mutation.SetProvider(v)
+ return _c
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (_c *ChannelMonitorCreate) SetEndpoint(v string) *ChannelMonitorCreate {
+ _c.mutation.SetEndpoint(v)
+ return _c
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (_c *ChannelMonitorCreate) SetAPIKeyEncrypted(v string) *ChannelMonitorCreate {
+ _c.mutation.SetAPIKeyEncrypted(v)
+ return _c
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (_c *ChannelMonitorCreate) SetPrimaryModel(v string) *ChannelMonitorCreate {
+ _c.mutation.SetPrimaryModel(v)
+ return _c
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (_c *ChannelMonitorCreate) SetExtraModels(v []string) *ChannelMonitorCreate {
+ _c.mutation.SetExtraModels(v)
+ return _c
+}
+
+// SetGroupName sets the "group_name" field.
+func (_c *ChannelMonitorCreate) SetGroupName(v string) *ChannelMonitorCreate {
+ _c.mutation.SetGroupName(v)
+ return _c
+}
+
+// SetNillableGroupName sets the "group_name" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableGroupName(v *string) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetGroupName(*v)
+ }
+ return _c
+}
+
+// SetEnabled sets the "enabled" field.
+func (_c *ChannelMonitorCreate) SetEnabled(v bool) *ChannelMonitorCreate {
+ _c.mutation.SetEnabled(v)
+ return _c
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableEnabled(v *bool) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetEnabled(*v)
+ }
+ return _c
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (_c *ChannelMonitorCreate) SetIntervalSeconds(v int) *ChannelMonitorCreate {
+ _c.mutation.SetIntervalSeconds(v)
+ return _c
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (_c *ChannelMonitorCreate) SetLastCheckedAt(v time.Time) *ChannelMonitorCreate {
+ _c.mutation.SetLastCheckedAt(v)
+ return _c
+}
+
+// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetLastCheckedAt(*v)
+ }
+ return _c
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_c *ChannelMonitorCreate) SetCreatedBy(v int64) *ChannelMonitorCreate {
+ _c.mutation.SetCreatedBy(v)
+ return _c
+}
+
+// SetTemplateID sets the "template_id" field.
+func (_c *ChannelMonitorCreate) SetTemplateID(v int64) *ChannelMonitorCreate {
+ _c.mutation.SetTemplateID(v)
+ return _c
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableTemplateID(v *int64) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetTemplateID(*v)
+ }
+ return _c
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_c *ChannelMonitorCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorCreate {
+ _c.mutation.SetExtraHeaders(v)
+ return _c
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_c *ChannelMonitorCreate) SetBodyOverrideMode(v string) *ChannelMonitorCreate {
+ _c.mutation.SetBodyOverrideMode(v)
+ return _c
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetBodyOverrideMode(*v)
+ }
+ return _c
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_c *ChannelMonitorCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorCreate {
+ _c.mutation.SetBodyOverride(v)
+ return _c
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (_c *ChannelMonitorCreate) AddHistoryIDs(ids ...int64) *ChannelMonitorCreate {
+ _c.mutation.AddHistoryIDs(ids...)
+ return _c
+}
+
+// AddHistory adds the "history" edges to the ChannelMonitorHistory entity.
+func (_c *ChannelMonitorCreate) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddHistoryIDs(ids...)
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_c *ChannelMonitorCreate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorCreate {
+ _c.mutation.AddDailyRollupIDs(ids...)
+ return _c
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_c *ChannelMonitorCreate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddDailyRollupIDs(ids...)
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_c *ChannelMonitorCreate) SetRequestTemplateID(id int64) *ChannelMonitorCreate {
+ _c.mutation.SetRequestTemplateID(id)
+ return _c
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorCreate {
+ if id != nil {
+ _c = _c.SetRequestTemplateID(*id)
+ }
+ return _c
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_c *ChannelMonitorCreate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorCreate {
+ return _c.SetRequestTemplateID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorMutation object of the builder.
+func (_c *ChannelMonitorCreate) Mutation() *ChannelMonitorMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitor in the database.
+func (_c *ChannelMonitorCreate) Save(ctx context.Context) (*ChannelMonitor, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorCreate) SaveX(ctx context.Context) *ChannelMonitor {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := channelmonitor.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := channelmonitor.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.ExtraModels(); !ok {
+ v := channelmonitor.DefaultExtraModels
+ _c.mutation.SetExtraModels(v)
+ }
+ if _, ok := _c.mutation.GroupName(); !ok {
+ v := channelmonitor.DefaultGroupName
+ _c.mutation.SetGroupName(v)
+ }
+ if _, ok := _c.mutation.Enabled(); !ok {
+ v := channelmonitor.DefaultEnabled
+ _c.mutation.SetEnabled(v)
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ v := channelmonitor.DefaultExtraHeaders
+ _c.mutation.SetExtraHeaders(v)
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ v := channelmonitor.DefaultBodyOverrideMode
+ _c.mutation.SetBodyOverrideMode(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ChannelMonitor.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ChannelMonitor.updated_at"`)}
+ }
+ if _, ok := _c.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ChannelMonitor.name"`)}
+ }
+ if v, ok := _c.mutation.Name(); ok {
+ if err := channelmonitor.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Provider(); !ok {
+ return &ValidationError{Name: "provider", err: errors.New(`ent: missing required field "ChannelMonitor.provider"`)}
+ }
+ if v, ok := _c.mutation.Provider(); ok {
+ if err := channelmonitor.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Endpoint(); !ok {
+ return &ValidationError{Name: "endpoint", err: errors.New(`ent: missing required field "ChannelMonitor.endpoint"`)}
+ }
+ if v, ok := _c.mutation.Endpoint(); ok {
+ if err := channelmonitor.EndpointValidator(v); err != nil {
+ return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.APIKeyEncrypted(); !ok {
+ return &ValidationError{Name: "api_key_encrypted", err: errors.New(`ent: missing required field "ChannelMonitor.api_key_encrypted"`)}
+ }
+ if v, ok := _c.mutation.APIKeyEncrypted(); ok {
+ if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil {
+ return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.PrimaryModel(); !ok {
+ return &ValidationError{Name: "primary_model", err: errors.New(`ent: missing required field "ChannelMonitor.primary_model"`)}
+ }
+ if v, ok := _c.mutation.PrimaryModel(); ok {
+ if err := channelmonitor.PrimaryModelValidator(v); err != nil {
+ return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ExtraModels(); !ok {
+ return &ValidationError{Name: "extra_models", err: errors.New(`ent: missing required field "ChannelMonitor.extra_models"`)}
+ }
+ if v, ok := _c.mutation.GroupName(); ok {
+ if err := channelmonitor.GroupNameValidator(v); err != nil {
+ return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Enabled(); !ok {
+ return &ValidationError{Name: "enabled", err: errors.New(`ent: missing required field "ChannelMonitor.enabled"`)}
+ }
+ if _, ok := _c.mutation.IntervalSeconds(); !ok {
+ return &ValidationError{Name: "interval_seconds", err: errors.New(`ent: missing required field "ChannelMonitor.interval_seconds"`)}
+ }
+ if v, ok := _c.mutation.IntervalSeconds(); ok {
+ if err := channelmonitor.IntervalSecondsValidator(v); err != nil {
+ return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.CreatedBy(); !ok {
+ return &ValidationError{Name: "created_by", err: errors.New(`ent: missing required field "ChannelMonitor.created_by"`)}
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitor.extra_headers"`)}
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ return &ValidationError{Name: "body_override_mode", err: errors.New(`ent: missing required field "ChannelMonitor.body_override_mode"`)}
+ }
+ if v, ok := _c.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorCreate) sqlSave(ctx context.Context) (*ChannelMonitor, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitor{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitor.Table, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.Name(); ok {
+ _spec.SetField(channelmonitor.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := _c.mutation.Provider(); ok {
+ _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
+ _node.Provider = value
+ }
+ if value, ok := _c.mutation.Endpoint(); ok {
+ _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
+ _node.Endpoint = value
+ }
+ if value, ok := _c.mutation.APIKeyEncrypted(); ok {
+ _spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value)
+ _node.APIKeyEncrypted = value
+ }
+ if value, ok := _c.mutation.PrimaryModel(); ok {
+ _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value)
+ _node.PrimaryModel = value
+ }
+ if value, ok := _c.mutation.ExtraModels(); ok {
+ _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value)
+ _node.ExtraModels = value
+ }
+ if value, ok := _c.mutation.GroupName(); ok {
+ _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value)
+ _node.GroupName = value
+ }
+ if value, ok := _c.mutation.Enabled(); ok {
+ _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value)
+ _node.Enabled = value
+ }
+ if value, ok := _c.mutation.IntervalSeconds(); ok {
+ _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ _node.IntervalSeconds = value
+ }
+ if value, ok := _c.mutation.LastCheckedAt(); ok {
+ _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value)
+ _node.LastCheckedAt = &value
+ }
+ if value, ok := _c.mutation.CreatedBy(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ _node.CreatedBy = value
+ }
+ if value, ok := _c.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ _node.ExtraHeaders = value
+ }
+ if value, ok := _c.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ _node.BodyOverrideMode = value
+ }
+ if value, ok := _c.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ _node.BodyOverride = value
+ }
+ if nodes := _c.mutation.HistoryIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TemplateID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitor.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+//	// that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorCreate) OnConflictColumns(columns ...string) *ChannelMonitorUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitor node.
+ ChannelMonitorUpsertOne struct {
+ create *ChannelMonitorCreate
+ }
+
+ // ChannelMonitorUpsert is the "OnConflict" setter.
+ ChannelMonitorUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorUpsert) SetUpdatedAt(v time.Time) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateUpdatedAt() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldUpdatedAt)
+ return u
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorUpsert) SetName(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldName, v)
+ return u
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateName() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldName)
+ return u
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorUpsert) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldProvider, v)
+ return u
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateProvider() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldProvider)
+ return u
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (u *ChannelMonitorUpsert) SetEndpoint(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldEndpoint, v)
+ return u
+}
+
+// UpdateEndpoint sets the "endpoint" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateEndpoint() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldEndpoint)
+ return u
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (u *ChannelMonitorUpsert) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldAPIKeyEncrypted, v)
+ return u
+}
+
+// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateAPIKeyEncrypted() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldAPIKeyEncrypted)
+ return u
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (u *ChannelMonitorUpsert) SetPrimaryModel(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldPrimaryModel, v)
+ return u
+}
+
+// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdatePrimaryModel() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldPrimaryModel)
+ return u
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (u *ChannelMonitorUpsert) SetExtraModels(v []string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldExtraModels, v)
+ return u
+}
+
+// UpdateExtraModels sets the "extra_models" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateExtraModels() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldExtraModels)
+ return u
+}
+
+// SetGroupName sets the "group_name" field.
+func (u *ChannelMonitorUpsert) SetGroupName(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldGroupName, v)
+ return u
+}
+
+// UpdateGroupName sets the "group_name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateGroupName() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldGroupName)
+ return u
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (u *ChannelMonitorUpsert) ClearGroupName() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldGroupName)
+ return u
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *ChannelMonitorUpsert) SetEnabled(v bool) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldEnabled, v)
+ return u
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateEnabled() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldEnabled)
+ return u
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (u *ChannelMonitorUpsert) SetIntervalSeconds(v int) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldIntervalSeconds, v)
+ return u
+}
+
+// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateIntervalSeconds() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldIntervalSeconds)
+ return u
+}
+
+// AddIntervalSeconds adds v to the "interval_seconds" field.
+func (u *ChannelMonitorUpsert) AddIntervalSeconds(v int) *ChannelMonitorUpsert {
+ u.Add(channelmonitor.FieldIntervalSeconds, v)
+ return u
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (u *ChannelMonitorUpsert) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldLastCheckedAt, v)
+ return u
+}
+
+// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateLastCheckedAt() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldLastCheckedAt)
+ return u
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (u *ChannelMonitorUpsert) ClearLastCheckedAt() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldLastCheckedAt)
+ return u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *ChannelMonitorUpsert) SetCreatedBy(v int64) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldCreatedBy, v)
+ return u
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateCreatedBy() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldCreatedBy)
+ return u
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *ChannelMonitorUpsert) AddCreatedBy(v int64) *ChannelMonitorUpsert {
+ u.Add(channelmonitor.FieldCreatedBy, v)
+ return u
+}
+
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsert) SetTemplateID(v int64) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldTemplateID, v)
+ return u
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateTemplateID() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldTemplateID)
+ return u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsert) ClearTemplateID() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldTemplateID)
+ return u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldExtraHeaders, v)
+ return u
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateExtraHeaders() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldExtraHeaders)
+ return u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsert) SetBodyOverrideMode(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldBodyOverrideMode, v)
+ return u
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateBodyOverrideMode() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldBodyOverrideMode)
+ return u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldBodyOverride, v)
+ return u
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateBodyOverride() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldBodyOverride)
+ return u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsert) ClearBodyOverride() *ChannelMonitorUpsert {
+ u.SetNull(channelmonitor.FieldBodyOverride)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertOne) UpdateNewValues() *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitor.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertOne) Ignore() *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorUpsertOne) DoNothing() *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorUpsertOne) Update(set func(*ChannelMonitorUpsert)) *ChannelMonitorUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorUpsertOne) SetUpdatedAt(v time.Time) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateUpdatedAt() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorUpsertOne) SetName(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateName() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorUpsertOne) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateProvider() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (u *ChannelMonitorUpsertOne) SetEndpoint(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEndpoint(v)
+ })
+}
+
+// UpdateEndpoint sets the "endpoint" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateEndpoint() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEndpoint()
+ })
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (u *ChannelMonitorUpsertOne) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetAPIKeyEncrypted(v)
+ })
+}
+
+// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateAPIKeyEncrypted() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateAPIKeyEncrypted()
+ })
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (u *ChannelMonitorUpsertOne) SetPrimaryModel(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetPrimaryModel(v)
+ })
+}
+
+// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdatePrimaryModel() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdatePrimaryModel()
+ })
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (u *ChannelMonitorUpsertOne) SetExtraModels(v []string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraModels(v)
+ })
+}
+
+// UpdateExtraModels sets the "extra_models" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateExtraModels() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraModels()
+ })
+}
+
+// SetGroupName sets the "group_name" field.
+func (u *ChannelMonitorUpsertOne) SetGroupName(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetGroupName(v)
+ })
+}
+
+// UpdateGroupName sets the "group_name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateGroupName() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateGroupName()
+ })
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (u *ChannelMonitorUpsertOne) ClearGroupName() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearGroupName()
+ })
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *ChannelMonitorUpsertOne) SetEnabled(v bool) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEnabled(v)
+ })
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateEnabled() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEnabled()
+ })
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (u *ChannelMonitorUpsertOne) SetIntervalSeconds(v int) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetIntervalSeconds(v)
+ })
+}
+
+// AddIntervalSeconds adds v to the "interval_seconds" field.
+func (u *ChannelMonitorUpsertOne) AddIntervalSeconds(v int) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddIntervalSeconds(v)
+ })
+}
+
+// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateIntervalSeconds() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateIntervalSeconds()
+ })
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (u *ChannelMonitorUpsertOne) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetLastCheckedAt(v)
+ })
+}
+
+// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateLastCheckedAt() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateLastCheckedAt()
+ })
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (u *ChannelMonitorUpsertOne) ClearLastCheckedAt() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearLastCheckedAt()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *ChannelMonitorUpsertOne) SetCreatedBy(v int64) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *ChannelMonitorUpsertOne) AddCreatedBy(v int64) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateCreatedBy() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsertOne) SetTemplateID(v int64) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetTemplateID(v)
+ })
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateTemplateID() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateTemplateID()
+ })
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsertOne) ClearTemplateID() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearTemplateID()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateExtraHeaders() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateBodyOverride() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsertOne) ClearBodyOverride() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorCreateBulk is the builder for creating many ChannelMonitor entities in bulk.
+type ChannelMonitorCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitor entities in the database.
+func (_c *ChannelMonitorCreateBulk) Save(ctx context.Context) ([]*ChannelMonitor, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitor, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorCreateBulk) SaveX(ctx context.Context) []*ChannelMonitor {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitor.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitor nodes.
+type ChannelMonitorUpsertBulk struct {
+ create *ChannelMonitorCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertBulk) UpdateNewValues() *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitor.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitor.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorUpsertBulk) Ignore() *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorUpsertBulk) DoNothing() *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorUpsertBulk) Update(set func(*ChannelMonitorUpsert)) *ChannelMonitorUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorUpsertBulk) SetUpdatedAt(v time.Time) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateUpdatedAt() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorUpsertBulk) SetName(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateName() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorUpsertBulk) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateProvider() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (u *ChannelMonitorUpsertBulk) SetEndpoint(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEndpoint(v)
+ })
+}
+
+// UpdateEndpoint sets the "endpoint" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateEndpoint() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEndpoint()
+ })
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (u *ChannelMonitorUpsertBulk) SetAPIKeyEncrypted(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetAPIKeyEncrypted(v)
+ })
+}
+
+// UpdateAPIKeyEncrypted sets the "api_key_encrypted" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateAPIKeyEncrypted() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateAPIKeyEncrypted()
+ })
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (u *ChannelMonitorUpsertBulk) SetPrimaryModel(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetPrimaryModel(v)
+ })
+}
+
+// UpdatePrimaryModel sets the "primary_model" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdatePrimaryModel() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdatePrimaryModel()
+ })
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (u *ChannelMonitorUpsertBulk) SetExtraModels(v []string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraModels(v)
+ })
+}
+
+// UpdateExtraModels sets the "extra_models" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateExtraModels() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraModels()
+ })
+}
+
+// SetGroupName sets the "group_name" field.
+func (u *ChannelMonitorUpsertBulk) SetGroupName(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetGroupName(v)
+ })
+}
+
+// UpdateGroupName sets the "group_name" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateGroupName() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateGroupName()
+ })
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (u *ChannelMonitorUpsertBulk) ClearGroupName() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearGroupName()
+ })
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *ChannelMonitorUpsertBulk) SetEnabled(v bool) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetEnabled(v)
+ })
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateEnabled() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateEnabled()
+ })
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (u *ChannelMonitorUpsertBulk) SetIntervalSeconds(v int) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetIntervalSeconds(v)
+ })
+}
+
+// AddIntervalSeconds adds v to the "interval_seconds" field.
+func (u *ChannelMonitorUpsertBulk) AddIntervalSeconds(v int) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddIntervalSeconds(v)
+ })
+}
+
+// UpdateIntervalSeconds sets the "interval_seconds" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateIntervalSeconds() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateIntervalSeconds()
+ })
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (u *ChannelMonitorUpsertBulk) SetLastCheckedAt(v time.Time) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetLastCheckedAt(v)
+ })
+}
+
+// UpdateLastCheckedAt sets the "last_checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateLastCheckedAt() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateLastCheckedAt()
+ })
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (u *ChannelMonitorUpsertBulk) ClearLastCheckedAt() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearLastCheckedAt()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *ChannelMonitorUpsertBulk) SetCreatedBy(v int64) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *ChannelMonitorUpsertBulk) AddCreatedBy(v int64) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateCreatedBy() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// SetTemplateID sets the "template_id" field.
+func (u *ChannelMonitorUpsertBulk) SetTemplateID(v int64) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetTemplateID(v)
+ })
+}
+
+// UpdateTemplateID sets the "template_id" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateTemplateID() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateTemplateID()
+ })
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (u *ChannelMonitorUpsertBulk) ClearTemplateID() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearTemplateID()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateExtraHeaders() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateBodyOverride() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorUpsertBulk) ClearBodyOverride() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitor_delete.go b/backend/ent/channelmonitor_delete.go
new file mode 100644
index 00000000..500dbb48
--- /dev/null
+++ b/backend/ent/channelmonitor_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDelete is the builder for deleting a ChannelMonitor entity.
+type ChannelMonitorDelete struct {
+	config
+	hooks    []Hook
+	mutation *ChannelMonitorMutation
+}
+
+// Where appends a list of predicates to the ChannelMonitorDelete builder.
+func (_d *ChannelMonitorDelete) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorDelete {
+	_d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorDelete) Exec(ctx context.Context) (int, error) {
+	return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDelete) ExecX(ctx context.Context) int {
+	n, err := _d.Exec(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return n
+}
+
+func (_d *ChannelMonitorDelete) sqlExec(ctx context.Context) (int, error) {
+	_spec := sqlgraph.NewDeleteSpec(channelmonitor.Table, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+	if ps := _d.mutation.predicates; len(ps) > 0 {
+		_spec.Predicate = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+	if err != nil && sqlgraph.IsConstraintError(err) {
+		err = &ConstraintError{msg: err.Error(), wrap: err}
+	}
+	_d.mutation.done = true
+	return affected, err
+}
+
+// ChannelMonitorDeleteOne is the builder for deleting a single ChannelMonitor entity.
+type ChannelMonitorDeleteOne struct {
+	_d *ChannelMonitorDelete
+}
+
+// Where appends a list of predicates to the ChannelMonitorDelete builder.
+func (_d *ChannelMonitorDeleteOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorDeleteOne {
+	_d._d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorDeleteOne) Exec(ctx context.Context) error {
+	n, err := _d._d.Exec(ctx)
+	switch {
+	case err != nil:
+		return err
+	case n == 0:
+		return &NotFoundError{channelmonitor.Label}
+	default:
+		return nil
+	}
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDeleteOne) ExecX(ctx context.Context) {
+	if err := _d.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
diff --git a/backend/ent/channelmonitor_query.go b/backend/ent/channelmonitor_query.go
new file mode 100644
index 00000000..b6722e78
--- /dev/null
+++ b/backend/ent/channelmonitor_query.go
@@ -0,0 +1,797 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorQuery is the builder for querying ChannelMonitor entities.
+type ChannelMonitorQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitor.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitor
+ withHistory *ChannelMonitorHistoryQuery
+ withDailyRollups *ChannelMonitorDailyRollupQuery
+ withRequestTemplate *ChannelMonitorRequestTemplateQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorQuery builder.
+func (_q *ChannelMonitorQuery) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorQuery) Limit(limit int) *ChannelMonitorQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorQuery) Offset(offset int) *ChannelMonitorQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorQuery) Unique(unique bool) *ChannelMonitorQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorQuery) Order(o ...channelmonitor.OrderOption) *ChannelMonitorQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryHistory chains the current query on the "history" edge.
+func (_q *ChannelMonitorQuery) QueryHistory() *ChannelMonitorHistoryQuery {
+ query := (&ChannelMonitorHistoryClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitorhistory.Table, channelmonitorhistory.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.HistoryTable, channelmonitor.HistoryColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryDailyRollups chains the current query on the "daily_rollups" edge.
+func (_q *ChannelMonitorQuery) QueryDailyRollups() *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryRequestTemplate chains the current query on the "request_template" edge.
+func (_q *ChannelMonitorQuery) QueryRequestTemplate() *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, selector),
+ sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitor entity from the query.
+// Returns a *NotFoundError when no ChannelMonitor was found.
+func (_q *ChannelMonitorQuery) First(ctx context.Context) (*ChannelMonitor, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitor.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) FirstX(ctx context.Context) *ChannelMonitor {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitor ID from the query.
+// Returns a *NotFoundError when no ChannelMonitor ID was found.
+func (_q *ChannelMonitorQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitor.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitor entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitor entity is found.
+// Returns a *NotFoundError when no ChannelMonitor entities are found.
+func (_q *ChannelMonitorQuery) Only(ctx context.Context) (*ChannelMonitor, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitor.Label}
+ default:
+ return nil, &NotSingularError{channelmonitor.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) OnlyX(ctx context.Context) *ChannelMonitor {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitor ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitor ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitor.Label}
+ default:
+ err = &NotSingularError{channelmonitor.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitors.
+func (_q *ChannelMonitorQuery) All(ctx context.Context) ([]*ChannelMonitor, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitor, *ChannelMonitorQuery]()
+ return withInterceptors[[]*ChannelMonitor](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) AllX(ctx context.Context) []*ChannelMonitor {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitor IDs.
+func (_q *ChannelMonitorQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitor.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorQuery) Clone() *ChannelMonitorQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitor.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitor{}, _q.predicates...),
+ withHistory: _q.withHistory.Clone(),
+ withDailyRollups: _q.withDailyRollups.Clone(),
+ withRequestTemplate: _q.withRequestTemplate.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithHistory tells the query-builder to eager-load the nodes that are connected to
+// the "history" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithHistory(opts ...func(*ChannelMonitorHistoryQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorHistoryClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withHistory = query
+ return _q
+}
+
+// WithDailyRollups tells the query-builder to eager-load the nodes that are connected to
+// the "daily_rollups" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithDailyRollups(opts ...func(*ChannelMonitorDailyRollupQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withDailyRollups = query
+ return _q
+}
+
+// WithRequestTemplate tells the query-builder to eager-load the nodes that are connected to
+// the "request_template" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorQuery) WithRequestTemplate(opts ...func(*ChannelMonitorRequestTemplateQuery)) *ChannelMonitorQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withRequestTemplate = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitor.Query().
+// GroupBy(channelmonitor.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorQuery) GroupBy(field string, fields ...string) *ChannelMonitorGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitor.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection of one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+//	var v []struct {
+//		CreatedAt time.Time `json:"created_at,omitempty"`
+//	}
+//
+//	client.ChannelMonitor.Query().
+//		Select(channelmonitor.FieldCreatedAt).
+//		Scan(ctx, &v)
+func (_q *ChannelMonitorQuery) Select(fields ...string) *ChannelMonitorSelect {
+	_q.ctx.Fields = append(_q.ctx.Fields, fields...)
+	sbuild := &ChannelMonitorSelect{ChannelMonitorQuery: _q}
+	sbuild.label = channelmonitor.Label
+	sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+	return sbuild
+}
+
+// Aggregate returns a ChannelMonitorSelect configured with the given aggregations.
+func (_q *ChannelMonitorQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitor.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitor, error) {
+ var (
+ nodes = []*ChannelMonitor{}
+ _spec = _q.querySpec()
+ loadedTypes = [3]bool{
+ _q.withHistory != nil,
+ _q.withDailyRollups != nil,
+ _q.withRequestTemplate != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitor).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitor{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withHistory; query != nil {
+ if err := _q.loadHistory(ctx, query, nodes,
+ func(n *ChannelMonitor) { n.Edges.History = []*ChannelMonitorHistory{} },
+ func(n *ChannelMonitor, e *ChannelMonitorHistory) { n.Edges.History = append(n.Edges.History, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withDailyRollups; query != nil {
+ if err := _q.loadDailyRollups(ctx, query, nodes,
+ func(n *ChannelMonitor) { n.Edges.DailyRollups = []*ChannelMonitorDailyRollup{} },
+ func(n *ChannelMonitor, e *ChannelMonitorDailyRollup) {
+ n.Edges.DailyRollups = append(n.Edges.DailyRollups, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withRequestTemplate; query != nil {
+ if err := _q.loadRequestTemplate(ctx, query, nodes, nil,
+ func(n *ChannelMonitor, e *ChannelMonitorRequestTemplate) { n.Edges.RequestTemplate = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorQuery) loadHistory(ctx context.Context, query *ChannelMonitorHistoryQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorHistory)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitor)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitorhistory.FieldMonitorID)
+ }
+ query.Where(predicate.ChannelMonitorHistory(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitor.HistoryColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.MonitorID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "monitor_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *ChannelMonitorQuery) loadDailyRollups(ctx context.Context, query *ChannelMonitorDailyRollupQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorDailyRollup)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitor)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitordailyrollup.FieldMonitorID)
+ }
+ query.Where(predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitor.DailyRollupsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.MonitorID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "monitor_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *ChannelMonitorQuery) loadRequestTemplate(ctx context.Context, query *ChannelMonitorRequestTemplateQuery, nodes []*ChannelMonitor, init func(*ChannelMonitor), assign func(*ChannelMonitor, *ChannelMonitorRequestTemplate)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitor)
+ for i := range nodes {
+ if nodes[i].TemplateID == nil {
+ continue
+ }
+ fk := *nodes[i].TemplateID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitorrequesttemplate.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "template_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitor.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitor.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withRequestTemplate != nil {
+ _spec.Node.AddColumnOnce(channelmonitor.FieldTemplateID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitor.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitor.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorGroupBy is the group-by builder for ChannelMonitor entities.
+type ChannelMonitorGroupBy struct {
+ selector
+ build *ChannelMonitorQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorQuery, *ChannelMonitorGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorSelect is the builder for selecting fields of ChannelMonitor entities.
+type ChannelMonitorSelect struct {
+ *ChannelMonitorQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorQuery, *ChannelMonitorSelect](ctx, _s.ChannelMonitorQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorSelect) sqlScan(ctx context.Context, root *ChannelMonitorQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitor_update.go b/backend/ent/channelmonitor_update.go
new file mode 100644
index 00000000..4bbcd564
--- /dev/null
+++ b/backend/ent/channelmonitor_update.go
@@ -0,0 +1,1328 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/dialect/sql/sqljson"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorUpdate is the builder for updating ChannelMonitor entities.
+type ChannelMonitorUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorUpdate builder.
+func (_u *ChannelMonitorUpdate) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorUpdate) SetUpdatedAt(v time.Time) *ChannelMonitorUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorUpdate) SetName(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableName(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorUpdate) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpdate {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableProvider(v *channelmonitor.Provider) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (_u *ChannelMonitorUpdate) SetEndpoint(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetEndpoint(v)
+ return _u
+}
+
+// SetNillableEndpoint sets the "endpoint" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableEndpoint(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetEndpoint(*v)
+ }
+ return _u
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (_u *ChannelMonitorUpdate) SetAPIKeyEncrypted(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetAPIKeyEncrypted(v)
+ return _u
+}
+
+// SetNillableAPIKeyEncrypted sets the "api_key_encrypted" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableAPIKeyEncrypted(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetAPIKeyEncrypted(*v)
+ }
+ return _u
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (_u *ChannelMonitorUpdate) SetPrimaryModel(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetPrimaryModel(v)
+ return _u
+}
+
+// SetNillablePrimaryModel sets the "primary_model" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillablePrimaryModel(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetPrimaryModel(*v)
+ }
+ return _u
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (_u *ChannelMonitorUpdate) SetExtraModels(v []string) *ChannelMonitorUpdate {
+ _u.mutation.SetExtraModels(v)
+ return _u
+}
+
+// AppendExtraModels appends value to the "extra_models" field.
+func (_u *ChannelMonitorUpdate) AppendExtraModels(v []string) *ChannelMonitorUpdate {
+ _u.mutation.AppendExtraModels(v)
+ return _u
+}
+
+// SetGroupName sets the "group_name" field.
+func (_u *ChannelMonitorUpdate) SetGroupName(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetGroupName(v)
+ return _u
+}
+
+// SetNillableGroupName sets the "group_name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableGroupName(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetGroupName(*v)
+ }
+ return _u
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (_u *ChannelMonitorUpdate) ClearGroupName() *ChannelMonitorUpdate {
+ _u.mutation.ClearGroupName()
+ return _u
+}
+
+// SetEnabled sets the "enabled" field.
+func (_u *ChannelMonitorUpdate) SetEnabled(v bool) *ChannelMonitorUpdate {
+ _u.mutation.SetEnabled(v)
+ return _u
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableEnabled(v *bool) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetEnabled(*v)
+ }
+ return _u
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (_u *ChannelMonitorUpdate) SetIntervalSeconds(v int) *ChannelMonitorUpdate {
+ _u.mutation.ResetIntervalSeconds()
+ _u.mutation.SetIntervalSeconds(v)
+ return _u
+}
+
+// SetNillableIntervalSeconds sets the "interval_seconds" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableIntervalSeconds(v *int) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetIntervalSeconds(*v)
+ }
+ return _u
+}
+
+// AddIntervalSeconds adds value to the "interval_seconds" field.
+func (_u *ChannelMonitorUpdate) AddIntervalSeconds(v int) *ChannelMonitorUpdate {
+ _u.mutation.AddIntervalSeconds(v)
+ return _u
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (_u *ChannelMonitorUpdate) SetLastCheckedAt(v time.Time) *ChannelMonitorUpdate {
+ _u.mutation.SetLastCheckedAt(v)
+ return _u
+}
+
+// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetLastCheckedAt(*v)
+ }
+ return _u
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (_u *ChannelMonitorUpdate) ClearLastCheckedAt() *ChannelMonitorUpdate {
+ _u.mutation.ClearLastCheckedAt()
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *ChannelMonitorUpdate) SetCreatedBy(v int64) *ChannelMonitorUpdate {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableCreatedBy(v *int64) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *ChannelMonitorUpdate) AddCreatedBy(v int64) *ChannelMonitorUpdate {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// SetTemplateID sets the "template_id" field.
+func (_u *ChannelMonitorUpdate) SetTemplateID(v int64) *ChannelMonitorUpdate {
+ _u.mutation.SetTemplateID(v)
+ return _u
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableTemplateID(v *int64) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetTemplateID(*v)
+ }
+ return _u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (_u *ChannelMonitorUpdate) ClearTemplateID() *ChannelMonitorUpdate {
+ _u.mutation.ClearTemplateID()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdate {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorUpdate) SetBodyOverrideMode(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdate {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorUpdate) ClearBodyOverride() *ChannelMonitorUpdate {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (_u *ChannelMonitorUpdate) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.AddHistoryIDs(ids...)
+ return _u
+}
+
+// AddHistory adds the "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdate) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddHistoryIDs(ids...)
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_u *ChannelMonitorUpdate) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.AddDailyRollupIDs(ids...)
+ return _u
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdate) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddDailyRollupIDs(ids...)
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_u *ChannelMonitorUpdate) SetRequestTemplateID(id int64) *ChannelMonitorUpdate {
+ _u.mutation.SetRequestTemplateID(id)
+ return _u
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdate {
+ if id != nil {
+ _u = _u.SetRequestTemplateID(*id)
+ }
+ return _u
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdate) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdate {
+ return _u.SetRequestTemplateID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorMutation object of the builder.
+func (_u *ChannelMonitorUpdate) Mutation() *ChannelMonitorMutation {
+ return _u.mutation
+}
+
+// ClearHistory clears all "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdate) ClearHistory() *ChannelMonitorUpdate {
+ _u.mutation.ClearHistory()
+ return _u
+}
+
+// RemoveHistoryIDs removes the "history" edge to ChannelMonitorHistory entities by IDs.
+func (_u *ChannelMonitorUpdate) RemoveHistoryIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.RemoveHistoryIDs(ids...)
+ return _u
+}
+
+// RemoveHistory removes "history" edges to ChannelMonitorHistory entities.
+func (_u *ChannelMonitorUpdate) RemoveHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveHistoryIDs(ids...)
+}
+
+// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdate) ClearDailyRollups() *ChannelMonitorUpdate {
+ _u.mutation.ClearDailyRollups()
+ return _u
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs.
+func (_u *ChannelMonitorUpdate) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdate {
+ _u.mutation.RemoveDailyRollupIDs(ids...)
+ return _u
+}
+
+// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities.
+func (_u *ChannelMonitorUpdate) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveDailyRollupIDs(ids...)
+}
+
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdate) ClearRequestTemplate() *ChannelMonitorUpdate {
+ _u.mutation.ClearRequestTemplate()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitor.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorUpdate) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitor.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitor.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Endpoint(); ok {
+ if err := channelmonitor.EndpointValidator(v); err != nil {
+ return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.APIKeyEncrypted(); ok {
+ if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil {
+ return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PrimaryModel(); ok {
+ if err := channelmonitor.PrimaryModelValidator(v); err != nil {
+ return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.GroupName(); ok {
+ if err := channelmonitor.GroupNameValidator(v); err != nil {
+ return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.IntervalSeconds(); ok {
+ if err := channelmonitor.IntervalSecondsValidator(v); err != nil {
+ return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitor.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Endpoint(); ok {
+ _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.APIKeyEncrypted(); ok {
+ _spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PrimaryModel(); ok {
+ _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ExtraModels(); ok {
+ _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.AppendedExtraModels(); ok {
+ _spec.AddModifier(func(u *sql.UpdateBuilder) {
+ sqljson.Append(u, channelmonitor.FieldExtraModels, value)
+ })
+ }
+ if value, ok := _u.mutation.GroupName(); ok {
+ _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value)
+ }
+ if _u.mutation.GroupNameCleared() {
+ _spec.ClearField(channelmonitor.FieldGroupName, field.TypeString)
+ }
+ if value, ok := _u.mutation.Enabled(); ok {
+ _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.IntervalSeconds(); ok {
+ _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedIntervalSeconds(); ok {
+ _spec.AddField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.LastCheckedAt(); ok {
+ _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastCheckedAtCleared() {
+ _spec.ClearField(channelmonitor.FieldLastCheckedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedHistoryIDs(); len(nodes) > 0 && !_u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.HistoryIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.RequestTemplateCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitor.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorUpdateOne is the builder for updating a single ChannelMonitor entity.
+type ChannelMonitorUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorUpdateOne) SetUpdatedAt(v time.Time) *ChannelMonitorUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorUpdateOne) SetName(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableName(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorUpdateOne) SetProvider(v channelmonitor.Provider) *ChannelMonitorUpdateOne {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableProvider(v *channelmonitor.Provider) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (_u *ChannelMonitorUpdateOne) SetEndpoint(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetEndpoint(v)
+ return _u
+}
+
+// SetNillableEndpoint sets the "endpoint" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableEndpoint(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetEndpoint(*v)
+ }
+ return _u
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (_u *ChannelMonitorUpdateOne) SetAPIKeyEncrypted(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetAPIKeyEncrypted(v)
+ return _u
+}
+
+// SetNillableAPIKeyEncrypted sets the "api_key_encrypted" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableAPIKeyEncrypted(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetAPIKeyEncrypted(*v)
+ }
+ return _u
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (_u *ChannelMonitorUpdateOne) SetPrimaryModel(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetPrimaryModel(v)
+ return _u
+}
+
+// SetNillablePrimaryModel sets the "primary_model" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillablePrimaryModel(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetPrimaryModel(*v)
+ }
+ return _u
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (_u *ChannelMonitorUpdateOne) SetExtraModels(v []string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetExtraModels(v)
+ return _u
+}
+
+// AppendExtraModels appends value to the "extra_models" field.
+func (_u *ChannelMonitorUpdateOne) AppendExtraModels(v []string) *ChannelMonitorUpdateOne {
+ _u.mutation.AppendExtraModels(v)
+ return _u
+}
+
+// SetGroupName sets the "group_name" field.
+func (_u *ChannelMonitorUpdateOne) SetGroupName(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetGroupName(v)
+ return _u
+}
+
+// SetNillableGroupName sets the "group_name" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableGroupName(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetGroupName(*v)
+ }
+ return _u
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (_u *ChannelMonitorUpdateOne) ClearGroupName() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearGroupName()
+ return _u
+}
+
+// SetEnabled sets the "enabled" field.
+func (_u *ChannelMonitorUpdateOne) SetEnabled(v bool) *ChannelMonitorUpdateOne {
+ _u.mutation.SetEnabled(v)
+ return _u
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableEnabled(v *bool) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetEnabled(*v)
+ }
+ return _u
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (_u *ChannelMonitorUpdateOne) SetIntervalSeconds(v int) *ChannelMonitorUpdateOne {
+ _u.mutation.ResetIntervalSeconds()
+ _u.mutation.SetIntervalSeconds(v)
+ return _u
+}
+
+// SetNillableIntervalSeconds sets the "interval_seconds" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableIntervalSeconds(v *int) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetIntervalSeconds(*v)
+ }
+ return _u
+}
+
+// AddIntervalSeconds adds value to the "interval_seconds" field.
+func (_u *ChannelMonitorUpdateOne) AddIntervalSeconds(v int) *ChannelMonitorUpdateOne {
+ _u.mutation.AddIntervalSeconds(v)
+ return _u
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (_u *ChannelMonitorUpdateOne) SetLastCheckedAt(v time.Time) *ChannelMonitorUpdateOne {
+ _u.mutation.SetLastCheckedAt(v)
+ return _u
+}
+
+// SetNillableLastCheckedAt sets the "last_checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableLastCheckedAt(v *time.Time) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetLastCheckedAt(*v)
+ }
+ return _u
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (_u *ChannelMonitorUpdateOne) ClearLastCheckedAt() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearLastCheckedAt()
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *ChannelMonitorUpdateOne) SetCreatedBy(v int64) *ChannelMonitorUpdateOne {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableCreatedBy(v *int64) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *ChannelMonitorUpdateOne) AddCreatedBy(v int64) *ChannelMonitorUpdateOne {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// SetTemplateID sets the "template_id" field.
+func (_u *ChannelMonitorUpdateOne) SetTemplateID(v int64) *ChannelMonitorUpdateOne {
+ _u.mutation.SetTemplateID(v)
+ return _u
+}
+
+// SetNillableTemplateID sets the "template_id" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableTemplateID(v *int64) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetTemplateID(*v)
+ }
+ return _u
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (_u *ChannelMonitorUpdateOne) ClearTemplateID() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearTemplateID()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorUpdateOne {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorUpdateOne) ClearBodyOverride() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (_u *ChannelMonitorUpdateOne) AddHistoryIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.AddHistoryIDs(ids...)
+ return _u
+}
+
+// AddHistory adds the "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdateOne) AddHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddHistoryIDs(ids...)
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (_u *ChannelMonitorUpdateOne) AddDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.AddDailyRollupIDs(ids...)
+ return _u
+}
+
+// AddDailyRollups adds the "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdateOne) AddDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddDailyRollupIDs(ids...)
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID.
+func (_u *ChannelMonitorUpdateOne) SetRequestTemplateID(id int64) *ChannelMonitorUpdateOne {
+ _u.mutation.SetRequestTemplateID(id)
+ return _u
+}
+
+// SetNillableRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by ID if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableRequestTemplateID(id *int64) *ChannelMonitorUpdateOne {
+ if id != nil {
+ _u = _u.SetRequestTemplateID(*id)
+ }
+ return _u
+}
+
+// SetRequestTemplate sets the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdateOne) SetRequestTemplate(v *ChannelMonitorRequestTemplate) *ChannelMonitorUpdateOne {
+ return _u.SetRequestTemplateID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorMutation object of the builder.
+func (_u *ChannelMonitorUpdateOne) Mutation() *ChannelMonitorMutation {
+ return _u.mutation
+}
+
+// ClearHistory clears all "history" edges to the ChannelMonitorHistory entity.
+func (_u *ChannelMonitorUpdateOne) ClearHistory() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearHistory()
+ return _u
+}
+
+// RemoveHistoryIDs removes the "history" edge to ChannelMonitorHistory entities by IDs.
+func (_u *ChannelMonitorUpdateOne) RemoveHistoryIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.RemoveHistoryIDs(ids...)
+ return _u
+}
+
+// RemoveHistory removes "history" edges to ChannelMonitorHistory entities.
+func (_u *ChannelMonitorUpdateOne) RemoveHistory(v ...*ChannelMonitorHistory) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveHistoryIDs(ids...)
+}
+
+// ClearDailyRollups clears all "daily_rollups" edges to the ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorUpdateOne) ClearDailyRollups() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearDailyRollups()
+ return _u
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to ChannelMonitorDailyRollup entities by IDs.
+func (_u *ChannelMonitorUpdateOne) RemoveDailyRollupIDs(ids ...int64) *ChannelMonitorUpdateOne {
+ _u.mutation.RemoveDailyRollupIDs(ids...)
+ return _u
+}
+
+// RemoveDailyRollups removes "daily_rollups" edges to ChannelMonitorDailyRollup entities.
+func (_u *ChannelMonitorUpdateOne) RemoveDailyRollups(v ...*ChannelMonitorDailyRollup) *ChannelMonitorUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveDailyRollupIDs(ids...)
+}
+
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorUpdateOne) ClearRequestTemplate() *ChannelMonitorUpdateOne {
+ _u.mutation.ClearRequestTemplate()
+ return _u
+}
+
+// Where appends a list predicates to the ChannelMonitorUpdate builder.
+func (_u *ChannelMonitorUpdateOne) Where(ps ...predicate.ChannelMonitor) *ChannelMonitorUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorUpdateOne) Select(field string, fields ...string) *ChannelMonitorUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitor entity.
+func (_u *ChannelMonitorUpdateOne) Save(ctx context.Context) (*ChannelMonitor, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorUpdateOne) SaveX(ctx context.Context) *ChannelMonitor {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitor.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorUpdateOne) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitor.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitor.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Endpoint(); ok {
+ if err := channelmonitor.EndpointValidator(v); err != nil {
+ return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.APIKeyEncrypted(); ok {
+ if err := channelmonitor.APIKeyEncryptedValidator(v); err != nil {
+ return &ValidationError{Name: "api_key_encrypted", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_key_encrypted": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PrimaryModel(); ok {
+ if err := channelmonitor.PrimaryModelValidator(v); err != nil {
+ return &ValidationError{Name: "primary_model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.primary_model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.GroupName(); ok {
+ if err := channelmonitor.GroupNameValidator(v); err != nil {
+ return &ValidationError{Name: "group_name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.group_name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.IntervalSeconds(); ok {
+ if err := channelmonitor.IntervalSecondsValidator(v); err != nil {
+ return &ValidationError{Name: "interval_seconds", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.interval_seconds": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitor.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitor, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitor.Table, channelmonitor.Columns, sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitor.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitor.FieldID)
+ for _, f := range fields {
+ if !channelmonitor.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitor.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitor.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitor.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Endpoint(); ok {
+ _spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.APIKeyEncrypted(); ok {
+ _spec.SetField(channelmonitor.FieldAPIKeyEncrypted, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PrimaryModel(); ok {
+ _spec.SetField(channelmonitor.FieldPrimaryModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ExtraModels(); ok {
+ _spec.SetField(channelmonitor.FieldExtraModels, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.AppendedExtraModels(); ok {
+ _spec.AddModifier(func(u *sql.UpdateBuilder) {
+ sqljson.Append(u, channelmonitor.FieldExtraModels, value)
+ })
+ }
+ if value, ok := _u.mutation.GroupName(); ok {
+ _spec.SetField(channelmonitor.FieldGroupName, field.TypeString, value)
+ }
+ if _u.mutation.GroupNameCleared() {
+ _spec.ClearField(channelmonitor.FieldGroupName, field.TypeString)
+ }
+ if value, ok := _u.mutation.Enabled(); ok {
+ _spec.SetField(channelmonitor.FieldEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.IntervalSeconds(); ok {
+ _spec.SetField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedIntervalSeconds(); ok {
+ _spec.AddField(channelmonitor.FieldIntervalSeconds, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.LastCheckedAt(); ok {
+ _spec.SetField(channelmonitor.FieldLastCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastCheckedAtCleared() {
+ _spec.ClearField(channelmonitor.FieldLastCheckedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(channelmonitor.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitor.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitor.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitor.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedHistoryIDs(); len(nodes) > 0 && !_u.mutation.HistoryCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.HistoryIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.HistoryTable,
+ Columns: []string{channelmonitor.HistoryColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedDailyRollupsIDs(); len(nodes) > 0 && !_u.mutation.DailyRollupsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.DailyRollupsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: channelmonitor.DailyRollupsTable,
+ Columns: []string{channelmonitor.DailyRollupsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.RequestTemplateCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RequestTemplateIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: false,
+ Table: channelmonitor.RequestTemplateTable,
+ Columns: []string{channelmonitor.RequestTemplateColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitor{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitor.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup.go
new file mode 100644
index 00000000..78a5f489
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup.go
@@ -0,0 +1,278 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+)
+
+// ChannelMonitorDailyRollup is the model entity for the ChannelMonitorDailyRollup schema.
+type ChannelMonitorDailyRollup struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // MonitorID holds the value of the "monitor_id" field.
+ MonitorID int64 `json:"monitor_id,omitempty"`
+ // Model holds the value of the "model" field.
+ Model string `json:"model,omitempty"`
+ // BucketDate holds the value of the "bucket_date" field.
+ BucketDate time.Time `json:"bucket_date,omitempty"`
+ // TotalChecks holds the value of the "total_checks" field.
+ TotalChecks int `json:"total_checks,omitempty"`
+ // OkCount holds the value of the "ok_count" field.
+ OkCount int `json:"ok_count,omitempty"`
+ // OperationalCount holds the value of the "operational_count" field.
+ OperationalCount int `json:"operational_count,omitempty"`
+ // DegradedCount holds the value of the "degraded_count" field.
+ DegradedCount int `json:"degraded_count,omitempty"`
+ // FailedCount holds the value of the "failed_count" field.
+ FailedCount int `json:"failed_count,omitempty"`
+ // ErrorCount holds the value of the "error_count" field.
+ ErrorCount int `json:"error_count,omitempty"`
+ // SumLatencyMs holds the value of the "sum_latency_ms" field.
+ SumLatencyMs int64 `json:"sum_latency_ms,omitempty"`
+ // CountLatency holds the value of the "count_latency" field.
+ CountLatency int `json:"count_latency,omitempty"`
+ // SumPingLatencyMs holds the value of the "sum_ping_latency_ms" field.
+ SumPingLatencyMs int64 `json:"sum_ping_latency_ms,omitempty"`
+ // CountPingLatency holds the value of the "count_ping_latency" field.
+ CountPingLatency int `json:"count_ping_latency,omitempty"`
+ // ComputedAt holds the value of the "computed_at" field.
+ ComputedAt time.Time `json:"computed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorDailyRollupQuery when eager-loading is set.
+ Edges ChannelMonitorDailyRollupEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorDailyRollupEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorDailyRollupEdges struct {
+ // Monitor holds the value of the monitor edge.
+ Monitor *ChannelMonitor `json:"monitor,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorOrErr returns the Monitor value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorDailyRollupEdges) MonitorOrErr() (*ChannelMonitor, error) {
+ if e.Monitor != nil {
+ return e.Monitor, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: channelmonitor.Label}
+ }
+ return nil, &NotLoadedError{edge: "monitor"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorDailyRollup) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitordailyrollup.FieldID, channelmonitordailyrollup.FieldMonitorID, channelmonitordailyrollup.FieldTotalChecks, channelmonitordailyrollup.FieldOkCount, channelmonitordailyrollup.FieldOperationalCount, channelmonitordailyrollup.FieldDegradedCount, channelmonitordailyrollup.FieldFailedCount, channelmonitordailyrollup.FieldErrorCount, channelmonitordailyrollup.FieldSumLatencyMs, channelmonitordailyrollup.FieldCountLatency, channelmonitordailyrollup.FieldSumPingLatencyMs, channelmonitordailyrollup.FieldCountPingLatency:
+ values[i] = new(sql.NullInt64)
+ case channelmonitordailyrollup.FieldModel:
+ values[i] = new(sql.NullString)
+ case channelmonitordailyrollup.FieldBucketDate, channelmonitordailyrollup.FieldComputedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorDailyRollup fields.
+func (_m *ChannelMonitorDailyRollup) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitordailyrollup.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitordailyrollup.FieldMonitorID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
+ } else if value.Valid {
+ _m.MonitorID = value.Int64
+ }
+ case channelmonitordailyrollup.FieldModel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field model", values[i])
+ } else if value.Valid {
+ _m.Model = value.String
+ }
+ case channelmonitordailyrollup.FieldBucketDate:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field bucket_date", values[i])
+ } else if value.Valid {
+ _m.BucketDate = value.Time
+ }
+ case channelmonitordailyrollup.FieldTotalChecks:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field total_checks", values[i])
+ } else if value.Valid {
+ _m.TotalChecks = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldOkCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field ok_count", values[i])
+ } else if value.Valid {
+ _m.OkCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldOperationalCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field operational_count", values[i])
+ } else if value.Valid {
+ _m.OperationalCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldDegradedCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field degraded_count", values[i])
+ } else if value.Valid {
+ _m.DegradedCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldFailedCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field failed_count", values[i])
+ } else if value.Valid {
+ _m.FailedCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldErrorCount:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field error_count", values[i])
+ } else if value.Valid {
+ _m.ErrorCount = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldSumLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sum_latency_ms", values[i])
+ } else if value.Valid {
+ _m.SumLatencyMs = value.Int64
+ }
+ case channelmonitordailyrollup.FieldCountLatency:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field count_latency", values[i])
+ } else if value.Valid {
+ _m.CountLatency = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldSumPingLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sum_ping_latency_ms", values[i])
+ } else if value.Valid {
+ _m.SumPingLatencyMs = value.Int64
+ }
+ case channelmonitordailyrollup.FieldCountPingLatency:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field count_ping_latency", values[i])
+ } else if value.Valid {
+ _m.CountPingLatency = int(value.Int64)
+ }
+ case channelmonitordailyrollup.FieldComputedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field computed_at", values[i])
+ } else if value.Valid {
+ _m.ComputedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorDailyRollup.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorDailyRollup) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitor queries the "monitor" edge of the ChannelMonitorDailyRollup entity.
+func (_m *ChannelMonitorDailyRollup) QueryMonitor() *ChannelMonitorQuery {
+ return NewChannelMonitorDailyRollupClient(_m.config).QueryMonitor(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorDailyRollup.
+// Note that you need to call ChannelMonitorDailyRollup.Unwrap() before calling this method if this ChannelMonitorDailyRollup
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorDailyRollup) Update() *ChannelMonitorDailyRollupUpdateOne {
+ return NewChannelMonitorDailyRollupClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorDailyRollup entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorDailyRollup) Unwrap() *ChannelMonitorDailyRollup {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorDailyRollup is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorDailyRollup) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorDailyRollup(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("monitor_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
+ builder.WriteString(", ")
+ builder.WriteString("model=")
+ builder.WriteString(_m.Model)
+ builder.WriteString(", ")
+ builder.WriteString("bucket_date=")
+ builder.WriteString(_m.BucketDate.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("total_checks=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotalChecks))
+ builder.WriteString(", ")
+ builder.WriteString("ok_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.OkCount))
+ builder.WriteString(", ")
+ builder.WriteString("operational_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.OperationalCount))
+ builder.WriteString(", ")
+ builder.WriteString("degraded_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.DegradedCount))
+ builder.WriteString(", ")
+ builder.WriteString("failed_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.FailedCount))
+ builder.WriteString(", ")
+ builder.WriteString("error_count=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ErrorCount))
+ builder.WriteString(", ")
+ builder.WriteString("sum_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SumLatencyMs))
+ builder.WriteString(", ")
+ builder.WriteString("count_latency=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CountLatency))
+ builder.WriteString(", ")
+ builder.WriteString("sum_ping_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SumPingLatencyMs))
+ builder.WriteString(", ")
+ builder.WriteString("count_ping_latency=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CountPingLatency))
+ builder.WriteString(", ")
+ builder.WriteString("computed_at=")
+ builder.WriteString(_m.ComputedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorDailyRollups is a parsable slice of ChannelMonitorDailyRollup.
+type ChannelMonitorDailyRollups []*ChannelMonitorDailyRollup
diff --git a/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
new file mode 100644
index 00000000..e7cb9307
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup/channelmonitordailyrollup.go
@@ -0,0 +1,206 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitordailyrollup
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitordailyrollup type in the database.
+ Label = "channel_monitor_daily_rollup"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldMonitorID holds the string denoting the monitor_id field in the database.
+ FieldMonitorID = "monitor_id"
+ // FieldModel holds the string denoting the model field in the database.
+ FieldModel = "model"
+ // FieldBucketDate holds the string denoting the bucket_date field in the database.
+ FieldBucketDate = "bucket_date"
+ // FieldTotalChecks holds the string denoting the total_checks field in the database.
+ FieldTotalChecks = "total_checks"
+ // FieldOkCount holds the string denoting the ok_count field in the database.
+ FieldOkCount = "ok_count"
+ // FieldOperationalCount holds the string denoting the operational_count field in the database.
+ FieldOperationalCount = "operational_count"
+ // FieldDegradedCount holds the string denoting the degraded_count field in the database.
+ FieldDegradedCount = "degraded_count"
+ // FieldFailedCount holds the string denoting the failed_count field in the database.
+ FieldFailedCount = "failed_count"
+ // FieldErrorCount holds the string denoting the error_count field in the database.
+ FieldErrorCount = "error_count"
+ // FieldSumLatencyMs holds the string denoting the sum_latency_ms field in the database.
+ FieldSumLatencyMs = "sum_latency_ms"
+ // FieldCountLatency holds the string denoting the count_latency field in the database.
+ FieldCountLatency = "count_latency"
+ // FieldSumPingLatencyMs holds the string denoting the sum_ping_latency_ms field in the database.
+ FieldSumPingLatencyMs = "sum_ping_latency_ms"
+ // FieldCountPingLatency holds the string denoting the count_ping_latency field in the database.
+ FieldCountPingLatency = "count_ping_latency"
+ // FieldComputedAt holds the string denoting the computed_at field in the database.
+ FieldComputedAt = "computed_at"
+ // EdgeMonitor holds the string denoting the monitor edge name in mutations.
+ EdgeMonitor = "monitor"
+ // Table holds the table name of the channelmonitordailyrollup in the database.
+ Table = "channel_monitor_daily_rollups"
+ // MonitorTable is the table that holds the monitor relation/edge.
+ MonitorTable = "channel_monitor_daily_rollups"
+ // MonitorInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorInverseTable = "channel_monitors"
+ // MonitorColumn is the table column denoting the monitor relation/edge.
+ MonitorColumn = "monitor_id"
+)
+
+// Columns holds all SQL columns for channelmonitordailyrollup fields.
+var Columns = []string{
+ FieldID,
+ FieldMonitorID,
+ FieldModel,
+ FieldBucketDate,
+ FieldTotalChecks,
+ FieldOkCount,
+ FieldOperationalCount,
+ FieldDegradedCount,
+ FieldFailedCount,
+ FieldErrorCount,
+ FieldSumLatencyMs,
+ FieldCountLatency,
+ FieldSumPingLatencyMs,
+ FieldCountPingLatency,
+ FieldComputedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ ModelValidator func(string) error
+ // DefaultTotalChecks holds the default value on creation for the "total_checks" field.
+ DefaultTotalChecks int
+ // DefaultOkCount holds the default value on creation for the "ok_count" field.
+ DefaultOkCount int
+ // DefaultOperationalCount holds the default value on creation for the "operational_count" field.
+ DefaultOperationalCount int
+ // DefaultDegradedCount holds the default value on creation for the "degraded_count" field.
+ DefaultDegradedCount int
+ // DefaultFailedCount holds the default value on creation for the "failed_count" field.
+ DefaultFailedCount int
+ // DefaultErrorCount holds the default value on creation for the "error_count" field.
+ DefaultErrorCount int
+ // DefaultSumLatencyMs holds the default value on creation for the "sum_latency_ms" field.
+ DefaultSumLatencyMs int64
+ // DefaultCountLatency holds the default value on creation for the "count_latency" field.
+ DefaultCountLatency int
+ // DefaultSumPingLatencyMs holds the default value on creation for the "sum_ping_latency_ms" field.
+ DefaultSumPingLatencyMs int64
+ // DefaultCountPingLatency holds the default value on creation for the "count_ping_latency" field.
+ DefaultCountPingLatency int
+ // DefaultComputedAt holds the default value on creation for the "computed_at" field.
+ DefaultComputedAt func() time.Time
+ // UpdateDefaultComputedAt holds the default value on update for the "computed_at" field.
+ UpdateDefaultComputedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the ChannelMonitorDailyRollup queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByMonitorID orders the results by the monitor_id field.
+func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
+}
+
+// ByModel orders the results by the model field.
+func ByModel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldModel, opts...).ToFunc()
+}
+
+// ByBucketDate orders the results by the bucket_date field.
+func ByBucketDate(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBucketDate, opts...).ToFunc()
+}
+
+// ByTotalChecks orders the results by the total_checks field.
+func ByTotalChecks(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotalChecks, opts...).ToFunc()
+}
+
+// ByOkCount orders the results by the ok_count field.
+func ByOkCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOkCount, opts...).ToFunc()
+}
+
+// ByOperationalCount orders the results by the operational_count field.
+func ByOperationalCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOperationalCount, opts...).ToFunc()
+}
+
+// ByDegradedCount orders the results by the degraded_count field.
+func ByDegradedCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDegradedCount, opts...).ToFunc()
+}
+
+// ByFailedCount orders the results by the failed_count field.
+func ByFailedCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFailedCount, opts...).ToFunc()
+}
+
+// ByErrorCount orders the results by the error_count field.
+func ByErrorCount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldErrorCount, opts...).ToFunc()
+}
+
+// BySumLatencyMs orders the results by the sum_latency_ms field.
+func BySumLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSumLatencyMs, opts...).ToFunc()
+}
+
+// ByCountLatency orders the results by the count_latency field.
+func ByCountLatency(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCountLatency, opts...).ToFunc()
+}
+
+// BySumPingLatencyMs orders the results by the sum_ping_latency_ms field.
+func BySumPingLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSumPingLatencyMs, opts...).ToFunc()
+}
+
+// ByCountPingLatency orders the results by the count_ping_latency field.
+func ByCountPingLatency(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCountPingLatency, opts...).ToFunc()
+}
+
+// ByComputedAt orders the results by the computed_at field.
+func ByComputedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldComputedAt, opts...).ToFunc()
+}
+
+// ByMonitorField orders the results by monitor field.
+func ByMonitorField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newMonitorStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+}
diff --git a/backend/ent/channelmonitordailyrollup/where.go b/backend/ent/channelmonitordailyrollup/where.go
new file mode 100644
index 00000000..424c957e
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup/where.go
@@ -0,0 +1,729 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitordailyrollup
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldID, id))
+}
+
+// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
+func MonitorID(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// Model applies equality check predicate on the "model" field. It's identical to ModelEQ.
+func Model(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v))
+}
+
+// BucketDate applies equality check predicate on the "bucket_date" field. It's identical to BucketDateEQ.
+func BucketDate(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v))
+}
+
+// TotalChecks applies equality check predicate on the "total_checks" field. It's identical to TotalChecksEQ.
+func TotalChecks(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v))
+}
+
+// OkCount applies equality check predicate on the "ok_count" field. It's identical to OkCountEQ.
+func OkCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v))
+}
+
+// OperationalCount applies equality check predicate on the "operational_count" field. It's identical to OperationalCountEQ.
+func OperationalCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v))
+}
+
+// DegradedCount applies equality check predicate on the "degraded_count" field. It's identical to DegradedCountEQ.
+func DegradedCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v))
+}
+
+// FailedCount applies equality check predicate on the "failed_count" field. It's identical to FailedCountEQ.
+func FailedCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v))
+}
+
+// ErrorCount applies equality check predicate on the "error_count" field. It's identical to ErrorCountEQ.
+func ErrorCount(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v))
+}
+
+// SumLatencyMs applies equality check predicate on the "sum_latency_ms" field. It's identical to SumLatencyMsEQ.
+func SumLatencyMs(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v))
+}
+
+// CountLatency applies equality check predicate on the "count_latency" field. It's identical to CountLatencyEQ.
+func CountLatency(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v))
+}
+
+// SumPingLatencyMs applies equality check predicate on the "sum_ping_latency_ms" field. It's identical to SumPingLatencyMsEQ.
+func SumPingLatencyMs(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v))
+}
+
+// CountPingLatency applies equality check predicate on the "count_ping_latency" field. It's identical to CountPingLatencyEQ.
+func CountPingLatency(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v))
+}
+
+// ComputedAt applies equality check predicate on the "computed_at" field. It's identical to ComputedAtEQ.
+func ComputedAt(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v))
+}
+
+// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
+func MonitorIDEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// MonitorIDNEQ applies the NEQ predicate on the "monitor_id" field.
+func MonitorIDNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldMonitorID, v))
+}
+
+// MonitorIDIn applies the In predicate on the "monitor_id" field.
+func MonitorIDIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldMonitorID, vs...))
+}
+
+// MonitorIDNotIn applies the NotIn predicate on the "monitor_id" field.
+func MonitorIDNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldMonitorID, vs...))
+}
+
+// ModelEQ applies the EQ predicate on the "model" field.
+func ModelEQ(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldModel, v))
+}
+
+// ModelNEQ applies the NEQ predicate on the "model" field.
+func ModelNEQ(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldModel, v))
+}
+
+// ModelIn applies the In predicate on the "model" field.
+func ModelIn(vs ...string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldModel, vs...))
+}
+
+// ModelNotIn applies the NotIn predicate on the "model" field.
+func ModelNotIn(vs ...string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldModel, vs...))
+}
+
+// ModelGT applies the GT predicate on the "model" field.
+func ModelGT(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldModel, v))
+}
+
+// ModelGTE applies the GTE predicate on the "model" field.
+func ModelGTE(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldModel, v))
+}
+
+// ModelLT applies the LT predicate on the "model" field.
+func ModelLT(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldModel, v))
+}
+
+// ModelLTE applies the LTE predicate on the "model" field.
+func ModelLTE(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldModel, v))
+}
+
+// ModelContains applies the Contains predicate on the "model" field.
+func ModelContains(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldContains(FieldModel, v))
+}
+
+// ModelHasPrefix applies the HasPrefix predicate on the "model" field.
+func ModelHasPrefix(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldHasPrefix(FieldModel, v))
+}
+
+// ModelHasSuffix applies the HasSuffix predicate on the "model" field.
+func ModelHasSuffix(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldHasSuffix(FieldModel, v))
+}
+
+// ModelEqualFold applies the EqualFold predicate on the "model" field.
+func ModelEqualFold(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEqualFold(FieldModel, v))
+}
+
+// ModelContainsFold applies the ContainsFold predicate on the "model" field.
+func ModelContainsFold(v string) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldContainsFold(FieldModel, v))
+}
+
+// BucketDateEQ applies the EQ predicate on the "bucket_date" field.
+func BucketDateEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldBucketDate, v))
+}
+
+// BucketDateNEQ applies the NEQ predicate on the "bucket_date" field.
+func BucketDateNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldBucketDate, v))
+}
+
+// BucketDateIn applies the In predicate on the "bucket_date" field.
+func BucketDateIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldBucketDate, vs...))
+}
+
+// BucketDateNotIn applies the NotIn predicate on the "bucket_date" field.
+func BucketDateNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldBucketDate, vs...))
+}
+
+// BucketDateGT applies the GT predicate on the "bucket_date" field.
+func BucketDateGT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldBucketDate, v))
+}
+
+// BucketDateGTE applies the GTE predicate on the "bucket_date" field.
+func BucketDateGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldBucketDate, v))
+}
+
+// BucketDateLT applies the LT predicate on the "bucket_date" field.
+func BucketDateLT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldBucketDate, v))
+}
+
+// BucketDateLTE applies the LTE predicate on the "bucket_date" field.
+func BucketDateLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldBucketDate, v))
+}
+
+// TotalChecksEQ applies the EQ predicate on the "total_checks" field.
+func TotalChecksEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldTotalChecks, v))
+}
+
+// TotalChecksNEQ applies the NEQ predicate on the "total_checks" field.
+func TotalChecksNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldTotalChecks, v))
+}
+
+// TotalChecksIn applies the In predicate on the "total_checks" field.
+func TotalChecksIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldTotalChecks, vs...))
+}
+
+// TotalChecksNotIn applies the NotIn predicate on the "total_checks" field.
+func TotalChecksNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldTotalChecks, vs...))
+}
+
+// TotalChecksGT applies the GT predicate on the "total_checks" field.
+func TotalChecksGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldTotalChecks, v))
+}
+
+// TotalChecksGTE applies the GTE predicate on the "total_checks" field.
+func TotalChecksGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldTotalChecks, v))
+}
+
+// TotalChecksLT applies the LT predicate on the "total_checks" field.
+func TotalChecksLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldTotalChecks, v))
+}
+
+// TotalChecksLTE applies the LTE predicate on the "total_checks" field.
+func TotalChecksLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldTotalChecks, v))
+}
+
+// OkCountEQ applies the EQ predicate on the "ok_count" field.
+func OkCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOkCount, v))
+}
+
+// OkCountNEQ applies the NEQ predicate on the "ok_count" field.
+func OkCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOkCount, v))
+}
+
+// OkCountIn applies the In predicate on the "ok_count" field.
+func OkCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOkCount, vs...))
+}
+
+// OkCountNotIn applies the NotIn predicate on the "ok_count" field.
+func OkCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOkCount, vs...))
+}
+
+// OkCountGT applies the GT predicate on the "ok_count" field.
+func OkCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOkCount, v))
+}
+
+// OkCountGTE applies the GTE predicate on the "ok_count" field.
+func OkCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOkCount, v))
+}
+
+// OkCountLT applies the LT predicate on the "ok_count" field.
+func OkCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOkCount, v))
+}
+
+// OkCountLTE applies the LTE predicate on the "ok_count" field.
+func OkCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOkCount, v))
+}
+
+// OperationalCountEQ applies the EQ predicate on the "operational_count" field.
+func OperationalCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldOperationalCount, v))
+}
+
+// OperationalCountNEQ applies the NEQ predicate on the "operational_count" field.
+func OperationalCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldOperationalCount, v))
+}
+
+// OperationalCountIn applies the In predicate on the "operational_count" field.
+func OperationalCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldOperationalCount, vs...))
+}
+
+// OperationalCountNotIn applies the NotIn predicate on the "operational_count" field.
+func OperationalCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldOperationalCount, vs...))
+}
+
+// OperationalCountGT applies the GT predicate on the "operational_count" field.
+func OperationalCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldOperationalCount, v))
+}
+
+// OperationalCountGTE applies the GTE predicate on the "operational_count" field.
+func OperationalCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldOperationalCount, v))
+}
+
+// OperationalCountLT applies the LT predicate on the "operational_count" field.
+func OperationalCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldOperationalCount, v))
+}
+
+// OperationalCountLTE applies the LTE predicate on the "operational_count" field.
+func OperationalCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldOperationalCount, v))
+}
+
+// DegradedCountEQ applies the EQ predicate on the "degraded_count" field.
+func DegradedCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldDegradedCount, v))
+}
+
+// DegradedCountNEQ applies the NEQ predicate on the "degraded_count" field.
+func DegradedCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldDegradedCount, v))
+}
+
+// DegradedCountIn applies the In predicate on the "degraded_count" field.
+func DegradedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldDegradedCount, vs...))
+}
+
+// DegradedCountNotIn applies the NotIn predicate on the "degraded_count" field.
+func DegradedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldDegradedCount, vs...))
+}
+
+// DegradedCountGT applies the GT predicate on the "degraded_count" field.
+func DegradedCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldDegradedCount, v))
+}
+
+// DegradedCountGTE applies the GTE predicate on the "degraded_count" field.
+func DegradedCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldDegradedCount, v))
+}
+
+// DegradedCountLT applies the LT predicate on the "degraded_count" field.
+func DegradedCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldDegradedCount, v))
+}
+
+// DegradedCountLTE applies the LTE predicate on the "degraded_count" field.
+func DegradedCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldDegradedCount, v))
+}
+
+// FailedCountEQ applies the EQ predicate on the "failed_count" field.
+func FailedCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldFailedCount, v))
+}
+
+// FailedCountNEQ applies the NEQ predicate on the "failed_count" field.
+func FailedCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldFailedCount, v))
+}
+
+// FailedCountIn applies the In predicate on the "failed_count" field.
+func FailedCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldFailedCount, vs...))
+}
+
+// FailedCountNotIn applies the NotIn predicate on the "failed_count" field.
+func FailedCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldFailedCount, vs...))
+}
+
+// FailedCountGT applies the GT predicate on the "failed_count" field.
+func FailedCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldFailedCount, v))
+}
+
+// FailedCountGTE applies the GTE predicate on the "failed_count" field.
+func FailedCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldFailedCount, v))
+}
+
+// FailedCountLT applies the LT predicate on the "failed_count" field.
+func FailedCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldFailedCount, v))
+}
+
+// FailedCountLTE applies the LTE predicate on the "failed_count" field.
+func FailedCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldFailedCount, v))
+}
+
+// ErrorCountEQ applies the EQ predicate on the "error_count" field.
+func ErrorCountEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldErrorCount, v))
+}
+
+// ErrorCountNEQ applies the NEQ predicate on the "error_count" field.
+func ErrorCountNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldErrorCount, v))
+}
+
+// ErrorCountIn applies the In predicate on the "error_count" field.
+func ErrorCountIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldErrorCount, vs...))
+}
+
+// ErrorCountNotIn applies the NotIn predicate on the "error_count" field.
+func ErrorCountNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldErrorCount, vs...))
+}
+
+// ErrorCountGT applies the GT predicate on the "error_count" field.
+func ErrorCountGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldErrorCount, v))
+}
+
+// ErrorCountGTE applies the GTE predicate on the "error_count" field.
+func ErrorCountGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldErrorCount, v))
+}
+
+// ErrorCountLT applies the LT predicate on the "error_count" field.
+func ErrorCountLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldErrorCount, v))
+}
+
+// ErrorCountLTE applies the LTE predicate on the "error_count" field.
+func ErrorCountLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldErrorCount, v))
+}
+
+// SumLatencyMsEQ applies the EQ predicate on the "sum_latency_ms" field.
+func SumLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsNEQ applies the NEQ predicate on the "sum_latency_ms" field.
+func SumLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsIn applies the In predicate on the "sum_latency_ms" field.
+func SumLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumLatencyMs, vs...))
+}
+
+// SumLatencyMsNotIn applies the NotIn predicate on the "sum_latency_ms" field.
+func SumLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumLatencyMs, vs...))
+}
+
+// SumLatencyMsGT applies the GT predicate on the "sum_latency_ms" field.
+func SumLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsGTE applies the GTE predicate on the "sum_latency_ms" field.
+func SumLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsLT applies the LT predicate on the "sum_latency_ms" field.
+func SumLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumLatencyMs, v))
+}
+
+// SumLatencyMsLTE applies the LTE predicate on the "sum_latency_ms" field.
+func SumLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumLatencyMs, v))
+}
+
+// CountLatencyEQ applies the EQ predicate on the "count_latency" field.
+func CountLatencyEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountLatency, v))
+}
+
+// CountLatencyNEQ applies the NEQ predicate on the "count_latency" field.
+func CountLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountLatency, v))
+}
+
+// CountLatencyIn applies the In predicate on the "count_latency" field.
+func CountLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountLatency, vs...))
+}
+
+// CountLatencyNotIn applies the NotIn predicate on the "count_latency" field.
+func CountLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountLatency, vs...))
+}
+
+// CountLatencyGT applies the GT predicate on the "count_latency" field.
+func CountLatencyGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountLatency, v))
+}
+
+// CountLatencyGTE applies the GTE predicate on the "count_latency" field.
+func CountLatencyGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountLatency, v))
+}
+
+// CountLatencyLT applies the LT predicate on the "count_latency" field.
+func CountLatencyLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountLatency, v))
+}
+
+// CountLatencyLTE applies the LTE predicate on the "count_latency" field.
+func CountLatencyLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountLatency, v))
+}
+
+// SumPingLatencyMsEQ applies the EQ predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsNEQ applies the NEQ predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsNEQ(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsIn applies the In predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldSumPingLatencyMs, vs...))
+}
+
+// SumPingLatencyMsNotIn applies the NotIn predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsNotIn(vs ...int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldSumPingLatencyMs, vs...))
+}
+
+// SumPingLatencyMsGT applies the GT predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsGT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsGTE applies the GTE predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsGTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsLT applies the LT predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsLT(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldSumPingLatencyMs, v))
+}
+
+// SumPingLatencyMsLTE applies the LTE predicate on the "sum_ping_latency_ms" field.
+func SumPingLatencyMsLTE(v int64) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldSumPingLatencyMs, v))
+}
+
+// CountPingLatencyEQ applies the EQ predicate on the "count_ping_latency" field.
+func CountPingLatencyEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyNEQ applies the NEQ predicate on the "count_ping_latency" field.
+func CountPingLatencyNEQ(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyIn applies the In predicate on the "count_ping_latency" field.
+func CountPingLatencyIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldCountPingLatency, vs...))
+}
+
+// CountPingLatencyNotIn applies the NotIn predicate on the "count_ping_latency" field.
+func CountPingLatencyNotIn(vs ...int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldCountPingLatency, vs...))
+}
+
+// CountPingLatencyGT applies the GT predicate on the "count_ping_latency" field.
+func CountPingLatencyGT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyGTE applies the GTE predicate on the "count_ping_latency" field.
+func CountPingLatencyGTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyLT applies the LT predicate on the "count_ping_latency" field.
+func CountPingLatencyLT(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldCountPingLatency, v))
+}
+
+// CountPingLatencyLTE applies the LTE predicate on the "count_ping_latency" field.
+func CountPingLatencyLTE(v int) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldCountPingLatency, v))
+}
+
+// ComputedAtEQ applies the EQ predicate on the "computed_at" field.
+func ComputedAtEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldEQ(FieldComputedAt, v))
+}
+
+// ComputedAtNEQ applies the NEQ predicate on the "computed_at" field.
+func ComputedAtNEQ(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNEQ(FieldComputedAt, v))
+}
+
+// ComputedAtIn applies the In predicate on the "computed_at" field.
+func ComputedAtIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldIn(FieldComputedAt, vs...))
+}
+
+// ComputedAtNotIn applies the NotIn predicate on the "computed_at" field.
+func ComputedAtNotIn(vs ...time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldNotIn(FieldComputedAt, vs...))
+}
+
+// ComputedAtGT applies the GT predicate on the "computed_at" field.
+func ComputedAtGT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGT(FieldComputedAt, v))
+}
+
+// ComputedAtGTE applies the GTE predicate on the "computed_at" field.
+func ComputedAtGTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldGTE(FieldComputedAt, v))
+}
+
+// ComputedAtLT applies the LT predicate on the "computed_at" field.
+func ComputedAtLT(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLT(FieldComputedAt, v))
+}
+
+// ComputedAtLTE applies the LTE predicate on the "computed_at" field.
+func ComputedAtLTE(v time.Time) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.FieldLTE(FieldComputedAt, v))
+}
+
+// HasMonitor applies the HasEdge predicate on the "monitor" edge.
+func HasMonitor() predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorWith applies the HasEdge predicate on the "monitor" edge with a given conditions (other predicates).
+func HasMonitorWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(func(s *sql.Selector) {
+ step := newMonitorStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorDailyRollup) predicate.ChannelMonitorDailyRollup {
+ return predicate.ChannelMonitorDailyRollup(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitordailyrollup_create.go b/backend/ent/channelmonitordailyrollup_create.go
new file mode 100644
index 00000000..5f8754ba
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_create.go
@@ -0,0 +1,1509 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+)
+
+// ChannelMonitorDailyRollupCreate is the builder for creating a ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupCreate struct {
+ config
+ mutation *ChannelMonitorDailyRollupMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitorID(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetMonitorID(v)
+ return _c
+}
+
+// SetModel sets the "model" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetModel(v string) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetModel(v)
+ return _c
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetBucketDate(v)
+ return _c
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetTotalChecks(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetTotalChecks(v)
+ return _c
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetTotalChecks(*v)
+ }
+ return _c
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetOkCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetOkCount(v)
+ return _c
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetOkCount(*v)
+ }
+ return _c
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetOperationalCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetOperationalCount(v)
+ return _c
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetOperationalCount(*v)
+ }
+ return _c
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetDegradedCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetDegradedCount(v)
+ return _c
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetDegradedCount(*v)
+ }
+ return _c
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetFailedCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetFailedCount(v)
+ return _c
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetFailedCount(*v)
+ }
+ return _c
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetErrorCount(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetErrorCount(v)
+ return _c
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetErrorCount(*v)
+ }
+ return _c
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetSumLatencyMs(v)
+ return _c
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetSumLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetCountLatency(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetCountLatency(v)
+ return _c
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetCountLatency(*v)
+ }
+ return _c
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetSumPingLatencyMs(v)
+ return _c
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetSumPingLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetCountPingLatency(v)
+ return _c
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetCountPingLatency(*v)
+ }
+ return _c
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (_c *ChannelMonitorDailyRollupCreate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupCreate {
+ _c.mutation.SetComputedAt(v)
+ return _c
+}
+
+// SetNillableComputedAt sets the "computed_at" field if the given value is not nil.
+func (_c *ChannelMonitorDailyRollupCreate) SetNillableComputedAt(v *time.Time) *ChannelMonitorDailyRollupCreate {
+ if v != nil {
+ _c.SetComputedAt(*v)
+ }
+ return _c
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_c *ChannelMonitorDailyRollupCreate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupCreate {
+ return _c.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_c *ChannelMonitorDailyRollupCreate) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorDailyRollup in the database.
+func (_c *ChannelMonitorDailyRollupCreate) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorDailyRollupCreate) SaveX(ctx context.Context) *ChannelMonitorDailyRollup {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorDailyRollupCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorDailyRollupCreate) defaults() {
+ if _, ok := _c.mutation.TotalChecks(); !ok {
+ v := channelmonitordailyrollup.DefaultTotalChecks
+ _c.mutation.SetTotalChecks(v)
+ }
+ if _, ok := _c.mutation.OkCount(); !ok {
+ v := channelmonitordailyrollup.DefaultOkCount
+ _c.mutation.SetOkCount(v)
+ }
+ if _, ok := _c.mutation.OperationalCount(); !ok {
+ v := channelmonitordailyrollup.DefaultOperationalCount
+ _c.mutation.SetOperationalCount(v)
+ }
+ if _, ok := _c.mutation.DegradedCount(); !ok {
+ v := channelmonitordailyrollup.DefaultDegradedCount
+ _c.mutation.SetDegradedCount(v)
+ }
+ if _, ok := _c.mutation.FailedCount(); !ok {
+ v := channelmonitordailyrollup.DefaultFailedCount
+ _c.mutation.SetFailedCount(v)
+ }
+ if _, ok := _c.mutation.ErrorCount(); !ok {
+ v := channelmonitordailyrollup.DefaultErrorCount
+ _c.mutation.SetErrorCount(v)
+ }
+ if _, ok := _c.mutation.SumLatencyMs(); !ok {
+ v := channelmonitordailyrollup.DefaultSumLatencyMs
+ _c.mutation.SetSumLatencyMs(v)
+ }
+ if _, ok := _c.mutation.CountLatency(); !ok {
+ v := channelmonitordailyrollup.DefaultCountLatency
+ _c.mutation.SetCountLatency(v)
+ }
+ if _, ok := _c.mutation.SumPingLatencyMs(); !ok {
+ v := channelmonitordailyrollup.DefaultSumPingLatencyMs
+ _c.mutation.SetSumPingLatencyMs(v)
+ }
+ if _, ok := _c.mutation.CountPingLatency(); !ok {
+ v := channelmonitordailyrollup.DefaultCountPingLatency
+ _c.mutation.SetCountPingLatency(v)
+ }
+ if _, ok := _c.mutation.ComputedAt(); !ok {
+ v := channelmonitordailyrollup.DefaultComputedAt()
+ _c.mutation.SetComputedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorDailyRollupCreate) check() error {
+ if _, ok := _c.mutation.MonitorID(); !ok {
+ return &ValidationError{Name: "monitor_id", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.monitor_id"`)}
+ }
+ if _, ok := _c.mutation.Model(); !ok {
+ return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.model"`)}
+ }
+ if v, ok := _c.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.BucketDate(); !ok {
+ return &ValidationError{Name: "bucket_date", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.bucket_date"`)}
+ }
+ if _, ok := _c.mutation.TotalChecks(); !ok {
+ return &ValidationError{Name: "total_checks", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.total_checks"`)}
+ }
+ if _, ok := _c.mutation.OkCount(); !ok {
+ return &ValidationError{Name: "ok_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.ok_count"`)}
+ }
+ if _, ok := _c.mutation.OperationalCount(); !ok {
+ return &ValidationError{Name: "operational_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.operational_count"`)}
+ }
+ if _, ok := _c.mutation.DegradedCount(); !ok {
+ return &ValidationError{Name: "degraded_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.degraded_count"`)}
+ }
+ if _, ok := _c.mutation.FailedCount(); !ok {
+ return &ValidationError{Name: "failed_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.failed_count"`)}
+ }
+ if _, ok := _c.mutation.ErrorCount(); !ok {
+ return &ValidationError{Name: "error_count", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.error_count"`)}
+ }
+ if _, ok := _c.mutation.SumLatencyMs(); !ok {
+ return &ValidationError{Name: "sum_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_latency_ms"`)}
+ }
+ if _, ok := _c.mutation.CountLatency(); !ok {
+ return &ValidationError{Name: "count_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_latency"`)}
+ }
+ if _, ok := _c.mutation.SumPingLatencyMs(); !ok {
+ return &ValidationError{Name: "sum_ping_latency_ms", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.sum_ping_latency_ms"`)}
+ }
+ if _, ok := _c.mutation.CountPingLatency(); !ok {
+ return &ValidationError{Name: "count_ping_latency", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.count_ping_latency"`)}
+ }
+ if _, ok := _c.mutation.ComputedAt(); !ok {
+ return &ValidationError{Name: "computed_at", err: errors.New(`ent: missing required field "ChannelMonitorDailyRollup.computed_at"`)}
+ }
+ if len(_c.mutation.MonitorIDs()) == 0 {
+ return &ValidationError{Name: "monitor", err: errors.New(`ent: missing required edge "ChannelMonitorDailyRollup.monitor"`)}
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorDailyRollupCreate) sqlSave(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorDailyRollupCreate) createSpec() (*ChannelMonitorDailyRollup, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorDailyRollup{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ _node.Model = value
+ }
+ if value, ok := _c.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ _node.BucketDate = value
+ }
+ if value, ok := _c.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ _node.TotalChecks = value
+ }
+ if value, ok := _c.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ _node.OkCount = value
+ }
+ if value, ok := _c.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ _node.OperationalCount = value
+ }
+ if value, ok := _c.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ _node.DegradedCount = value
+ }
+ if value, ok := _c.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ _node.FailedCount = value
+ }
+ if value, ok := _c.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ _node.ErrorCount = value
+ }
+ if value, ok := _c.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ _node.SumLatencyMs = value
+ }
+ if value, ok := _c.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ _node.CountLatency = value
+ }
+ if value, ok := _c.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ _node.SumPingLatencyMs = value
+ }
+ if value, ok := _c.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ _node.CountPingLatency = value
+ }
+ if value, ok := _c.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ _node.ComputedAt = value
+ }
+ if nodes := _c.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.MonitorID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// SetMonitorID(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorDailyRollupUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreate) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorDailyRollupUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorDailyRollupUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorDailyRollup node.
+ ChannelMonitorDailyRollupUpsertOne struct {
+ create *ChannelMonitorDailyRollupCreate
+ }
+
+ // ChannelMonitorDailyRollupUpsert is the "OnConflict" setter.
+ ChannelMonitorDailyRollupUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldMonitorID, v)
+ return u
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateMonitorID() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldMonitorID)
+ return u
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetModel(v string) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldModel, v)
+ return u
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateModel() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldModel)
+ return u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldBucketDate, v)
+ return u
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateBucketDate() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldBucketDate)
+ return u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldTotalChecks, v)
+ return u
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldTotalChecks)
+ return u
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldTotalChecks, v)
+ return u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetOkCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldOkCount, v)
+ return u
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateOkCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldOkCount)
+ return u
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddOkCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldOkCount, v)
+ return u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldOperationalCount, v)
+ return u
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldOperationalCount)
+ return u
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldOperationalCount, v)
+ return u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldDegradedCount, v)
+ return u
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldDegradedCount)
+ return u
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldDegradedCount, v)
+ return u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldFailedCount, v)
+ return u
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateFailedCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldFailedCount)
+ return u
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldFailedCount, v)
+ return u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldErrorCount, v)
+ return u
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateErrorCount() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldErrorCount)
+ return u
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldErrorCount, v)
+ return u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldSumLatencyMs, v)
+ return u
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldSumLatencyMs)
+ return u
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldSumLatencyMs, v)
+ return u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldCountLatency, v)
+ return u
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateCountLatency() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldCountLatency)
+ return u
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldCountLatency, v)
+ return u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
+ return u
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldSumPingLatencyMs)
+ return u
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldSumPingLatencyMs, v)
+ return u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldCountPingLatency, v)
+ return u
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldCountPingLatency)
+ return u
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsert) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsert {
+ u.Add(channelmonitordailyrollup.FieldCountPingLatency, v)
+ return u
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsert) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsert {
+ u.Set(channelmonitordailyrollup.FieldComputedAt, v)
+ return u
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsert) UpdateComputedAt() *ChannelMonitorDailyRollupUpsert {
+ u.SetExcluded(channelmonitordailyrollup.FieldComputedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateNewValues() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertOne) Ignore() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorDailyRollupUpsertOne) DoNothing() *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorDailyRollupUpsertOne) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetModel(v string) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateModel() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetBucketDate(v)
+ })
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateBucketDate()
+ })
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetTotalChecks(v)
+ })
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddTotalChecks(v)
+ })
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateTotalChecks()
+ })
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOkCount(v)
+ })
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOkCount(v)
+ })
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOkCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOkCount()
+ })
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOperationalCount(v)
+ })
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOperationalCount(v)
+ })
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOperationalCount()
+ })
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDegradedCount(v)
+ })
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddDegradedCount(v)
+ })
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDegradedCount()
+ })
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetFailedCount(v)
+ })
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddFailedCount(v)
+ })
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateFailedCount()
+ })
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetErrorCount(v)
+ })
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddErrorCount(v)
+ })
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateErrorCount()
+ })
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumLatencyMs(v)
+ })
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumLatencyMs(v)
+ })
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumLatencyMs()
+ })
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountLatency(v)
+ })
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountLatency(v)
+ })
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountLatency()
+ })
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumPingLatencyMs(v)
+ })
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumPingLatencyMs(v)
+ })
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumPingLatencyMs()
+ })
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountPingLatency(v)
+ })
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountPingLatency(v)
+ })
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountPingLatency()
+ })
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsertOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetComputedAt(v)
+ })
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertOne) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertOne {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateComputedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorDailyRollupUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorDailyRollupCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorDailyRollupUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorDailyRollupCreateBulk is the builder for creating many ChannelMonitorDailyRollup entities in bulk.
+type ChannelMonitorDailyRollupCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorDailyRollupCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorDailyRollup entities in the database.
+func (_c *ChannelMonitorDailyRollupCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorDailyRollup, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorDailyRollupMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorDailyRollup {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorDailyRollupCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorDailyRollupCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorDailyRollup.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorDailyRollupUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorDailyRollupUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorDailyRollupUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorDailyRollupCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorDailyRollupUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorDailyRollupUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorDailyRollupUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorDailyRollup nodes.
+type ChannelMonitorDailyRollupUpsertBulk struct {
+ create *ChannelMonitorDailyRollupCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateNewValues() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorDailyRollup.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorDailyRollupUpsertBulk) Ignore() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorDailyRollupUpsertBulk) DoNothing() *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorDailyRollupCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorDailyRollupUpsertBulk) Update(set func(*ChannelMonitorDailyRollupUpsert)) *ChannelMonitorDailyRollupUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorDailyRollupUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateMonitorID() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetModel(v string) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateModel() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetBucketDate(v)
+ })
+}
+
+// UpdateBucketDate sets the "bucket_date" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateBucketDate() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateBucketDate()
+ })
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetTotalChecks(v)
+ })
+}
+
+// AddTotalChecks adds v to the "total_checks" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddTotalChecks(v)
+ })
+}
+
+// UpdateTotalChecks sets the "total_checks" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateTotalChecks() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateTotalChecks()
+ })
+}
+
+// SetOkCount sets the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOkCount(v)
+ })
+}
+
+// AddOkCount adds v to the "ok_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddOkCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOkCount(v)
+ })
+}
+
+// UpdateOkCount sets the "ok_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOkCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOkCount()
+ })
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetOperationalCount(v)
+ })
+}
+
+// AddOperationalCount adds v to the "operational_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddOperationalCount(v)
+ })
+}
+
+// UpdateOperationalCount sets the "operational_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateOperationalCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateOperationalCount()
+ })
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetDegradedCount(v)
+ })
+}
+
+// AddDegradedCount adds v to the "degraded_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddDegradedCount(v)
+ })
+}
+
+// UpdateDegradedCount sets the "degraded_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateDegradedCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateDegradedCount()
+ })
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetFailedCount(v)
+ })
+}
+
+// AddFailedCount adds v to the "failed_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddFailedCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddFailedCount(v)
+ })
+}
+
+// UpdateFailedCount sets the "failed_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateFailedCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateFailedCount()
+ })
+}
+
+// SetErrorCount sets the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetErrorCount(v)
+ })
+}
+
+// AddErrorCount adds v to the "error_count" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddErrorCount(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddErrorCount(v)
+ })
+}
+
+// UpdateErrorCount sets the "error_count" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateErrorCount() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateErrorCount()
+ })
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumLatencyMs(v)
+ })
+}
+
+// AddSumLatencyMs adds v to the "sum_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumLatencyMs(v)
+ })
+}
+
+// UpdateSumLatencyMs sets the "sum_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumLatencyMs() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumLatencyMs()
+ })
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountLatency(v)
+ })
+}
+
+// AddCountLatency adds v to the "count_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountLatency(v)
+ })
+}
+
+// UpdateCountLatency sets the "count_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountLatency() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountLatency()
+ })
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetSumPingLatencyMs(v)
+ })
+}
+
+// AddSumPingLatencyMs adds v to the "sum_ping_latency_ms" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddSumPingLatencyMs(v)
+ })
+}
+
+// UpdateSumPingLatencyMs sets the "sum_ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateSumPingLatencyMs() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateSumPingLatencyMs()
+ })
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetCountPingLatency(v)
+ })
+}
+
+// AddCountPingLatency adds v to the "count_ping_latency" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.AddCountPingLatency(v)
+ })
+}
+
+// UpdateCountPingLatency sets the "count_ping_latency" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateCountPingLatency() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateCountPingLatency()
+ })
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (u *ChannelMonitorDailyRollupUpsertBulk) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.SetComputedAt(v)
+ })
+}
+
+// UpdateComputedAt sets the "computed_at" field to the value that was provided on create.
+func (u *ChannelMonitorDailyRollupUpsertBulk) UpdateComputedAt() *ChannelMonitorDailyRollupUpsertBulk {
+ return u.Update(func(s *ChannelMonitorDailyRollupUpsert) {
+ s.UpdateComputedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorDailyRollupUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorDailyRollupCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorDailyRollupCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorDailyRollupUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitordailyrollup_delete.go b/backend/ent/channelmonitordailyrollup_delete.go
new file mode 100644
index 00000000..460c94f8
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupDelete is the builder for deleting a ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder.
+func (_d *ChannelMonitorDailyRollupDelete) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorDailyRollupDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDailyRollupDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorDailyRollupDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitordailyrollup.Table, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorDailyRollupDeleteOne is the builder for deleting a single ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupDeleteOne struct {
+ _d *ChannelMonitorDailyRollupDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupDelete builder.
+func (_d *ChannelMonitorDailyRollupDeleteOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorDailyRollupDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorDailyRollupDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitordailyrollup_query.go b/backend/ent/channelmonitordailyrollup_query.go
new file mode 100644
index 00000000..e34afc61
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupQuery is the builder for querying ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitordailyrollup.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorDailyRollup
+ withMonitor *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorDailyRollupQuery builder.
+func (_q *ChannelMonitorDailyRollupQuery) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorDailyRollupQuery) Limit(limit int) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorDailyRollupQuery) Offset(offset int) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorDailyRollupQuery) Unique(unique bool) *ChannelMonitorDailyRollupQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorDailyRollupQuery) Order(o ...channelmonitordailyrollup.OrderOption) *ChannelMonitorDailyRollupQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitor chains the current query on the "monitor" edge.
+func (_q *ChannelMonitorDailyRollupQuery) QueryMonitor() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorDailyRollup entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup was found.
+func (_q *ChannelMonitorDailyRollupQuery) First(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitordailyrollup.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) FirstX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorDailyRollup ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup ID was found.
+func (_q *ChannelMonitorDailyRollupQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorDailyRollup entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup entity is found.
+// Returns a *NotFoundError when no ChannelMonitorDailyRollup entities are found.
+func (_q *ChannelMonitorDailyRollupQuery) Only(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ return nil, &NotSingularError{channelmonitordailyrollup.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorDailyRollup ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorDailyRollup ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ default:
+ err = &NotSingularError{channelmonitordailyrollup.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorDailyRollups.
+func (_q *ChannelMonitorDailyRollupQuery) All(ctx context.Context) ([]*ChannelMonitorDailyRollup, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorDailyRollup, *ChannelMonitorDailyRollupQuery]()
+ return withInterceptors[[]*ChannelMonitorDailyRollup](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) AllX(ctx context.Context) []*ChannelMonitorDailyRollup {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorDailyRollup IDs.
+func (_q *ChannelMonitorDailyRollupQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitordailyrollup.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorDailyRollupQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorDailyRollupQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorDailyRollupQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorDailyRollupQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorDailyRollupQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorDailyRollupQuery) Clone() *ChannelMonitorDailyRollupQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorDailyRollupQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitordailyrollup.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorDailyRollup{}, _q.predicates...),
+ withMonitor: _q.withMonitor.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitor tells the query-builder to eager-load the nodes that are connected to
+// the "monitor" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorDailyRollupQuery) WithMonitor(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitor = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorDailyRollup.Query().
+// GroupBy(channelmonitordailyrollup.FieldMonitorID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorDailyRollupQuery) GroupBy(field string, fields ...string) *ChannelMonitorDailyRollupGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorDailyRollupGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitordailyrollup.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// }
+//
+// client.ChannelMonitorDailyRollup.Query().
+// Select(channelmonitordailyrollup.FieldMonitorID).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorDailyRollupQuery) Select(fields ...string) *ChannelMonitorDailyRollupSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorDailyRollupSelect{ChannelMonitorDailyRollupQuery: _q}
+ sbuild.label = channelmonitordailyrollup.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorDailyRollupSelect configured with the given aggregations.
+func (_q *ChannelMonitorDailyRollupQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitordailyrollup.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorDailyRollup, error) {
+ var (
+ nodes = []*ChannelMonitorDailyRollup{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitor != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorDailyRollup).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorDailyRollup{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitor; query != nil {
+ if err := _q.loadMonitor(ctx, query, nodes, nil,
+ func(n *ChannelMonitorDailyRollup, e *ChannelMonitor) { n.Edges.Monitor = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) loadMonitor(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorDailyRollup, init func(*ChannelMonitorDailyRollup), assign func(*ChannelMonitorDailyRollup, *ChannelMonitor)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitorDailyRollup)
+ for i := range nodes {
+ fk := nodes[i].MonitorID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitor.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "monitor_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitordailyrollup.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withMonitor != nil {
+ _spec.Node.AddColumnOnce(channelmonitordailyrollup.FieldMonitorID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorDailyRollupQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitordailyrollup.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitordailyrollup.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorDailyRollupQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorDailyRollupQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorDailyRollupQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorDailyRollupGroupBy is the group-by builder for ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupGroupBy struct {
+ selector
+ build *ChannelMonitorDailyRollupQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorDailyRollupGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorDailyRollupGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorDailyRollupGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorDailyRollupSelect is the builder for selecting fields of ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupSelect struct {
+ *ChannelMonitorDailyRollupQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorDailyRollupSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorDailyRollupSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorDailyRollupSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorDailyRollupQuery, *ChannelMonitorDailyRollupSelect](ctx, _s.ChannelMonitorDailyRollupQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorDailyRollupSelect) sqlScan(ctx context.Context, root *ChannelMonitorDailyRollupQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitordailyrollup_update.go b/backend/ent/channelmonitordailyrollup_update.go
new file mode 100644
index 00000000..02cd86c5
--- /dev/null
+++ b/backend/ent/channelmonitordailyrollup_update.go
@@ -0,0 +1,961 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorDailyRollupUpdate is the builder for updating ChannelMonitorDailyRollup entities.
+type ChannelMonitorDailyRollupUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder.
+func (_u *ChannelMonitorDailyRollupUpdate) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetModel(v string) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetBucketDate(v)
+ return _u
+}
+
+// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetBucketDate(*v)
+ }
+ return _u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetTotalChecks()
+ _u.mutation.SetTotalChecks(v)
+ return _u
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetTotalChecks(*v)
+ }
+ return _u
+}
+
+// AddTotalChecks adds value to the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddTotalChecks(v)
+ return _u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetOkCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetOkCount()
+ _u.mutation.SetOkCount(v)
+ return _u
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetOkCount(*v)
+ }
+ return _u
+}
+
+// AddOkCount adds value to the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddOkCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddOkCount(v)
+ return _u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetOperationalCount()
+ _u.mutation.SetOperationalCount(v)
+ return _u
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetOperationalCount(*v)
+ }
+ return _u
+}
+
+// AddOperationalCount adds value to the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddOperationalCount(v)
+ return _u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetDegradedCount()
+ _u.mutation.SetDegradedCount(v)
+ return _u
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetDegradedCount(*v)
+ }
+ return _u
+}
+
+// AddDegradedCount adds value to the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddDegradedCount(v)
+ return _u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetFailedCount()
+ _u.mutation.SetFailedCount(v)
+ return _u
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetFailedCount(*v)
+ }
+ return _u
+}
+
+// AddFailedCount adds value to the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddFailedCount(v)
+ return _u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetErrorCount()
+ _u.mutation.SetErrorCount(v)
+ return _u
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetErrorCount(*v)
+ }
+ return _u
+}
+
+// AddErrorCount adds value to the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddErrorCount(v)
+ return _u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetSumLatencyMs()
+ _u.mutation.SetSumLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetSumLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumLatencyMs adds value to the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddSumLatencyMs(v)
+ return _u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetCountLatency()
+ _u.mutation.SetCountLatency(v)
+ return _u
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetCountLatency(*v)
+ }
+ return _u
+}
+
+// AddCountLatency adds value to the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddCountLatency(v)
+ return _u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetSumPingLatencyMs()
+ _u.mutation.SetSumPingLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetSumPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddSumPingLatencyMs(v)
+ return _u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ResetCountPingLatency()
+ _u.mutation.SetCountPingLatency(v)
+ return _u
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdate) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdate {
+ if v != nil {
+ _u.SetCountPingLatency(*v)
+ }
+ return _u
+}
+
+// AddCountPingLatency adds value to the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdate) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.AddCountPingLatency(v)
+ return _u
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (_u *ChannelMonitorDailyRollupUpdate) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.SetComputedAt(v)
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdate) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdate {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_u *ChannelMonitorDailyRollupUpdate) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdate) ClearMonitor() *ChannelMonitorDailyRollupUpdate {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorDailyRollupUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorDailyRollupUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorDailyRollupUpdate) defaults() {
+ if _, ok := _u.mutation.ComputedAt(); !ok {
+ v := channelmonitordailyrollup.UpdateDefaultComputedAt()
+ _u.mutation.SetComputedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorDailyRollupUpdate) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorDailyRollupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedTotalChecks(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOkCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOperationalCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedDegradedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedFailedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedErrorCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountPingLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorDailyRollupUpdateOne is the builder for updating a single ChannelMonitorDailyRollup entity.
+type ChannelMonitorDailyRollupUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorDailyRollupMutation
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitorID(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableMonitorID(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetModel(v string) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableModel(v *string) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetBucketDate(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetBucketDate(v)
+ return _u
+}
+
+// SetNillableBucketDate sets the "bucket_date" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableBucketDate(v *time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetBucketDate(*v)
+ }
+ return _u
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetTotalChecks()
+ _u.mutation.SetTotalChecks(v)
+ return _u
+}
+
+// SetNillableTotalChecks sets the "total_checks" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableTotalChecks(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetTotalChecks(*v)
+ }
+ return _u
+}
+
+// AddTotalChecks adds value to the "total_checks" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddTotalChecks(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddTotalChecks(v)
+ return _u
+}
+
+// SetOkCount sets the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetOkCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetOkCount()
+ _u.mutation.SetOkCount(v)
+ return _u
+}
+
+// SetNillableOkCount sets the "ok_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOkCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetOkCount(*v)
+ }
+ return _u
+}
+
+// AddOkCount adds value to the "ok_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddOkCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddOkCount(v)
+ return _u
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetOperationalCount()
+ _u.mutation.SetOperationalCount(v)
+ return _u
+}
+
+// SetNillableOperationalCount sets the "operational_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableOperationalCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetOperationalCount(*v)
+ }
+ return _u
+}
+
+// AddOperationalCount adds value to the "operational_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddOperationalCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddOperationalCount(v)
+ return _u
+}
+
+// SetDegradedCount sets the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetDegradedCount()
+ _u.mutation.SetDegradedCount(v)
+ return _u
+}
+
+// SetNillableDegradedCount sets the "degraded_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableDegradedCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetDegradedCount(*v)
+ }
+ return _u
+}
+
+// AddDegradedCount adds value to the "degraded_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddDegradedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddDegradedCount(v)
+ return _u
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetFailedCount()
+ _u.mutation.SetFailedCount(v)
+ return _u
+}
+
+// SetNillableFailedCount sets the "failed_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableFailedCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetFailedCount(*v)
+ }
+ return _u
+}
+
+// AddFailedCount adds value to the "failed_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddFailedCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddFailedCount(v)
+ return _u
+}
+
+// SetErrorCount sets the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetErrorCount()
+ _u.mutation.SetErrorCount(v)
+ return _u
+}
+
+// SetNillableErrorCount sets the "error_count" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableErrorCount(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetErrorCount(*v)
+ }
+ return _u
+}
+
+// AddErrorCount adds value to the "error_count" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddErrorCount(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddErrorCount(v)
+ return _u
+}
+
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetSumLatencyMs()
+ _u.mutation.SetSumLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumLatencyMs sets the "sum_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetSumLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumLatencyMs adds value to the "sum_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddSumLatencyMs(v)
+ return _u
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetCountLatency()
+ _u.mutation.SetCountLatency(v)
+ return _u
+}
+
+// SetNillableCountLatency sets the "count_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountLatency(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetCountLatency(*v)
+ }
+ return _u
+}
+
+// AddCountLatency adds value to the "count_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddCountLatency(v)
+ return _u
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetSumPingLatencyMs()
+ _u.mutation.SetSumPingLatencyMs(v)
+ return _u
+}
+
+// SetNillableSumPingLatencyMs sets the "sum_ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableSumPingLatencyMs(v *int64) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetSumPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddSumPingLatencyMs adds value to the "sum_ping_latency_ms" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddSumPingLatencyMs(v int64) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddSumPingLatencyMs(v)
+ return _u
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ResetCountPingLatency()
+ _u.mutation.SetCountPingLatency(v)
+ return _u
+}
+
+// SetNillableCountPingLatency sets the "count_ping_latency" field if the given value is not nil.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetNillableCountPingLatency(v *int) *ChannelMonitorDailyRollupUpdateOne {
+ if v != nil {
+ _u.SetCountPingLatency(*v)
+ }
+ return _u
+}
+
+// AddCountPingLatency adds value to the "count_ping_latency" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) AddCountPingLatency(v int) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.AddCountPingLatency(v)
+ return _u
+}
+
+// SetComputedAt sets the "computed_at" field.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetComputedAt(v time.Time) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.SetComputedAt(v)
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SetMonitor(v *ChannelMonitor) *ChannelMonitorDailyRollupUpdateOne {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorDailyRollupMutation object of the builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Mutation() *ChannelMonitorDailyRollupMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) ClearMonitor() *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Where appends a list predicates to the ChannelMonitorDailyRollupUpdate builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Where(ps ...predicate.ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Select(field string, fields ...string) *ChannelMonitorDailyRollupUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorDailyRollup entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Save(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdateOne) SaveX(ctx context.Context) *ChannelMonitorDailyRollup {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorDailyRollupUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorDailyRollupUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorDailyRollupUpdateOne) defaults() {
+ if _, ok := _u.mutation.ComputedAt(); !ok {
+ v := channelmonitordailyrollup.UpdateDefaultComputedAt()
+ _u.mutation.SetComputedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorDailyRollupUpdateOne) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitordailyrollup.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorDailyRollup.model": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorDailyRollup.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorDailyRollupUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorDailyRollup, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitordailyrollup.Table, channelmonitordailyrollup.Columns, sqlgraph.NewFieldSpec(channelmonitordailyrollup.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorDailyRollup.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitordailyrollup.FieldID)
+ for _, f := range fields {
+ if !channelmonitordailyrollup.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitordailyrollup.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BucketDate(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldBucketDate, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.TotalChecks(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedTotalChecks(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldTotalChecks, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OkCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOkCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOkCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.OperationalCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedOperationalCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldOperationalCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.DegradedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedDegradedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldDegradedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.FailedCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedFailedCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldFailedCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ErrorCount(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedErrorCount(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldErrorCount, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.SumPingLatencyMs(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSumPingLatencyMs(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldSumPingLatencyMs, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.CountPingLatency(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedCountPingLatency(); ok {
+ _spec.AddField(channelmonitordailyrollup.FieldCountPingLatency, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ComputedAt(); ok {
+ _spec.SetField(channelmonitordailyrollup.FieldComputedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitordailyrollup.MonitorTable,
+ Columns: []string{channelmonitordailyrollup.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitorDailyRollup{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitordailyrollup.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitorhistory.go b/backend/ent/channelmonitorhistory.go
new file mode 100644
index 00000000..70dde542
--- /dev/null
+++ b/backend/ent/channelmonitorhistory.go
@@ -0,0 +1,207 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+)
+
+// ChannelMonitorHistory is the model entity for the ChannelMonitorHistory schema.
+type ChannelMonitorHistory struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // MonitorID holds the value of the "monitor_id" field.
+ MonitorID int64 `json:"monitor_id,omitempty"`
+ // Model holds the value of the "model" field.
+ Model string `json:"model,omitempty"`
+ // Status holds the value of the "status" field.
+ Status channelmonitorhistory.Status `json:"status,omitempty"`
+ // LatencyMs holds the value of the "latency_ms" field.
+ LatencyMs *int `json:"latency_ms,omitempty"`
+ // PingLatencyMs holds the value of the "ping_latency_ms" field.
+ PingLatencyMs *int `json:"ping_latency_ms,omitempty"`
+ // Message holds the value of the "message" field.
+ Message string `json:"message,omitempty"`
+ // CheckedAt holds the value of the "checked_at" field.
+ CheckedAt time.Time `json:"checked_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorHistoryQuery when eager-loading is set.
+ Edges ChannelMonitorHistoryEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorHistoryEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorHistoryEdges struct {
+ // Monitor holds the value of the monitor edge.
+ Monitor *ChannelMonitor `json:"monitor,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorOrErr returns the Monitor value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e ChannelMonitorHistoryEdges) MonitorOrErr() (*ChannelMonitor, error) {
+ if e.Monitor != nil {
+ return e.Monitor, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: channelmonitor.Label}
+ }
+ return nil, &NotLoadedError{edge: "monitor"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorHistory) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorhistory.FieldID, channelmonitorhistory.FieldMonitorID, channelmonitorhistory.FieldLatencyMs, channelmonitorhistory.FieldPingLatencyMs:
+ values[i] = new(sql.NullInt64)
+ case channelmonitorhistory.FieldModel, channelmonitorhistory.FieldStatus, channelmonitorhistory.FieldMessage:
+ values[i] = new(sql.NullString)
+ case channelmonitorhistory.FieldCheckedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorHistory fields.
+func (_m *ChannelMonitorHistory) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorhistory.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitorhistory.FieldMonitorID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field monitor_id", values[i])
+ } else if value.Valid {
+ _m.MonitorID = value.Int64
+ }
+ case channelmonitorhistory.FieldModel:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field model", values[i])
+ } else if value.Valid {
+ _m.Model = value.String
+ }
+ case channelmonitorhistory.FieldStatus:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field status", values[i])
+ } else if value.Valid {
+ _m.Status = channelmonitorhistory.Status(value.String)
+ }
+ case channelmonitorhistory.FieldLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field latency_ms", values[i])
+ } else if value.Valid {
+ _m.LatencyMs = new(int)
+ *_m.LatencyMs = int(value.Int64)
+ }
+ case channelmonitorhistory.FieldPingLatencyMs:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field ping_latency_ms", values[i])
+ } else if value.Valid {
+ _m.PingLatencyMs = new(int)
+ *_m.PingLatencyMs = int(value.Int64)
+ }
+ case channelmonitorhistory.FieldMessage:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field message", values[i])
+ } else if value.Valid {
+ _m.Message = value.String
+ }
+ case channelmonitorhistory.FieldCheckedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field checked_at", values[i])
+ } else if value.Valid {
+ _m.CheckedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorHistory.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorHistory) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitor queries the "monitor" edge of the ChannelMonitorHistory entity.
+func (_m *ChannelMonitorHistory) QueryMonitor() *ChannelMonitorQuery {
+ return NewChannelMonitorHistoryClient(_m.config).QueryMonitor(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorHistory.
+// Note that you need to call ChannelMonitorHistory.Unwrap() before calling this method if this ChannelMonitorHistory
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorHistory) Update() *ChannelMonitorHistoryUpdateOne {
+ return NewChannelMonitorHistoryClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorHistory entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorHistory) Unwrap() *ChannelMonitorHistory {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorHistory is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorHistory) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorHistory(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("monitor_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.MonitorID))
+ builder.WriteString(", ")
+ builder.WriteString("model=")
+ builder.WriteString(_m.Model)
+ builder.WriteString(", ")
+ builder.WriteString("status=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Status))
+ builder.WriteString(", ")
+ if v := _m.LatencyMs; v != nil {
+ builder.WriteString("latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.PingLatencyMs; v != nil {
+ builder.WriteString("ping_latency_ms=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("message=")
+ builder.WriteString(_m.Message)
+ builder.WriteString(", ")
+ builder.WriteString("checked_at=")
+ builder.WriteString(_m.CheckedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorHistories is a parsable slice of ChannelMonitorHistory.
+type ChannelMonitorHistories []*ChannelMonitorHistory
diff --git a/backend/ent/channelmonitorhistory/channelmonitorhistory.go b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
new file mode 100644
index 00000000..6a9dc006
--- /dev/null
+++ b/backend/ent/channelmonitorhistory/channelmonitorhistory.go
@@ -0,0 +1,158 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorhistory
+
+import (
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitorhistory type in the database.
+ Label = "channel_monitor_history"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldMonitorID holds the string denoting the monitor_id field in the database.
+ FieldMonitorID = "monitor_id"
+ // FieldModel holds the string denoting the model field in the database.
+ FieldModel = "model"
+ // FieldStatus holds the string denoting the status field in the database.
+ FieldStatus = "status"
+ // FieldLatencyMs holds the string denoting the latency_ms field in the database.
+ FieldLatencyMs = "latency_ms"
+ // FieldPingLatencyMs holds the string denoting the ping_latency_ms field in the database.
+ FieldPingLatencyMs = "ping_latency_ms"
+ // FieldMessage holds the string denoting the message field in the database.
+ FieldMessage = "message"
+ // FieldCheckedAt holds the string denoting the checked_at field in the database.
+ FieldCheckedAt = "checked_at"
+ // EdgeMonitor holds the string denoting the monitor edge name in mutations.
+ EdgeMonitor = "monitor"
+ // Table holds the table name of the channelmonitorhistory in the database.
+ Table = "channel_monitor_histories"
+ // MonitorTable is the table that holds the monitor relation/edge.
+ MonitorTable = "channel_monitor_histories"
+ // MonitorInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorInverseTable = "channel_monitors"
+ // MonitorColumn is the table column denoting the monitor relation/edge.
+ MonitorColumn = "monitor_id"
+)
+
+// Columns holds all SQL columns for channelmonitorhistory fields.
+var Columns = []string{
+ FieldID,
+ FieldMonitorID,
+ FieldModel,
+ FieldStatus,
+ FieldLatencyMs,
+ FieldPingLatencyMs,
+ FieldMessage,
+ FieldCheckedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ ModelValidator func(string) error
+ // DefaultMessage holds the default value on creation for the "message" field.
+ DefaultMessage string
+ // MessageValidator is a validator for the "message" field. It is called by the builders before save.
+ MessageValidator func(string) error
+ // DefaultCheckedAt holds the default value on creation for the "checked_at" field.
+ DefaultCheckedAt func() time.Time
+)
+
+// Status defines the type for the "status" enum field.
+type Status string
+
+// Status values.
+const (
+ StatusOperational Status = "operational"
+ StatusDegraded Status = "degraded"
+ StatusFailed Status = "failed"
+ StatusError Status = "error"
+)
+
+func (s Status) String() string {
+ return string(s)
+}
+
+// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save.
+func StatusValidator(s Status) error {
+ switch s {
+ case StatusOperational, StatusDegraded, StatusFailed, StatusError:
+ return nil
+ default:
+ return fmt.Errorf("channelmonitorhistory: invalid enum value for status field: %q", s)
+ }
+}
+
+// OrderOption defines the ordering options for the ChannelMonitorHistory queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByMonitorID orders the results by the monitor_id field.
+func ByMonitorID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMonitorID, opts...).ToFunc()
+}
+
+// ByModel orders the results by the model field.
+func ByModel(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldModel, opts...).ToFunc()
+}
+
+// ByStatus orders the results by the status field.
+func ByStatus(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStatus, opts...).ToFunc()
+}
+
+// ByLatencyMs orders the results by the latency_ms field.
+func ByLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLatencyMs, opts...).ToFunc()
+}
+
+// ByPingLatencyMs orders the results by the ping_latency_ms field.
+func ByPingLatencyMs(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPingLatencyMs, opts...).ToFunc()
+}
+
+// ByMessage orders the results by the message field.
+func ByMessage(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldMessage, opts...).ToFunc()
+}
+
+// ByCheckedAt orders the results by the checked_at field.
+func ByCheckedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCheckedAt, opts...).ToFunc()
+}
+
+// ByMonitorField orders the results by monitor field.
+func ByMonitorField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newMonitorStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+}
diff --git a/backend/ent/channelmonitorhistory/where.go b/backend/ent/channelmonitorhistory/where.go
new file mode 100644
index 00000000..afa73f35
--- /dev/null
+++ b/backend/ent/channelmonitorhistory/where.go
@@ -0,0 +1,444 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorhistory
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldID, id))
+}
+
+// MonitorID applies equality check predicate on the "monitor_id" field. It's identical to MonitorIDEQ.
+func MonitorID(v int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// Model applies equality check predicate on the "model" field. It's identical to ModelEQ.
+func Model(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldModel, v))
+}
+
+// LatencyMs applies equality check predicate on the "latency_ms" field. It's identical to LatencyMsEQ.
+func LatencyMs(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldLatencyMs, v))
+}
+
+// PingLatencyMs applies equality check predicate on the "ping_latency_ms" field. It's identical to PingLatencyMsEQ.
+func PingLatencyMs(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldPingLatencyMs, v))
+}
+
+// Message applies equality check predicate on the "message" field. It's identical to MessageEQ.
+func Message(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMessage, v))
+}
+
+// CheckedAt applies equality check predicate on the "checked_at" field. It's identical to CheckedAtEQ.
+func CheckedAt(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v))
+}
+
+// MonitorIDEQ applies the EQ predicate on the "monitor_id" field.
+func MonitorIDEQ(v int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMonitorID, v))
+}
+
+// MonitorIDNEQ applies the NEQ predicate on the "monitor_id" field.
+func MonitorIDNEQ(v int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldMonitorID, v))
+}
+
+// MonitorIDIn applies the In predicate on the "monitor_id" field.
+func MonitorIDIn(vs ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldMonitorID, vs...))
+}
+
+// MonitorIDNotIn applies the NotIn predicate on the "monitor_id" field.
+func MonitorIDNotIn(vs ...int64) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldMonitorID, vs...))
+}
+
+// ModelEQ applies the EQ predicate on the "model" field.
+func ModelEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldModel, v))
+}
+
+// ModelNEQ applies the NEQ predicate on the "model" field.
+func ModelNEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldModel, v))
+}
+
+// ModelIn applies the In predicate on the "model" field.
+func ModelIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldModel, vs...))
+}
+
+// ModelNotIn applies the NotIn predicate on the "model" field.
+func ModelNotIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldModel, vs...))
+}
+
+// ModelGT applies the GT predicate on the "model" field.
+func ModelGT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldModel, v))
+}
+
+// ModelGTE applies the GTE predicate on the "model" field.
+func ModelGTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldModel, v))
+}
+
+// ModelLT applies the LT predicate on the "model" field.
+func ModelLT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldModel, v))
+}
+
+// ModelLTE applies the LTE predicate on the "model" field.
+func ModelLTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldModel, v))
+}
+
+// ModelContains applies the Contains predicate on the "model" field.
+func ModelContains(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContains(FieldModel, v))
+}
+
+// ModelHasPrefix applies the HasPrefix predicate on the "model" field.
+func ModelHasPrefix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasPrefix(FieldModel, v))
+}
+
+// ModelHasSuffix applies the HasSuffix predicate on the "model" field.
+func ModelHasSuffix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasSuffix(FieldModel, v))
+}
+
+// ModelEqualFold applies the EqualFold predicate on the "model" field.
+func ModelEqualFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEqualFold(FieldModel, v))
+}
+
+// ModelContainsFold applies the ContainsFold predicate on the "model" field.
+func ModelContainsFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContainsFold(FieldModel, v))
+}
+
+// StatusEQ applies the EQ predicate on the "status" field.
+func StatusEQ(v Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldStatus, v))
+}
+
+// StatusNEQ applies the NEQ predicate on the "status" field.
+func StatusNEQ(v Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldStatus, v))
+}
+
+// StatusIn applies the In predicate on the "status" field.
+func StatusIn(vs ...Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldStatus, vs...))
+}
+
+// StatusNotIn applies the NotIn predicate on the "status" field.
+func StatusNotIn(vs ...Status) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldStatus, vs...))
+}
+
+// LatencyMsEQ applies the EQ predicate on the "latency_ms" field.
+func LatencyMsEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldLatencyMs, v))
+}
+
+// LatencyMsNEQ applies the NEQ predicate on the "latency_ms" field.
+func LatencyMsNEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldLatencyMs, v))
+}
+
+// LatencyMsIn applies the In predicate on the "latency_ms" field.
+func LatencyMsIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldLatencyMs, vs...))
+}
+
+// LatencyMsNotIn applies the NotIn predicate on the "latency_ms" field.
+func LatencyMsNotIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldLatencyMs, vs...))
+}
+
+// LatencyMsGT applies the GT predicate on the "latency_ms" field.
+func LatencyMsGT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldLatencyMs, v))
+}
+
+// LatencyMsGTE applies the GTE predicate on the "latency_ms" field.
+func LatencyMsGTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldLatencyMs, v))
+}
+
+// LatencyMsLT applies the LT predicate on the "latency_ms" field.
+func LatencyMsLT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldLatencyMs, v))
+}
+
+// LatencyMsLTE applies the LTE predicate on the "latency_ms" field.
+func LatencyMsLTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldLatencyMs, v))
+}
+
+// LatencyMsIsNil applies the IsNil predicate on the "latency_ms" field.
+func LatencyMsIsNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldLatencyMs))
+}
+
+// LatencyMsNotNil applies the NotNil predicate on the "latency_ms" field.
+func LatencyMsNotNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldLatencyMs))
+}
+
+// PingLatencyMsEQ applies the EQ predicate on the "ping_latency_ms" field.
+func PingLatencyMsEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsNEQ applies the NEQ predicate on the "ping_latency_ms" field.
+func PingLatencyMsNEQ(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsIn applies the In predicate on the "ping_latency_ms" field.
+func PingLatencyMsIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldPingLatencyMs, vs...))
+}
+
+// PingLatencyMsNotIn applies the NotIn predicate on the "ping_latency_ms" field.
+func PingLatencyMsNotIn(vs ...int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldPingLatencyMs, vs...))
+}
+
+// PingLatencyMsGT applies the GT predicate on the "ping_latency_ms" field.
+func PingLatencyMsGT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsGTE applies the GTE predicate on the "ping_latency_ms" field.
+func PingLatencyMsGTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsLT applies the LT predicate on the "ping_latency_ms" field.
+func PingLatencyMsLT(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsLTE applies the LTE predicate on the "ping_latency_ms" field.
+func PingLatencyMsLTE(v int) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldPingLatencyMs, v))
+}
+
+// PingLatencyMsIsNil applies the IsNil predicate on the "ping_latency_ms" field.
+func PingLatencyMsIsNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldPingLatencyMs))
+}
+
+// PingLatencyMsNotNil applies the NotNil predicate on the "ping_latency_ms" field.
+func PingLatencyMsNotNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldPingLatencyMs))
+}
+
+// MessageEQ applies the EQ predicate on the "message" field.
+func MessageEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldMessage, v))
+}
+
+// MessageNEQ applies the NEQ predicate on the "message" field.
+func MessageNEQ(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldMessage, v))
+}
+
+// MessageIn applies the In predicate on the "message" field.
+func MessageIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldMessage, vs...))
+}
+
+// MessageNotIn applies the NotIn predicate on the "message" field.
+func MessageNotIn(vs ...string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldMessage, vs...))
+}
+
+// MessageGT applies the GT predicate on the "message" field.
+func MessageGT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldMessage, v))
+}
+
+// MessageGTE applies the GTE predicate on the "message" field.
+func MessageGTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldMessage, v))
+}
+
+// MessageLT applies the LT predicate on the "message" field.
+func MessageLT(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldMessage, v))
+}
+
+// MessageLTE applies the LTE predicate on the "message" field.
+func MessageLTE(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldMessage, v))
+}
+
+// MessageContains applies the Contains predicate on the "message" field.
+func MessageContains(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContains(FieldMessage, v))
+}
+
+// MessageHasPrefix applies the HasPrefix predicate on the "message" field.
+func MessageHasPrefix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasPrefix(FieldMessage, v))
+}
+
+// MessageHasSuffix applies the HasSuffix predicate on the "message" field.
+func MessageHasSuffix(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldHasSuffix(FieldMessage, v))
+}
+
+// MessageIsNil applies the IsNil predicate on the "message" field.
+func MessageIsNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIsNull(FieldMessage))
+}
+
+// MessageNotNil applies the NotNil predicate on the "message" field.
+func MessageNotNil() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotNull(FieldMessage))
+}
+
+// MessageEqualFold applies the EqualFold predicate on the "message" field.
+func MessageEqualFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEqualFold(FieldMessage, v))
+}
+
+// MessageContainsFold applies the ContainsFold predicate on the "message" field.
+func MessageContainsFold(v string) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldContainsFold(FieldMessage, v))
+}
+
+// CheckedAtEQ applies the EQ predicate on the "checked_at" field.
+func CheckedAtEQ(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldEQ(FieldCheckedAt, v))
+}
+
+// CheckedAtNEQ applies the NEQ predicate on the "checked_at" field.
+func CheckedAtNEQ(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNEQ(FieldCheckedAt, v))
+}
+
+// CheckedAtIn applies the In predicate on the "checked_at" field.
+func CheckedAtIn(vs ...time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldIn(FieldCheckedAt, vs...))
+}
+
+// CheckedAtNotIn applies the NotIn predicate on the "checked_at" field.
+func CheckedAtNotIn(vs ...time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldNotIn(FieldCheckedAt, vs...))
+}
+
+// CheckedAtGT applies the GT predicate on the "checked_at" field.
+func CheckedAtGT(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGT(FieldCheckedAt, v))
+}
+
+// CheckedAtGTE applies the GTE predicate on the "checked_at" field.
+func CheckedAtGTE(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldGTE(FieldCheckedAt, v))
+}
+
+// CheckedAtLT applies the LT predicate on the "checked_at" field.
+func CheckedAtLT(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLT(FieldCheckedAt, v))
+}
+
+// CheckedAtLTE applies the LTE predicate on the "checked_at" field.
+func CheckedAtLTE(v time.Time) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.FieldLTE(FieldCheckedAt, v))
+}
+
+// HasMonitor applies the HasEdge predicate on the "monitor" edge.
+func HasMonitor() predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, MonitorTable, MonitorColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorWith applies the HasEdge predicate on the "monitor" edge with a given conditions (other predicates).
+func HasMonitorWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(func(s *sql.Selector) {
+ step := newMonitorStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorHistory) predicate.ChannelMonitorHistory {
+ return predicate.ChannelMonitorHistory(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitorhistory_create.go b/backend/ent/channelmonitorhistory_create.go
new file mode 100644
index 00000000..71034865
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_create.go
@@ -0,0 +1,947 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+)
+
+// ChannelMonitorHistoryCreate is the builder for creating a ChannelMonitorHistory entity.
+type ChannelMonitorHistoryCreate struct {
+ config
+ mutation *ChannelMonitorHistoryMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_c *ChannelMonitorHistoryCreate) SetMonitorID(v int64) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetMonitorID(v)
+ return _c
+}
+
+// SetModel sets the "model" field.
+func (_c *ChannelMonitorHistoryCreate) SetModel(v string) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetModel(v)
+ return _c
+}
+
+// SetStatus sets the "status" field.
+func (_c *ChannelMonitorHistoryCreate) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetStatus(v)
+ return _c
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (_c *ChannelMonitorHistoryCreate) SetLatencyMs(v int) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetLatencyMs(v)
+ return _c
+}
+
+// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (_c *ChannelMonitorHistoryCreate) SetPingLatencyMs(v int) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetPingLatencyMs(v)
+ return _c
+}
+
+// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetPingLatencyMs(*v)
+ }
+ return _c
+}
+
+// SetMessage sets the "message" field.
+func (_c *ChannelMonitorHistoryCreate) SetMessage(v string) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetMessage(v)
+ return _c
+}
+
+// SetNillableMessage sets the "message" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillableMessage(v *string) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetMessage(*v)
+ }
+ return _c
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (_c *ChannelMonitorHistoryCreate) SetCheckedAt(v time.Time) *ChannelMonitorHistoryCreate {
+ _c.mutation.SetCheckedAt(v)
+ return _c
+}
+
+// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil.
+func (_c *ChannelMonitorHistoryCreate) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryCreate {
+ if v != nil {
+ _c.SetCheckedAt(*v)
+ }
+ return _c
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_c *ChannelMonitorHistoryCreate) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryCreate {
+ return _c.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorHistoryMutation object of the builder.
+func (_c *ChannelMonitorHistoryCreate) Mutation() *ChannelMonitorHistoryMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorHistory in the database.
+func (_c *ChannelMonitorHistoryCreate) Save(ctx context.Context) (*ChannelMonitorHistory, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorHistoryCreate) SaveX(ctx context.Context) *ChannelMonitorHistory {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorHistoryCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorHistoryCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorHistoryCreate) defaults() {
+ if _, ok := _c.mutation.Message(); !ok {
+ v := channelmonitorhistory.DefaultMessage
+ _c.mutation.SetMessage(v)
+ }
+ if _, ok := _c.mutation.CheckedAt(); !ok {
+ v := channelmonitorhistory.DefaultCheckedAt()
+ _c.mutation.SetCheckedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorHistoryCreate) check() error {
+ if _, ok := _c.mutation.MonitorID(); !ok {
+ return &ValidationError{Name: "monitor_id", err: errors.New(`ent: missing required field "ChannelMonitorHistory.monitor_id"`)}
+ }
+ if _, ok := _c.mutation.Model(); !ok {
+ return &ValidationError{Name: "model", err: errors.New(`ent: missing required field "ChannelMonitorHistory.model"`)}
+ }
+ if v, ok := _c.mutation.Model(); ok {
+ if err := channelmonitorhistory.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Status(); !ok {
+ return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "ChannelMonitorHistory.status"`)}
+ }
+ if v, ok := _c.mutation.Status(); ok {
+ if err := channelmonitorhistory.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.Message(); ok {
+ if err := channelmonitorhistory.MessageValidator(v); err != nil {
+ return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.CheckedAt(); !ok {
+ return &ValidationError{Name: "checked_at", err: errors.New(`ent: missing required field "ChannelMonitorHistory.checked_at"`)}
+ }
+ if len(_c.mutation.MonitorIDs()) == 0 {
+ return &ValidationError{Name: "monitor", err: errors.New(`ent: missing required edge "ChannelMonitorHistory.monitor"`)}
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorHistoryCreate) sqlSave(ctx context.Context) (*ChannelMonitorHistory, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorHistoryCreate) createSpec() (*ChannelMonitorHistory, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorHistory{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.Model(); ok {
+ _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
+ _node.Model = value
+ }
+ if value, ok := _c.mutation.Status(); ok {
+ _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value)
+ _node.Status = value
+ }
+ if value, ok := _c.mutation.LatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ _node.LatencyMs = &value
+ }
+ if value, ok := _c.mutation.PingLatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ _node.PingLatencyMs = &value
+ }
+ if value, ok := _c.mutation.Message(); ok {
+ _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value)
+ _node.Message = value
+ }
+ if value, ok := _c.mutation.CheckedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value)
+ _node.CheckedAt = value
+ }
+ if nodes := _c.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.MonitorID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorHistory.Create().
+// SetMonitorID(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorHistoryUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreate) OnConflictColumns(columns ...string) *ChannelMonitorHistoryUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorHistoryUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorHistoryUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorHistory node.
+ ChannelMonitorHistoryUpsertOne struct {
+ create *ChannelMonitorHistoryCreate
+ }
+
+ // ChannelMonitorHistoryUpsert is the "OnConflict" setter.
+ ChannelMonitorHistoryUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorHistoryUpsert) SetMonitorID(v int64) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldMonitorID, v)
+ return u
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateMonitorID() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldMonitorID)
+ return u
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorHistoryUpsert) SetModel(v string) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldModel, v)
+ return u
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateModel() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldModel)
+ return u
+}
+
+// SetStatus sets the "status" field.
+func (u *ChannelMonitorHistoryUpsert) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldStatus, v)
+ return u
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateStatus() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldStatus)
+ return u
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) SetLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldLatencyMs, v)
+ return u
+}
+
+// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldLatencyMs)
+ return u
+}
+
+// AddLatencyMs adds v to the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) AddLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Add(channelmonitorhistory.FieldLatencyMs, v)
+ return u
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) ClearLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetNull(channelmonitorhistory.FieldLatencyMs)
+ return u
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldPingLatencyMs, v)
+ return u
+}
+
+// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldPingLatencyMs)
+ return u
+}
+
+// AddPingLatencyMs adds v to the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsert {
+ u.Add(channelmonitorhistory.FieldPingLatencyMs, v)
+ return u
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsert) ClearPingLatencyMs() *ChannelMonitorHistoryUpsert {
+ u.SetNull(channelmonitorhistory.FieldPingLatencyMs)
+ return u
+}
+
+// SetMessage sets the "message" field.
+func (u *ChannelMonitorHistoryUpsert) SetMessage(v string) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldMessage, v)
+ return u
+}
+
+// UpdateMessage sets the "message" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateMessage() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldMessage)
+ return u
+}
+
+// ClearMessage clears the value of the "message" field.
+func (u *ChannelMonitorHistoryUpsert) ClearMessage() *ChannelMonitorHistoryUpsert {
+ u.SetNull(channelmonitorhistory.FieldMessage)
+ return u
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (u *ChannelMonitorHistoryUpsert) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsert {
+ u.Set(channelmonitorhistory.FieldCheckedAt, v)
+ return u
+}
+
+// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsert) UpdateCheckedAt() *ChannelMonitorHistoryUpsert {
+ u.SetExcluded(channelmonitorhistory.FieldCheckedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertOne) UpdateNewValues() *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertOne) Ignore() *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorHistoryUpsertOne) DoNothing() *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorHistoryCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorHistoryUpsertOne) Update(set func(*ChannelMonitorHistoryUpsert)) *ChannelMonitorHistoryUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorHistoryUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateMonitorID() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetModel(v string) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateModel() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateStatus() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetLatencyMs(v)
+ })
+}
+
+// AddLatencyMs adds v to the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) AddLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddLatencyMs(v)
+ })
+}
+
+// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateLatencyMs()
+ })
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) ClearLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearLatencyMs()
+ })
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetPingLatencyMs(v)
+ })
+}
+
+// AddPingLatencyMs adds v to the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddPingLatencyMs(v)
+ })
+}
+
+// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdatePingLatencyMs()
+ })
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertOne) ClearPingLatencyMs() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearPingLatencyMs()
+ })
+}
+
+// SetMessage sets the "message" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetMessage(v string) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMessage(v)
+ })
+}
+
+// UpdateMessage sets the "message" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateMessage() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMessage()
+ })
+}
+
+// ClearMessage clears the value of the "message" field.
+func (u *ChannelMonitorHistoryUpsertOne) ClearMessage() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearMessage()
+ })
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (u *ChannelMonitorHistoryUpsertOne) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetCheckedAt(v)
+ })
+}
+
+// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertOne) UpdateCheckedAt() *ChannelMonitorHistoryUpsertOne {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateCheckedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorHistoryUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorHistoryCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorHistoryUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorHistoryUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorHistoryUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorHistoryCreateBulk is the builder for creating many ChannelMonitorHistory entities in bulk.
+type ChannelMonitorHistoryCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorHistoryCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorHistory entities in the database.
+func (_c *ChannelMonitorHistoryCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorHistory, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorHistory, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorHistoryMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorHistoryCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorHistory {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorHistoryCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorHistoryCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorHistory.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorHistoryUpsert) {
+// SetMonitorID(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorHistoryUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorHistoryUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorHistoryCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorHistoryUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorHistoryUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorHistoryUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorHistory nodes.
+type ChannelMonitorHistoryUpsertBulk struct {
+ create *ChannelMonitorHistoryCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateNewValues() *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorHistory.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorHistoryUpsertBulk) Ignore() *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorHistoryUpsertBulk) DoNothing() *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorHistoryCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorHistoryUpsertBulk) Update(set func(*ChannelMonitorHistoryUpsert)) *ChannelMonitorHistoryUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorHistoryUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetMonitorID(v int64) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMonitorID(v)
+ })
+}
+
+// UpdateMonitorID sets the "monitor_id" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateMonitorID() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMonitorID()
+ })
+}
+
+// SetModel sets the "model" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetModel(v string) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetModel(v)
+ })
+}
+
+// UpdateModel sets the "model" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateModel() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateModel()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateStatus() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetLatencyMs(v)
+ })
+}
+
+// AddLatencyMs adds v to the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) AddLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddLatencyMs(v)
+ })
+}
+
+// UpdateLatencyMs sets the "latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateLatencyMs()
+ })
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) ClearLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearLatencyMs()
+ })
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetPingLatencyMs(v)
+ })
+}
+
+// AddPingLatencyMs adds v to the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.AddPingLatencyMs(v)
+ })
+}
+
+// UpdatePingLatencyMs sets the "ping_latency_ms" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdatePingLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdatePingLatencyMs()
+ })
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (u *ChannelMonitorHistoryUpsertBulk) ClearPingLatencyMs() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearPingLatencyMs()
+ })
+}
+
+// SetMessage sets the "message" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetMessage(v string) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetMessage(v)
+ })
+}
+
+// UpdateMessage sets the "message" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateMessage() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateMessage()
+ })
+}
+
+// ClearMessage clears the value of the "message" field.
+func (u *ChannelMonitorHistoryUpsertBulk) ClearMessage() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.ClearMessage()
+ })
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (u *ChannelMonitorHistoryUpsertBulk) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.SetCheckedAt(v)
+ })
+}
+
+// UpdateCheckedAt sets the "checked_at" field to the value that was provided on create.
+func (u *ChannelMonitorHistoryUpsertBulk) UpdateCheckedAt() *ChannelMonitorHistoryUpsertBulk {
+ return u.Update(func(s *ChannelMonitorHistoryUpsert) {
+ s.UpdateCheckedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorHistoryUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorHistoryCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorHistoryCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorHistoryUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorhistory_delete.go b/backend/ent/channelmonitorhistory_delete.go
new file mode 100644
index 00000000..97110e69
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorHistoryDelete is the builder for deleting a ChannelMonitorHistory entity.
+type ChannelMonitorHistoryDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorHistoryMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryDelete builder.
+func (_d *ChannelMonitorHistoryDelete) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorHistoryDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorHistoryDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorHistoryDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitorhistory.Table, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorHistoryDeleteOne is the builder for deleting a single ChannelMonitorHistory entity.
+type ChannelMonitorHistoryDeleteOne struct {
+ _d *ChannelMonitorHistoryDelete
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryDelete builder.
+func (_d *ChannelMonitorHistoryDeleteOne) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorHistoryDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitorhistory.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorHistoryDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorhistory_query.go b/backend/ent/channelmonitorhistory_query.go
new file mode 100644
index 00000000..1fb872ad
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorHistoryQuery is the builder for querying ChannelMonitorHistory entities.
+type ChannelMonitorHistoryQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitorhistory.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorHistory
+ withMonitor *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorHistoryQuery builder.
+func (_q *ChannelMonitorHistoryQuery) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorHistoryQuery) Limit(limit int) *ChannelMonitorHistoryQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorHistoryQuery) Offset(offset int) *ChannelMonitorHistoryQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorHistoryQuery) Unique(unique bool) *ChannelMonitorHistoryQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorHistoryQuery) Order(o ...channelmonitorhistory.OrderOption) *ChannelMonitorHistoryQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitor chains the current query on the "monitor" edge.
+func (_q *ChannelMonitorHistoryQuery) QueryMonitor() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorhistory.Table, channelmonitorhistory.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitorhistory.MonitorTable, channelmonitorhistory.MonitorColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorHistory entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorHistory was found.
+func (_q *ChannelMonitorHistoryQuery) First(ctx context.Context) (*ChannelMonitorHistory, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitorhistory.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) FirstX(ctx context.Context) *ChannelMonitorHistory {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorHistory ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorHistory ID was found.
+func (_q *ChannelMonitorHistoryQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitorhistory.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorHistory entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorHistory entity is found.
+// Returns a *NotFoundError when no ChannelMonitorHistory entities are found.
+func (_q *ChannelMonitorHistoryQuery) Only(ctx context.Context) (*ChannelMonitorHistory, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitorhistory.Label}
+ default:
+ return nil, &NotSingularError{channelmonitorhistory.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) OnlyX(ctx context.Context) *ChannelMonitorHistory {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorHistory ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorHistory ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorHistoryQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitorhistory.Label}
+ default:
+ err = &NotSingularError{channelmonitorhistory.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorHistories.
+func (_q *ChannelMonitorHistoryQuery) All(ctx context.Context) ([]*ChannelMonitorHistory, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorHistory, *ChannelMonitorHistoryQuery]()
+ return withInterceptors[[]*ChannelMonitorHistory](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) AllX(ctx context.Context) []*ChannelMonitorHistory {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorHistory IDs.
+func (_q *ChannelMonitorHistoryQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitorhistory.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorHistoryQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorHistoryQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorHistoryQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorHistoryQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorHistoryQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorHistoryQuery) Clone() *ChannelMonitorHistoryQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorHistoryQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitorhistory.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorHistory{}, _q.predicates...),
+ withMonitor: _q.withMonitor.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitor tells the query-builder to eager-load the nodes that are connected to
+// the "monitor" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorHistoryQuery) WithMonitor(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorHistoryQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitor = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorHistory.Query().
+// GroupBy(channelmonitorhistory.FieldMonitorID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorHistoryQuery) GroupBy(field string, fields ...string) *ChannelMonitorHistoryGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorHistoryGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitorhistory.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// MonitorID int64 `json:"monitor_id,omitempty"`
+// }
+//
+// client.ChannelMonitorHistory.Query().
+// Select(channelmonitorhistory.FieldMonitorID).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorHistoryQuery) Select(fields ...string) *ChannelMonitorHistorySelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorHistorySelect{ChannelMonitorHistoryQuery: _q}
+ sbuild.label = channelmonitorhistory.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorHistorySelect configured with the given aggregations.
+func (_q *ChannelMonitorHistoryQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistorySelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorHistoryQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitorhistory.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorHistoryQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorHistory, error) {
+ var (
+ nodes = []*ChannelMonitorHistory{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitor != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorHistory).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorHistory{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitor; query != nil {
+ if err := _q.loadMonitor(ctx, query, nodes, nil,
+ func(n *ChannelMonitorHistory, e *ChannelMonitor) { n.Edges.Monitor = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorHistoryQuery) loadMonitor(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorHistory, init func(*ChannelMonitorHistory), assign func(*ChannelMonitorHistory, *ChannelMonitor)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*ChannelMonitorHistory)
+ for i := range nodes {
+ fk := nodes[i].MonitorID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(channelmonitor.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "monitor_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorHistoryQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorHistoryQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorhistory.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitorhistory.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withMonitor != nil {
+ _spec.Node.AddColumnOnce(channelmonitorhistory.FieldMonitorID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorHistoryQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitorhistory.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitorhistory.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorHistoryQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorHistoryQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorHistoryQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorHistoryQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorHistoryGroupBy is the group-by builder for ChannelMonitorHistory entities.
+type ChannelMonitorHistoryGroupBy struct {
+ selector
+ build *ChannelMonitorHistoryQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorHistoryGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistoryGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorHistoryGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorHistoryQuery, *ChannelMonitorHistoryGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorHistoryGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorHistoryQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorHistorySelect is the builder for selecting fields of ChannelMonitorHistory entities.
+type ChannelMonitorHistorySelect struct {
+ *ChannelMonitorHistoryQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorHistorySelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorHistorySelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorHistorySelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorHistoryQuery, *ChannelMonitorHistorySelect](ctx, _s.ChannelMonitorHistoryQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorHistorySelect) sqlScan(ctx context.Context, root *ChannelMonitorHistoryQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitorhistory_update.go b/backend/ent/channelmonitorhistory_update.go
new file mode 100644
index 00000000..a85a8072
--- /dev/null
+++ b/backend/ent/channelmonitorhistory_update.go
@@ -0,0 +1,635 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorHistoryUpdate is the builder for updating ChannelMonitorHistory entities.
+type ChannelMonitorHistoryUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorHistoryMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryUpdate builder.
+func (_u *ChannelMonitorHistoryUpdate) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorHistoryUpdate) SetMonitorID(v int64) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableMonitorID(v *int64) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorHistoryUpdate) SetModel(v string) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableModel(v *string) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *ChannelMonitorHistoryUpdate) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableStatus(v *channelmonitorhistory.Status) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) SetLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.ResetLatencyMs()
+ _u.mutation.SetLatencyMs(v)
+ return _u
+}
+
+// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddLatencyMs adds value to the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) AddLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.AddLatencyMs(v)
+ return _u
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) ClearLatencyMs() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearLatencyMs()
+ return _u
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.ResetPingLatencyMs()
+ _u.mutation.SetPingLatencyMs(v)
+ return _u
+}
+
+// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddPingLatencyMs adds value to the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpdate {
+ _u.mutation.AddPingLatencyMs(v)
+ return _u
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdate) ClearPingLatencyMs() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearPingLatencyMs()
+ return _u
+}
+
+// SetMessage sets the "message" field.
+func (_u *ChannelMonitorHistoryUpdate) SetMessage(v string) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetMessage(v)
+ return _u
+}
+
+// SetNillableMessage sets the "message" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableMessage(v *string) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetMessage(*v)
+ }
+ return _u
+}
+
+// ClearMessage clears the value of the "message" field.
+func (_u *ChannelMonitorHistoryUpdate) ClearMessage() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearMessage()
+ return _u
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (_u *ChannelMonitorHistoryUpdate) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpdate {
+ _u.mutation.SetCheckedAt(v)
+ return _u
+}
+
+// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdate) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryUpdate {
+ if v != nil {
+ _u.SetCheckedAt(*v)
+ }
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdate) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryUpdate {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorHistoryMutation object of the builder.
+func (_u *ChannelMonitorHistoryUpdate) Mutation() *ChannelMonitorHistoryMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdate) ClearMonitor() *ChannelMonitorHistoryUpdate {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorHistoryUpdate) Save(ctx context.Context) (int, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorHistoryUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorHistoryUpdate) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitorhistory.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := channelmonitorhistory.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Message(); ok {
+ if err := channelmonitorhistory.MessageValidator(v); err != nil {
+ return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorHistory.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorHistoryUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.LatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.LatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.PingLatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedPingLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.PingLatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.Message(); ok {
+ _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value)
+ }
+ if _u.mutation.MessageCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldMessage, field.TypeString)
+ }
+ if value, ok := _u.mutation.CheckedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorhistory.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorHistoryUpdateOne is the builder for updating a single ChannelMonitorHistory entity.
+type ChannelMonitorHistoryUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorHistoryMutation
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetMonitorID(v int64) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetMonitorID(v)
+ return _u
+}
+
+// SetNillableMonitorID sets the "monitor_id" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableMonitorID(v *int64) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetMonitorID(*v)
+ }
+ return _u
+}
+
+// SetModel sets the "model" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetModel(v string) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetModel(v)
+ return _u
+}
+
+// SetNillableModel sets the "model" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableModel(v *string) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetModel(*v)
+ }
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetStatus(v channelmonitorhistory.Status) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableStatus(v *channelmonitorhistory.Status) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetLatencyMs sets the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ResetLatencyMs()
+ _u.mutation.SetLatencyMs(v)
+ return _u
+}
+
+// SetNillableLatencyMs sets the "latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableLatencyMs(v *int) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddLatencyMs adds value to the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) AddLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.AddLatencyMs(v)
+ return _u
+}
+
+// ClearLatencyMs clears the value of the "latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearLatencyMs() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearLatencyMs()
+ return _u
+}
+
+// SetPingLatencyMs sets the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetPingLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ResetPingLatencyMs()
+ _u.mutation.SetPingLatencyMs(v)
+ return _u
+}
+
+// SetNillablePingLatencyMs sets the "ping_latency_ms" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillablePingLatencyMs(v *int) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetPingLatencyMs(*v)
+ }
+ return _u
+}
+
+// AddPingLatencyMs adds value to the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) AddPingLatencyMs(v int) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.AddPingLatencyMs(v)
+ return _u
+}
+
+// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearPingLatencyMs() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearPingLatencyMs()
+ return _u
+}
+
+// SetMessage sets the "message" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetMessage(v string) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetMessage(v)
+ return _u
+}
+
+// SetNillableMessage sets the "message" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableMessage(v *string) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetMessage(*v)
+ }
+ return _u
+}
+
+// ClearMessage clears the value of the "message" field.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearMessage() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearMessage()
+ return _u
+}
+
+// SetCheckedAt sets the "checked_at" field.
+func (_u *ChannelMonitorHistoryUpdateOne) SetCheckedAt(v time.Time) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.SetCheckedAt(v)
+ return _u
+}
+
+// SetNillableCheckedAt sets the "checked_at" field if the given value is not nil.
+func (_u *ChannelMonitorHistoryUpdateOne) SetNillableCheckedAt(v *time.Time) *ChannelMonitorHistoryUpdateOne {
+ if v != nil {
+ _u.SetCheckedAt(*v)
+ }
+ return _u
+}
+
+// SetMonitor sets the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdateOne) SetMonitor(v *ChannelMonitor) *ChannelMonitorHistoryUpdateOne {
+ return _u.SetMonitorID(v.ID)
+}
+
+// Mutation returns the ChannelMonitorHistoryMutation object of the builder.
+func (_u *ChannelMonitorHistoryUpdateOne) Mutation() *ChannelMonitorHistoryMutation {
+ return _u.mutation
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (_u *ChannelMonitorHistoryUpdateOne) ClearMonitor() *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.ClearMonitor()
+ return _u
+}
+
+// Where appends a list predicates to the ChannelMonitorHistoryUpdate builder.
+func (_u *ChannelMonitorHistoryUpdateOne) Where(ps ...predicate.ChannelMonitorHistory) *ChannelMonitorHistoryUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorHistoryUpdateOne) Select(field string, fields ...string) *ChannelMonitorHistoryUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorHistory entity.
+func (_u *ChannelMonitorHistoryUpdateOne) Save(ctx context.Context) (*ChannelMonitorHistory, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdateOne) SaveX(ctx context.Context) *ChannelMonitorHistory {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorHistoryUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorHistoryUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorHistoryUpdateOne) check() error {
+ if v, ok := _u.mutation.Model(); ok {
+ if err := channelmonitorhistory.ModelValidator(v); err != nil {
+ return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.model": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := channelmonitorhistory.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Message(); ok {
+ if err := channelmonitorhistory.MessageValidator(v); err != nil {
+ return &ValidationError{Name: "message", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorHistory.message": %w`, err)}
+ }
+ }
+ if _u.mutation.MonitorCleared() && len(_u.mutation.MonitorIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "ChannelMonitorHistory.monitor"`)
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorHistoryUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorHistory, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorhistory.Table, channelmonitorhistory.Columns, sqlgraph.NewFieldSpec(channelmonitorhistory.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorHistory.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorhistory.FieldID)
+ for _, f := range fields {
+ if !channelmonitorhistory.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitorhistory.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.Model(); ok {
+ _spec.SetField(channelmonitorhistory.FieldModel, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(channelmonitorhistory.FieldStatus, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.LatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.LatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.PingLatencyMs(); ok {
+ _spec.SetField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedPingLatencyMs(); ok {
+ _spec.AddField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt, value)
+ }
+ if _u.mutation.PingLatencyMsCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldPingLatencyMs, field.TypeInt)
+ }
+ if value, ok := _u.mutation.Message(); ok {
+ _spec.SetField(channelmonitorhistory.FieldMessage, field.TypeString, value)
+ }
+ if _u.mutation.MessageCleared() {
+ _spec.ClearField(channelmonitorhistory.FieldMessage, field.TypeString)
+ }
+ if value, ok := _u.mutation.CheckedAt(); ok {
+ _spec.SetField(channelmonitorhistory.FieldCheckedAt, field.TypeTime, value)
+ }
+ if _u.mutation.MonitorCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: channelmonitorhistory.MonitorTable,
+ Columns: []string{channelmonitorhistory.MonitorColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitorHistory{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorhistory.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate.go
new file mode 100644
index 00000000..b8429a4d
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate.go
@@ -0,0 +1,216 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorRequestTemplate is the model entity for the ChannelMonitorRequestTemplate schema.
+type ChannelMonitorRequestTemplate struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // Provider holds the value of the "provider" field.
+ Provider channelmonitorrequesttemplate.Provider `json:"provider,omitempty"`
+ // Description holds the value of the "description" field.
+ Description string `json:"description,omitempty"`
+ // ExtraHeaders holds the value of the "extra_headers" field.
+ ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
+ // BodyOverrideMode holds the value of the "body_override_mode" field.
+ BodyOverrideMode string `json:"body_override_mode,omitempty"`
+ // BodyOverride holds the value of the "body_override" field.
+ BodyOverride map[string]interface{} `json:"body_override,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the ChannelMonitorRequestTemplateQuery when eager-loading is set.
+ Edges ChannelMonitorRequestTemplateEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// ChannelMonitorRequestTemplateEdges holds the relations/edges for other nodes in the graph.
+type ChannelMonitorRequestTemplateEdges struct {
+ // Monitors holds the value of the monitors edge.
+ Monitors []*ChannelMonitor `json:"monitors,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// MonitorsOrErr returns the Monitors value or an error if the edge
+// was not loaded in eager-loading.
+func (e ChannelMonitorRequestTemplateEdges) MonitorsOrErr() ([]*ChannelMonitor, error) {
+ if e.loadedTypes[0] {
+ return e.Monitors, nil
+ }
+ return nil, &NotLoadedError{edge: "monitors"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*ChannelMonitorRequestTemplate) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorrequesttemplate.FieldExtraHeaders, channelmonitorrequesttemplate.FieldBodyOverride:
+ values[i] = new([]byte)
+ case channelmonitorrequesttemplate.FieldID:
+ values[i] = new(sql.NullInt64)
+ case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ values[i] = new(sql.NullString)
+ case channelmonitorrequesttemplate.FieldCreatedAt, channelmonitorrequesttemplate.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the ChannelMonitorRequestTemplate fields.
+func (_m *ChannelMonitorRequestTemplate) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case channelmonitorrequesttemplate.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case channelmonitorrequesttemplate.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ _m.Name = value.String
+ }
+ case channelmonitorrequesttemplate.FieldProvider:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider", values[i])
+ } else if value.Valid {
+ _m.Provider = channelmonitorrequesttemplate.Provider(value.String)
+ }
+ case channelmonitorrequesttemplate.FieldDescription:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field description", values[i])
+ } else if value.Valid {
+ _m.Description = value.String
+ }
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field extra_headers", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ExtraHeaders); err != nil {
+ return fmt.Errorf("unmarshal field extra_headers: %w", err)
+ }
+ }
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override_mode", values[i])
+ } else if value.Valid {
+ _m.BodyOverrideMode = value.String
+ }
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field body_override", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.BodyOverride); err != nil {
+ return fmt.Errorf("unmarshal field body_override: %w", err)
+ }
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the ChannelMonitorRequestTemplate.
+// This includes values selected through modifiers, order, etc.
+func (_m *ChannelMonitorRequestTemplate) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryMonitors queries the "monitors" edge of the ChannelMonitorRequestTemplate entity.
+func (_m *ChannelMonitorRequestTemplate) QueryMonitors() *ChannelMonitorQuery {
+ return NewChannelMonitorRequestTemplateClient(_m.config).QueryMonitors(_m)
+}
+
+// Update returns a builder for updating this ChannelMonitorRequestTemplate.
+// Note that you need to call ChannelMonitorRequestTemplate.Unwrap() before calling this method if this ChannelMonitorRequestTemplate
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *ChannelMonitorRequestTemplate) Update() *ChannelMonitorRequestTemplateUpdateOne {
+ return NewChannelMonitorRequestTemplateClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the ChannelMonitorRequestTemplate entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *ChannelMonitorRequestTemplate) Unwrap() *ChannelMonitorRequestTemplate {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: ChannelMonitorRequestTemplate is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *ChannelMonitorRequestTemplate) String() string {
+ var builder strings.Builder
+ builder.WriteString("ChannelMonitorRequestTemplate(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(_m.Name)
+ builder.WriteString(", ")
+ builder.WriteString("provider=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Provider))
+ builder.WriteString(", ")
+ builder.WriteString("description=")
+ builder.WriteString(_m.Description)
+ builder.WriteString(", ")
+ builder.WriteString("extra_headers=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ExtraHeaders))
+ builder.WriteString(", ")
+ builder.WriteString("body_override_mode=")
+ builder.WriteString(_m.BodyOverrideMode)
+ builder.WriteString(", ")
+ builder.WriteString("body_override=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BodyOverride))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// ChannelMonitorRequestTemplates is a parsable slice of ChannelMonitorRequestTemplate.
+type ChannelMonitorRequestTemplates []*ChannelMonitorRequestTemplate
diff --git a/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
new file mode 100644
index 00000000..65b8d641
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
@@ -0,0 +1,172 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorrequesttemplate
+
+import (
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the channelmonitorrequesttemplate type in the database.
+ Label = "channel_monitor_request_template"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldProvider holds the string denoting the provider field in the database.
+ FieldProvider = "provider"
+ // FieldDescription holds the string denoting the description field in the database.
+ FieldDescription = "description"
+ // FieldExtraHeaders holds the string denoting the extra_headers field in the database.
+ FieldExtraHeaders = "extra_headers"
+ // FieldBodyOverrideMode holds the string denoting the body_override_mode field in the database.
+ FieldBodyOverrideMode = "body_override_mode"
+ // FieldBodyOverride holds the string denoting the body_override field in the database.
+ FieldBodyOverride = "body_override"
+ // EdgeMonitors holds the string denoting the monitors edge name in mutations.
+ EdgeMonitors = "monitors"
+ // Table holds the table name of the channelmonitorrequesttemplate in the database.
+ Table = "channel_monitor_request_templates"
+ // MonitorsTable is the table that holds the monitors relation/edge.
+ MonitorsTable = "channel_monitors"
+ // MonitorsInverseTable is the table name for the ChannelMonitor entity.
+ // It exists in this package in order to avoid circular dependency with the "channelmonitor" package.
+ MonitorsInverseTable = "channel_monitors"
+ // MonitorsColumn is the table column denoting the monitors relation/edge.
+ MonitorsColumn = "template_id"
+)
+
+// Columns holds all SQL columns for channelmonitorrequesttemplate fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldName,
+ FieldProvider,
+ FieldDescription,
+ FieldExtraHeaders,
+ FieldBodyOverrideMode,
+ FieldBodyOverride,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // NameValidator is a validator for the "name" field. It is called by the builders before save.
+ NameValidator func(string) error
+ // DefaultDescription holds the default value on creation for the "description" field.
+ DefaultDescription string
+ // DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
+ DescriptionValidator func(string) error
+ // DefaultExtraHeaders holds the default value on creation for the "extra_headers" field.
+ DefaultExtraHeaders map[string]string
+ // DefaultBodyOverrideMode holds the default value on creation for the "body_override_mode" field.
+ DefaultBodyOverrideMode string
+ // BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ BodyOverrideModeValidator func(string) error
+)
+
+// Provider defines the type for the "provider" enum field.
+type Provider string
+
+// Provider values.
+const (
+ ProviderOpenai Provider = "openai"
+ ProviderAnthropic Provider = "anthropic"
+ ProviderGemini Provider = "gemini"
+)
+
+func (pr Provider) String() string {
+ return string(pr)
+}
+
+// ProviderValidator is a validator for the "provider" field enum values. It is called by the builders before save.
+func ProviderValidator(pr Provider) error {
+ switch pr {
+ case ProviderOpenai, ProviderAnthropic, ProviderGemini:
+ return nil
+ default:
+ return fmt.Errorf("channelmonitorrequesttemplate: invalid enum value for provider field: %q", pr)
+ }
+}
+
+// OrderOption defines the ordering options for the ChannelMonitorRequestTemplate queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByProvider orders the results by the provider field.
+func ByProvider(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProvider, opts...).ToFunc()
+}
+
+// ByDescription orders the results by the description field.
+func ByDescription(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDescription, opts...).ToFunc()
+}
+
+// ByBodyOverrideMode orders the results by the body_override_mode field.
+func ByBodyOverrideMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBodyOverrideMode, opts...).ToFunc()
+}
+
+// ByMonitorsCount orders the results by monitors count.
+func ByMonitorsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newMonitorsStep(), opts...)
+ }
+}
+
+// ByMonitors orders the results by monitors terms.
+func ByMonitors(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newMonitorsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+func newMonitorsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(MonitorsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn),
+ )
+}
diff --git a/backend/ent/channelmonitorrequesttemplate/where.go b/backend/ent/channelmonitorrequesttemplate/where.go
new file mode 100644
index 00000000..b95e5df0
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate/where.go
@@ -0,0 +1,434 @@
+// Code generated by ent, DO NOT EDIT.
+
+package channelmonitorrequesttemplate
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
+}
+
+// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
+func Description(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
+}
+
+// BodyOverrideMode applies equality check predicate on the "body_override_mode" field. It's identical to BodyOverrideModeEQ.
+func BodyOverrideMode(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldName, v))
+}
+
+// ProviderEQ applies the EQ predicate on the "provider" field.
+func ProviderEQ(v Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldProvider, v))
+}
+
+// ProviderNEQ applies the NEQ predicate on the "provider" field.
+func ProviderNEQ(v Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldProvider, v))
+}
+
+// ProviderIn applies the In predicate on the "provider" field.
+func ProviderIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldProvider, vs...))
+}
+
+// ProviderNotIn applies the NotIn predicate on the "provider" field.
+func ProviderNotIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldProvider, vs...))
+}
+
+// DescriptionEQ applies the EQ predicate on the "description" field.
+func DescriptionEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
+}
+
+// DescriptionNEQ applies the NEQ predicate on the "description" field.
+func DescriptionNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldDescription, v))
+}
+
+// DescriptionIn applies the In predicate on the "description" field.
+func DescriptionIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldDescription, vs...))
+}
+
+// DescriptionNotIn applies the NotIn predicate on the "description" field.
+func DescriptionNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldDescription, vs...))
+}
+
+// DescriptionGT applies the GT predicate on the "description" field.
+func DescriptionGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldDescription, v))
+}
+
+// DescriptionGTE applies the GTE predicate on the "description" field.
+func DescriptionGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldDescription, v))
+}
+
+// DescriptionLT applies the LT predicate on the "description" field.
+func DescriptionLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldDescription, v))
+}
+
+// DescriptionLTE applies the LTE predicate on the "description" field.
+func DescriptionLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldDescription, v))
+}
+
+// DescriptionContains applies the Contains predicate on the "description" field.
+func DescriptionContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldDescription, v))
+}
+
+// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
+func DescriptionHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldDescription, v))
+}
+
+// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
+func DescriptionHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldDescription, v))
+}
+
+// DescriptionIsNil applies the IsNil predicate on the "description" field.
+func DescriptionIsNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldDescription))
+}
+
+// DescriptionNotNil applies the NotNil predicate on the "description" field.
+func DescriptionNotNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldDescription))
+}
+
+// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
+func DescriptionEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldDescription, v))
+}
+
+// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
+func DescriptionContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldDescription, v))
+}
+
+// BodyOverrideModeEQ applies the EQ predicate on the "body_override_mode" field.
+func BodyOverrideModeEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeNEQ applies the NEQ predicate on the "body_override_mode" field.
+func BodyOverrideModeNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeIn applies the In predicate on the "body_override_mode" field.
+func BodyOverrideModeIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeNotIn applies the NotIn predicate on the "body_override_mode" field.
+func BodyOverrideModeNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldBodyOverrideMode, vs...))
+}
+
+// BodyOverrideModeGT applies the GT predicate on the "body_override_mode" field.
+func BodyOverrideModeGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeGTE applies the GTE predicate on the "body_override_mode" field.
+func BodyOverrideModeGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLT applies the LT predicate on the "body_override_mode" field.
+func BodyOverrideModeLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeLTE applies the LTE predicate on the "body_override_mode" field.
+func BodyOverrideModeLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContains applies the Contains predicate on the "body_override_mode" field.
+func BodyOverrideModeContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasPrefix applies the HasPrefix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeHasSuffix applies the HasSuffix predicate on the "body_override_mode" field.
+func BodyOverrideModeHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeEqualFold applies the EqualFold predicate on the "body_override_mode" field.
+func BodyOverrideModeEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideModeContainsFold applies the ContainsFold predicate on the "body_override_mode" field.
+func BodyOverrideModeContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldBodyOverrideMode, v))
+}
+
+// BodyOverrideIsNil applies the IsNil predicate on the "body_override" field.
+func BodyOverrideIsNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIsNull(FieldBodyOverride))
+}
+
+// BodyOverrideNotNil applies the NotNil predicate on the "body_override" field.
+func BodyOverrideNotNil() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotNull(FieldBodyOverride))
+}
+
+// HasMonitors applies the HasEdge predicate on the "monitors" edge.
+func HasMonitors() predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, MonitorsTable, MonitorsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasMonitorsWith applies the HasEdge predicate on the "monitors" edge with a given conditions (other predicates).
+func HasMonitorsWith(preds ...predicate.ChannelMonitor) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(func(s *sql.Selector) {
+ step := newMonitorsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.ChannelMonitorRequestTemplate) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.NotPredicates(p))
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_create.go b/backend/ent/channelmonitorrequesttemplate_create.go
new file mode 100644
index 00000000..1ba842cd
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_create.go
@@ -0,0 +1,942 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+)
+
+// ChannelMonitorRequestTemplateCreate is the builder for creating a ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateCreate struct {
+ config
+ mutation *ChannelMonitorRequestTemplateMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetCreatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableCreatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableUpdatedAt(v *time.Time) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetName sets the "name" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetName(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetName(v)
+ return _c
+}
+
+// SetProvider sets the "provider" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetProvider(v)
+ return _c
+}
+
+// SetDescription sets the "description" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetDescription(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetDescription(v)
+ return _c
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetDescription(*v)
+ }
+ return _c
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetExtraHeaders(v)
+ return _c
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetBodyOverrideMode(v)
+ return _c
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetBodyOverrideMode(*v)
+ }
+ return _c
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetBodyOverride(v)
+ return _c
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_c *ChannelMonitorRequestTemplateCreate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.AddMonitorIDs(ids...)
+ return _c
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_c *ChannelMonitorRequestTemplateCreate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_c *ChannelMonitorRequestTemplateCreate) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _c.mutation
+}
+
+// Save creates the ChannelMonitorRequestTemplate in the database.
+func (_c *ChannelMonitorRequestTemplateCreate) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *ChannelMonitorRequestTemplateCreate) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorRequestTemplateCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *ChannelMonitorRequestTemplateCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := channelmonitorrequesttemplate.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.Description(); !ok {
+ v := channelmonitorrequesttemplate.DefaultDescription
+ _c.mutation.SetDescription(v)
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ v := channelmonitorrequesttemplate.DefaultExtraHeaders
+ _c.mutation.SetExtraHeaders(v)
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ v := channelmonitorrequesttemplate.DefaultBodyOverrideMode
+ _c.mutation.SetBodyOverrideMode(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *ChannelMonitorRequestTemplateCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.updated_at"`)}
+ }
+ if _, ok := _c.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.name"`)}
+ }
+ if v, ok := _c.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Provider(); !ok {
+ return &ValidationError{Name: "provider", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.provider"`)}
+ }
+ if v, ok := _c.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ExtraHeaders(); !ok {
+ return &ValidationError{Name: "extra_headers", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.extra_headers"`)}
+ }
+ if _, ok := _c.mutation.BodyOverrideMode(); !ok {
+ return &ValidationError{Name: "body_override_mode", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.body_override_mode"`)}
+ }
+ if v, ok := _c.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_c *ChannelMonitorRequestTemplateCreate) sqlSave(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *ChannelMonitorRequestTemplateCreate) createSpec() (*ChannelMonitorRequestTemplate, *sqlgraph.CreateSpec) {
+ var (
+ _node = &ChannelMonitorRequestTemplate{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := _c.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ _node.Provider = value
+ }
+ if value, ok := _c.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ _node.Description = value
+ }
+ if value, ok := _c.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ _node.ExtraHeaders = value
+ }
+ if value, ok := _c.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ _node.BodyOverrideMode = value
+ }
+ if value, ok := _c.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ _node.BodyOverride = value
+ }
+ if nodes := _c.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreate) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertOne {
+ _c.conflict = opts
+ return &ChannelMonitorRequestTemplateUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreate) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorRequestTemplateUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // ChannelMonitorRequestTemplateUpsertOne is the builder for "upsert"-ing
+ // one ChannelMonitorRequestTemplate node.
+ ChannelMonitorRequestTemplateUpsertOne struct {
+ create *ChannelMonitorRequestTemplateCreate
+ }
+
+ // ChannelMonitorRequestTemplateUpsert is the "OnConflict" setter.
+ ChannelMonitorRequestTemplateUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldUpdatedAt)
+ return u
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetName(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldName, v)
+ return u
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateName() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldName)
+ return u
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldProvider, v)
+ return u
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateProvider() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldProvider)
+ return u
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetDescription(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldDescription, v)
+ return u
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateDescription() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldDescription)
+ return u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsert) ClearDescription() *ChannelMonitorRequestTemplateUpsert {
+ u.SetNull(channelmonitorrequesttemplate.FieldDescription)
+ return u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldExtraHeaders, v)
+ return u
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldExtraHeaders)
+ return u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldBodyOverrideMode, v)
+ return u
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverrideMode)
+ return u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldBodyOverride, v)
+ return u
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldBodyOverride)
+ return u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsert) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsert {
+ u.SetNull(channelmonitorrequesttemplate.FieldBodyOverride)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertOne) Ignore() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorRequestTemplateUpsertOne) DoNothing() *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreate.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorRequestTemplateUpsertOne) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetName(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateName() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateProvider() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateDescription() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ClearDescription() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearDescription()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorRequestTemplateUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *ChannelMonitorRequestTemplateUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// ChannelMonitorRequestTemplateCreateBulk is the builder for creating many ChannelMonitorRequestTemplate entities in bulk.
+type ChannelMonitorRequestTemplateCreateBulk struct {
+ config
+ err error
+ builders []*ChannelMonitorRequestTemplateCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the ChannelMonitorRequestTemplate entities in the database.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) Save(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*ChannelMonitorRequestTemplate, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*ChannelMonitorRequestTemplateMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) SaveX(ctx context.Context) []*ChannelMonitorRequestTemplate {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *ChannelMonitorRequestTemplateCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.ChannelMonitorRequestTemplate.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+//			// that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.ChannelMonitorRequestTemplateUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflict(opts ...sql.ConflictOption) *ChannelMonitorRequestTemplateUpsertBulk {
+ _c.conflict = opts
+ return &ChannelMonitorRequestTemplateUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *ChannelMonitorRequestTemplateCreateBulk) OnConflictColumns(columns ...string) *ChannelMonitorRequestTemplateUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &ChannelMonitorRequestTemplateUpsertBulk{
+ create: _c,
+ }
+}
+
+// ChannelMonitorRequestTemplateUpsertBulk is the builder for "upsert"-ing
+// a bulk of ChannelMonitorRequestTemplate nodes.
+type ChannelMonitorRequestTemplateUpsertBulk struct {
+ create *ChannelMonitorRequestTemplateCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateNewValues() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.ChannelMonitorRequestTemplate.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Ignore() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) DoNothing() *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the ChannelMonitorRequestTemplateCreateBulk.OnConflict
+// documentation for more info.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Update(set func(*ChannelMonitorRequestTemplateUpsert)) *ChannelMonitorRequestTemplateUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&ChannelMonitorRequestTemplateUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateUpdatedAt() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetName(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateName() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetProvider sets the "provider" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetProvider(v)
+ })
+}
+
+// UpdateProvider sets the "provider" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateProvider() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateProvider()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateDescription() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// ClearDescription clears the value of the "description" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearDescription() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearDescription()
+ })
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetExtraHeaders(v)
+ })
+}
+
+// UpdateExtraHeaders sets the "extra_headers" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateExtraHeaders() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateExtraHeaders()
+ })
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverrideMode(v)
+ })
+}
+
+// UpdateBodyOverrideMode sets the "body_override_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverrideMode() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverrideMode()
+ })
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetBodyOverride(v)
+ })
+}
+
+// UpdateBodyOverride sets the "body_override" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateBodyOverride()
+ })
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ClearBodyOverride() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.ClearBodyOverride()
+ })
+}
+
+// Exec executes the query.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ChannelMonitorRequestTemplateCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for ChannelMonitorRequestTemplateCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_delete.go b/backend/ent/channelmonitorrequesttemplate_delete.go
new file mode 100644
index 00000000..98d365c8
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateDelete is the builder for deleting a ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateDelete struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// Where appends a list of predicates to the ChannelMonitorRequestTemplateDelete builder.
+func (_d *ChannelMonitorRequestTemplateDelete) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *ChannelMonitorRequestTemplateDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorRequestTemplateDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *ChannelMonitorRequestTemplateDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(channelmonitorrequesttemplate.Table, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// ChannelMonitorRequestTemplateDeleteOne is the builder for deleting a single ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateDeleteOne struct {
+ _d *ChannelMonitorRequestTemplateDelete
+}
+
+// Where appends a list of predicates to the ChannelMonitorRequestTemplateDelete builder.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *ChannelMonitorRequestTemplateDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_query.go b/backend/ent/channelmonitorrequesttemplate_query.go
new file mode 100644
index 00000000..6491ea60
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_query.go
@@ -0,0 +1,648 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateQuery is the builder for querying ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateQuery struct {
+ config
+ ctx *QueryContext
+ order []channelmonitorrequesttemplate.OrderOption
+ inters []Interceptor
+ predicates []predicate.ChannelMonitorRequestTemplate
+ withMonitors *ChannelMonitorQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the ChannelMonitorRequestTemplateQuery builder.
+func (_q *ChannelMonitorRequestTemplateQuery) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *ChannelMonitorRequestTemplateQuery) Limit(limit int) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *ChannelMonitorRequestTemplateQuery) Offset(offset int) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *ChannelMonitorRequestTemplateQuery) Unique(unique bool) *ChannelMonitorRequestTemplateQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *ChannelMonitorRequestTemplateQuery) Order(o ...channelmonitorrequesttemplate.OrderOption) *ChannelMonitorRequestTemplateQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryMonitors chains the current query on the "monitors" edge.
+func (_q *ChannelMonitorRequestTemplateQuery) QueryMonitors() *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, selector),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first ChannelMonitorRequestTemplate entity from the query.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate was found.
+func (_q *ChannelMonitorRequestTemplateQuery) First(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{channelmonitorrequesttemplate.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first ChannelMonitorRequestTemplate ID from the query.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate ID was found.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single ChannelMonitorRequestTemplate entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate entity is found.
+// Returns a *NotFoundError when no ChannelMonitorRequestTemplate entities are found.
+func (_q *ChannelMonitorRequestTemplateQuery) Only(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ return nil, &NotSingularError{channelmonitorrequesttemplate.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only ChannelMonitorRequestTemplate ID in the query.
+// Returns a *NotSingularError when more than one ChannelMonitorRequestTemplate ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ default:
+ err = &NotSingularError{channelmonitorrequesttemplate.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of ChannelMonitorRequestTemplates.
+func (_q *ChannelMonitorRequestTemplateQuery) All(ctx context.Context) ([]*ChannelMonitorRequestTemplate, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*ChannelMonitorRequestTemplate, *ChannelMonitorRequestTemplateQuery]()
+ return withInterceptors[[]*ChannelMonitorRequestTemplate](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) AllX(ctx context.Context) []*ChannelMonitorRequestTemplate {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of ChannelMonitorRequestTemplate IDs.
+func (_q *ChannelMonitorRequestTemplateQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(channelmonitorrequesttemplate.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *ChannelMonitorRequestTemplateQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*ChannelMonitorRequestTemplateQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *ChannelMonitorRequestTemplateQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *ChannelMonitorRequestTemplateQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the ChannelMonitorRequestTemplateQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *ChannelMonitorRequestTemplateQuery) Clone() *ChannelMonitorRequestTemplateQuery {
+ if _q == nil {
+ return nil
+ }
+ return &ChannelMonitorRequestTemplateQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]channelmonitorrequesttemplate.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.ChannelMonitorRequestTemplate{}, _q.predicates...),
+ withMonitors: _q.withMonitors.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithMonitors tells the query-builder to eager-load the nodes that are connected to
+// the "monitors" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *ChannelMonitorRequestTemplateQuery) WithMonitors(opts ...func(*ChannelMonitorQuery)) *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withMonitors = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.ChannelMonitorRequestTemplate.Query().
+// GroupBy(channelmonitorrequesttemplate.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorRequestTemplateQuery) GroupBy(field string, fields ...string) *ChannelMonitorRequestTemplateGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &ChannelMonitorRequestTemplateGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = channelmonitorrequesttemplate.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.ChannelMonitorRequestTemplate.Query().
+// Select(channelmonitorrequesttemplate.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *ChannelMonitorRequestTemplateQuery) Select(fields ...string) *ChannelMonitorRequestTemplateSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &ChannelMonitorRequestTemplateSelect{ChannelMonitorRequestTemplateQuery: _q}
+ sbuild.label = channelmonitorrequesttemplate.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a ChannelMonitorRequestTemplateSelect configured with the given aggregations.
+func (_q *ChannelMonitorRequestTemplateQuery) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !channelmonitorrequesttemplate.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ChannelMonitorRequestTemplate, error) {
+ var (
+ nodes = []*ChannelMonitorRequestTemplate{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withMonitors != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*ChannelMonitorRequestTemplate).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &ChannelMonitorRequestTemplate{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withMonitors; query != nil {
+ if err := _q.loadMonitors(ctx, query, nodes,
+ func(n *ChannelMonitorRequestTemplate) { n.Edges.Monitors = []*ChannelMonitor{} },
+ func(n *ChannelMonitorRequestTemplate, e *ChannelMonitor) {
+ n.Edges.Monitors = append(n.Edges.Monitors, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) loadMonitors(ctx context.Context, query *ChannelMonitorQuery, nodes []*ChannelMonitorRequestTemplate, init func(*ChannelMonitorRequestTemplate), assign func(*ChannelMonitorRequestTemplate, *ChannelMonitor)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*ChannelMonitorRequestTemplate)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(channelmonitor.FieldTemplateID)
+ }
+ query.Where(predicate.ChannelMonitor(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(channelmonitorrequesttemplate.MonitorsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TemplateID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "template_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "template_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID)
+ for i := range fields {
+ if fields[i] != channelmonitorrequesttemplate.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *ChannelMonitorRequestTemplateQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(channelmonitorrequesttemplate.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = channelmonitorrequesttemplate.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *ChannelMonitorRequestTemplateQuery) ForUpdate(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *ChannelMonitorRequestTemplateQuery) ForShare(opts ...sql.LockOption) *ChannelMonitorRequestTemplateQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// ChannelMonitorRequestTemplateGroupBy is the group-by builder for ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateGroupBy struct {
+ selector
+ build *ChannelMonitorRequestTemplateQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *ChannelMonitorRequestTemplateGroupBy) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *ChannelMonitorRequestTemplateGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *ChannelMonitorRequestTemplateGroupBy) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// ChannelMonitorRequestTemplateSelect is the builder for selecting fields of ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateSelect struct {
+ *ChannelMonitorRequestTemplateQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *ChannelMonitorRequestTemplateSelect) Aggregate(fns ...AggregateFunc) *ChannelMonitorRequestTemplateSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *ChannelMonitorRequestTemplateSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*ChannelMonitorRequestTemplateQuery, *ChannelMonitorRequestTemplateSelect](ctx, _s.ChannelMonitorRequestTemplateQuery, _s, _s.inters, v)
+}
+
+func (_s *ChannelMonitorRequestTemplateSelect) sqlScan(ctx context.Context, root *ChannelMonitorRequestTemplateQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/channelmonitorrequesttemplate_update.go b/backend/ent/channelmonitorrequesttemplate_update.go
new file mode 100644
index 00000000..8f55ba04
--- /dev/null
+++ b/backend/ent/channelmonitorrequesttemplate_update.go
@@ -0,0 +1,639 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ChannelMonitorRequestTemplateUpdate is the builder for updating ChannelMonitorRequestTemplate entities.
+type ChannelMonitorRequestTemplateUpdate struct {
+ config
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateUpdate builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetName(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetDescription(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearDescription() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearDescription()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.AddMonitorIDs(ids...)
+ return _u
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdate) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _u.mutation
+}
+
+// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdate) ClearMonitors() *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.ClearMonitors()
+ return _u
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.RemoveMonitorIDs(ids...)
+ return _u
+}
+
+// RemoveMonitors removes "monitors" edges to ChannelMonitor entities.
+func (_u *ChannelMonitorRequestTemplateUpdate) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveMonitorIDs(ids...)
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *ChannelMonitorRequestTemplateUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *ChannelMonitorRequestTemplateUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorRequestTemplateUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorRequestTemplateUpdate) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorRequestTemplateUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ }
+ if _u.mutation.DescriptionCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// ChannelMonitorRequestTemplateUpdateOne is the builder for updating a single ChannelMonitorRequestTemplate entity.
+type ChannelMonitorRequestTemplateUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *ChannelMonitorRequestTemplateMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetUpdatedAt(v time.Time) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetName(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableName(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetProvider sets the "provider" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetProvider(v channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetProvider(v)
+ return _u
+}
+
+// SetNillableProvider sets the "provider" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableProvider(v *channelmonitorrequesttemplate.Provider) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetProvider(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableDescription(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// ClearDescription clears the value of the "description" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearDescription() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearDescription()
+ return _u
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetExtraHeaders(v map[string]string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetExtraHeaders(v)
+ return _u
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverrideMode(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetBodyOverrideMode(v)
+ return _u
+}
+
+// SetNillableBodyOverrideMode sets the "body_override_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableBodyOverrideMode(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetBodyOverrideMode(*v)
+ }
+ return _u
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetBodyOverride(v map[string]interface{}) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetBodyOverride(v)
+ return _u
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearBodyOverride() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearBodyOverride()
+ return _u
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.AddMonitorIDs(ids...)
+ return _u
+}
+
+// AddMonitors adds the "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) AddMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddMonitorIDs(ids...)
+}
+
+// Mutation returns the ChannelMonitorRequestTemplateMutation object of the builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Mutation() *ChannelMonitorRequestTemplateMutation {
+ return _u.mutation
+}
+
+// ClearMonitors clears all "monitors" edges to the ChannelMonitor entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ClearMonitors() *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.ClearMonitors()
+ return _u
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to ChannelMonitor entities by IDs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitorIDs(ids ...int64) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.RemoveMonitorIDs(ids...)
+ return _u
+}
+
+// RemoveMonitors removes "monitors" edges to ChannelMonitor entities.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) RemoveMonitors(v ...*ChannelMonitor) *ChannelMonitorRequestTemplateUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveMonitorIDs(ids...)
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateUpdate builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Where(ps ...predicate.ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Select(field string, fields ...string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated ChannelMonitorRequestTemplate entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Save(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SaveX(ctx context.Context) *ChannelMonitorRequestTemplate {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := channelmonitorrequesttemplate.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := channelmonitorrequesttemplate.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Provider(); ok {
+ if err := channelmonitorrequesttemplate.ProviderValidator(v); err != nil {
+ return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Description(); ok {
+ if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
+ return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BodyOverrideMode(); ok {
+ if err := channelmonitorrequesttemplate.BodyOverrideModeValidator(v); err != nil {
+ return &ValidationError{Name: "body_override_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.body_override_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *ChannelMonitorRequestTemplateUpdateOne) sqlSave(ctx context.Context) (_node *ChannelMonitorRequestTemplate, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.Columns, sqlgraph.NewFieldSpec(channelmonitorrequesttemplate.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ChannelMonitorRequestTemplate.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, channelmonitorrequesttemplate.FieldID)
+ for _, f := range fields {
+ if !channelmonitorrequesttemplate.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != channelmonitorrequesttemplate.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Provider(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
+ }
+ if _u.mutation.DescriptionCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldDescription, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExtraHeaders(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldExtraHeaders, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BodyOverrideMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverrideMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BodyOverride(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON, value)
+ }
+ if _u.mutation.BodyOverrideCleared() {
+ _spec.ClearField(channelmonitorrequesttemplate.FieldBodyOverride, field.TypeJSON)
+ }
+ if _u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedMonitorsIDs(); len(nodes) > 0 && !_u.mutation.MonitorsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.MonitorsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: true,
+ Table: channelmonitorrequesttemplate.MonitorsTable,
+ Columns: []string{channelmonitorrequesttemplate.MonitorsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(channelmonitor.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &ChannelMonitorRequestTemplate{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{channelmonitorrequesttemplate.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/client.go b/backend/ent/client.go
index 4129d6c5..df20ddfa 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -20,15 +20,27 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -56,12 +68,34 @@ type Client struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
+ // ChannelMonitor is the client for interacting with the ChannelMonitor builders.
+ ChannelMonitor *ChannelMonitorClient
+ // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
+ ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
+ // ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
+ ChannelMonitorHistory *ChannelMonitorHistoryClient
+ // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
+ ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
+ // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
+ PaymentAuditLog *PaymentAuditLogClient
+ // PaymentOrder is the client for interacting with the PaymentOrder builders.
+ PaymentOrder *PaymentOrderClient
+ // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
+ PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -74,6 +108,8 @@ type Client struct {
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
+ // SubscriptionPlan is the client for interacting with the SubscriptionPlan builders.
+ SubscriptionPlan *SubscriptionPlanClient
// TLSFingerprintProfile is the client for interacting with the TLSFingerprintProfile builders.
TLSFingerprintProfile *TLSFingerprintProfileClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
@@ -106,15 +142,27 @@ func (c *Client) init() {
c.AccountGroup = NewAccountGroupClient(c.config)
c.Announcement = NewAnnouncementClient(c.config)
c.AnnouncementRead = NewAnnouncementReadClient(c.config)
+ c.AuthIdentity = NewAuthIdentityClient(c.config)
+ c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config)
+ c.ChannelMonitor = NewChannelMonitorClient(c.config)
+ c.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(c.config)
+ c.ChannelMonitorHistory = NewChannelMonitorHistoryClient(c.config)
+ c.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(c.config)
c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config)
c.Group = NewGroupClient(c.config)
c.IdempotencyRecord = NewIdempotencyRecordClient(c.config)
+ c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config)
+ c.PaymentAuditLog = NewPaymentAuditLogClient(c.config)
+ c.PaymentOrder = NewPaymentOrderClient(c.config)
+ c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config)
+ c.PendingAuthSession = NewPendingAuthSessionClient(c.config)
c.PromoCode = NewPromoCodeClient(c.config)
c.PromoCodeUsage = NewPromoCodeUsageClient(c.config)
c.Proxy = NewProxyClient(c.config)
c.RedeemCode = NewRedeemCodeClient(c.config)
c.SecuritySecret = NewSecuritySecretClient(c.config)
c.Setting = NewSettingClient(c.config)
+ c.SubscriptionPlan = NewSubscriptionPlanClient(c.config)
c.TLSFingerprintProfile = NewTLSFingerprintProfileClient(c.config)
c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
c.UsageLog = NewUsageLogClient(c.config)
@@ -213,30 +261,42 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
cfg := c.config
cfg.driver = tx
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -254,30 +314,42 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
cfg := c.config
cfg.driver = &txDriver{tx: tx, drv: c.driver}
return &Tx{
- ctx: ctx,
- config: cfg,
- APIKey: NewAPIKeyClient(cfg),
- Account: NewAccountClient(cfg),
- AccountGroup: NewAccountGroupClient(cfg),
- Announcement: NewAnnouncementClient(cfg),
- AnnouncementRead: NewAnnouncementReadClient(cfg),
- ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
- Group: NewGroupClient(cfg),
- IdempotencyRecord: NewIdempotencyRecordClient(cfg),
- PromoCode: NewPromoCodeClient(cfg),
- PromoCodeUsage: NewPromoCodeUsageClient(cfg),
- Proxy: NewProxyClient(cfg),
- RedeemCode: NewRedeemCodeClient(cfg),
- SecuritySecret: NewSecuritySecretClient(cfg),
- Setting: NewSettingClient(cfg),
- TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
- UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
- UsageLog: NewUsageLogClient(cfg),
- User: NewUserClient(cfg),
- UserAllowedGroup: NewUserAllowedGroupClient(cfg),
- UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
- UserAttributeValue: NewUserAttributeValueClient(cfg),
- UserSubscription: NewUserSubscriptionClient(cfg),
+ ctx: ctx,
+ config: cfg,
+ APIKey: NewAPIKeyClient(cfg),
+ Account: NewAccountClient(cfg),
+ AccountGroup: NewAccountGroupClient(cfg),
+ Announcement: NewAnnouncementClient(cfg),
+ AnnouncementRead: NewAnnouncementReadClient(cfg),
+ AuthIdentity: NewAuthIdentityClient(cfg),
+ AuthIdentityChannel: NewAuthIdentityChannelClient(cfg),
+ ChannelMonitor: NewChannelMonitorClient(cfg),
+ ChannelMonitorDailyRollup: NewChannelMonitorDailyRollupClient(cfg),
+ ChannelMonitorHistory: NewChannelMonitorHistoryClient(cfg),
+ ChannelMonitorRequestTemplate: NewChannelMonitorRequestTemplateClient(cfg),
+ ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg),
+ Group: NewGroupClient(cfg),
+ IdempotencyRecord: NewIdempotencyRecordClient(cfg),
+ IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg),
+ PaymentAuditLog: NewPaymentAuditLogClient(cfg),
+ PaymentOrder: NewPaymentOrderClient(cfg),
+ PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg),
+ PendingAuthSession: NewPendingAuthSessionClient(cfg),
+ PromoCode: NewPromoCodeClient(cfg),
+ PromoCodeUsage: NewPromoCodeUsageClient(cfg),
+ Proxy: NewProxyClient(cfg),
+ RedeemCode: NewRedeemCodeClient(cfg),
+ SecuritySecret: NewSecuritySecretClient(cfg),
+ Setting: NewSettingClient(cfg),
+ SubscriptionPlan: NewSubscriptionPlanClient(cfg),
+ TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
+ UsageLog: NewUsageLogClient(cfg),
+ User: NewUserClient(cfg),
+ UserAllowedGroup: NewUserAllowedGroupClient(cfg),
+ UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg),
+ UserAttributeValue: NewUserAttributeValueClient(cfg),
+ UserSubscription: NewUserSubscriptionClient(cfg),
}, nil
}
@@ -308,10 +380,14 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory,
+ c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Use(hooks...)
@@ -323,10 +399,14 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead,
- c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode,
+ c.AuthIdentity, c.AuthIdentityChannel, c.ChannelMonitor,
+ c.ChannelMonitorDailyRollup, c.ChannelMonitorHistory,
+ c.ChannelMonitorRequestTemplate, c.ErrorPassthroughRule, c.Group,
+ c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog,
+ c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode,
c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting,
- c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User,
- c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog,
+ c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
c.UserSubscription,
} {
n.Intercept(interceptors...)
@@ -346,12 +426,34 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.Announcement.mutate(ctx, m)
case *AnnouncementReadMutation:
return c.AnnouncementRead.mutate(ctx, m)
+ case *AuthIdentityMutation:
+ return c.AuthIdentity.mutate(ctx, m)
+ case *AuthIdentityChannelMutation:
+ return c.AuthIdentityChannel.mutate(ctx, m)
+ case *ChannelMonitorMutation:
+ return c.ChannelMonitor.mutate(ctx, m)
+ case *ChannelMonitorDailyRollupMutation:
+ return c.ChannelMonitorDailyRollup.mutate(ctx, m)
+ case *ChannelMonitorHistoryMutation:
+ return c.ChannelMonitorHistory.mutate(ctx, m)
+ case *ChannelMonitorRequestTemplateMutation:
+ return c.ChannelMonitorRequestTemplate.mutate(ctx, m)
case *ErrorPassthroughRuleMutation:
return c.ErrorPassthroughRule.mutate(ctx, m)
case *GroupMutation:
return c.Group.mutate(ctx, m)
case *IdempotencyRecordMutation:
return c.IdempotencyRecord.mutate(ctx, m)
+ case *IdentityAdoptionDecisionMutation:
+ return c.IdentityAdoptionDecision.mutate(ctx, m)
+ case *PaymentAuditLogMutation:
+ return c.PaymentAuditLog.mutate(ctx, m)
+ case *PaymentOrderMutation:
+ return c.PaymentOrder.mutate(ctx, m)
+ case *PaymentProviderInstanceMutation:
+ return c.PaymentProviderInstance.mutate(ctx, m)
+ case *PendingAuthSessionMutation:
+ return c.PendingAuthSession.mutate(ctx, m)
case *PromoCodeMutation:
return c.PromoCode.mutate(ctx, m)
case *PromoCodeUsageMutation:
@@ -364,6 +466,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.SecuritySecret.mutate(ctx, m)
case *SettingMutation:
return c.Setting.mutate(ctx, m)
+ case *SubscriptionPlanMutation:
+ return c.SubscriptionPlan.mutate(ctx, m)
case *TLSFingerprintProfileMutation:
return c.TLSFingerprintProfile.mutate(ctx, m)
case *UsageCleanupTaskMutation:
@@ -1197,6 +1301,964 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead
}
}
+// AuthIdentityClient is a client for the AuthIdentity schema.
+type AuthIdentityClient struct {
+ config
+}
+
+// NewAuthIdentityClient returns a client for the AuthIdentity from the given config.
+func NewAuthIdentityClient(c config) *AuthIdentityClient {
+ return &AuthIdentityClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`.
+func (c *AuthIdentityClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`.
+func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentity entity.
+func (c *AuthIdentityClient) Create() *AuthIdentityCreate {
+ mutation := newAuthIdentityMutation(c.config, OpCreate)
+ return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentity entities.
+func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk {
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentity.
+func (c *AuthIdentityClient) Update() *AuthIdentityUpdate {
+ mutation := newAuthIdentityMutation(c.config, OpUpdate)
+ return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne {
+ mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id))
+ return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentity.
+func (c *AuthIdentityClient) Delete() *AuthIdentityDelete {
+ mutation := newAuthIdentityMutation(c.config, OpDelete)
+ return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne {
+ builder := c.Delete().Where(authidentity.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentity.
+func (c *AuthIdentityClient) Query() *AuthIdentityQuery {
+ return &AuthIdentityQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentity},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentity entity by its id.
+func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) {
+ return c.Query().Where(authidentity.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryUser queries the user edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryChannels queries the channels edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery {
+ query := (&AuthIdentityChannelClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity.
+func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentity.Table, authidentity.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityClient) Hooks() []Hook {
+ return c.hooks.AuthIdentity
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentity
+}
+
+func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op())
+ }
+}
+
+// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema.
+type AuthIdentityChannelClient struct {
+ config
+}
+
+// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config.
+func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient {
+ return &AuthIdentityChannelClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Use(hooks ...Hook) {
+ c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`.
+func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) {
+ c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...)
+}
+
+// Create returns a builder for creating a AuthIdentityChannel entity.
+func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpCreate)
+ return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities.
+func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk {
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*AuthIdentityChannelCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdate)
+ return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne {
+ mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id))
+ return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete {
+ mutation := newAuthIdentityChannelMutation(c.config, OpDelete)
+ return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne {
+ builder := c.Delete().Where(authidentitychannel.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &AuthIdentityChannelDeleteOne{builder}
+}
+
+// Query returns a query builder for AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery {
+ return &AuthIdentityChannelQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeAuthIdentityChannel},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a AuthIdentityChannel entity by its id.
+func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) {
+ return c.Query().Where(authidentitychannel.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryIdentity queries the identity edge of a AuthIdentityChannel.
+func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *AuthIdentityChannelClient) Hooks() []Hook {
+ return c.hooks.AuthIdentityChannel
+}
+
+// Interceptors returns the client interceptors.
+func (c *AuthIdentityChannelClient) Interceptors() []Interceptor {
+ return c.inters.AuthIdentityChannel
+}
+
+func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorClient is a client for the ChannelMonitor schema.
+type ChannelMonitorClient struct {
+ config
+}
+
+// NewChannelMonitorClient returns a client for the ChannelMonitor from the given config.
+func NewChannelMonitorClient(c config) *ChannelMonitorClient {
+ return &ChannelMonitorClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitor.Hooks(f(g(h())))`.
+func (c *ChannelMonitorClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitor = append(c.hooks.ChannelMonitor, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitor.Intercept(f(g(h())))`.
+func (c *ChannelMonitorClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitor = append(c.inters.ChannelMonitor, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitor entity.
+func (c *ChannelMonitorClient) Create() *ChannelMonitorCreate {
+ mutation := newChannelMonitorMutation(c.config, OpCreate)
+ return &ChannelMonitorCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitor entities.
+func (c *ChannelMonitorClient) CreateBulk(builders ...*ChannelMonitorCreate) *ChannelMonitorCreateBulk {
+ return &ChannelMonitorCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorCreate, int)) *ChannelMonitorCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorCreateBulk{err: fmt.Errorf("calling to ChannelMonitorClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitor.
+func (c *ChannelMonitorClient) Update() *ChannelMonitorUpdate {
+ mutation := newChannelMonitorMutation(c.config, OpUpdate)
+ return &ChannelMonitorUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorClient) UpdateOne(_m *ChannelMonitor) *ChannelMonitorUpdateOne {
+ mutation := newChannelMonitorMutation(c.config, OpUpdateOne, withChannelMonitor(_m))
+ return &ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorClient) UpdateOneID(id int64) *ChannelMonitorUpdateOne {
+ mutation := newChannelMonitorMutation(c.config, OpUpdateOne, withChannelMonitorID(id))
+ return &ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitor.
+func (c *ChannelMonitorClient) Delete() *ChannelMonitorDelete {
+ mutation := newChannelMonitorMutation(c.config, OpDelete)
+ return &ChannelMonitorDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorClient) DeleteOne(_m *ChannelMonitor) *ChannelMonitorDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorClient) DeleteOneID(id int64) *ChannelMonitorDeleteOne {
+ builder := c.Delete().Where(channelmonitor.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitor.
+func (c *ChannelMonitorClient) Query() *ChannelMonitorQuery {
+ return &ChannelMonitorQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitor},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitor entity by its id.
+func (c *ChannelMonitorClient) Get(ctx context.Context, id int64) (*ChannelMonitor, error) {
+ return c.Query().Where(channelmonitor.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorClient) GetX(ctx context.Context, id int64) *ChannelMonitor {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryHistory queries the history edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryHistory(_m *ChannelMonitor) *ChannelMonitorHistoryQuery {
+ query := (&ChannelMonitorHistoryClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitorhistory.Table, channelmonitorhistory.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.HistoryTable, channelmonitor.HistoryColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryDailyRollups queries the daily_rollups edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryDailyRollups(_m *ChannelMonitor) *ChannelMonitorDailyRollupQuery {
+ query := (&ChannelMonitorDailyRollupClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, channelmonitor.DailyRollupsTable, channelmonitor.DailyRollupsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryRequestTemplate queries the request_template edge of a ChannelMonitor.
+func (c *ChannelMonitorClient) QueryRequestTemplate(_m *ChannelMonitor) *ChannelMonitorRequestTemplateQuery {
+ query := (&ChannelMonitorRequestTemplateClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitor.Table, channelmonitor.FieldID, id),
+ sqlgraph.To(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, false, channelmonitor.RequestTemplateTable, channelmonitor.RequestTemplateColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitor
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitor
+}
+
+func (c *ChannelMonitorClient) mutate(ctx context.Context, m *ChannelMonitorMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitor mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorDailyRollupClient is a client for the ChannelMonitorDailyRollup schema.
+type ChannelMonitorDailyRollupClient struct {
+ config
+}
+
+// NewChannelMonitorDailyRollupClient returns a client for the ChannelMonitorDailyRollup from the given config.
+func NewChannelMonitorDailyRollupClient(c config) *ChannelMonitorDailyRollupClient {
+ return &ChannelMonitorDailyRollupClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitordailyrollup.Hooks(f(g(h())))`.
+func (c *ChannelMonitorDailyRollupClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorDailyRollup = append(c.hooks.ChannelMonitorDailyRollup, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitordailyrollup.Intercept(f(g(h())))`.
+func (c *ChannelMonitorDailyRollupClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorDailyRollup = append(c.inters.ChannelMonitorDailyRollup, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorDailyRollup entity.
+func (c *ChannelMonitorDailyRollupClient) Create() *ChannelMonitorDailyRollupCreate {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpCreate)
+ return &ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorDailyRollup entities.
+func (c *ChannelMonitorDailyRollupClient) CreateBulk(builders ...*ChannelMonitorDailyRollupCreate) *ChannelMonitorDailyRollupCreateBulk {
+ return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorDailyRollupClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorDailyRollupCreate, int)) *ChannelMonitorDailyRollupCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorDailyRollupCreateBulk{err: fmt.Errorf("calling to ChannelMonitorDailyRollupClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorDailyRollupCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorDailyRollupCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Update() *ChannelMonitorDailyRollupUpdate {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdate)
+ return &ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorDailyRollupClient) UpdateOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupUpdateOne {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollup(_m))
+ return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorDailyRollupClient) UpdateOneID(id int64) *ChannelMonitorDailyRollupUpdateOne {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpUpdateOne, withChannelMonitorDailyRollupID(id))
+ return &ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Delete() *ChannelMonitorDailyRollupDelete {
+ mutation := newChannelMonitorDailyRollupMutation(c.config, OpDelete)
+ return &ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorDailyRollupClient) DeleteOne(_m *ChannelMonitorDailyRollup) *ChannelMonitorDailyRollupDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorDailyRollupClient) DeleteOneID(id int64) *ChannelMonitorDailyRollupDeleteOne {
+ builder := c.Delete().Where(channelmonitordailyrollup.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorDailyRollupDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) Query() *ChannelMonitorDailyRollupQuery {
+ return &ChannelMonitorDailyRollupQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorDailyRollup},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorDailyRollup entity by its id.
+func (c *ChannelMonitorDailyRollupClient) Get(ctx context.Context, id int64) (*ChannelMonitorDailyRollup, error) {
+ return c.Query().Where(channelmonitordailyrollup.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorDailyRollupClient) GetX(ctx context.Context, id int64) *ChannelMonitorDailyRollup {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitor queries the monitor edge of a ChannelMonitorDailyRollup.
+func (c *ChannelMonitorDailyRollupClient) QueryMonitor(_m *ChannelMonitorDailyRollup) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitordailyrollup.Table, channelmonitordailyrollup.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitordailyrollup.MonitorTable, channelmonitordailyrollup.MonitorColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorDailyRollupClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitorDailyRollup
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorDailyRollupClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitorDailyRollup
+}
+
+func (c *ChannelMonitorDailyRollupClient) mutate(ctx context.Context, m *ChannelMonitorDailyRollupMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorDailyRollupCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorDailyRollupUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorDailyRollupUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorDailyRollupDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorDailyRollup mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorHistoryClient is a client for the ChannelMonitorHistory schema.
+type ChannelMonitorHistoryClient struct {
+ config
+}
+
+// NewChannelMonitorHistoryClient returns a client for the ChannelMonitorHistory from the given config.
+func NewChannelMonitorHistoryClient(c config) *ChannelMonitorHistoryClient {
+ return &ChannelMonitorHistoryClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitorhistory.Hooks(f(g(h())))`.
+func (c *ChannelMonitorHistoryClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorHistory = append(c.hooks.ChannelMonitorHistory, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitorhistory.Intercept(f(g(h())))`.
+func (c *ChannelMonitorHistoryClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorHistory = append(c.inters.ChannelMonitorHistory, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorHistory entity.
+func (c *ChannelMonitorHistoryClient) Create() *ChannelMonitorHistoryCreate {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpCreate)
+ return &ChannelMonitorHistoryCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorHistory entities.
+func (c *ChannelMonitorHistoryClient) CreateBulk(builders ...*ChannelMonitorHistoryCreate) *ChannelMonitorHistoryCreateBulk {
+ return &ChannelMonitorHistoryCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorHistoryClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorHistoryCreate, int)) *ChannelMonitorHistoryCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorHistoryCreateBulk{err: fmt.Errorf("calling to ChannelMonitorHistoryClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorHistoryCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorHistoryCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) Update() *ChannelMonitorHistoryUpdate {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpUpdate)
+ return &ChannelMonitorHistoryUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorHistoryClient) UpdateOne(_m *ChannelMonitorHistory) *ChannelMonitorHistoryUpdateOne {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpUpdateOne, withChannelMonitorHistory(_m))
+ return &ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorHistoryClient) UpdateOneID(id int64) *ChannelMonitorHistoryUpdateOne {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpUpdateOne, withChannelMonitorHistoryID(id))
+ return &ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) Delete() *ChannelMonitorHistoryDelete {
+ mutation := newChannelMonitorHistoryMutation(c.config, OpDelete)
+ return &ChannelMonitorHistoryDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorHistoryClient) DeleteOne(_m *ChannelMonitorHistory) *ChannelMonitorHistoryDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorHistoryClient) DeleteOneID(id int64) *ChannelMonitorHistoryDeleteOne {
+ builder := c.Delete().Where(channelmonitorhistory.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorHistoryDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) Query() *ChannelMonitorHistoryQuery {
+ return &ChannelMonitorHistoryQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorHistory},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorHistory entity by its id.
+func (c *ChannelMonitorHistoryClient) Get(ctx context.Context, id int64) (*ChannelMonitorHistory, error) {
+ return c.Query().Where(channelmonitorhistory.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorHistoryClient) GetX(ctx context.Context, id int64) *ChannelMonitorHistory {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitor queries the monitor edge of a ChannelMonitorHistory.
+func (c *ChannelMonitorHistoryClient) QueryMonitor(_m *ChannelMonitorHistory) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorhistory.Table, channelmonitorhistory.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, channelmonitorhistory.MonitorTable, channelmonitorhistory.MonitorColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorHistoryClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitorHistory
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorHistoryClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitorHistory
+}
+
+func (c *ChannelMonitorHistoryClient) mutate(ctx context.Context, m *ChannelMonitorHistoryMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorHistoryCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorHistoryUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorHistoryUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorHistoryDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorHistory mutation op: %q", m.Op())
+ }
+}
+
+// ChannelMonitorRequestTemplateClient is a client for the ChannelMonitorRequestTemplate schema.
+type ChannelMonitorRequestTemplateClient struct {
+ config
+}
+
+// NewChannelMonitorRequestTemplateClient returns a client for the ChannelMonitorRequestTemplate from the given config.
+func NewChannelMonitorRequestTemplateClient(c config) *ChannelMonitorRequestTemplateClient {
+ return &ChannelMonitorRequestTemplateClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `channelmonitorrequesttemplate.Hooks(f(g(h())))`.
+func (c *ChannelMonitorRequestTemplateClient) Use(hooks ...Hook) {
+ c.hooks.ChannelMonitorRequestTemplate = append(c.hooks.ChannelMonitorRequestTemplate, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `channelmonitorrequesttemplate.Intercept(f(g(h())))`.
+func (c *ChannelMonitorRequestTemplateClient) Intercept(interceptors ...Interceptor) {
+ c.inters.ChannelMonitorRequestTemplate = append(c.inters.ChannelMonitorRequestTemplate, interceptors...)
+}
+
+// Create returns a builder for creating a ChannelMonitorRequestTemplate entity.
+func (c *ChannelMonitorRequestTemplateClient) Create() *ChannelMonitorRequestTemplateCreate {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpCreate)
+ return &ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of ChannelMonitorRequestTemplate entities.
+func (c *ChannelMonitorRequestTemplateClient) CreateBulk(builders ...*ChannelMonitorRequestTemplateCreate) *ChannelMonitorRequestTemplateCreateBulk {
+ return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *ChannelMonitorRequestTemplateClient) MapCreateBulk(slice any, setFunc func(*ChannelMonitorRequestTemplateCreate, int)) *ChannelMonitorRequestTemplateCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &ChannelMonitorRequestTemplateCreateBulk{err: fmt.Errorf("calling to ChannelMonitorRequestTemplateClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*ChannelMonitorRequestTemplateCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &ChannelMonitorRequestTemplateCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Update() *ChannelMonitorRequestTemplateUpdate {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdate)
+ return &ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *ChannelMonitorRequestTemplateClient) UpdateOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateUpdateOne {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplate(_m))
+ return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *ChannelMonitorRequestTemplateClient) UpdateOneID(id int64) *ChannelMonitorRequestTemplateUpdateOne {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpUpdateOne, withChannelMonitorRequestTemplateID(id))
+ return &ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Delete() *ChannelMonitorRequestTemplateDelete {
+ mutation := newChannelMonitorRequestTemplateMutation(c.config, OpDelete)
+ return &ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *ChannelMonitorRequestTemplateClient) DeleteOne(_m *ChannelMonitorRequestTemplate) *ChannelMonitorRequestTemplateDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *ChannelMonitorRequestTemplateClient) DeleteOneID(id int64) *ChannelMonitorRequestTemplateDeleteOne {
+ builder := c.Delete().Where(channelmonitorrequesttemplate.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &ChannelMonitorRequestTemplateDeleteOne{builder}
+}
+
+// Query returns a query builder for ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) Query() *ChannelMonitorRequestTemplateQuery {
+ return &ChannelMonitorRequestTemplateQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeChannelMonitorRequestTemplate},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a ChannelMonitorRequestTemplate entity by its id.
+func (c *ChannelMonitorRequestTemplateClient) Get(ctx context.Context, id int64) (*ChannelMonitorRequestTemplate, error) {
+ return c.Query().Where(channelmonitorrequesttemplate.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *ChannelMonitorRequestTemplateClient) GetX(ctx context.Context, id int64) *ChannelMonitorRequestTemplate {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryMonitors queries the monitors edge of a ChannelMonitorRequestTemplate.
+func (c *ChannelMonitorRequestTemplateClient) QueryMonitors(_m *ChannelMonitorRequestTemplate) *ChannelMonitorQuery {
+ query := (&ChannelMonitorClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(channelmonitorrequesttemplate.Table, channelmonitorrequesttemplate.FieldID, id),
+ sqlgraph.To(channelmonitor.Table, channelmonitor.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, true, channelmonitorrequesttemplate.MonitorsTable, channelmonitorrequesttemplate.MonitorsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *ChannelMonitorRequestTemplateClient) Hooks() []Hook {
+ return c.hooks.ChannelMonitorRequestTemplate
+}
+
+// Interceptors returns the client interceptors.
+func (c *ChannelMonitorRequestTemplateClient) Interceptors() []Interceptor {
+ return c.inters.ChannelMonitorRequestTemplate
+}
+
+func (c *ChannelMonitorRequestTemplateClient) mutate(ctx context.Context, m *ChannelMonitorRequestTemplateMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&ChannelMonitorRequestTemplateCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&ChannelMonitorRequestTemplateUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&ChannelMonitorRequestTemplateUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&ChannelMonitorRequestTemplateDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown ChannelMonitorRequestTemplate mutation op: %q", m.Op())
+ }
+}
+
// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema.
type ErrorPassthroughRuleClient struct {
config
@@ -1726,6 +2788,751 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco
}
}
+// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecisionClient struct {
+ config
+}
+
+// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config.
+func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient {
+ return &IdentityAdoptionDecisionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) {
+ c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`.
+func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...)
+}
+
+// Create returns a builder for creating a IdentityAdoptionDecision entity.
+func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate)
+ return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities.
+func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk {
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*IdentityAdoptionDecisionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate)
+ return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id))
+ return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete {
+ mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete)
+ return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne {
+ builder := c.Delete().Where(identityadoptiondecision.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &IdentityAdoptionDecisionDeleteOne{builder}
+}
+
+// Query returns a query builder for IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery {
+ return &IdentityAdoptionDecisionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeIdentityAdoptionDecision},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a IdentityAdoptionDecision entity by its id.
+func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) {
+ return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryIdentity queries the identity edge of a IdentityAdoptionDecision.
+func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *IdentityAdoptionDecisionClient) Hooks() []Hook {
+ return c.hooks.IdentityAdoptionDecision
+}
+
+// Interceptors returns the client interceptors.
+func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor {
+ return c.inters.IdentityAdoptionDecision
+}
+
+func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op())
+ }
+}
+
+// PaymentAuditLogClient is a client for the PaymentAuditLog schema.
+type PaymentAuditLogClient struct {
+ config
+}
+
+// NewPaymentAuditLogClient returns a client for the PaymentAuditLog from the given config.
+func NewPaymentAuditLogClient(c config) *PaymentAuditLogClient {
+ return &PaymentAuditLogClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `paymentauditlog.Hooks(f(g(h())))`.
+func (c *PaymentAuditLogClient) Use(hooks ...Hook) {
+ c.hooks.PaymentAuditLog = append(c.hooks.PaymentAuditLog, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `paymentauditlog.Intercept(f(g(h())))`.
+func (c *PaymentAuditLogClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PaymentAuditLog = append(c.inters.PaymentAuditLog, interceptors...)
+}
+
+// Create returns a builder for creating a PaymentAuditLog entity.
+func (c *PaymentAuditLogClient) Create() *PaymentAuditLogCreate {
+ mutation := newPaymentAuditLogMutation(c.config, OpCreate)
+ return &PaymentAuditLogCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PaymentAuditLog entities.
+func (c *PaymentAuditLogClient) CreateBulk(builders ...*PaymentAuditLogCreate) *PaymentAuditLogCreateBulk {
+ return &PaymentAuditLogCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PaymentAuditLogClient) MapCreateBulk(slice any, setFunc func(*PaymentAuditLogCreate, int)) *PaymentAuditLogCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PaymentAuditLogCreateBulk{err: fmt.Errorf("calling to PaymentAuditLogClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PaymentAuditLogCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PaymentAuditLogCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PaymentAuditLog.
+func (c *PaymentAuditLogClient) Update() *PaymentAuditLogUpdate {
+ mutation := newPaymentAuditLogMutation(c.config, OpUpdate)
+ return &PaymentAuditLogUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PaymentAuditLogClient) UpdateOne(_m *PaymentAuditLog) *PaymentAuditLogUpdateOne {
+ mutation := newPaymentAuditLogMutation(c.config, OpUpdateOne, withPaymentAuditLog(_m))
+ return &PaymentAuditLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PaymentAuditLogClient) UpdateOneID(id int64) *PaymentAuditLogUpdateOne {
+ mutation := newPaymentAuditLogMutation(c.config, OpUpdateOne, withPaymentAuditLogID(id))
+ return &PaymentAuditLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PaymentAuditLog.
+func (c *PaymentAuditLogClient) Delete() *PaymentAuditLogDelete {
+ mutation := newPaymentAuditLogMutation(c.config, OpDelete)
+ return &PaymentAuditLogDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PaymentAuditLogClient) DeleteOne(_m *PaymentAuditLog) *PaymentAuditLogDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PaymentAuditLogClient) DeleteOneID(id int64) *PaymentAuditLogDeleteOne {
+ builder := c.Delete().Where(paymentauditlog.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PaymentAuditLogDeleteOne{builder}
+}
+
+// Query returns a query builder for PaymentAuditLog.
+func (c *PaymentAuditLogClient) Query() *PaymentAuditLogQuery {
+ return &PaymentAuditLogQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePaymentAuditLog},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PaymentAuditLog entity by its id.
+func (c *PaymentAuditLogClient) Get(ctx context.Context, id int64) (*PaymentAuditLog, error) {
+ return c.Query().Where(paymentauditlog.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PaymentAuditLogClient) GetX(ctx context.Context, id int64) *PaymentAuditLog {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// Hooks returns the client hooks.
+func (c *PaymentAuditLogClient) Hooks() []Hook {
+ return c.hooks.PaymentAuditLog
+}
+
+// Interceptors returns the client interceptors.
+func (c *PaymentAuditLogClient) Interceptors() []Interceptor {
+ return c.inters.PaymentAuditLog
+}
+
+func (c *PaymentAuditLogClient) mutate(ctx context.Context, m *PaymentAuditLogMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PaymentAuditLogCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PaymentAuditLogUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PaymentAuditLogUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PaymentAuditLogDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PaymentAuditLog mutation op: %q", m.Op())
+ }
+}
+
+// PaymentOrderClient is a client for the PaymentOrder schema.
+type PaymentOrderClient struct {
+ config
+}
+
+// NewPaymentOrderClient returns a client for the PaymentOrder from the given config.
+func NewPaymentOrderClient(c config) *PaymentOrderClient {
+ return &PaymentOrderClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `paymentorder.Hooks(f(g(h())))`.
+func (c *PaymentOrderClient) Use(hooks ...Hook) {
+ c.hooks.PaymentOrder = append(c.hooks.PaymentOrder, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `paymentorder.Intercept(f(g(h())))`.
+func (c *PaymentOrderClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PaymentOrder = append(c.inters.PaymentOrder, interceptors...)
+}
+
+// Create returns a builder for creating a PaymentOrder entity.
+func (c *PaymentOrderClient) Create() *PaymentOrderCreate {
+ mutation := newPaymentOrderMutation(c.config, OpCreate)
+ return &PaymentOrderCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PaymentOrder entities.
+func (c *PaymentOrderClient) CreateBulk(builders ...*PaymentOrderCreate) *PaymentOrderCreateBulk {
+ return &PaymentOrderCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PaymentOrderClient) MapCreateBulk(slice any, setFunc func(*PaymentOrderCreate, int)) *PaymentOrderCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PaymentOrderCreateBulk{err: fmt.Errorf("calling to PaymentOrderClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PaymentOrderCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PaymentOrderCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PaymentOrder.
+func (c *PaymentOrderClient) Update() *PaymentOrderUpdate {
+ mutation := newPaymentOrderMutation(c.config, OpUpdate)
+ return &PaymentOrderUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PaymentOrderClient) UpdateOne(_m *PaymentOrder) *PaymentOrderUpdateOne {
+ mutation := newPaymentOrderMutation(c.config, OpUpdateOne, withPaymentOrder(_m))
+ return &PaymentOrderUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PaymentOrderClient) UpdateOneID(id int64) *PaymentOrderUpdateOne {
+ mutation := newPaymentOrderMutation(c.config, OpUpdateOne, withPaymentOrderID(id))
+ return &PaymentOrderUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PaymentOrder.
+func (c *PaymentOrderClient) Delete() *PaymentOrderDelete {
+ mutation := newPaymentOrderMutation(c.config, OpDelete)
+ return &PaymentOrderDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PaymentOrderClient) DeleteOne(_m *PaymentOrder) *PaymentOrderDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PaymentOrderClient) DeleteOneID(id int64) *PaymentOrderDeleteOne {
+ builder := c.Delete().Where(paymentorder.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PaymentOrderDeleteOne{builder}
+}
+
+// Query returns a query builder for PaymentOrder.
+func (c *PaymentOrderClient) Query() *PaymentOrderQuery {
+ return &PaymentOrderQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePaymentOrder},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PaymentOrder entity by its id.
+func (c *PaymentOrderClient) Get(ctx context.Context, id int64) (*PaymentOrder, error) {
+ return c.Query().Where(paymentorder.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PaymentOrderClient) GetX(ctx context.Context, id int64) *PaymentOrder {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryUser queries the user edge of a PaymentOrder.
+func (c *PaymentOrderClient) QueryUser(_m *PaymentOrder) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(paymentorder.Table, paymentorder.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, paymentorder.UserTable, paymentorder.UserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *PaymentOrderClient) Hooks() []Hook {
+ return c.hooks.PaymentOrder
+}
+
+// Interceptors returns the client interceptors.
+func (c *PaymentOrderClient) Interceptors() []Interceptor {
+ return c.inters.PaymentOrder
+}
+
+func (c *PaymentOrderClient) mutate(ctx context.Context, m *PaymentOrderMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PaymentOrderCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PaymentOrderUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PaymentOrderUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PaymentOrderDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PaymentOrder mutation op: %q", m.Op())
+ }
+}
+
+// PaymentProviderInstanceClient is a client for the PaymentProviderInstance schema.
+type PaymentProviderInstanceClient struct {
+ config
+}
+
+// NewPaymentProviderInstanceClient returns a client for the PaymentProviderInstance from the given config.
+func NewPaymentProviderInstanceClient(c config) *PaymentProviderInstanceClient {
+ return &PaymentProviderInstanceClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `paymentproviderinstance.Hooks(f(g(h())))`.
+func (c *PaymentProviderInstanceClient) Use(hooks ...Hook) {
+ c.hooks.PaymentProviderInstance = append(c.hooks.PaymentProviderInstance, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `paymentproviderinstance.Intercept(f(g(h())))`.
+func (c *PaymentProviderInstanceClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PaymentProviderInstance = append(c.inters.PaymentProviderInstance, interceptors...)
+}
+
+// Create returns a builder for creating a PaymentProviderInstance entity.
+func (c *PaymentProviderInstanceClient) Create() *PaymentProviderInstanceCreate {
+ mutation := newPaymentProviderInstanceMutation(c.config, OpCreate)
+ return &PaymentProviderInstanceCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PaymentProviderInstance entities.
+func (c *PaymentProviderInstanceClient) CreateBulk(builders ...*PaymentProviderInstanceCreate) *PaymentProviderInstanceCreateBulk {
+ return &PaymentProviderInstanceCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PaymentProviderInstanceClient) MapCreateBulk(slice any, setFunc func(*PaymentProviderInstanceCreate, int)) *PaymentProviderInstanceCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PaymentProviderInstanceCreateBulk{err: fmt.Errorf("calling to PaymentProviderInstanceClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PaymentProviderInstanceCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PaymentProviderInstanceCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PaymentProviderInstance.
+func (c *PaymentProviderInstanceClient) Update() *PaymentProviderInstanceUpdate {
+ mutation := newPaymentProviderInstanceMutation(c.config, OpUpdate)
+ return &PaymentProviderInstanceUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PaymentProviderInstanceClient) UpdateOne(_m *PaymentProviderInstance) *PaymentProviderInstanceUpdateOne {
+ mutation := newPaymentProviderInstanceMutation(c.config, OpUpdateOne, withPaymentProviderInstance(_m))
+ return &PaymentProviderInstanceUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PaymentProviderInstanceClient) UpdateOneID(id int64) *PaymentProviderInstanceUpdateOne {
+ mutation := newPaymentProviderInstanceMutation(c.config, OpUpdateOne, withPaymentProviderInstanceID(id))
+ return &PaymentProviderInstanceUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PaymentProviderInstance.
+func (c *PaymentProviderInstanceClient) Delete() *PaymentProviderInstanceDelete {
+ mutation := newPaymentProviderInstanceMutation(c.config, OpDelete)
+ return &PaymentProviderInstanceDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PaymentProviderInstanceClient) DeleteOne(_m *PaymentProviderInstance) *PaymentProviderInstanceDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PaymentProviderInstanceClient) DeleteOneID(id int64) *PaymentProviderInstanceDeleteOne {
+ builder := c.Delete().Where(paymentproviderinstance.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PaymentProviderInstanceDeleteOne{builder}
+}
+
+// Query returns a query builder for PaymentProviderInstance.
+func (c *PaymentProviderInstanceClient) Query() *PaymentProviderInstanceQuery {
+ return &PaymentProviderInstanceQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePaymentProviderInstance},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PaymentProviderInstance entity by its id.
+func (c *PaymentProviderInstanceClient) Get(ctx context.Context, id int64) (*PaymentProviderInstance, error) {
+ return c.Query().Where(paymentproviderinstance.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PaymentProviderInstanceClient) GetX(ctx context.Context, id int64) *PaymentProviderInstance {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// Hooks returns the client hooks.
+func (c *PaymentProviderInstanceClient) Hooks() []Hook {
+ return c.hooks.PaymentProviderInstance
+}
+
+// Interceptors returns the client interceptors.
+func (c *PaymentProviderInstanceClient) Interceptors() []Interceptor {
+ return c.inters.PaymentProviderInstance
+}
+
+func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentProviderInstanceMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PaymentProviderInstanceCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PaymentProviderInstanceUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PaymentProviderInstanceUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PaymentProviderInstanceDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PaymentProviderInstance mutation op: %q", m.Op())
+ }
+}
+
+// PendingAuthSessionClient is a client for the PendingAuthSession schema.
+type PendingAuthSessionClient struct {
+ config
+}
+
+// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config.
+func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient {
+ return &PendingAuthSessionClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`.
+func (c *PendingAuthSessionClient) Use(hooks ...Hook) {
+ c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`.
+func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) {
+ c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...)
+}
+
+// Create returns a builder for creating a PendingAuthSession entity.
+func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate {
+ mutation := newPendingAuthSessionMutation(c.config, OpCreate)
+ return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities.
+func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk {
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*PendingAuthSessionCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &PendingAuthSessionCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdate)
+ return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne {
+ mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id))
+ return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete {
+ mutation := newPendingAuthSessionMutation(c.config, OpDelete)
+ return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne {
+ builder := c.Delete().Where(pendingauthsession.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &PendingAuthSessionDeleteOne{builder}
+}
+
+// Query returns a query builder for PendingAuthSession.
+func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery {
+ return &PendingAuthSessionQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypePendingAuthSession},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a PendingAuthSession entity by its id.
+func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) {
+ return c.Query().Where(pendingauthsession.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// QueryTargetUser queries the target_user edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery {
+ query := (&UserClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession.
+func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// Hooks returns the client hooks.
+func (c *PendingAuthSessionClient) Hooks() []Hook {
+ return c.hooks.PendingAuthSession
+}
+
+// Interceptors returns the client interceptors.
+func (c *PendingAuthSessionClient) Interceptors() []Interceptor {
+ return c.inters.PendingAuthSession
+}
+
+func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op())
+ }
+}
+
// PromoCodeClient is a client for the PromoCode schema.
type PromoCodeClient struct {
config
@@ -2622,6 +4429,139 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
}
}
+// SubscriptionPlanClient is a client for the SubscriptionPlan schema.
+type SubscriptionPlanClient struct {
+ config
+}
+
+// NewSubscriptionPlanClient returns a client for the SubscriptionPlan from the given config.
+func NewSubscriptionPlanClient(c config) *SubscriptionPlanClient {
+ return &SubscriptionPlanClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `subscriptionplan.Hooks(f(g(h())))`.
+func (c *SubscriptionPlanClient) Use(hooks ...Hook) {
+ c.hooks.SubscriptionPlan = append(c.hooks.SubscriptionPlan, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `subscriptionplan.Intercept(f(g(h())))`.
+func (c *SubscriptionPlanClient) Intercept(interceptors ...Interceptor) {
+ c.inters.SubscriptionPlan = append(c.inters.SubscriptionPlan, interceptors...)
+}
+
+// Create returns a builder for creating a SubscriptionPlan entity.
+func (c *SubscriptionPlanClient) Create() *SubscriptionPlanCreate {
+ mutation := newSubscriptionPlanMutation(c.config, OpCreate)
+ return &SubscriptionPlanCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of SubscriptionPlan entities.
+func (c *SubscriptionPlanClient) CreateBulk(builders ...*SubscriptionPlanCreate) *SubscriptionPlanCreateBulk {
+ return &SubscriptionPlanCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *SubscriptionPlanClient) MapCreateBulk(slice any, setFunc func(*SubscriptionPlanCreate, int)) *SubscriptionPlanCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &SubscriptionPlanCreateBulk{err: fmt.Errorf("calling to SubscriptionPlanClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*SubscriptionPlanCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &SubscriptionPlanCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for SubscriptionPlan.
+func (c *SubscriptionPlanClient) Update() *SubscriptionPlanUpdate {
+ mutation := newSubscriptionPlanMutation(c.config, OpUpdate)
+ return &SubscriptionPlanUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *SubscriptionPlanClient) UpdateOne(_m *SubscriptionPlan) *SubscriptionPlanUpdateOne {
+ mutation := newSubscriptionPlanMutation(c.config, OpUpdateOne, withSubscriptionPlan(_m))
+ return &SubscriptionPlanUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *SubscriptionPlanClient) UpdateOneID(id int64) *SubscriptionPlanUpdateOne {
+ mutation := newSubscriptionPlanMutation(c.config, OpUpdateOne, withSubscriptionPlanID(id))
+ return &SubscriptionPlanUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for SubscriptionPlan.
+func (c *SubscriptionPlanClient) Delete() *SubscriptionPlanDelete {
+ mutation := newSubscriptionPlanMutation(c.config, OpDelete)
+ return &SubscriptionPlanDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *SubscriptionPlanClient) DeleteOne(_m *SubscriptionPlan) *SubscriptionPlanDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *SubscriptionPlanClient) DeleteOneID(id int64) *SubscriptionPlanDeleteOne {
+ builder := c.Delete().Where(subscriptionplan.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &SubscriptionPlanDeleteOne{builder}
+}
+
+// Query returns a query builder for SubscriptionPlan.
+func (c *SubscriptionPlanClient) Query() *SubscriptionPlanQuery {
+ return &SubscriptionPlanQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeSubscriptionPlan},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a SubscriptionPlan entity by its id.
+func (c *SubscriptionPlanClient) Get(ctx context.Context, id int64) (*SubscriptionPlan, error) {
+ return c.Query().Where(subscriptionplan.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *SubscriptionPlanClient) GetX(ctx context.Context, id int64) *SubscriptionPlan {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// Hooks returns the client hooks.
+func (c *SubscriptionPlanClient) Hooks() []Hook {
+ return c.hooks.SubscriptionPlan
+}
+
+// Interceptors returns the client interceptors.
+func (c *SubscriptionPlanClient) Interceptors() []Interceptor {
+ return c.inters.SubscriptionPlan
+}
+
+func (c *SubscriptionPlanClient) mutate(ctx context.Context, m *SubscriptionPlanMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&SubscriptionPlanCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&SubscriptionPlanUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&SubscriptionPlanUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&SubscriptionPlanDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown SubscriptionPlan mutation op: %q", m.Op())
+ }
+}
+
// TLSFingerprintProfileClient is a client for the TLSFingerprintProfile schema.
type TLSFingerprintProfileClient struct {
config
@@ -3353,6 +5293,54 @@ func (c *UserClient) QueryPromoCodeUsages(_m *User) *PromoCodeUsageQuery {
return query
}
+// QueryPaymentOrders queries the payment_orders edge of a User.
+func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery {
+ query := (&PaymentOrderClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(paymentorder.Table, paymentorder.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PaymentOrdersTable, user.PaymentOrdersColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryAuthIdentities queries the auth_identities edge of a User.
+func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User.
+func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: c.config}).Query()
+ query.path = func(context.Context) (fromV *sql.Selector, _ error) {
+ id := _m.ID
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, id),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step)
+ return fromV, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups queries the user_allowed_groups edge of a User.
func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: c.config}).Query()
@@ -4030,18 +6018,24 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
// hooks and interceptors per client, for fast access.
type (
hooks struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
- Proxy, RedeemCode, SecuritySecret, Setting, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Hook
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
+ ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule,
+ Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
- APIKey, Account, AccountGroup, Announcement, AnnouncementRead,
- ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage,
- Proxy, RedeemCode, SecuritySecret, Setting, TLSFingerprintProfile,
- UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Interceptor
+ APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity,
+ AuthIdentityChannel, ChannelMonitor, ChannelMonitorDailyRollup,
+ ChannelMonitorHistory, ChannelMonitorRequestTemplate, ErrorPassthroughRule,
+ Group, IdempotencyRecord, IdentityAdoptionDecision, PaymentAuditLog,
+ PaymentOrder, PaymentProviderInstance, PendingAuthSession, PromoCode,
+ PromoCodeUsage, Proxy, RedeemCode, SecuritySecret, Setting, SubscriptionPlan,
+ TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index bdeaed8a..c9fcc314 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -17,15 +17,27 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -94,28 +106,40 @@ var (
func checkColumn(t, c string) error {
initCheck.Do(func() {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
- apikey.Table: apikey.ValidColumn,
- account.Table: account.ValidColumn,
- accountgroup.Table: accountgroup.ValidColumn,
- announcement.Table: announcement.ValidColumn,
- announcementread.Table: announcementread.ValidColumn,
- errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
- group.Table: group.ValidColumn,
- idempotencyrecord.Table: idempotencyrecord.ValidColumn,
- promocode.Table: promocode.ValidColumn,
- promocodeusage.Table: promocodeusage.ValidColumn,
- proxy.Table: proxy.ValidColumn,
- redeemcode.Table: redeemcode.ValidColumn,
- securitysecret.Table: securitysecret.ValidColumn,
- setting.Table: setting.ValidColumn,
- tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
- usagecleanuptask.Table: usagecleanuptask.ValidColumn,
- usagelog.Table: usagelog.ValidColumn,
- user.Table: user.ValidColumn,
- userallowedgroup.Table: userallowedgroup.ValidColumn,
- userattributedefinition.Table: userattributedefinition.ValidColumn,
- userattributevalue.Table: userattributevalue.ValidColumn,
- usersubscription.Table: usersubscription.ValidColumn,
+ apikey.Table: apikey.ValidColumn,
+ account.Table: account.ValidColumn,
+ accountgroup.Table: accountgroup.ValidColumn,
+ announcement.Table: announcement.ValidColumn,
+ announcementread.Table: announcementread.ValidColumn,
+ authidentity.Table: authidentity.ValidColumn,
+ authidentitychannel.Table: authidentitychannel.ValidColumn,
+ channelmonitor.Table: channelmonitor.ValidColumn,
+ channelmonitordailyrollup.Table: channelmonitordailyrollup.ValidColumn,
+ channelmonitorhistory.Table: channelmonitorhistory.ValidColumn,
+ channelmonitorrequesttemplate.Table: channelmonitorrequesttemplate.ValidColumn,
+ errorpassthroughrule.Table: errorpassthroughrule.ValidColumn,
+ group.Table: group.ValidColumn,
+ idempotencyrecord.Table: idempotencyrecord.ValidColumn,
+ identityadoptiondecision.Table: identityadoptiondecision.ValidColumn,
+ paymentauditlog.Table: paymentauditlog.ValidColumn,
+ paymentorder.Table: paymentorder.ValidColumn,
+ paymentproviderinstance.Table: paymentproviderinstance.ValidColumn,
+ pendingauthsession.Table: pendingauthsession.ValidColumn,
+ promocode.Table: promocode.ValidColumn,
+ promocodeusage.Table: promocodeusage.ValidColumn,
+ proxy.Table: proxy.ValidColumn,
+ redeemcode.Table: redeemcode.ValidColumn,
+ securitysecret.Table: securitysecret.ValidColumn,
+ setting.Table: setting.ValidColumn,
+ subscriptionplan.Table: subscriptionplan.ValidColumn,
+ tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
+ usagelog.Table: usagelog.ValidColumn,
+ user.Table: user.ValidColumn,
+ userallowedgroup.Table: userallowedgroup.ValidColumn,
+ userattributedefinition.Table: userattributedefinition.ValidColumn,
+ userattributevalue.Table: userattributevalue.ValidColumn,
+ usersubscription.Table: usersubscription.ValidColumn,
})
})
return columnCheck(t, c)
diff --git a/backend/ent/group.go b/backend/ent/group.go
index fc691a9b..5d9ae2ed 100644
--- a/backend/ent/group.go
+++ b/backend/ent/group.go
@@ -11,6 +11,7 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
)
// Group is the model entity for the Group schema.
@@ -52,16 +53,6 @@ type Group struct {
ImagePrice2k *float64 `json:"image_price_2k,omitempty"`
// ImagePrice4k holds the value of the "image_price_4k" field.
ImagePrice4k *float64 `json:"image_price_4k,omitempty"`
- // SoraImagePrice360 holds the value of the "sora_image_price_360" field.
- SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
- // SoraImagePrice540 holds the value of the "sora_image_price_540" field.
- SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
- // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field.
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
- // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field.
- SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"`
- // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
// 是否仅允许 Claude Code 客户端
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
@@ -86,6 +77,10 @@ type Group struct {
RequirePrivacySet bool `json:"require_privacy_set,omitempty"`
// 默认映射模型 ID,当账号级映射找不到时使用此值
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
+ // OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型
+ MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
+ // 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流
+ RpmLimit int `json:"rpm_limit,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -192,13 +187,13 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case group.FieldModelRouting, group.FieldSupportedModelScopes:
+ case group.FieldModelRouting, group.FieldSupportedModelScopes, group.FieldMessagesDispatchModelConfig:
values[i] = new([]byte)
case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject, group.FieldAllowMessagesDispatch, group.FieldRequireOauthOnly, group.FieldRequirePrivacySet:
values[i] = new(sql.NullBool)
- case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd:
+ case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
- case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder:
+ case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit:
values[i] = new(sql.NullInt64)
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel:
values[i] = new(sql.NullString)
@@ -335,40 +330,6 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.ImagePrice4k = new(float64)
*_m.ImagePrice4k = value.Float64
}
- case group.FieldSoraImagePrice360:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i])
- } else if value.Valid {
- _m.SoraImagePrice360 = new(float64)
- *_m.SoraImagePrice360 = value.Float64
- }
- case group.FieldSoraImagePrice540:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i])
- } else if value.Valid {
- _m.SoraImagePrice540 = new(float64)
- *_m.SoraImagePrice540 = value.Float64
- }
- case group.FieldSoraVideoPricePerRequest:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i])
- } else if value.Valid {
- _m.SoraVideoPricePerRequest = new(float64)
- *_m.SoraVideoPricePerRequest = value.Float64
- }
- case group.FieldSoraVideoPricePerRequestHd:
- if value, ok := values[i].(*sql.NullFloat64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i])
- } else if value.Valid {
- _m.SoraVideoPricePerRequestHd = new(float64)
- *_m.SoraVideoPricePerRequestHd = value.Float64
- }
- case group.FieldSoraStorageQuotaBytes:
- if value, ok := values[i].(*sql.NullInt64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
- } else if value.Valid {
- _m.SoraStorageQuotaBytes = value.Int64
- }
case group.FieldClaudeCodeOnly:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field claude_code_only", values[i])
@@ -447,6 +408,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.DefaultMappedModel = value.String
}
+ case group.FieldMessagesDispatchModelConfig:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field messages_dispatch_model_config", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.MessagesDispatchModelConfig); err != nil {
+ return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err)
+ }
+ }
+ case group.FieldRpmLimit:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
+ } else if value.Valid {
+ _m.RpmLimit = int(value.Int64)
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -590,29 +565,6 @@ func (_m *Group) String() string {
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
- if v := _m.SoraImagePrice360; v != nil {
- builder.WriteString("sora_image_price_360=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- if v := _m.SoraImagePrice540; v != nil {
- builder.WriteString("sora_image_price_540=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- if v := _m.SoraVideoPricePerRequest; v != nil {
- builder.WriteString("sora_video_price_per_request=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- if v := _m.SoraVideoPricePerRequestHd; v != nil {
- builder.WriteString("sora_video_price_per_request_hd=")
- builder.WriteString(fmt.Sprintf("%v", *v))
- }
- builder.WriteString(", ")
- builder.WriteString("sora_storage_quota_bytes=")
- builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
- builder.WriteString(", ")
builder.WriteString("claude_code_only=")
builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly))
builder.WriteString(", ")
@@ -652,6 +604,12 @@ func (_m *Group) String() string {
builder.WriteString(", ")
builder.WriteString("default_mapped_model=")
builder.WriteString(_m.DefaultMappedModel)
+ builder.WriteString(", ")
+ builder.WriteString("messages_dispatch_model_config=")
+ builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig))
+ builder.WriteString(", ")
+ builder.WriteString("rpm_limit=")
+ builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go
index 35222127..24bd9c13 100644
--- a/backend/ent/group/group.go
+++ b/backend/ent/group/group.go
@@ -8,6 +8,7 @@ import (
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
)
const (
@@ -49,16 +50,6 @@ const (
FieldImagePrice2k = "image_price_2k"
// FieldImagePrice4k holds the string denoting the image_price_4k field in the database.
FieldImagePrice4k = "image_price_4k"
- // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database.
- FieldSoraImagePrice360 = "sora_image_price_360"
- // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database.
- FieldSoraImagePrice540 = "sora_image_price_540"
- // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database.
- FieldSoraVideoPricePerRequest = "sora_video_price_per_request"
- // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database.
- FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd"
- // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
- FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
// FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database.
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
@@ -83,6 +74,10 @@ const (
FieldRequirePrivacySet = "require_privacy_set"
// FieldDefaultMappedModel holds the string denoting the default_mapped_model field in the database.
FieldDefaultMappedModel = "default_mapped_model"
+ // FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database.
+ FieldMessagesDispatchModelConfig = "messages_dispatch_model_config"
+ // FieldRpmLimit holds the string denoting the rpm_limit field in the database.
+ FieldRpmLimit = "rpm_limit"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -175,11 +170,6 @@ var Columns = []string{
FieldImagePrice1k,
FieldImagePrice2k,
FieldImagePrice4k,
- FieldSoraImagePrice360,
- FieldSoraImagePrice540,
- FieldSoraVideoPricePerRequest,
- FieldSoraVideoPricePerRequestHd,
- FieldSoraStorageQuotaBytes,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
FieldFallbackGroupIDOnInvalidRequest,
@@ -192,6 +182,8 @@ var Columns = []string{
FieldRequireOauthOnly,
FieldRequirePrivacySet,
FieldDefaultMappedModel,
+ FieldMessagesDispatchModelConfig,
+ FieldRpmLimit,
}
var (
@@ -247,8 +239,6 @@ var (
SubscriptionTypeValidator func(string) error
// DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field.
DefaultDefaultValidityDays int
- // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
- DefaultSoraStorageQuotaBytes int64
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
// DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
@@ -269,6 +259,10 @@ var (
DefaultDefaultMappedModel string
// DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
DefaultMappedModelValidator func(string) error
+ // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field.
+ DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig
+ // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
+ DefaultRpmLimit int
)
// OrderOption defines the ordering options for the Group queries.
@@ -364,31 +358,6 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc()
}
-// BySoraImagePrice360 orders the results by the sora_image_price_360 field.
-func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc()
-}
-
-// BySoraImagePrice540 orders the results by the sora_image_price_540 field.
-func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc()
-}
-
-// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field.
-func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc()
-}
-
-// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field.
-func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc()
-}
-
-// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
-func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
-}
-
// ByClaudeCodeOnly orders the results by the claude_code_only field.
func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc()
@@ -439,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc()
}
+// ByRpmLimit orders the results by the rpm_limit field.
+func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go
index 41bd575a..2814d130 100644
--- a/backend/ent/group/where.go
+++ b/backend/ent/group/where.go
@@ -140,31 +140,6 @@ func ImagePrice4k(v float64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v))
}
-// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ.
-func SoraImagePrice360(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ.
-func SoraImagePrice540(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
-}
-
-// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ.
-func SoraVideoPricePerRequest(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ.
-func SoraVideoPricePerRequestHd(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
-func SoraStorageQuotaBytes(v int64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
-}
-
// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ.
func ClaudeCodeOnly(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
@@ -215,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v))
}
+// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
+func RpmLimit(v int) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1070,246 +1050,6 @@ func ImagePrice4kNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldImagePrice4k))
}
-// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field.
-func SoraImagePrice360EQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field.
-func SoraImagePrice360NEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field.
-func SoraImagePrice360In(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...))
-}
-
-// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field.
-func SoraImagePrice360NotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...))
-}
-
-// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field.
-func SoraImagePrice360GT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field.
-func SoraImagePrice360GTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field.
-func SoraImagePrice360LT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field.
-func SoraImagePrice360LTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v))
-}
-
-// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field.
-func SoraImagePrice360IsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360))
-}
-
-// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field.
-func SoraImagePrice360NotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360))
-}
-
-// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field.
-func SoraImagePrice540EQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field.
-func SoraImagePrice540NEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field.
-func SoraImagePrice540In(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...))
-}
-
-// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field.
-func SoraImagePrice540NotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...))
-}
-
-// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field.
-func SoraImagePrice540GT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field.
-func SoraImagePrice540GTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field.
-func SoraImagePrice540LT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field.
-func SoraImagePrice540LTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v))
-}
-
-// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field.
-func SoraImagePrice540IsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540))
-}
-
-// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field.
-func SoraImagePrice540NotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540))
-}
-
-// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestNEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...))
-}
-
-// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...))
-}
-
-// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestGT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestGTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestLT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestLTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v))
-}
-
-// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestIsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest))
-}
-
-// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field.
-func SoraVideoPricePerRequestNotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest))
-}
-
-// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...))
-}
-
-// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...))
-}
-
-// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdGT(v float64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdLT(v float64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v))
-}
-
-// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdIsNil() predicate.Group {
- return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd))
-}
-
-// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field.
-func SoraVideoPricePerRequestHdNotNil() predicate.Group {
- return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd))
-}
-
-// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesEQ(v int64) predicate.Group {
- return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNEQ(v int64) predicate.Group {
- return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group {
- return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
-}
-
-// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group {
- return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
-}
-
-// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGT(v int64) predicate.Group {
- return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGTE(v int64) predicate.Group {
- return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLT(v int64) predicate.Group {
- return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
-}
-
-// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLTE(v int64) predicate.Group {
- return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
-}
-
// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field.
func ClaudeCodeOnlyEQ(v bool) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v))
@@ -1585,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group {
return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v))
}
+// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
+func RpmLimitEQ(v int) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
+func RpmLimitNEQ(v int) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitIn applies the In predicate on the "rpm_limit" field.
+func RpmLimitIn(vs ...int) predicate.Group {
+ return predicate.Group(sql.FieldIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
+func RpmLimitNotIn(vs ...int) predicate.Group {
+ return predicate.Group(sql.FieldNotIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
+func RpmLimitGT(v int) predicate.Group {
+ return predicate.Group(sql.FieldGT(FieldRpmLimit, v))
+}
+
+// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
+func RpmLimitGTE(v int) predicate.Group {
+ return predicate.Group(sql.FieldGTE(FieldRpmLimit, v))
+}
+
+// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
+func RpmLimitLT(v int) predicate.Group {
+ return predicate.Group(sql.FieldLT(FieldRpmLimit, v))
+}
+
+// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
+func RpmLimitLTE(v int) predicate.Group {
+ return predicate.Group(sql.FieldLTE(FieldRpmLimit, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go
index a635dfd9..20ea0a0f 100644
--- a/backend/ent/group_create.go
+++ b/backend/ent/group_create.go
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
)
// GroupCreate is the builder for creating a Group entity.
@@ -258,76 +259,6 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate {
return _c
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate {
- _c.mutation.SetSoraImagePrice360(v)
- return _c
-}
-
-// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraImagePrice360(*v)
- }
- return _c
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate {
- _c.mutation.SetSoraImagePrice540(v)
- return _c
-}
-
-// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraImagePrice540(*v)
- }
- return _c
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate {
- _c.mutation.SetSoraVideoPricePerRequest(v)
- return _c
-}
-
-// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraVideoPricePerRequest(*v)
- }
- return _c
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate {
- _c.mutation.SetSoraVideoPricePerRequestHd(v)
- return _c
-}
-
-// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate {
- if v != nil {
- _c.SetSoraVideoPricePerRequestHd(*v)
- }
- return _c
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate {
- _c.mutation.SetSoraStorageQuotaBytes(v)
- return _c
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate {
- if v != nil {
- _c.SetSoraStorageQuotaBytes(*v)
- }
- return _c
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate {
_c.mutation.SetClaudeCodeOnly(v)
@@ -480,6 +411,34 @@ func (_c *GroupCreate) SetNillableDefaultMappedModel(v *string) *GroupCreate {
return _c
}
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (_c *GroupCreate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupCreate {
+ _c.mutation.SetMessagesDispatchModelConfig(v)
+ return _c
+}
+
+// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupCreate {
+ if v != nil {
+ _c.SetMessagesDispatchModelConfig(*v)
+ }
+ return _c
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate {
+ _c.mutation.SetRpmLimit(v)
+ return _c
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableRpmLimit(v *int) *GroupCreate {
+ if v != nil {
+ _c.SetRpmLimit(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -645,10 +604,6 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultValidityDays
_c.mutation.SetDefaultValidityDays(v)
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- v := group.DefaultSoraStorageQuotaBytes
- _c.mutation.SetSoraStorageQuotaBytes(v)
- }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
@@ -685,6 +640,14 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultDefaultMappedModel
_c.mutation.SetDefaultMappedModel(v)
}
+ if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
+ v := group.DefaultMessagesDispatchModelConfig
+ _c.mutation.SetMessagesDispatchModelConfig(v)
+ }
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ v := group.DefaultRpmLimit
+ _c.mutation.SetRpmLimit(v)
+ }
return nil
}
@@ -737,9 +700,6 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.DefaultValidityDays(); !ok {
return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)}
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)}
- }
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
@@ -772,6 +732,12 @@ func (_c *GroupCreate) check() error {
return &ValidationError{Name: "default_mapped_model", err: fmt.Errorf(`ent: validator failed for field "Group.default_mapped_model": %w`, err)}
}
}
+ if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok {
+ return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)}
+ }
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)}
+ }
return nil
}
@@ -867,26 +833,6 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value)
_node.ImagePrice4k = &value
}
- if value, ok := _c.mutation.SoraImagePrice360(); ok {
- _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- _node.SoraImagePrice360 = &value
- }
- if value, ok := _c.mutation.SoraImagePrice540(); ok {
- _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- _node.SoraImagePrice540 = &value
- }
- if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- _node.SoraVideoPricePerRequest = &value
- }
- if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- _node.SoraVideoPricePerRequestHd = &value
- }
- if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- _node.SoraStorageQuotaBytes = value
- }
if value, ok := _c.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
_node.ClaudeCodeOnly = value
@@ -935,6 +881,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
_node.DefaultMappedModel = value
}
+ if value, ok := _c.mutation.MessagesDispatchModelConfig(); ok {
+ _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
+ _node.MessagesDispatchModelConfig = value
+ }
+ if value, ok := _c.mutation.RpmLimit(); ok {
+ _spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
+ _node.RpmLimit = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1379,120 +1333,6 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert {
return u
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert {
- u.Set(group.FieldSoraImagePrice360, v)
- return u
-}
-
-// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert {
- u.SetExcluded(group.FieldSoraImagePrice360)
- return u
-}
-
-// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
-func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert {
- u.Add(group.FieldSoraImagePrice360, v)
- return u
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert {
- u.SetNull(group.FieldSoraImagePrice360)
- return u
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert {
- u.Set(group.FieldSoraImagePrice540, v)
- return u
-}
-
-// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert {
- u.SetExcluded(group.FieldSoraImagePrice540)
- return u
-}
-
-// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
-func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert {
- u.Add(group.FieldSoraImagePrice540, v)
- return u
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert {
- u.SetNull(group.FieldSoraImagePrice540)
- return u
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert {
- u.Set(group.FieldSoraVideoPricePerRequest, v)
- return u
-}
-
-// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert {
- u.SetExcluded(group.FieldSoraVideoPricePerRequest)
- return u
-}
-
-// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
-func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert {
- u.Add(group.FieldSoraVideoPricePerRequest, v)
- return u
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert {
- u.SetNull(group.FieldSoraVideoPricePerRequest)
- return u
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
- u.Set(group.FieldSoraVideoPricePerRequestHd, v)
- return u
-}
-
-// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert {
- u.SetExcluded(group.FieldSoraVideoPricePerRequestHd)
- return u
-}
-
-// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert {
- u.Add(group.FieldSoraVideoPricePerRequestHd, v)
- return u
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert {
- u.SetNull(group.FieldSoraVideoPricePerRequestHd)
- return u
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert {
- u.Set(group.FieldSoraStorageQuotaBytes, v)
- return u
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert {
- u.SetExcluded(group.FieldSoraStorageQuotaBytes)
- return u
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert {
- u.Add(group.FieldSoraStorageQuotaBytes, v)
- return u
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert {
u.Set(group.FieldClaudeCodeOnly, v)
@@ -1673,6 +1513,36 @@ func (u *GroupUpsert) UpdateDefaultMappedModel() *GroupUpsert {
return u
}
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (u *GroupUpsert) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsert {
+ u.Set(group.FieldMessagesDispatchModelConfig, v)
+ return u
+}
+
+// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert {
+ u.SetExcluded(group.FieldMessagesDispatchModelConfig)
+ return u
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert {
+ u.Set(group.FieldRpmLimit, v)
+ return u
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateRpmLimit() *GroupUpsert {
+ u.SetExcluded(group.FieldRpmLimit)
+ return u
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *GroupUpsert) AddRpmLimit(v int) *GroupUpsert {
+ u.Add(group.FieldRpmLimit, v)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -2054,139 +1924,6 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne {
})
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice360(v)
- })
-}
-
-// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
-func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice360(v)
- })
-}
-
-// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice360()
- })
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice360()
- })
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice540(v)
- })
-}
-
-// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
-func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice540(v)
- })
-}
-
-// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice540()
- })
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice540()
- })
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequest(v)
- })
-}
-
-// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
-func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequest(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequest()
- })
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequest()
- })
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequestHd(v)
- })
-}
-
-// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequestHd(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequestHd()
- })
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequestHd()
- })
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraStorageQuotaBytes(v)
- })
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraStorageQuotaBytes(v)
- })
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraStorageQuotaBytes()
- })
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne {
return u.Update(func(s *GroupUpsert) {
@@ -2397,6 +2134,41 @@ func (u *GroupUpsertOne) UpdateDefaultMappedModel() *GroupUpsertOne {
})
}
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (u *GroupUpsertOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetMessagesDispatchModelConfig(v)
+ })
+}
+
+// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateMessagesDispatchModelConfig()
+ })
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *GroupUpsertOne) AddRpmLimit(v int) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateRpmLimit() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateRpmLimit()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2944,139 +2716,6 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk {
})
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice360(v)
- })
-}
-
-// AddSoraImagePrice360 adds v to the "sora_image_price_360" field.
-func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice360(v)
- })
-}
-
-// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice360()
- })
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice360()
- })
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraImagePrice540(v)
- })
-}
-
-// AddSoraImagePrice540 adds v to the "sora_image_price_540" field.
-func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraImagePrice540(v)
- })
-}
-
-// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraImagePrice540()
- })
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraImagePrice540()
- })
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequest(v)
- })
-}
-
-// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field.
-func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequest(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequest()
- })
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequest()
- })
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraVideoPricePerRequestHd(v)
- })
-}
-
-// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraVideoPricePerRequestHd(v)
- })
-}
-
-// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraVideoPricePerRequestHd()
- })
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.ClearSoraVideoPricePerRequestHd()
- })
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.SetSoraStorageQuotaBytes(v)
- })
-}
-
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.AddSoraStorageQuotaBytes(v)
- })
-}
-
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk {
- return u.Update(func(s *GroupUpsert) {
- s.UpdateSoraStorageQuotaBytes()
- })
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk {
return u.Update(func(s *GroupUpsert) {
@@ -3287,6 +2926,41 @@ func (u *GroupUpsertBulk) UpdateDefaultMappedModel() *GroupUpsertBulk {
})
}
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (u *GroupUpsertBulk) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetMessagesDispatchModelConfig(v)
+ })
+}
+
+// UpdateMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateMessagesDispatchModelConfig()
+ })
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *GroupUpsertBulk) AddRpmLimit(v int) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateRpmLimit() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateRpmLimit()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go
index a9a4b9da..cc14f897 100644
--- a/backend/ent/group_update.go
+++ b/backend/ent/group_update.go
@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
)
// GroupUpdate is the builder for updating Group entities.
@@ -355,135 +356,6 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate {
return _u
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate {
- _u.mutation.ResetSoraImagePrice360()
- _u.mutation.SetSoraImagePrice360(v)
- return _u
-}
-
-// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraImagePrice360(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
-func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate {
- _u.mutation.AddSoraImagePrice360(v)
- return _u
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate {
- _u.mutation.ClearSoraImagePrice360()
- return _u
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate {
- _u.mutation.ResetSoraImagePrice540()
- _u.mutation.SetSoraImagePrice540(v)
- return _u
-}
-
-// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraImagePrice540(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
-func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate {
- _u.mutation.AddSoraImagePrice540(v)
- return _u
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate {
- _u.mutation.ClearSoraImagePrice540()
- return _u
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate {
- _u.mutation.ResetSoraVideoPricePerRequest()
- _u.mutation.SetSoraVideoPricePerRequest(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraVideoPricePerRequest(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
-func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate {
- _u.mutation.AddSoraVideoPricePerRequest(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate {
- _u.mutation.ClearSoraVideoPricePerRequest()
- return _u
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
- _u.mutation.ResetSoraVideoPricePerRequestHd()
- _u.mutation.SetSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate {
- if v != nil {
- _u.SetSoraVideoPricePerRequestHd(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate {
- _u.mutation.AddSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate {
- _u.mutation.ClearSoraVideoPricePerRequestHd()
- return _u
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate {
- if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate {
- _u.mutation.AddSoraStorageQuotaBytes(v)
- return _u
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate {
_u.mutation.SetClaudeCodeOnly(v)
@@ -681,6 +553,41 @@ func (_u *GroupUpdate) SetNillableDefaultMappedModel(v *string) *GroupUpdate {
return _u
}
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (_u *GroupUpdate) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate {
+ _u.mutation.SetMessagesDispatchModelConfig(v)
+ return _u
+}
+
+// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdate {
+ if v != nil {
+ _u.SetMessagesDispatchModelConfig(*v)
+ }
+ return _u
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableRpmLimit(v *int) *GroupUpdate {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *GroupUpdate) AddRpmLimit(v int) *GroupUpdate {
+ _u.mutation.AddRpmLimit(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1082,48 +989,6 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
- if value, ok := _u.mutation.SoraImagePrice360(); ok {
- _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
- _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice360Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraImagePrice540(); ok {
- _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
- _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice540Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestHdCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
@@ -1183,6 +1048,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
+ if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
+ _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1817,135 +1691,6 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne {
return _u
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraImagePrice360()
- _u.mutation.SetSoraImagePrice360(v)
- return _u
-}
-
-// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraImagePrice360(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice360 adds value to the "sora_image_price_360" field.
-func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraImagePrice360(v)
- return _u
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne {
- _u.mutation.ClearSoraImagePrice360()
- return _u
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraImagePrice540()
- _u.mutation.SetSoraImagePrice540(v)
- return _u
-}
-
-// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraImagePrice540(*v)
- }
- return _u
-}
-
-// AddSoraImagePrice540 adds value to the "sora_image_price_540" field.
-func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraImagePrice540(v)
- return _u
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne {
- _u.mutation.ClearSoraImagePrice540()
- return _u
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraVideoPricePerRequest()
- _u.mutation.SetSoraVideoPricePerRequest(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraVideoPricePerRequest(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field.
-func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraVideoPricePerRequest(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne {
- _u.mutation.ClearSoraVideoPricePerRequest()
- return _u
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
- _u.mutation.ResetSoraVideoPricePerRequestHd()
- _u.mutation.SetSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraVideoPricePerRequestHd(*v)
- }
- return _u
-}
-
-// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne {
- _u.mutation.AddSoraVideoPricePerRequestHd(v)
- return _u
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne {
- _u.mutation.ClearSoraVideoPricePerRequestHd()
- return _u
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne {
- if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
- }
- return _u
-}
-
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne {
- _u.mutation.AddSoraStorageQuotaBytes(v)
- return _u
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne {
_u.mutation.SetClaudeCodeOnly(v)
@@ -2143,6 +1888,41 @@ func (_u *GroupUpdateOne) SetNillableDefaultMappedModel(v *string) *GroupUpdateO
return _u
}
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (_u *GroupUpdateOne) SetMessagesDispatchModelConfig(v domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne {
+ _u.mutation.SetMessagesDispatchModelConfig(v)
+ return _u
+}
+
+// SetNillableMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMessagesDispatchModelConfig) *GroupUpdateOne {
+ if v != nil {
+ _u.SetMessagesDispatchModelConfig(*v)
+ }
+ return _u
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableRpmLimit(v *int) *GroupUpdateOne {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *GroupUpdateOne) AddRpmLimit(v int) *GroupUpdateOne {
+ _u.mutation.AddRpmLimit(v)
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -2574,48 +2354,6 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.ImagePrice4kCleared() {
_spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64)
}
- if value, ok := _u.mutation.SoraImagePrice360(); ok {
- _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice360(); ok {
- _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice360Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraImagePrice540(); ok {
- _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraImagePrice540(); ok {
- _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value)
- }
- if _u.mutation.SoraImagePrice540Cleared() {
- _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok {
- _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok {
- _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value)
- }
- if _u.mutation.SoraVideoPricePerRequestHdCleared() {
- _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64)
- }
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- }
if value, ok := _u.mutation.ClaudeCodeOnly(); ok {
_spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value)
}
@@ -2675,6 +2413,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if value, ok := _u.mutation.DefaultMappedModel(); ok {
_spec.SetField(group.FieldDefaultMappedModel, field.TypeString, value)
}
+ if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok {
+ _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(group.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(group.FieldRpmLimit, field.TypeInt, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index f6f7b4e9..414eba24 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -69,6 +69,78 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentity mutator.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary
+// function as AuthIdentityChannel mutator.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m)
+}
+
+// The ChannelMonitorFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitor mutator.
+type ChannelMonitorFunc func(context.Context, *ent.ChannelMonitorMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorMutation", m)
+}
+
+// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorDailyRollup mutator.
+type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorDailyRollupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorDailyRollupMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorDailyRollupMutation", m)
+}
+
+// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorHistory mutator.
+type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorHistoryFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorHistoryMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorHistoryMutation", m)
+}
+
+// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary
+// function as ChannelMonitorRequestTemplate mutator.
+type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f ChannelMonitorRequestTemplateFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.ChannelMonitorRequestTemplateMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ChannelMonitorRequestTemplateMutation", m)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary
// function as ErrorPassthroughRule mutator.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error)
@@ -105,6 +177,66 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary
+// function as IdentityAdoptionDecision mutator.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m)
+}
+
+// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary
+// function as PaymentAuditLog mutator.
+type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PaymentAuditLogFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PaymentAuditLogMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentAuditLogMutation", m)
+}
+
+// The PaymentOrderFunc type is an adapter to allow the use of ordinary
+// function as PaymentOrder mutator.
+type PaymentOrderFunc func(context.Context, *ent.PaymentOrderMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PaymentOrderFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PaymentOrderMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentOrderMutation", m)
+}
+
+// The PaymentProviderInstanceFunc type is an adapter to allow the use of ordinary
+// function as PaymentProviderInstance mutator.
+type PaymentProviderInstanceFunc func(context.Context, *ent.PaymentProviderInstanceMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PaymentProviderInstanceMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m)
+}
+
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary
+// function as PendingAuthSession mutator.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.PendingAuthSessionMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary
// function as PromoCode mutator.
type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error)
@@ -177,6 +309,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
}
+// The SubscriptionPlanFunc type is an adapter to allow the use of ordinary
+// function as SubscriptionPlan mutator.
+type SubscriptionPlanFunc func(context.Context, *ent.SubscriptionPlanMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f SubscriptionPlanFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.SubscriptionPlanMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SubscriptionPlanMutation", m)
+}
+
// The TLSFingerprintProfileFunc type is an adapter to allow the use of ordinary
// function as TLSFingerprintProfile mutator.
type TLSFingerprintProfileFunc func(context.Context, *ent.TLSFingerprintProfileMutation) (ent.Value, error)
diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go
new file mode 100644
index 00000000..ecaee65c
--- /dev/null
+++ b/backend/ent/identityadoptiondecision.go
@@ -0,0 +1,223 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema.
+type IdentityAdoptionDecision struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // PendingAuthSessionID holds the value of the "pending_auth_session_id" field.
+ PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"`
+ // IdentityID holds the value of the "identity_id" field.
+ IdentityID *int64 `json:"identity_id,omitempty"`
+ // AdoptDisplayName holds the value of the "adopt_display_name" field.
+ AdoptDisplayName bool `json:"adopt_display_name,omitempty"`
+ // AdoptAvatar holds the value of the "adopt_avatar" field.
+ AdoptAvatar bool `json:"adopt_avatar,omitempty"`
+ // DecidedAt holds the value of the "decided_at" field.
+ DecidedAt time.Time `json:"decided_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set.
+ Edges IdentityAdoptionDecisionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph.
+type IdentityAdoptionDecisionEdges struct {
+ // PendingAuthSession holds the value of the pending_auth_session edge.
+ PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"`
+ // Identity holds the value of the identity edge.
+ Identity *AuthIdentity `json:"identity,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) {
+ if e.PendingAuthSession != nil {
+ return e.PendingAuthSession, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: pendingauthsession.Label}
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_session"}
+}
+
+// IdentityOrErr returns the Identity value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) {
+ if e.Identity != nil {
+ return e.Identity, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: authidentity.Label}
+ }
+ return nil, &NotLoadedError{edge: "identity"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar:
+ values[i] = new(sql.NullBool)
+ case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID:
+ values[i] = new(sql.NullInt64)
+ case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the IdentityAdoptionDecision fields.
+func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case identityadoptiondecision.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case identityadoptiondecision.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i])
+ } else if value.Valid {
+ _m.PendingAuthSessionID = value.Int64
+ }
+ case identityadoptiondecision.FieldIdentityID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field identity_id", values[i])
+ } else if value.Valid {
+ _m.IdentityID = new(int64)
+ *_m.IdentityID = value.Int64
+ }
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i])
+ } else if value.Valid {
+ _m.AdoptDisplayName = value.Bool
+ }
+ case identityadoptiondecision.FieldAdoptAvatar:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i])
+ } else if value.Valid {
+ _m.AdoptAvatar = value.Bool
+ }
+ case identityadoptiondecision.FieldDecidedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field decided_at", values[i])
+ } else if value.Valid {
+ _m.DecidedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision.
+// This includes values selected through modifiers, order, etc.
+func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m)
+}
+
+// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity.
+func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery {
+ return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m)
+}
+
+// Update returns a builder for updating this IdentityAdoptionDecision.
+// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne {
+ return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: IdentityAdoptionDecision is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *IdentityAdoptionDecision) String() string {
+ var builder strings.Builder
+ builder.WriteString("IdentityAdoptionDecision(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("pending_auth_session_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID))
+ builder.WriteString(", ")
+ if v := _m.IdentityID; v != nil {
+ builder.WriteString("identity_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("adopt_display_name=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName))
+ builder.WriteString(", ")
+ builder.WriteString("adopt_avatar=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar))
+ builder.WriteString(", ")
+ builder.WriteString("decided_at=")
+ builder.WriteString(_m.DecidedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision.
+type IdentityAdoptionDecisions []*IdentityAdoptionDecision
diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
new file mode 100644
index 00000000..93adaf73
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go
@@ -0,0 +1,159 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the identityadoptiondecision type in the database.
+ Label = "identity_adoption_decision"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database.
+ FieldPendingAuthSessionID = "pending_auth_session_id"
+ // FieldIdentityID holds the string denoting the identity_id field in the database.
+ FieldIdentityID = "identity_id"
+ // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database.
+ FieldAdoptDisplayName = "adopt_display_name"
+ // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database.
+ FieldAdoptAvatar = "adopt_avatar"
+ // FieldDecidedAt holds the string denoting the decided_at field in the database.
+ FieldDecidedAt = "decided_at"
+ // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations.
+ EdgePendingAuthSession = "pending_auth_session"
+ // EdgeIdentity holds the string denoting the identity edge name in mutations.
+ EdgeIdentity = "identity"
+ // Table holds the table name of the identityadoptiondecision in the database.
+ Table = "identity_adoption_decisions"
+ // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge.
+ PendingAuthSessionTable = "identity_adoption_decisions"
+ // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge.
+ PendingAuthSessionColumn = "pending_auth_session_id"
+ // IdentityTable is the table that holds the identity relation/edge.
+ IdentityTable = "identity_adoption_decisions"
+ // IdentityInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ IdentityInverseTable = "auth_identities"
+ // IdentityColumn is the table column denoting the identity relation/edge.
+ IdentityColumn = "identity_id"
+)
+
+// Columns holds all SQL columns for identityadoptiondecision fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldPendingAuthSessionID,
+ FieldIdentityID,
+ FieldAdoptDisplayName,
+ FieldAdoptAvatar,
+ FieldDecidedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field.
+ DefaultAdoptDisplayName bool
+ // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field.
+ DefaultAdoptAvatar bool
+ // DefaultDecidedAt holds the default value on creation for the "decided_at" field.
+ DefaultDecidedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the IdentityAdoptionDecision queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionID orders the results by the pending_auth_session_id field.
+func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc()
+}
+
+// ByIdentityID orders the results by the identity_id field.
+func ByIdentityID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIdentityID, opts...).ToFunc()
+}
+
+// ByAdoptDisplayName orders the results by the adopt_display_name field.
+func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc()
+}
+
+// ByAdoptAvatar orders the results by the adopt_avatar field.
+func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc()
+}
+
+// ByDecidedAt orders the results by the decided_at field.
+func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDecidedAt, opts...).ToFunc()
+}
+
+// ByPendingAuthSessionField orders the results by pending_auth_session field.
+func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByIdentityField orders the results by identity field.
+func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newPendingAuthSessionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+}
+func newIdentityStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(IdentityInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+}
diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go
new file mode 100644
index 00000000..1968f175
--- /dev/null
+++ b/backend/ent/identityadoptiondecision/where.go
@@ -0,0 +1,342 @@
+// Code generated by ent, DO NOT EDIT.
+
+package identityadoptiondecision
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ.
+func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ.
+func IdentityID(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ.
+func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ.
+func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ.
+func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v))
+}
+
+// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...))
+}
+
+// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field.
+func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...))
+}
+
+// IdentityIDEQ applies the EQ predicate on the "identity_id" field.
+func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v))
+}
+
+// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field.
+func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v))
+}
+
+// IdentityIDIn applies the In predicate on the "identity_id" field.
+func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field.
+func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...))
+}
+
+// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field.
+func IdentityIDIsNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID))
+}
+
+// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field.
+func IdentityIDNotNil() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID))
+}
+
+// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field.
+func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v))
+}
+
+// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field.
+func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v))
+}
+
+// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field.
+func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v))
+}
+
+// DecidedAtEQ applies the EQ predicate on the "decided_at" field.
+func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field.
+func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v))
+}
+
+// DecidedAtIn applies the In predicate on the "decided_at" field.
+func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field.
+func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...))
+}
+
+// DecidedAtGT applies the GT predicate on the "decided_at" field.
+func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v))
+}
+
+// DecidedAtGTE applies the GTE predicate on the "decided_at" field.
+func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v))
+}
+
+// DecidedAtLT applies the LT predicate on the "decided_at" field.
+func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v))
+}
+
+// DecidedAtLTE applies the LTE predicate on the "decided_at" field.
+func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v))
+}
+
+// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge.
+func HasPendingAuthSession() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates).
+func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newPendingAuthSessionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasIdentity applies the HasEdge predicate on the "identity" edge.
+func HasIdentity() predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates).
+func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ step := newIdentityStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision {
+ return predicate.IdentityAdoptionDecision(sql.NotPredicates(p))
+}
diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go
new file mode 100644
index 00000000..491ba9f9
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_create.go
@@ -0,0 +1,843 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+)
+
+// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionCreate struct {
+ config
+ mutation *IdentityAdoptionDecisionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetPendingAuthSessionID(v)
+ return _c
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetIdentityID(v)
+ return _c
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetIdentityID(*v)
+ }
+ return _c
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptDisplayName(v)
+ return _c
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptDisplayName(*v)
+ }
+ return _c
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetAdoptAvatar(v)
+ return _c
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetAdoptAvatar(*v)
+ }
+ return _c
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate {
+ _c.mutation.SetDecidedAt(v)
+ return _c
+}
+
+// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil.
+func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate {
+ if v != nil {
+ _c.SetDecidedAt(*v)
+ }
+ return _c
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate {
+ return _c.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate {
+ return _c.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _c.mutation
+}
+
+// Save creates the IdentityAdoptionDecision in the database.
+func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *IdentityAdoptionDecisionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := identityadoptiondecision.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ v := identityadoptiondecision.DefaultAdoptDisplayName
+ _c.mutation.SetAdoptDisplayName(v)
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ v := identityadoptiondecision.DefaultAdoptAvatar
+ _c.mutation.SetAdoptAvatar(v)
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ v := identityadoptiondecision.DefaultDecidedAt()
+ _c.mutation.SetDecidedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *IdentityAdoptionDecisionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)}
+ }
+ if _, ok := _c.mutation.PendingAuthSessionID(); !ok {
+ return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)}
+ }
+ if _, ok := _c.mutation.AdoptDisplayName(); !ok {
+ return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)}
+ }
+ if _, ok := _c.mutation.AdoptAvatar(); !ok {
+ return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)}
+ }
+ if _, ok := _c.mutation.DecidedAt(); !ok {
+ return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)}
+ }
+ if len(_c.mutation.PendingAuthSessionIDs()) == 0 {
+ return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)}
+ }
+ return nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) {
+ var (
+ _node = &IdentityAdoptionDecision{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ _node.AdoptDisplayName = value
+ }
+ if value, ok := _c.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ _node.AdoptAvatar = value
+ }
+ if value, ok := _c.mutation.DecidedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value)
+ _node.DecidedAt = value
+ }
+ if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.PendingAuthSessionID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.IdentityID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing
+ // one IdentityAdoptionDecision node.
+ IdentityAdoptionDecisionUpsertOne struct {
+ create *IdentityAdoptionDecisionCreate
+ }
+
+ // IdentityAdoptionDecisionUpsert is the "OnConflict" setter.
+ IdentityAdoptionDecisionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldUpdatedAt)
+ return u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v)
+ return u
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID)
+ return u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldIdentityID, v)
+ return u
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert {
+ u.SetNull(identityadoptiondecision.FieldIdentityID)
+ return u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptDisplayName, v)
+ return u
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName)
+ return u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert {
+ u.Set(identityadoptiondecision.FieldAdoptAvatar, v)
+ return u
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert {
+ u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := u.create.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk.
+type IdentityAdoptionDecisionCreateBulk struct {
+ config
+ err error
+ builders []*IdentityAdoptionDecisionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the IdentityAdoptionDecision entities in the database.
+func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*IdentityAdoptionDecision, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*IdentityAdoptionDecisionMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.IdentityAdoptionDecision.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.IdentityAdoptionDecisionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = opts
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &IdentityAdoptionDecisionUpsertBulk{
+ create: _c,
+ }
+}
+
+// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing
+// a bulk of IdentityAdoptionDecision nodes.
+type IdentityAdoptionDecisionUpsertBulk struct {
+ create *IdentityAdoptionDecisionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldCreatedAt)
+ }
+ if _, exists := b.mutation.DecidedAt(); exists {
+ s.SetIgnore(identityadoptiondecision.FieldDecidedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.IdentityAdoptionDecision.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict
+// documentation for more info.
+func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&IdentityAdoptionDecisionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetPendingAuthSessionID(v)
+ })
+}
+
+// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdatePendingAuthSessionID()
+ })
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetIdentityID(v)
+ })
+}
+
+// UpdateIdentityID sets the "identity_id" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateIdentityID()
+ })
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.ClearIdentityID()
+ })
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptDisplayName(v)
+ })
+}
+
+// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptDisplayName()
+ })
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.SetAdoptAvatar(v)
+ })
+}
+
+// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create.
+func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk {
+ return u.Update(func(s *IdentityAdoptionDecisionUpsert) {
+ s.UpdateAdoptAvatar()
+ })
+}
+
+// Exec executes the query.
+func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go
new file mode 100644
index 00000000..ef3d328d
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDelete struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionDeleteOne struct {
+ _d *IdentityAdoptionDecisionDelete
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder.
+func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go
new file mode 100644
index 00000000..4082d8ee
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_query.go
@@ -0,0 +1,721 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionQuery struct {
+ config
+ ctx *QueryContext
+ order []identityadoptiondecision.OrderOption
+ inters []Interceptor
+ predicates []predicate.IdentityAdoptionDecision
+ withPendingAuthSession *PendingAuthSessionQuery
+ withIdentity *AuthIdentityQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder.
+func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryIdentity chains the current query on the "identity" edge.
+func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first IdentityAdoptionDecision entity from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision was found.
+func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first IdentityAdoptionDecision ID from the query.
+// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found.
+func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found.
+// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found.
+func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{identityadoptiondecision.Label}
+ default:
+ return nil, &NotSingularError{identityadoptiondecision.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query.
+// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{identityadoptiondecision.Label}
+ default:
+ err = &NotSingularError{identityadoptiondecision.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of IdentityAdoptionDecisions.
+func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]()
+ return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of IdentityAdoptionDecision IDs.
+func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &IdentityAdoptionDecisionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]identityadoptiondecision.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...),
+ withPendingAuthSession: _q.withPendingAuthSession.Clone(),
+ withIdentity: _q.withIdentity.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSession = query
+ return _q
+}
+
+// WithIdentity tells the query-builder to eager-load the nodes that are connected to
+// the "identity" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withIdentity = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// GroupBy(identityadoptiondecision.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &IdentityAdoptionDecisionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = identityadoptiondecision.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.IdentityAdoptionDecision.Query().
+// Select(identityadoptiondecision.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q}
+ sbuild.label = identityadoptiondecision.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations.
+func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) {
+ var (
+ nodes = []*IdentityAdoptionDecision{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withPendingAuthSession != nil,
+ _q.withIdentity != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*IdentityAdoptionDecision).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &IdentityAdoptionDecision{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withPendingAuthSession; query != nil {
+ if err := _q.loadPendingAuthSession(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withIdentity; query != nil {
+ if err := _q.loadIdentity(ctx, query, nodes, nil,
+ func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ fk := nodes[i].PendingAuthSessionID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(pendingauthsession.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*IdentityAdoptionDecision)
+ for i := range nodes {
+ if nodes[i].IdentityID == nil {
+ continue
+ }
+ fk := *nodes[i].IdentityID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(authidentity.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for i := range fields {
+ if fields[i] != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withPendingAuthSession != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if _q.withIdentity != nil {
+ _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(identityadoptiondecision.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = identityadoptiondecision.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionGroupBy struct {
+ selector
+ build *IdentityAdoptionDecisionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionSelect struct {
+ *IdentityAdoptionDecisionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v)
+}
+
+func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go
new file mode 100644
index 00000000..0ca21d27
--- /dev/null
+++ b/backend/ent/identityadoptiondecision_update.go
@@ -0,0 +1,532 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities.
+type IdentityAdoptionDecisionUpdate struct {
+ config
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdate) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity.
+type IdentityAdoptionDecisionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *IdentityAdoptionDecisionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetPendingAuthSessionID(v)
+ return _u
+}
+
+// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetPendingAuthSessionID(*v)
+ }
+ return _u
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetIdentityID(v)
+ return _u
+}
+
+// SetNillableIdentityID sets the "identity_id" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetIdentityID(*v)
+ }
+ return _u
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentityID()
+ return _u
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptDisplayName(v)
+ return _u
+}
+
+// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptDisplayName(*v)
+ }
+ return _u
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.SetAdoptAvatar(v)
+ return _u
+}
+
+// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne {
+ if v != nil {
+ _u.SetAdoptAvatar(*v)
+ }
+ return _u
+}
+
+// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetPendingAuthSessionID(v.ID)
+}
+
+// SetIdentity sets the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne {
+ return _u.SetIdentityID(v.ID)
+}
+
+// Mutation returns the IdentityAdoptionDecisionMutation object of the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation {
+ return _u.mutation
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearPendingAuthSession()
+ return _u
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.ClearIdentity()
+ return _u
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated IdentityAdoptionDecision entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *IdentityAdoptionDecisionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := identityadoptiondecision.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *IdentityAdoptionDecisionUpdateOne) check() error {
+ if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`)
+ }
+ return nil
+}
+
+func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID)
+ for _, f := range fields {
+ if !identityadoptiondecision.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != identityadoptiondecision.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.AdoptDisplayName(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AdoptAvatar(); ok {
+ _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value)
+ }
+ if _u.mutation.PendingAuthSessionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: true,
+ Table: identityadoptiondecision.PendingAuthSessionTable,
+ Columns: []string{identityadoptiondecision.PendingAuthSessionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.IdentityCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: identityadoptiondecision.IdentityTable,
+ Columns: []string{identityadoptiondecision.IdentityColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &IdentityAdoptionDecision{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{identityadoptiondecision.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 13169ca7..95b68e09 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -13,9 +13,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -23,6 +34,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -224,6 +236,168 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err
return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q)
}
+// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q)
+}
+
+// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier.
+type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
+// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.AuthIdentityChannelQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q)
+}
+
+// The ChannelMonitorFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorFunc func(context.Context, *ent.ChannelMonitorQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorQuery", q)
+}
+
+// The TraverseChannelMonitor type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitor func(context.Context, *ent.ChannelMonitorQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitor) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitor) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorQuery", q)
+}
+
+// The ChannelMonitorDailyRollupFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorDailyRollupFunc func(context.Context, *ent.ChannelMonitorDailyRollupQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorDailyRollupFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorDailyRollupQuery", q)
+}
+
+// The TraverseChannelMonitorDailyRollup type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorDailyRollup func(context.Context, *ent.ChannelMonitorDailyRollupQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitorDailyRollup) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitorDailyRollup) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorDailyRollupQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorDailyRollupQuery", q)
+}
+
+// The ChannelMonitorHistoryFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorHistoryFunc func(context.Context, *ent.ChannelMonitorHistoryQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorHistoryFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorHistoryQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorHistoryQuery", q)
+}
+
+// The TraverseChannelMonitorHistory type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorHistory func(context.Context, *ent.ChannelMonitorHistoryQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitorHistory) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitorHistory) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorHistoryQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorHistoryQuery", q)
+}
+
+// The ChannelMonitorRequestTemplateFunc type is an adapter to allow the use of ordinary function as a Querier.
+type ChannelMonitorRequestTemplateFunc func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f ChannelMonitorRequestTemplateFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q)
+}
+
+// The TraverseChannelMonitorRequestTemplate type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseChannelMonitorRequestTemplate func(context.Context, *ent.ChannelMonitorRequestTemplateQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseChannelMonitorRequestTemplate) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseChannelMonitorRequestTemplate) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.ChannelMonitorRequestTemplateQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.ChannelMonitorRequestTemplateQuery", q)
+}
+
// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier.
type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error)
@@ -305,6 +479,141 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er
return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q)
}
+// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
+// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q)
+}
+
+// The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PaymentAuditLogFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PaymentAuditLogQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q)
+}
+
+// The TraversePaymentAuditLog type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePaymentAuditLog func(context.Context, *ent.PaymentAuditLogQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePaymentAuditLog) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePaymentAuditLog) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PaymentAuditLogQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PaymentAuditLogQuery", q)
+}
+
+// The PaymentOrderFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PaymentOrderFunc func(context.Context, *ent.PaymentOrderQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PaymentOrderFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PaymentOrderQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PaymentOrderQuery", q)
+}
+
+// The TraversePaymentOrder type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePaymentOrder func(context.Context, *ent.PaymentOrderQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePaymentOrder) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePaymentOrder) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PaymentOrderQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PaymentOrderQuery", q)
+}
+
+// The PaymentProviderInstanceFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PaymentProviderInstanceFunc func(context.Context, *ent.PaymentProviderInstanceQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PaymentProviderInstanceFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PaymentProviderInstanceQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q)
+}
+
+// The TraversePaymentProviderInstance type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePaymentProviderInstance func(context.Context, *ent.PaymentProviderInstanceQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePaymentProviderInstance) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PaymentProviderInstanceQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q)
+}
+
+// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier.
+type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
+// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser.
+type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.PendingAuthSessionQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q)
+}
+
// The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier.
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error)
@@ -467,6 +776,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
}
+// The SubscriptionPlanFunc type is an adapter to allow the use of ordinary function as a Querier.
+type SubscriptionPlanFunc func(context.Context, *ent.SubscriptionPlanQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f SubscriptionPlanFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.SubscriptionPlanQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.SubscriptionPlanQuery", q)
+}
+
+// The TraverseSubscriptionPlan type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseSubscriptionPlan func(context.Context, *ent.SubscriptionPlanQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseSubscriptionPlan) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseSubscriptionPlan) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.SubscriptionPlanQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.SubscriptionPlanQuery", q)
+}
+
// The TLSFingerprintProfileFunc type is an adapter to allow the use of ordinary function as a Querier.
type TLSFingerprintProfileFunc func(context.Context, *ent.TLSFingerprintProfileQuery) (ent.Value, error)
@@ -696,12 +1032,34 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil
case *ent.AnnouncementReadQuery:
return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil
+ case *ent.AuthIdentityQuery:
+ return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil
+ case *ent.AuthIdentityChannelQuery:
+ return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil
+ case *ent.ChannelMonitorQuery:
+ return &query[*ent.ChannelMonitorQuery, predicate.ChannelMonitor, channelmonitor.OrderOption]{typ: ent.TypeChannelMonitor, tq: q}, nil
+ case *ent.ChannelMonitorDailyRollupQuery:
+ return &query[*ent.ChannelMonitorDailyRollupQuery, predicate.ChannelMonitorDailyRollup, channelmonitordailyrollup.OrderOption]{typ: ent.TypeChannelMonitorDailyRollup, tq: q}, nil
+ case *ent.ChannelMonitorHistoryQuery:
+ return &query[*ent.ChannelMonitorHistoryQuery, predicate.ChannelMonitorHistory, channelmonitorhistory.OrderOption]{typ: ent.TypeChannelMonitorHistory, tq: q}, nil
+ case *ent.ChannelMonitorRequestTemplateQuery:
+ return &query[*ent.ChannelMonitorRequestTemplateQuery, predicate.ChannelMonitorRequestTemplate, channelmonitorrequesttemplate.OrderOption]{typ: ent.TypeChannelMonitorRequestTemplate, tq: q}, nil
case *ent.ErrorPassthroughRuleQuery:
return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil
case *ent.GroupQuery:
return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil
case *ent.IdempotencyRecordQuery:
return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil
+ case *ent.IdentityAdoptionDecisionQuery:
+ return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil
+ case *ent.PaymentAuditLogQuery:
+ return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil
+ case *ent.PaymentOrderQuery:
+ return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil
+ case *ent.PaymentProviderInstanceQuery:
+ return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil
+ case *ent.PendingAuthSessionQuery:
+ return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil
case *ent.PromoCodeQuery:
return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil
case *ent.PromoCodeUsageQuery:
@@ -714,6 +1072,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil
case *ent.SettingQuery:
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
+ case *ent.SubscriptionPlanQuery:
+ return &query[*ent.SubscriptionPlanQuery, predicate.SubscriptionPlan, subscriptionplan.OrderOption]{typ: ent.TypeSubscriptionPlan, tq: q}, nil
case *ent.TLSFingerprintProfileQuery:
return &query[*ent.TLSFingerprintProfileQuery, predicate.TLSFingerprintProfile, tlsfingerprintprofile.OrderOption]{typ: ent.TypeTLSFingerprintProfile, tq: q}, nil
case *ent.UsageCleanupTaskQuery:
diff --git a/backend/ent/migrate/auth_identity_fk_ondelete_test.go b/backend/ent/migrate/auth_identity_fk_ondelete_test.go
new file mode 100644
index 00000000..0e37025a
--- /dev/null
+++ b/backend/ent/migrate/auth_identity_fk_ondelete_test.go
@@ -0,0 +1,73 @@
+package migrate
+
+import (
+ "testing"
+
+ "entgo.io/ent/dialect/entsql"
+ entschema "entgo.io/ent/dialect/sql/schema"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationForeignKeyOnDeleteActions(t *testing.T) {
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, AuthIdentitiesTable, "auth_identities_users_auth_identities").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, AuthIdentityChannelsTable, "auth_identity_channels_auth_identities_channels").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.Cascade,
+ findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_pending_auth_sessions_adoption_decision").OnDelete,
+ )
+
+ require.Equal(
+ t,
+ entschema.SetNull,
+ findForeignKeyBySymbol(t, PendingAuthSessionsTable, "pending_auth_sessions_users_pending_auth_sessions").OnDelete,
+ )
+ require.Equal(
+ t,
+ entschema.SetNull,
+ findForeignKeyBySymbol(t, IdentityAdoptionDecisionsTable, "identity_adoption_decisions_auth_identities_adoption_decisions").OnDelete,
+ )
+}
+
+func TestPaymentOrdersOutTradeNoPartialUniqueIndex(t *testing.T) {
+ idx := findIndexByName(t, PaymentOrdersTable, "paymentorder_out_trade_no")
+ require.True(t, idx.Unique)
+ require.Len(t, idx.Columns, 1)
+ require.Equal(t, "out_trade_no", idx.Columns[0].Name)
+ require.NotNil(t, idx.Annotation)
+ require.Equal(t, (&entsql.IndexAnnotation{Where: "out_trade_no <> ''"}).Where, idx.Annotation.Where)
+}
+
+func findForeignKeyBySymbol(t *testing.T, table *entschema.Table, symbol string) *entschema.ForeignKey {
+ t.Helper()
+
+ for _, fk := range table.ForeignKeys {
+ if fk.Symbol == symbol {
+ return fk
+ }
+ }
+
+ require.Failf(t, "missing foreign key", "table %s should include foreign key %s", table.Name, symbol)
+ return nil
+}
+
+func findIndexByName(t *testing.T, table *entschema.Table, name string) *entschema.Index {
+ t.Helper()
+
+ for _, idx := range table.Indexes {
+ if idx.Name == name {
+ return idx
+ }
+ }
+
+ require.Failf(t, "missing index", "table %s should include index %s", table.Name, name)
+ return nil
+}
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 6c56f2d0..178ae170 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -338,6 +338,252 @@ var (
},
},
}
+ // AuthIdentitiesColumns holds the columns for the "auth_identities" table.
+ AuthIdentitiesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "user_id", Type: field.TypeInt64},
+ }
+ // AuthIdentitiesTable holds the schema information for the "auth_identities" table.
+ AuthIdentitiesTable = &schema.Table{
+ Name: "auth_identities",
+ Columns: AuthIdentitiesColumns,
+ PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identities_users_auth_identities",
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentity_provider_type_provider_key_provider_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]},
+ },
+ {
+ Name: "authidentity_user_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9]},
+ },
+ {
+ Name: "authidentity_user_id_provider_type",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]},
+ },
+ },
+ }
+ // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table.
+ AuthIdentityChannelsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel", Type: field.TypeString, Size: 20},
+ {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "identity_id", Type: field.TypeInt64},
+ }
+ // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table.
+ AuthIdentityChannelsTable = &schema.Table{
+ Name: "auth_identity_channels",
+ Columns: AuthIdentityChannelsColumns,
+ PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "auth_identity_channels_auth_identities_channels",
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject",
+ Unique: true,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]},
+ },
+ {
+ Name: "authidentitychannel_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{AuthIdentityChannelsColumns[9]},
+ },
+ },
+ }
+ // ChannelMonitorsColumns holds the columns for the "channel_monitors" table.
+ ChannelMonitorsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "name", Type: field.TypeString, Size: 100},
+ {Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
+ {Name: "endpoint", Type: field.TypeString, Size: 500},
+ {Name: "api_key_encrypted", Type: field.TypeString},
+ {Name: "primary_model", Type: field.TypeString, Size: 200},
+ {Name: "extra_models", Type: field.TypeJSON},
+ {Name: "group_name", Type: field.TypeString, Nullable: true, Size: 100, Default: ""},
+ {Name: "enabled", Type: field.TypeBool, Default: true},
+ {Name: "interval_seconds", Type: field.TypeInt},
+ {Name: "last_checked_at", Type: field.TypeTime, Nullable: true},
+ {Name: "created_by", Type: field.TypeInt64},
+ {Name: "extra_headers", Type: field.TypeJSON},
+ {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
+ {Name: "body_override", Type: field.TypeJSON, Nullable: true},
+ {Name: "template_id", Type: field.TypeInt64, Nullable: true},
+ }
+ // ChannelMonitorsTable holds the schema information for the "channel_monitors" table.
+ ChannelMonitorsTable = &schema.Table{
+ Name: "channel_monitors",
+ Columns: ChannelMonitorsColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitors_channel_monitor_request_templates_request_template",
+ Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ RefColumns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitor_enabled_last_checked_at",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[10], ChannelMonitorsColumns[12]},
+ },
+ {
+ Name: "channelmonitor_provider",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[4]},
+ },
+ {
+ Name: "channelmonitor_group_name",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[9]},
+ },
+ {
+ Name: "channelmonitor_template_id",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ },
+ },
+ }
+ // ChannelMonitorDailyRollupsColumns holds the columns for the "channel_monitor_daily_rollups" table.
+ ChannelMonitorDailyRollupsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "model", Type: field.TypeString, Size: 200},
+ {Name: "bucket_date", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "date"}},
+ {Name: "total_checks", Type: field.TypeInt, Default: 0},
+ {Name: "ok_count", Type: field.TypeInt, Default: 0},
+ {Name: "operational_count", Type: field.TypeInt, Default: 0},
+ {Name: "degraded_count", Type: field.TypeInt, Default: 0},
+ {Name: "failed_count", Type: field.TypeInt, Default: 0},
+ {Name: "error_count", Type: field.TypeInt, Default: 0},
+ {Name: "sum_latency_ms", Type: field.TypeInt64, Default: 0},
+ {Name: "count_latency", Type: field.TypeInt, Default: 0},
+ {Name: "sum_ping_latency_ms", Type: field.TypeInt64, Default: 0},
+ {Name: "count_ping_latency", Type: field.TypeInt, Default: 0},
+ {Name: "computed_at", Type: field.TypeTime},
+ {Name: "monitor_id", Type: field.TypeInt64},
+ }
+ // ChannelMonitorDailyRollupsTable holds the schema information for the "channel_monitor_daily_rollups" table.
+ ChannelMonitorDailyRollupsTable = &schema.Table{
+ Name: "channel_monitor_daily_rollups",
+ Columns: ChannelMonitorDailyRollupsColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorDailyRollupsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitor_daily_rollups_channel_monitors_daily_rollups",
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14]},
+ RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitordailyrollup_monitor_id_model_bucket_date",
+ Unique: true,
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[14], ChannelMonitorDailyRollupsColumns[1], ChannelMonitorDailyRollupsColumns[2]},
+ },
+ {
+ Name: "channelmonitordailyrollup_bucket_date",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorDailyRollupsColumns[2]},
+ },
+ },
+ }
+ // ChannelMonitorHistoriesColumns holds the columns for the "channel_monitor_histories" table.
+ ChannelMonitorHistoriesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "model", Type: field.TypeString, Size: 200},
+ {Name: "status", Type: field.TypeEnum, Enums: []string{"operational", "degraded", "failed", "error"}},
+ {Name: "latency_ms", Type: field.TypeInt, Nullable: true},
+ {Name: "ping_latency_ms", Type: field.TypeInt, Nullable: true},
+ {Name: "message", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
+ {Name: "checked_at", Type: field.TypeTime},
+ {Name: "monitor_id", Type: field.TypeInt64},
+ }
+ // ChannelMonitorHistoriesTable holds the schema information for the "channel_monitor_histories" table.
+ ChannelMonitorHistoriesTable = &schema.Table{
+ Name: "channel_monitor_histories",
+ Columns: ChannelMonitorHistoriesColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorHistoriesColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "channel_monitor_histories_channel_monitors_history",
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7]},
+ RefColumns: []*schema.Column{ChannelMonitorsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitorhistory_monitor_id_model_checked_at",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[7], ChannelMonitorHistoriesColumns[1], ChannelMonitorHistoriesColumns[6]},
+ },
+ {
+ Name: "channelmonitorhistory_checked_at",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorHistoriesColumns[6]},
+ },
+ },
+ }
+ // ChannelMonitorRequestTemplatesColumns holds the columns for the "channel_monitor_request_templates" table.
+ ChannelMonitorRequestTemplatesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "name", Type: field.TypeString, Size: 100},
+ {Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
+ {Name: "description", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
+ {Name: "extra_headers", Type: field.TypeJSON},
+ {Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
+ {Name: "body_override", Type: field.TypeJSON, Nullable: true},
+ }
+ // ChannelMonitorRequestTemplatesTable holds the schema information for the "channel_monitor_request_templates" table.
+ ChannelMonitorRequestTemplatesTable = &schema.Table{
+ Name: "channel_monitor_request_templates",
+ Columns: ChannelMonitorRequestTemplatesColumns,
+ PrimaryKey: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "channelmonitorrequesttemplate_provider_name",
+ Unique: true,
+ Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[3]},
+ },
+ },
+ }
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
ErrorPassthroughRulesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -395,11 +641,6 @@ var (
{Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
- {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
{Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true},
@@ -412,6 +653,8 @@ var (
{Name: "require_oauth_only", Type: field.TypeBool, Default: false},
{Name: "require_privacy_set", Type: field.TypeBool, Default: false},
{Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""},
+ {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "rpm_limit", Type: field.TypeInt, Default: 0},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -447,7 +690,7 @@ var (
{
Name: "group_sort_order",
Unique: false,
- Columns: []*schema.Column{GroupsColumns[30]},
+ Columns: []*schema.Column{GroupsColumns[25]},
},
},
}
@@ -489,6 +732,273 @@ var (
},
},
}
+ // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "adopt_display_name", Type: field.TypeBool, Default: false},
+ {Name: "adopt_avatar", Type: field.TypeBool, Default: false},
+ {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "identity_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true},
+ }
+ // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table.
+ IdentityAdoptionDecisionsTable = &schema.Table{
+ Name: "identity_adoption_decisions",
+ Columns: IdentityAdoptionDecisionsColumns,
+ PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ RefColumns: []*schema.Column{AuthIdentitiesColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ {
+ Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision",
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]},
+ OnDelete: schema.Cascade,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "identityadoptiondecision_pending_auth_session_id",
+ Unique: true,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]},
+ },
+ {
+ Name: "identityadoptiondecision_identity_id",
+ Unique: false,
+ Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]},
+ },
+ },
+ }
+ // PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table.
+ PaymentAuditLogsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "order_id", Type: field.TypeString, Size: 64},
+ {Name: "action", Type: field.TypeString, Size: 50},
+ {Name: "detail", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "operator", Type: field.TypeString, Size: 100, Default: "system"},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ }
+ // PaymentAuditLogsTable holds the schema information for the "payment_audit_logs" table.
+ PaymentAuditLogsTable = &schema.Table{
+ Name: "payment_audit_logs",
+ Columns: PaymentAuditLogsColumns,
+ PrimaryKey: []*schema.Column{PaymentAuditLogsColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "paymentauditlog_order_id",
+ Unique: false,
+ Columns: []*schema.Column{PaymentAuditLogsColumns[1]},
+ },
+ },
+ }
+ // PaymentOrdersColumns holds the columns for the "payment_orders" table.
+ PaymentOrdersColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "user_email", Type: field.TypeString, Size: 255},
+ {Name: "user_name", Type: field.TypeString, Size: 100},
+ {Name: "user_notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "amount", Type: field.TypeFloat64, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
+ {Name: "pay_amount", Type: field.TypeFloat64, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
+ {Name: "fee_rate", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
+ {Name: "recharge_code", Type: field.TypeString, Size: 64},
+ {Name: "out_trade_no", Type: field.TypeString, Size: 64, Default: ""},
+ {Name: "payment_type", Type: field.TypeString, Size: 30},
+ {Name: "payment_trade_no", Type: field.TypeString, Size: 128},
+ {Name: "pay_url", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "qr_code", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "qr_code_img", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "order_type", Type: field.TypeString, Size: 20, Default: "balance"},
+ {Name: "plan_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "subscription_days", Type: field.TypeInt, Nullable: true},
+ {Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64},
+ {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30},
+ {Name: "provider_snapshot", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"},
+ {Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
+ {Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "refund_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "force_refund", Type: field.TypeBool, Default: false},
+ {Name: "refund_requested_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "refund_request_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "refund_requested_by", Type: field.TypeString, Nullable: true, Size: 20},
+ {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "paid_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "completed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "failed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "failed_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "client_ip", Type: field.TypeString, Size: 50},
+ {Name: "src_host", Type: field.TypeString, Size: 255},
+ {Name: "src_url", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "user_id", Type: field.TypeInt64},
+ }
+ // PaymentOrdersTable holds the schema information for the "payment_orders" table.
+ PaymentOrdersTable = &schema.Table{
+ Name: "payment_orders",
+ Columns: PaymentOrdersColumns,
+ PrimaryKey: []*schema.Column{PaymentOrdersColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "payment_orders_users_payment_orders",
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.NoAction,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "paymentorder_out_trade_no",
+ Unique: true,
+ Columns: []*schema.Column{PaymentOrdersColumns[8]},
+ Annotation: &entsql.IndexAnnotation{
+ Where: "out_trade_no <> ''",
+ },
+ },
+ {
+ Name: "paymentorder_user_id",
+ Unique: false,
+ Columns: []*schema.Column{PaymentOrdersColumns[39]},
+ },
+ {
+ Name: "paymentorder_status",
+ Unique: false,
+ Columns: []*schema.Column{PaymentOrdersColumns[21]},
+ },
+ {
+ Name: "paymentorder_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{PaymentOrdersColumns[29]},
+ },
+ {
+ Name: "paymentorder_created_at",
+ Unique: false,
+ Columns: []*schema.Column{PaymentOrdersColumns[37]},
+ },
+ {
+ Name: "paymentorder_paid_at",
+ Unique: false,
+ Columns: []*schema.Column{PaymentOrdersColumns[30]},
+ },
+ {
+ Name: "paymentorder_payment_type_paid_at",
+ Unique: false,
+ Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[30]},
+ },
+ {
+ Name: "paymentorder_order_type",
+ Unique: false,
+ Columns: []*schema.Column{PaymentOrdersColumns[14]},
+ },
+ },
+ }
+ // PaymentProviderInstancesColumns holds the columns for the "payment_provider_instances" table.
+ PaymentProviderInstancesColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "provider_key", Type: field.TypeString, Size: 30},
+ {Name: "name", Type: field.TypeString, Size: 100, Default: ""},
+ {Name: "config", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "supported_types", Type: field.TypeString, Size: 200, Default: ""},
+ {Name: "enabled", Type: field.TypeBool, Default: true},
+ {Name: "payment_mode", Type: field.TypeString, Size: 20, Default: ""},
+ {Name: "sort_order", Type: field.TypeInt, Default: 0},
+ {Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "refund_enabled", Type: field.TypeBool, Default: false},
+ {Name: "allow_user_refund", Type: field.TypeBool, Default: false},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ }
+ // PaymentProviderInstancesTable holds the schema information for the "payment_provider_instances" table.
+ PaymentProviderInstancesTable = &schema.Table{
+ Name: "payment_provider_instances",
+ Columns: PaymentProviderInstancesColumns,
+ PrimaryKey: []*schema.Column{PaymentProviderInstancesColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "paymentproviderinstance_provider_key",
+ Unique: false,
+ Columns: []*schema.Column{PaymentProviderInstancesColumns[1]},
+ },
+ {
+ Name: "paymentproviderinstance_enabled",
+ Unique: false,
+ Columns: []*schema.Column{PaymentProviderInstancesColumns[5]},
+ },
+ },
+ }
+ // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table.
+ PendingAuthSessionsColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "session_token", Type: field.TypeString, Size: 255},
+ {Name: "intent", Type: field.TypeString, Size: 40},
+ {Name: "provider_type", Type: field.TypeString, Size: 20},
+ {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "target_user_id", Type: field.TypeInt64, Nullable: true},
+ }
+ // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table.
+ PendingAuthSessionsTable = &schema.Table{
+ Name: "pending_auth_sessions",
+ Columns: PendingAuthSessionsColumns,
+ PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]},
+ ForeignKeys: []*schema.ForeignKey{
+ {
+ Symbol: "pending_auth_sessions_users_pending_auth_sessions",
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ RefColumns: []*schema.Column{UsersColumns[0]},
+ OnDelete: schema.SetNull,
+ },
+ },
+ Indexes: []*schema.Index{
+ {
+ Name: "pendingauthsession_session_token",
+ Unique: true,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[3]},
+ },
+ {
+ Name: "pendingauthsession_target_user_id",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[21]},
+ },
+ {
+ Name: "pendingauthsession_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[19]},
+ },
+ {
+ Name: "pendingauthsession_provider_type_provider_key_provider_subject",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]},
+ },
+ {
+ Name: "pendingauthsession_completion_code_hash",
+ Unique: false,
+ Columns: []*schema.Column{PendingAuthSessionsColumns[14]},
+ },
+ },
+ }
// PromoCodesColumns holds the columns for the "promo_codes" table.
PromoCodesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -675,6 +1185,41 @@ var (
Columns: SettingsColumns,
PrimaryKey: []*schema.Column{SettingsColumns[0]},
}
+ // SubscriptionPlansColumns holds the columns for the "subscription_plans" table.
+ SubscriptionPlansColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "group_id", Type: field.TypeInt64},
+ {Name: "name", Type: field.TypeString, Size: 100},
+ {Name: "description", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "price", Type: field.TypeFloat64, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
+ {Name: "original_price", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,2)"}},
+ {Name: "validity_days", Type: field.TypeInt, Default: 30},
+ {Name: "validity_unit", Type: field.TypeString, Size: 10, Default: "day"},
+ {Name: "features", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "product_name", Type: field.TypeString, Size: 100, Default: ""},
+ {Name: "for_sale", Type: field.TypeBool, Default: true},
+ {Name: "sort_order", Type: field.TypeInt, Default: 0},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ }
+ // SubscriptionPlansTable holds the schema information for the "subscription_plans" table.
+ SubscriptionPlansTable = &schema.Table{
+ Name: "subscription_plans",
+ Columns: SubscriptionPlansColumns,
+ PrimaryKey: []*schema.Column{SubscriptionPlansColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "subscriptionplan_group_id",
+ Unique: false,
+ Columns: []*schema.Column{SubscriptionPlansColumns[1]},
+ },
+ {
+ Name: "subscriptionplan_for_sale",
+ Unique: false,
+ Columns: []*schema.Column{SubscriptionPlansColumns[10]},
+ },
+ },
+ }
// TLSFingerprintProfilesColumns holds the columns for the "tls_fingerprint_profiles" table.
TLSFingerprintProfilesColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -744,6 +1289,10 @@ var (
{Name: "model", Type: field.TypeString, Size: 100},
{Name: "requested_model", Type: field.TypeString, Nullable: true, Size: 100},
{Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100},
+ {Name: "channel_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "model_mapping_chain", Type: field.TypeString, Nullable: true, Size: 500},
+ {Name: "billing_tier", Type: field.TypeString, Nullable: true, Size: 50},
+ {Name: "billing_mode", Type: field.TypeString, Nullable: true, Size: 20},
{Name: "input_tokens", Type: field.TypeInt, Default: 0},
{Name: "output_tokens", Type: field.TypeInt, Default: 0},
{Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0},
@@ -766,7 +1315,6 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
- {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
@@ -783,31 +1331,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[30]},
+ Columns: []*schema.Column{UsageLogsColumns[33]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[31]},
+ Columns: []*schema.Column{UsageLogsColumns[34]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[35]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[36]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -816,32 +1364,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[36]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[30]},
+ Columns: []*schema.Column{UsageLogsColumns[33]},
},
{
Name: "usagelog_account_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[31]},
+ Columns: []*schema.Column{UsageLogsColumns[34]},
},
{
Name: "usagelog_group_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[35]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
},
{
Name: "usagelog_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[32]},
},
{
Name: "usagelog_model",
@@ -861,17 +1409,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
},
},
}
@@ -892,8 +1440,15 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
- {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0},
- {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0},
+ {Name: "signup_source", Type: field.TypeString, Default: "email"},
+ {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
+ {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
+ {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
+ {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
+ {Name: "rpm_limit", Type: field.TypeInt, Default: 0},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
@@ -1128,15 +1683,27 @@ var (
AccountGroupsTable,
AnnouncementsTable,
AnnouncementReadsTable,
+ AuthIdentitiesTable,
+ AuthIdentityChannelsTable,
+ ChannelMonitorsTable,
+ ChannelMonitorDailyRollupsTable,
+ ChannelMonitorHistoriesTable,
+ ChannelMonitorRequestTemplatesTable,
ErrorPassthroughRulesTable,
GroupsTable,
IdempotencyRecordsTable,
+ IdentityAdoptionDecisionsTable,
+ PaymentAuditLogsTable,
+ PaymentOrdersTable,
+ PaymentProviderInstancesTable,
+ PendingAuthSessionsTable,
PromoCodesTable,
PromoCodeUsagesTable,
ProxiesTable,
RedeemCodesTable,
SecuritySecretsTable,
SettingsTable,
+ SubscriptionPlansTable,
TLSFingerprintProfilesTable,
UsageCleanupTasksTable,
UsageLogsTable,
@@ -1171,6 +1738,29 @@ func init() {
AnnouncementReadsTable.Annotation = &entsql.Annotation{
Table: "announcement_reads",
}
+ AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable
+ AuthIdentitiesTable.Annotation = &entsql.Annotation{
+ Table: "auth_identities",
+ }
+ AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ AuthIdentityChannelsTable.Annotation = &entsql.Annotation{
+ Table: "auth_identity_channels",
+ }
+ ChannelMonitorsTable.ForeignKeys[0].RefTable = ChannelMonitorRequestTemplatesTable
+ ChannelMonitorsTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitors",
+ }
+ ChannelMonitorDailyRollupsTable.ForeignKeys[0].RefTable = ChannelMonitorsTable
+ ChannelMonitorDailyRollupsTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_daily_rollups",
+ }
+ ChannelMonitorHistoriesTable.ForeignKeys[0].RefTable = ChannelMonitorsTable
+ ChannelMonitorHistoriesTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_histories",
+ }
+ ChannelMonitorRequestTemplatesTable.Annotation = &entsql.Annotation{
+ Table: "channel_monitor_request_templates",
+ }
ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{
Table: "error_passthrough_rules",
}
@@ -1180,6 +1770,25 @@ func init() {
IdempotencyRecordsTable.Annotation = &entsql.Annotation{
Table: "idempotency_records",
}
+ IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable
+ IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable
+ IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{
+ Table: "identity_adoption_decisions",
+ }
+ PaymentAuditLogsTable.Annotation = &entsql.Annotation{
+ Table: "payment_audit_logs",
+ }
+ PaymentOrdersTable.ForeignKeys[0].RefTable = UsersTable
+ PaymentOrdersTable.Annotation = &entsql.Annotation{
+ Table: "payment_orders",
+ }
+ PaymentProviderInstancesTable.Annotation = &entsql.Annotation{
+ Table: "payment_provider_instances",
+ }
+ PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable
+ PendingAuthSessionsTable.Annotation = &entsql.Annotation{
+ Table: "pending_auth_sessions",
+ }
PromoCodesTable.Annotation = &entsql.Annotation{
Table: "promo_codes",
}
@@ -1202,6 +1811,9 @@ func init() {
SettingsTable.Annotation = &entsql.Annotation{
Table: "settings",
}
+ SubscriptionPlansTable.Annotation = &entsql.Annotation{
+ Table: "subscription_plans",
+ }
TLSFingerprintProfilesTable.Annotation = &entsql.Annotation{
Table: "tls_fingerprint_profiles",
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index a862209d..d616e4ae 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -17,9 +17,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
@@ -27,6 +38,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -47,28 +59,40 @@ const (
OpUpdateOne = ent.OpUpdateOne
// Node types.
- TypeAPIKey = "APIKey"
- TypeAccount = "Account"
- TypeAccountGroup = "AccountGroup"
- TypeAnnouncement = "Announcement"
- TypeAnnouncementRead = "AnnouncementRead"
- TypeErrorPassthroughRule = "ErrorPassthroughRule"
- TypeGroup = "Group"
- TypeIdempotencyRecord = "IdempotencyRecord"
- TypePromoCode = "PromoCode"
- TypePromoCodeUsage = "PromoCodeUsage"
- TypeProxy = "Proxy"
- TypeRedeemCode = "RedeemCode"
- TypeSecuritySecret = "SecuritySecret"
- TypeSetting = "Setting"
- TypeTLSFingerprintProfile = "TLSFingerprintProfile"
- TypeUsageCleanupTask = "UsageCleanupTask"
- TypeUsageLog = "UsageLog"
- TypeUser = "User"
- TypeUserAllowedGroup = "UserAllowedGroup"
- TypeUserAttributeDefinition = "UserAttributeDefinition"
- TypeUserAttributeValue = "UserAttributeValue"
- TypeUserSubscription = "UserSubscription"
+ TypeAPIKey = "APIKey"
+ TypeAccount = "Account"
+ TypeAccountGroup = "AccountGroup"
+ TypeAnnouncement = "Announcement"
+ TypeAnnouncementRead = "AnnouncementRead"
+ TypeAuthIdentity = "AuthIdentity"
+ TypeAuthIdentityChannel = "AuthIdentityChannel"
+ TypeChannelMonitor = "ChannelMonitor"
+ TypeChannelMonitorDailyRollup = "ChannelMonitorDailyRollup"
+ TypeChannelMonitorHistory = "ChannelMonitorHistory"
+ TypeChannelMonitorRequestTemplate = "ChannelMonitorRequestTemplate"
+ TypeErrorPassthroughRule = "ErrorPassthroughRule"
+ TypeGroup = "Group"
+ TypeIdempotencyRecord = "IdempotencyRecord"
+ TypeIdentityAdoptionDecision = "IdentityAdoptionDecision"
+ TypePaymentAuditLog = "PaymentAuditLog"
+ TypePaymentOrder = "PaymentOrder"
+ TypePaymentProviderInstance = "PaymentProviderInstance"
+ TypePendingAuthSession = "PendingAuthSession"
+ TypePromoCode = "PromoCode"
+ TypePromoCodeUsage = "PromoCodeUsage"
+ TypeProxy = "Proxy"
+ TypeRedeemCode = "RedeemCode"
+ TypeSecuritySecret = "SecuritySecret"
+ TypeSetting = "Setting"
+ TypeSubscriptionPlan = "SubscriptionPlan"
+ TypeTLSFingerprintProfile = "TLSFingerprintProfile"
+ TypeUsageCleanupTask = "UsageCleanupTask"
+ TypeUsageLog = "UsageLog"
+ TypeUser = "User"
+ TypeUserAllowedGroup = "UserAllowedGroup"
+ TypeUserAttributeDefinition = "UserAttributeDefinition"
+ TypeUserAttributeValue = "UserAttributeValue"
+ TypeUserSubscription = "UserSubscription"
)
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
@@ -6879,6 +6903,6522 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AnnouncementRead edge %s", name)
}
+// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph.
+type AuthIdentityMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ verified_at *time.Time
+ issuer *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ user *int64
+ cleareduser bool
+ channels map[int64]struct{}
+ removedchannels map[int64]struct{}
+ clearedchannels bool
+ adoption_decisions map[int64]struct{}
+ removedadoption_decisions map[int64]struct{}
+ clearedadoption_decisions bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentity, error)
+ predicates []predicate.AuthIdentity
+}
+
+var _ ent.Mutation = (*AuthIdentityMutation)(nil)
+
+// authidentityOption allows management of the mutation configuration using functional options.
+type authidentityOption func(*AuthIdentityMutation)
+
+// newAuthIdentityMutation creates new mutation for the AuthIdentity entity.
+func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation {
+ m := &AuthIdentityMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentity,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityID sets the ID field of the mutation.
+func withAuthIdentityID(id int64) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentity
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentity, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentity.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentity sets the old AuthIdentity of the mutation.
+func withAuthIdentity(node *AuthIdentity) authidentityOption {
+ return func(m *AuthIdentityMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentity, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetUserID sets the "user_id" field.
+func (m *AuthIdentityMutation) SetUserID(i int64) {
+ m.user = &i
+}
+
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *AuthIdentityMutation) UserID() (r int64, exists bool) {
+ v := m.user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserID returns the old "user_id" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
+ }
+ return oldValue.UserID, nil
+}
+
+// ResetUserID resets all changes to the "user_id" field.
+func (m *AuthIdentityMutation) ResetUserID() {
+ m.user = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *AuthIdentityMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *AuthIdentityMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetVerifiedAt sets the "verified_at" field.
+func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) {
+ m.verified_at = &t
+}
+
+// VerifiedAt returns the value of the "verified_at" field in the mutation.
+func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) {
+ v := m.verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err)
+ }
+ return oldValue.VerifiedAt, nil
+}
+
+// ClearVerifiedAt clears the value of the "verified_at" field.
+func (m *AuthIdentityMutation) ClearVerifiedAt() {
+ m.verified_at = nil
+ m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{}
+}
+
+// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation.
+func (m *AuthIdentityMutation) VerifiedAtCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldVerifiedAt]
+ return ok
+}
+
+// ResetVerifiedAt resets all changes to the "verified_at" field.
+func (m *AuthIdentityMutation) ResetVerifiedAt() {
+ m.verified_at = nil
+ delete(m.clearedFields, authidentity.FieldVerifiedAt)
+}
+
+// SetIssuer sets the "issuer" field.
+func (m *AuthIdentityMutation) SetIssuer(s string) {
+ m.issuer = &s
+}
+
+// Issuer returns the value of the "issuer" field in the mutation.
+func (m *AuthIdentityMutation) Issuer() (r string, exists bool) {
+ v := m.issuer
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIssuer is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIssuer requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIssuer: %w", err)
+ }
+ return oldValue.Issuer, nil
+}
+
+// ClearIssuer clears the value of the "issuer" field.
+func (m *AuthIdentityMutation) ClearIssuer() {
+ m.issuer = nil
+ m.clearedFields[authidentity.FieldIssuer] = struct{}{}
+}
+
+// IssuerCleared returns if the "issuer" field was cleared in this mutation.
+func (m *AuthIdentityMutation) IssuerCleared() bool {
+ _, ok := m.clearedFields[authidentity.FieldIssuer]
+ return ok
+}
+
+// ResetIssuer resets all changes to the "issuer" field.
+func (m *AuthIdentityMutation) ResetIssuer() {
+ m.issuer = nil
+ delete(m.clearedFields, authidentity.FieldIssuer)
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity.
+// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (m *AuthIdentityMutation) ClearUser() {
+ m.cleareduser = true
+ m.clearedFields[authidentity.FieldUserID] = struct{}{}
+}
+
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *AuthIdentityMutation) UserCleared() bool {
+ return m.cleareduser
+}
+
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityMutation) UserIDs() (ids []int64) {
+ if id := m.user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetUser resets all changes to the "user" edge.
+func (m *AuthIdentityMutation) ResetUser() {
+ m.user = nil
+ m.cleareduser = false
+}
+
+// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids.
+func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) {
+ if m.channels == nil {
+ m.channels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.channels[ids[i]] = struct{}{}
+ }
+}
+
+// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) ClearChannels() {
+ m.clearedchannels = true
+}
+
+// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared.
+func (m *AuthIdentityMutation) ChannelsCleared() bool {
+ return m.clearedchannels
+}
+
+// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs.
+func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) {
+ if m.removedchannels == nil {
+ m.removedchannels = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.channels, ids[i])
+ m.removedchannels[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity.
+func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) {
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ChannelsIDs returns the "channels" edge IDs in the mutation.
+func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) {
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetChannels resets all changes to the "channels" edge.
+func (m *AuthIdentityMutation) ResetChannels() {
+ m.channels = nil
+ m.clearedchannels = false
+ m.removedchannels = nil
+}
+
+// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids.
+func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) {
+ if m.adoption_decisions == nil {
+ m.adoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.adoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) ClearAdoptionDecisions() {
+ m.clearedadoption_decisions = true
+}
+
+// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool {
+ return m.clearedadoption_decisions
+}
+
+// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs.
+func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) {
+ if m.removedadoption_decisions == nil {
+ m.removedadoption_decisions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.adoption_decisions, ids[i])
+ m.removedadoption_decisions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity.
+func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation.
+func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) {
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge.
+func (m *AuthIdentityMutation) ResetAdoptionDecisions() {
+ m.adoption_decisions = nil
+ m.clearedadoption_decisions = false
+ m.removedadoption_decisions = nil
+}
+
+// Where appends a list predicates to the AuthIdentityMutation builder.
+func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentity, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentity).
+func (m *AuthIdentityMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentity.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentity.FieldUpdatedAt)
+ }
+ if m.user != nil {
+ fields = append(fields, authidentity.FieldUserID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentity.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentity.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, authidentity.FieldProviderSubject)
+ }
+ if m.verified_at != nil {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.issuer != nil {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentity.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentity.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentity.FieldUserID:
+ return m.UserID()
+ case authidentity.FieldProviderType:
+ return m.ProviderType()
+ case authidentity.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentity.FieldProviderSubject:
+ return m.ProviderSubject()
+ case authidentity.FieldVerifiedAt:
+ return m.VerifiedAt()
+ case authidentity.FieldIssuer:
+ return m.Issuer()
+ case authidentity.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentity.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentity.FieldUserID:
+ return m.OldUserID(ctx)
+ case authidentity.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentity.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentity.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case authidentity.FieldVerifiedAt:
+ return m.OldVerifiedAt(ctx)
+ case authidentity.FieldIssuer:
+ return m.OldIssuer(ctx)
+ case authidentity.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentity.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentity.FieldUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserID(v)
+ return nil
+ case authidentity.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentity.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentity.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case authidentity.FieldVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetVerifiedAt(v)
+ return nil
+ case authidentity.FieldIssuer:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIssuer(v)
+ return nil
+ case authidentity.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentity numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(authidentity.FieldVerifiedAt) {
+ fields = append(fields, authidentity.FieldVerifiedAt)
+ }
+ if m.FieldCleared(authidentity.FieldIssuer) {
+ fields = append(fields, authidentity.FieldIssuer)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ClearField(name string) error {
+ switch name {
+ case authidentity.FieldVerifiedAt:
+ m.ClearVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ClearIssuer()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityMutation) ResetField(name string) error {
+ switch name {
+ case authidentity.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentity.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentity.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case authidentity.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentity.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentity.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case authidentity.FieldVerifiedAt:
+ m.ResetVerifiedAt()
+ return nil
+ case authidentity.FieldIssuer:
+ m.ResetIssuer()
+ return nil
+ case authidentity.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityMutation) AddedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.user != nil {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.channels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.adoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
+ }
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.channels))
+ for id := range m.channels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.adoption_decisions))
+ for id := range m.adoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.removedchannels != nil {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.removedadoption_decisions != nil {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case authidentity.EdgeChannels:
+ ids := make([]ent.Value, 0, len(m.removedchannels))
+ for id := range m.removedchannels {
+ ids = append(ids, id)
+ }
+ return ids
+ case authidentity.EdgeAdoptionDecisions:
+ ids := make([]ent.Value, 0, len(m.removedadoption_decisions))
+ for id := range m.removedadoption_decisions {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.cleareduser {
+ edges = append(edges, authidentity.EdgeUser)
+ }
+ if m.clearedchannels {
+ edges = append(edges, authidentity.EdgeChannels)
+ }
+ if m.clearedadoption_decisions {
+ edges = append(edges, authidentity.EdgeAdoptionDecisions)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentity.EdgeUser:
+ return m.cleareduser
+ case authidentity.EdgeChannels:
+ return m.clearedchannels
+ case authidentity.EdgeAdoptionDecisions:
+ return m.clearedadoption_decisions
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ClearUser()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentity.EdgeUser:
+ m.ResetUser()
+ return nil
+ case authidentity.EdgeChannels:
+ m.ResetChannels()
+ return nil
+ case authidentity.EdgeAdoptionDecisions:
+ m.ResetAdoptionDecisions()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentity edge %s", name)
+}
+
+// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph.
+type AuthIdentityChannelMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ provider_type *string
+ provider_key *string
+ channel *string
+ channel_app_id *string
+ channel_subject *string
+ metadata *map[string]interface{}
+ clearedFields map[string]struct{}
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*AuthIdentityChannel, error)
+ predicates []predicate.AuthIdentityChannel
+}
+
+var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil)
+
+// authidentitychannelOption allows management of the mutation configuration using functional options.
+type authidentitychannelOption func(*AuthIdentityChannelMutation)
+
+// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity.
+func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation {
+ m := &AuthIdentityChannelMutation{
+ config: c,
+ op: op,
+ typ: TypeAuthIdentityChannel,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withAuthIdentityChannelID sets the ID field of the mutation.
+func withAuthIdentityChannelID(id int64) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ var (
+ err error
+ once sync.Once
+ value *AuthIdentityChannel
+ )
+ m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().AuthIdentityChannel.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation.
+func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption {
+ return func(m *AuthIdentityChannelMutation) {
+ m.oldValue = func(context.Context) (*AuthIdentityChannel, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m AuthIdentityChannelMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m AuthIdentityChannelMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *AuthIdentityChannelMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *AuthIdentityChannelMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *AuthIdentityChannelMutation) ResetIdentityID() {
+ m.identity = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *AuthIdentityChannelMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *AuthIdentityChannelMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *AuthIdentityChannelMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *AuthIdentityChannelMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetChannel sets the "channel" field.
+func (m *AuthIdentityChannelMutation) SetChannel(s string) {
+ m.channel = &s
+}
+
+// Channel returns the value of the "channel" field in the mutation.
+func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) {
+ v := m.channel
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannel: %w", err)
+ }
+ return oldValue.Channel, nil
+}
+
+// ResetChannel resets all changes to the "channel" field.
+func (m *AuthIdentityChannelMutation) ResetChannel() {
+ m.channel = nil
+}
+
+// SetChannelAppID sets the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) {
+ m.channel_app_id = &s
+}
+
+// ChannelAppID returns the value of the "channel_app_id" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) {
+ v := m.channel_app_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelAppID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err)
+ }
+ return oldValue.ChannelAppID, nil
+}
+
+// ResetChannelAppID resets all changes to the "channel_app_id" field.
+func (m *AuthIdentityChannelMutation) ResetChannelAppID() {
+ m.channel_app_id = nil
+}
+
+// SetChannelSubject sets the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) {
+ m.channel_subject = &s
+}
+
+// ChannelSubject returns the value of the "channel_subject" field in the mutation.
+func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) {
+ v := m.channel_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err)
+ }
+ return oldValue.ChannelSubject, nil
+}
+
+// ResetChannelSubject resets all changes to the "channel_subject" field.
+func (m *AuthIdentityChannelMutation) ResetChannelSubject() {
+ m.channel_subject = nil
+}
+
+// SetMetadata sets the "metadata" field.
+func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) {
+ m.metadata = &value
+}
+
+// Metadata returns the value of the "metadata" field in the mutation.
+func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) {
+ v := m.metadata
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity.
+// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMetadata is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMetadata requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMetadata: %w", err)
+ }
+ return oldValue.Metadata, nil
+}
+
+// ResetMetadata resets all changes to the "metadata" field.
+func (m *AuthIdentityChannelMutation) ResetMetadata() {
+ m.metadata = nil
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *AuthIdentityChannelMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *AuthIdentityChannelMutation) IdentityCleared() bool {
+ return m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *AuthIdentityChannelMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the AuthIdentityChannelMutation builder.
+func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.AuthIdentityChannel, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *AuthIdentityChannelMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *AuthIdentityChannelMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (AuthIdentityChannel).
+func (m *AuthIdentityChannelMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *AuthIdentityChannelMutation) Fields() []string {
+ fields := make([]string, 0, 9)
+ if m.created_at != nil {
+ fields = append(fields, authidentitychannel.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, authidentitychannel.FieldUpdatedAt)
+ }
+ if m.identity != nil {
+ fields = append(fields, authidentitychannel.FieldIdentityID)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, authidentitychannel.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, authidentitychannel.FieldProviderKey)
+ }
+ if m.channel != nil {
+ fields = append(fields, authidentitychannel.FieldChannel)
+ }
+ if m.channel_app_id != nil {
+ fields = append(fields, authidentitychannel.FieldChannelAppID)
+ }
+ if m.channel_subject != nil {
+ fields = append(fields, authidentitychannel.FieldChannelSubject)
+ }
+ if m.metadata != nil {
+ fields = append(fields, authidentitychannel.FieldMetadata)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.CreatedAt()
+ case authidentitychannel.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case authidentitychannel.FieldIdentityID:
+ return m.IdentityID()
+ case authidentitychannel.FieldProviderType:
+ return m.ProviderType()
+ case authidentitychannel.FieldProviderKey:
+ return m.ProviderKey()
+ case authidentitychannel.FieldChannel:
+ return m.Channel()
+ case authidentitychannel.FieldChannelAppID:
+ return m.ChannelAppID()
+ case authidentitychannel.FieldChannelSubject:
+ return m.ChannelSubject()
+ case authidentitychannel.FieldMetadata:
+ return m.Metadata()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case authidentitychannel.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case authidentitychannel.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case authidentitychannel.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case authidentitychannel.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case authidentitychannel.FieldChannel:
+ return m.OldChannel(ctx)
+ case authidentitychannel.FieldChannelAppID:
+ return m.OldChannelAppID(ctx)
+ case authidentitychannel.FieldChannelSubject:
+ return m.OldChannelSubject(ctx)
+ case authidentitychannel.FieldMetadata:
+ return m.OldMetadata(ctx)
+ }
+ return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case authidentitychannel.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case authidentitychannel.FieldChannel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannel(v)
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelAppID(v)
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelSubject(v)
+ return nil
+ case authidentitychannel.FieldMetadata:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMetadata(v)
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *AuthIdentityChannelMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *AuthIdentityChannelMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetField(name string) error {
+ switch name {
+ case authidentitychannel.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case authidentitychannel.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case authidentitychannel.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case authidentitychannel.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case authidentitychannel.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case authidentitychannel.FieldChannel:
+ m.ResetChannel()
+ return nil
+ case authidentitychannel.FieldChannelAppID:
+ m.ResetChannelAppID()
+ return nil
+ case authidentitychannel.FieldChannelSubject:
+ m.ResetChannelSubject()
+ return nil
+ case authidentitychannel.FieldMetadata:
+ m.ResetMetadata()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *AuthIdentityChannelMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.identity != nil {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *AuthIdentityChannelMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedidentity {
+ edges = append(edges, authidentitychannel.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ClearEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *AuthIdentityChannelMutation) ResetEdge(name string) error {
+ switch name {
+ case authidentitychannel.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown AuthIdentityChannel edge %s", name)
+}
+
+// ChannelMonitorMutation represents an operation that mutates the ChannelMonitor nodes in the graph.
+type ChannelMonitorMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ name *string
+ provider *channelmonitor.Provider
+ endpoint *string
+ api_key_encrypted *string
+ primary_model *string
+ extra_models *[]string
+ appendextra_models []string
+ group_name *string
+ enabled *bool
+ interval_seconds *int
+ addinterval_seconds *int
+ last_checked_at *time.Time
+ created_by *int64
+ addcreated_by *int64
+ extra_headers *map[string]string
+ body_override_mode *string
+ body_override *map[string]interface{}
+ clearedFields map[string]struct{}
+ history map[int64]struct{}
+ removedhistory map[int64]struct{}
+ clearedhistory bool
+ daily_rollups map[int64]struct{}
+ removeddaily_rollups map[int64]struct{}
+ cleareddaily_rollups bool
+ request_template *int64
+ clearedrequest_template bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitor, error)
+ predicates []predicate.ChannelMonitor
+}
+
+var _ ent.Mutation = (*ChannelMonitorMutation)(nil)
+
+// channelmonitorOption allows management of the mutation configuration using functional options.
+type channelmonitorOption func(*ChannelMonitorMutation)
+
+// newChannelMonitorMutation creates new mutation for the ChannelMonitor entity.
+func newChannelMonitorMutation(c config, op Op, opts ...channelmonitorOption) *ChannelMonitorMutation {
+ m := &ChannelMonitorMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitor,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorID sets the ID field of the mutation.
+func withChannelMonitorID(id int64) channelmonitorOption {
+ return func(m *ChannelMonitorMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitor
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitor, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitor.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitor sets the old ChannelMonitor of the mutation.
+func withChannelMonitor(node *ChannelMonitor) channelmonitorOption {
+ return func(m *ChannelMonitorMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitor, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *ChannelMonitorMutation) IDs(ctx context.Context) ([]int64, error) {
+	switch {
+	case m.op.Is(OpUpdateOne | OpDeleteOne):
+		id, exists := m.ID()
+		if exists {
+			return []int64{id}, nil
+		}
+		fallthrough
+	case m.op.Is(OpUpdate | OpDelete):
+		return m.Client().ChannelMonitor.Query().Where(m.predicates...).IDs(ctx)
+	default:
+		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+	}
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *ChannelMonitorMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *ChannelMonitorMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *ChannelMonitorMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *ChannelMonitorMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *ChannelMonitorMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *ChannelMonitorMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetName sets the "name" field.
+func (m *ChannelMonitorMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *ChannelMonitorMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *ChannelMonitorMutation) ResetName() {
+ m.name = nil
+}
+
+// SetProvider sets the "provider" field.
+func (m *ChannelMonitorMutation) SetProvider(c channelmonitor.Provider) {
+ m.provider = &c
+}
+
+// Provider returns the value of the "provider" field in the mutation.
+func (m *ChannelMonitorMutation) Provider() (r channelmonitor.Provider, exists bool) {
+ v := m.provider
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProvider returns the old "provider" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldProvider(ctx context.Context) (v channelmonitor.Provider, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProvider is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProvider requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProvider: %w", err)
+ }
+ return oldValue.Provider, nil
+}
+
+// ResetProvider resets all changes to the "provider" field.
+func (m *ChannelMonitorMutation) ResetProvider() {
+ m.provider = nil
+}
+
+// SetEndpoint sets the "endpoint" field.
+func (m *ChannelMonitorMutation) SetEndpoint(s string) {
+ m.endpoint = &s
+}
+
+// Endpoint returns the value of the "endpoint" field in the mutation.
+func (m *ChannelMonitorMutation) Endpoint() (r string, exists bool) {
+ v := m.endpoint
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEndpoint returns the old "endpoint" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldEndpoint(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEndpoint is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEndpoint requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEndpoint: %w", err)
+ }
+ return oldValue.Endpoint, nil
+}
+
+// ResetEndpoint resets all changes to the "endpoint" field.
+func (m *ChannelMonitorMutation) ResetEndpoint() {
+ m.endpoint = nil
+}
+
+// SetAPIKeyEncrypted sets the "api_key_encrypted" field.
+func (m *ChannelMonitorMutation) SetAPIKeyEncrypted(s string) {
+ m.api_key_encrypted = &s
+}
+
+// APIKeyEncrypted returns the value of the "api_key_encrypted" field in the mutation.
+func (m *ChannelMonitorMutation) APIKeyEncrypted() (r string, exists bool) {
+ v := m.api_key_encrypted
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAPIKeyEncrypted returns the old "api_key_encrypted" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldAPIKeyEncrypted(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAPIKeyEncrypted is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAPIKeyEncrypted requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAPIKeyEncrypted: %w", err)
+ }
+ return oldValue.APIKeyEncrypted, nil
+}
+
+// ResetAPIKeyEncrypted resets all changes to the "api_key_encrypted" field.
+func (m *ChannelMonitorMutation) ResetAPIKeyEncrypted() {
+ m.api_key_encrypted = nil
+}
+
+// SetPrimaryModel sets the "primary_model" field.
+func (m *ChannelMonitorMutation) SetPrimaryModel(s string) {
+ m.primary_model = &s
+}
+
+// PrimaryModel returns the value of the "primary_model" field in the mutation.
+func (m *ChannelMonitorMutation) PrimaryModel() (r string, exists bool) {
+ v := m.primary_model
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPrimaryModel returns the old "primary_model" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldPrimaryModel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPrimaryModel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPrimaryModel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPrimaryModel: %w", err)
+ }
+ return oldValue.PrimaryModel, nil
+}
+
+// ResetPrimaryModel resets all changes to the "primary_model" field.
+func (m *ChannelMonitorMutation) ResetPrimaryModel() {
+ m.primary_model = nil
+}
+
+// SetExtraModels sets the "extra_models" field.
+func (m *ChannelMonitorMutation) SetExtraModels(s []string) {
+ m.extra_models = &s
+ m.appendextra_models = nil
+}
+
+// ExtraModels returns the value of the "extra_models" field in the mutation.
+func (m *ChannelMonitorMutation) ExtraModels() (r []string, exists bool) {
+ v := m.extra_models
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraModels returns the old "extra_models" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldExtraModels(ctx context.Context) (v []string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraModels is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraModels requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraModels: %w", err)
+ }
+ return oldValue.ExtraModels, nil
+}
+
+// AppendExtraModels adds s to the "extra_models" field.
+func (m *ChannelMonitorMutation) AppendExtraModels(s []string) {
+ m.appendextra_models = append(m.appendextra_models, s...)
+}
+
+// AppendedExtraModels returns the list of values that were appended to the "extra_models" field in this mutation.
+func (m *ChannelMonitorMutation) AppendedExtraModels() ([]string, bool) {
+ if len(m.appendextra_models) == 0 {
+ return nil, false
+ }
+ return m.appendextra_models, true
+}
+
+// ResetExtraModels resets all changes to the "extra_models" field.
+func (m *ChannelMonitorMutation) ResetExtraModels() {
+ m.extra_models = nil
+ m.appendextra_models = nil
+}
+
+// SetGroupName sets the "group_name" field.
+func (m *ChannelMonitorMutation) SetGroupName(s string) {
+ m.group_name = &s
+}
+
+// GroupName returns the value of the "group_name" field in the mutation.
+func (m *ChannelMonitorMutation) GroupName() (r string, exists bool) {
+ v := m.group_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldGroupName returns the old "group_name" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldGroupName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldGroupName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldGroupName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldGroupName: %w", err)
+ }
+ return oldValue.GroupName, nil
+}
+
+// ClearGroupName clears the value of the "group_name" field.
+func (m *ChannelMonitorMutation) ClearGroupName() {
+ m.group_name = nil
+ m.clearedFields[channelmonitor.FieldGroupName] = struct{}{}
+}
+
+// GroupNameCleared returns if the "group_name" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) GroupNameCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldGroupName]
+ return ok
+}
+
+// ResetGroupName resets all changes to the "group_name" field.
+func (m *ChannelMonitorMutation) ResetGroupName() {
+ m.group_name = nil
+ delete(m.clearedFields, channelmonitor.FieldGroupName)
+}
+
+// SetEnabled sets the "enabled" field.
+func (m *ChannelMonitorMutation) SetEnabled(b bool) {
+ m.enabled = &b
+}
+
+// Enabled returns the value of the "enabled" field in the mutation.
+func (m *ChannelMonitorMutation) Enabled() (r bool, exists bool) {
+ v := m.enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEnabled returns the old "enabled" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEnabled: %w", err)
+ }
+ return oldValue.Enabled, nil
+}
+
+// ResetEnabled resets all changes to the "enabled" field.
+func (m *ChannelMonitorMutation) ResetEnabled() {
+ m.enabled = nil
+}
+
+// SetIntervalSeconds sets the "interval_seconds" field.
+func (m *ChannelMonitorMutation) SetIntervalSeconds(i int) {
+ m.interval_seconds = &i
+ m.addinterval_seconds = nil
+}
+
+// IntervalSeconds returns the value of the "interval_seconds" field in the mutation.
+func (m *ChannelMonitorMutation) IntervalSeconds() (r int, exists bool) {
+ v := m.interval_seconds
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIntervalSeconds returns the old "interval_seconds" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldIntervalSeconds(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIntervalSeconds is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIntervalSeconds requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIntervalSeconds: %w", err)
+ }
+ return oldValue.IntervalSeconds, nil
+}
+
+// AddIntervalSeconds adds i to the "interval_seconds" field.
+func (m *ChannelMonitorMutation) AddIntervalSeconds(i int) {
+ if m.addinterval_seconds != nil {
+ *m.addinterval_seconds += i
+ } else {
+ m.addinterval_seconds = &i
+ }
+}
+
+// AddedIntervalSeconds returns the value that was added to the "interval_seconds" field in this mutation.
+func (m *ChannelMonitorMutation) AddedIntervalSeconds() (r int, exists bool) {
+ v := m.addinterval_seconds
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetIntervalSeconds resets all changes to the "interval_seconds" field.
+func (m *ChannelMonitorMutation) ResetIntervalSeconds() {
+ m.interval_seconds = nil
+ m.addinterval_seconds = nil
+}
+
+// SetLastCheckedAt sets the "last_checked_at" field.
+func (m *ChannelMonitorMutation) SetLastCheckedAt(t time.Time) {
+ m.last_checked_at = &t
+}
+
+// LastCheckedAt returns the value of the "last_checked_at" field in the mutation.
+func (m *ChannelMonitorMutation) LastCheckedAt() (r time.Time, exists bool) {
+ v := m.last_checked_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLastCheckedAt returns the old "last_checked_at" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldLastCheckedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastCheckedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastCheckedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastCheckedAt: %w", err)
+ }
+ return oldValue.LastCheckedAt, nil
+}
+
+// ClearLastCheckedAt clears the value of the "last_checked_at" field.
+func (m *ChannelMonitorMutation) ClearLastCheckedAt() {
+ m.last_checked_at = nil
+ m.clearedFields[channelmonitor.FieldLastCheckedAt] = struct{}{}
+}
+
+// LastCheckedAtCleared returns if the "last_checked_at" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) LastCheckedAtCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldLastCheckedAt]
+ return ok
+}
+
+// ResetLastCheckedAt resets all changes to the "last_checked_at" field.
+func (m *ChannelMonitorMutation) ResetLastCheckedAt() {
+ m.last_checked_at = nil
+ delete(m.clearedFields, channelmonitor.FieldLastCheckedAt)
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (m *ChannelMonitorMutation) SetCreatedBy(i int64) {
+ m.created_by = &i
+ m.addcreated_by = nil
+}
+
+// CreatedBy returns the value of the "created_by" field in the mutation.
+func (m *ChannelMonitorMutation) CreatedBy() (r int64, exists bool) {
+ v := m.created_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedBy returns the old "created_by" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldCreatedBy(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedBy: %w", err)
+ }
+ return oldValue.CreatedBy, nil
+}
+
+// AddCreatedBy adds i to the "created_by" field.
+func (m *ChannelMonitorMutation) AddCreatedBy(i int64) {
+ if m.addcreated_by != nil {
+ *m.addcreated_by += i
+ } else {
+ m.addcreated_by = &i
+ }
+}
+
+// AddedCreatedBy returns the value that was added to the "created_by" field in this mutation.
+func (m *ChannelMonitorMutation) AddedCreatedBy() (r int64, exists bool) {
+ v := m.addcreated_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetCreatedBy resets all changes to the "created_by" field.
+func (m *ChannelMonitorMutation) ResetCreatedBy() {
+ m.created_by = nil
+ m.addcreated_by = nil
+}
+
+// SetTemplateID sets the "template_id" field.
+func (m *ChannelMonitorMutation) SetTemplateID(i int64) {
+ m.request_template = &i
+}
+
+// TemplateID returns the value of the "template_id" field in the mutation.
+func (m *ChannelMonitorMutation) TemplateID() (r int64, exists bool) {
+ v := m.request_template
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTemplateID returns the old "template_id" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldTemplateID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTemplateID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTemplateID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTemplateID: %w", err)
+ }
+ return oldValue.TemplateID, nil
+}
+
+// ClearTemplateID clears the value of the "template_id" field.
+func (m *ChannelMonitorMutation) ClearTemplateID() {
+ m.request_template = nil
+ m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{}
+}
+
+// TemplateIDCleared returns if the "template_id" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) TemplateIDCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldTemplateID]
+ return ok
+}
+
+// ResetTemplateID resets all changes to the "template_id" field.
+func (m *ChannelMonitorMutation) ResetTemplateID() {
+ m.request_template = nil
+ delete(m.clearedFields, channelmonitor.FieldTemplateID)
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (m *ChannelMonitorMutation) SetExtraHeaders(value map[string]string) {
+ m.extra_headers = &value
+}
+
+// ExtraHeaders returns the value of the "extra_headers" field in the mutation.
+func (m *ChannelMonitorMutation) ExtraHeaders() (r map[string]string, exists bool) {
+ v := m.extra_headers
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraHeaders requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err)
+ }
+ return oldValue.ExtraHeaders, nil
+}
+
+// ResetExtraHeaders resets all changes to the "extra_headers" field.
+func (m *ChannelMonitorMutation) ResetExtraHeaders() {
+ m.extra_headers = nil
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (m *ChannelMonitorMutation) SetBodyOverrideMode(s string) {
+ m.body_override_mode = &s
+}
+
+// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation.
+func (m *ChannelMonitorMutation) BodyOverrideMode() (r string, exists bool) {
+ v := m.body_override_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err)
+ }
+ return oldValue.BodyOverrideMode, nil
+}
+
+// ResetBodyOverrideMode resets all changes to the "body_override_mode" field.
+func (m *ChannelMonitorMutation) ResetBodyOverrideMode() {
+ m.body_override_mode = nil
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (m *ChannelMonitorMutation) SetBodyOverride(value map[string]interface{}) {
+ m.body_override = &value
+}
+
+// BodyOverride returns the value of the "body_override" field in the mutation.
+func (m *ChannelMonitorMutation) BodyOverride() (r map[string]interface{}, exists bool) {
+ v := m.body_override
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverride requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err)
+ }
+ return oldValue.BodyOverride, nil
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (m *ChannelMonitorMutation) ClearBodyOverride() {
+ m.body_override = nil
+ m.clearedFields[channelmonitor.FieldBodyOverride] = struct{}{}
+}
+
+// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation.
+func (m *ChannelMonitorMutation) BodyOverrideCleared() bool {
+ _, ok := m.clearedFields[channelmonitor.FieldBodyOverride]
+ return ok
+}
+
+// ResetBodyOverride resets all changes to the "body_override" field.
+func (m *ChannelMonitorMutation) ResetBodyOverride() {
+ m.body_override = nil
+ delete(m.clearedFields, channelmonitor.FieldBodyOverride)
+}
+
+// AddHistoryIDs adds the "history" edge to the ChannelMonitorHistory entity by ids.
+func (m *ChannelMonitorMutation) AddHistoryIDs(ids ...int64) {
+ if m.history == nil {
+ m.history = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.history[ids[i]] = struct{}{}
+ }
+}
+
+// ClearHistory clears the "history" edge to the ChannelMonitorHistory entity.
+func (m *ChannelMonitorMutation) ClearHistory() {
+ m.clearedhistory = true
+}
+
+// HistoryCleared reports if the "history" edge to the ChannelMonitorHistory entity was cleared.
+func (m *ChannelMonitorMutation) HistoryCleared() bool {
+ return m.clearedhistory
+}
+
+// RemoveHistoryIDs removes the "history" edge to the ChannelMonitorHistory entity by IDs.
+func (m *ChannelMonitorMutation) RemoveHistoryIDs(ids ...int64) {
+ if m.removedhistory == nil {
+ m.removedhistory = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.history, ids[i])
+ m.removedhistory[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedHistoryIDs returns the removed IDs of the "history" edge to the ChannelMonitorHistory entity.
+func (m *ChannelMonitorMutation) RemovedHistoryIDs() (ids []int64) {
+	for id := range m.removedhistory {
+		ids = append(ids, id)
+	}
+	return
+}
+
+// HistoryIDs returns the "history" edge IDs in the mutation.
+func (m *ChannelMonitorMutation) HistoryIDs() (ids []int64) {
+ for id := range m.history {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetHistory resets all changes to the "history" edge.
+func (m *ChannelMonitorMutation) ResetHistory() {
+ m.history = nil
+ m.clearedhistory = false
+ m.removedhistory = nil
+}
+
+// AddDailyRollupIDs adds the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by ids.
+func (m *ChannelMonitorMutation) AddDailyRollupIDs(ids ...int64) {
+ if m.daily_rollups == nil {
+ m.daily_rollups = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.daily_rollups[ids[i]] = struct{}{}
+ }
+}
+
+// ClearDailyRollups clears the "daily_rollups" edge to the ChannelMonitorDailyRollup entity.
+func (m *ChannelMonitorMutation) ClearDailyRollups() {
+ m.cleareddaily_rollups = true
+}
+
+// DailyRollupsCleared reports if the "daily_rollups" edge to the ChannelMonitorDailyRollup entity was cleared.
+func (m *ChannelMonitorMutation) DailyRollupsCleared() bool {
+ return m.cleareddaily_rollups
+}
+
+// RemoveDailyRollupIDs removes the "daily_rollups" edge to the ChannelMonitorDailyRollup entity by IDs.
+func (m *ChannelMonitorMutation) RemoveDailyRollupIDs(ids ...int64) {
+ if m.removeddaily_rollups == nil {
+ m.removeddaily_rollups = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.daily_rollups, ids[i])
+ m.removeddaily_rollups[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedDailyRollupsIDs returns the removed IDs of the "daily_rollups" edge to the ChannelMonitorDailyRollup entity.
+func (m *ChannelMonitorMutation) RemovedDailyRollupsIDs() (ids []int64) {
+	for id := range m.removeddaily_rollups {
+		ids = append(ids, id)
+	}
+	return
+}
+
+// DailyRollupsIDs returns the "daily_rollups" edge IDs in the mutation.
+func (m *ChannelMonitorMutation) DailyRollupsIDs() (ids []int64) {
+ for id := range m.daily_rollups {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetDailyRollups resets all changes to the "daily_rollups" edge.
+func (m *ChannelMonitorMutation) ResetDailyRollups() {
+ m.daily_rollups = nil
+ m.cleareddaily_rollups = false
+ m.removeddaily_rollups = nil
+}
+
+// SetRequestTemplateID sets the "request_template" edge to the ChannelMonitorRequestTemplate entity by id.
+func (m *ChannelMonitorMutation) SetRequestTemplateID(id int64) {
+ m.request_template = &id
+}
+
+// ClearRequestTemplate clears the "request_template" edge to the ChannelMonitorRequestTemplate entity.
+func (m *ChannelMonitorMutation) ClearRequestTemplate() {
+ m.clearedrequest_template = true
+ m.clearedFields[channelmonitor.FieldTemplateID] = struct{}{}
+}
+
+// RequestTemplateCleared reports if the "request_template" edge to the ChannelMonitorRequestTemplate entity was cleared.
+func (m *ChannelMonitorMutation) RequestTemplateCleared() bool {
+ return m.TemplateIDCleared() || m.clearedrequest_template
+}
+
+// RequestTemplateID returns the "request_template" edge ID in the mutation.
+func (m *ChannelMonitorMutation) RequestTemplateID() (id int64, exists bool) {
+ if m.request_template != nil {
+ return *m.request_template, true
+ }
+ return
+}
+
+// RequestTemplateIDs returns the "request_template" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// RequestTemplateID instead. It exists only for internal usage by the builders.
+func (m *ChannelMonitorMutation) RequestTemplateIDs() (ids []int64) {
+ if id := m.request_template; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetRequestTemplate resets all changes to the "request_template" edge.
+func (m *ChannelMonitorMutation) ResetRequestTemplate() {
+ m.request_template = nil
+ m.clearedrequest_template = false
+}
+
+// Where appends a list predicates to the ChannelMonitorMutation builder.
+func (m *ChannelMonitorMutation) Where(ps ...predicate.ChannelMonitor) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitor, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitor).
+func (m *ChannelMonitorMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorMutation) Fields() []string {
+ fields := make([]string, 0, 17)
+ if m.created_at != nil {
+ fields = append(fields, channelmonitor.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, channelmonitor.FieldUpdatedAt)
+ }
+ if m.name != nil {
+ fields = append(fields, channelmonitor.FieldName)
+ }
+ if m.provider != nil {
+ fields = append(fields, channelmonitor.FieldProvider)
+ }
+ if m.endpoint != nil {
+ fields = append(fields, channelmonitor.FieldEndpoint)
+ }
+ if m.api_key_encrypted != nil {
+ fields = append(fields, channelmonitor.FieldAPIKeyEncrypted)
+ }
+ if m.primary_model != nil {
+ fields = append(fields, channelmonitor.FieldPrimaryModel)
+ }
+ if m.extra_models != nil {
+ fields = append(fields, channelmonitor.FieldExtraModels)
+ }
+ if m.group_name != nil {
+ fields = append(fields, channelmonitor.FieldGroupName)
+ }
+ if m.enabled != nil {
+ fields = append(fields, channelmonitor.FieldEnabled)
+ }
+ if m.interval_seconds != nil {
+ fields = append(fields, channelmonitor.FieldIntervalSeconds)
+ }
+ if m.last_checked_at != nil {
+ fields = append(fields, channelmonitor.FieldLastCheckedAt)
+ }
+ if m.created_by != nil {
+ fields = append(fields, channelmonitor.FieldCreatedBy)
+ }
+ if m.request_template != nil {
+ fields = append(fields, channelmonitor.FieldTemplateID)
+ }
+ if m.extra_headers != nil {
+ fields = append(fields, channelmonitor.FieldExtraHeaders)
+ }
+ if m.body_override_mode != nil {
+ fields = append(fields, channelmonitor.FieldBodyOverrideMode)
+ }
+ if m.body_override != nil {
+ fields = append(fields, channelmonitor.FieldBodyOverride)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ return m.CreatedAt()
+ case channelmonitor.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case channelmonitor.FieldName:
+ return m.Name()
+ case channelmonitor.FieldProvider:
+ return m.Provider()
+ case channelmonitor.FieldEndpoint:
+ return m.Endpoint()
+ case channelmonitor.FieldAPIKeyEncrypted:
+ return m.APIKeyEncrypted()
+ case channelmonitor.FieldPrimaryModel:
+ return m.PrimaryModel()
+ case channelmonitor.FieldExtraModels:
+ return m.ExtraModels()
+ case channelmonitor.FieldGroupName:
+ return m.GroupName()
+ case channelmonitor.FieldEnabled:
+ return m.Enabled()
+ case channelmonitor.FieldIntervalSeconds:
+ return m.IntervalSeconds()
+ case channelmonitor.FieldLastCheckedAt:
+ return m.LastCheckedAt()
+ case channelmonitor.FieldCreatedBy:
+ return m.CreatedBy()
+ case channelmonitor.FieldTemplateID:
+ return m.TemplateID()
+ case channelmonitor.FieldExtraHeaders:
+ return m.ExtraHeaders()
+ case channelmonitor.FieldBodyOverrideMode:
+ return m.BodyOverrideMode()
+ case channelmonitor.FieldBodyOverride:
+ return m.BodyOverride()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case channelmonitor.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case channelmonitor.FieldName:
+ return m.OldName(ctx)
+ case channelmonitor.FieldProvider:
+ return m.OldProvider(ctx)
+ case channelmonitor.FieldEndpoint:
+ return m.OldEndpoint(ctx)
+ case channelmonitor.FieldAPIKeyEncrypted:
+ return m.OldAPIKeyEncrypted(ctx)
+ case channelmonitor.FieldPrimaryModel:
+ return m.OldPrimaryModel(ctx)
+ case channelmonitor.FieldExtraModels:
+ return m.OldExtraModels(ctx)
+ case channelmonitor.FieldGroupName:
+ return m.OldGroupName(ctx)
+ case channelmonitor.FieldEnabled:
+ return m.OldEnabled(ctx)
+ case channelmonitor.FieldIntervalSeconds:
+ return m.OldIntervalSeconds(ctx)
+ case channelmonitor.FieldLastCheckedAt:
+ return m.OldLastCheckedAt(ctx)
+ case channelmonitor.FieldCreatedBy:
+ return m.OldCreatedBy(ctx)
+ case channelmonitor.FieldTemplateID:
+ return m.OldTemplateID(ctx)
+ case channelmonitor.FieldExtraHeaders:
+ return m.OldExtraHeaders(ctx)
+ case channelmonitor.FieldBodyOverrideMode:
+ return m.OldBodyOverrideMode(ctx)
+ case channelmonitor.FieldBodyOverride:
+ return m.OldBodyOverride(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitor field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case channelmonitor.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case channelmonitor.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case channelmonitor.FieldProvider:
+ v, ok := value.(channelmonitor.Provider)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProvider(v)
+ return nil
+ case channelmonitor.FieldEndpoint:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEndpoint(v)
+ return nil
+ case channelmonitor.FieldAPIKeyEncrypted:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAPIKeyEncrypted(v)
+ return nil
+ case channelmonitor.FieldPrimaryModel:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPrimaryModel(v)
+ return nil
+ case channelmonitor.FieldExtraModels:
+ v, ok := value.([]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraModels(v)
+ return nil
+ case channelmonitor.FieldGroupName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetGroupName(v)
+ return nil
+ case channelmonitor.FieldEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEnabled(v)
+ return nil
+ case channelmonitor.FieldIntervalSeconds:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIntervalSeconds(v)
+ return nil
+ case channelmonitor.FieldLastCheckedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastCheckedAt(v)
+ return nil
+ case channelmonitor.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedBy(v)
+ return nil
+ case channelmonitor.FieldTemplateID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTemplateID(v)
+ return nil
+ case channelmonitor.FieldExtraHeaders:
+ v, ok := value.(map[string]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraHeaders(v)
+ return nil
+ case channelmonitor.FieldBodyOverrideMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverrideMode(v)
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverride(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorMutation) AddedFields() []string {
+ var fields []string
+ if m.addinterval_seconds != nil {
+ fields = append(fields, channelmonitor.FieldIntervalSeconds)
+ }
+ if m.addcreated_by != nil {
+ fields = append(fields, channelmonitor.FieldCreatedBy)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitor.FieldIntervalSeconds:
+ return m.AddedIntervalSeconds()
+ case channelmonitor.FieldCreatedBy:
+ return m.AddedCreatedBy()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitor.FieldIntervalSeconds:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddIntervalSeconds(v)
+ return nil
+ case channelmonitor.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCreatedBy(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(channelmonitor.FieldGroupName) {
+ fields = append(fields, channelmonitor.FieldGroupName)
+ }
+ if m.FieldCleared(channelmonitor.FieldLastCheckedAt) {
+ fields = append(fields, channelmonitor.FieldLastCheckedAt)
+ }
+ if m.FieldCleared(channelmonitor.FieldTemplateID) {
+ fields = append(fields, channelmonitor.FieldTemplateID)
+ }
+ if m.FieldCleared(channelmonitor.FieldBodyOverride) {
+ fields = append(fields, channelmonitor.FieldBodyOverride)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorMutation) ClearField(name string) error {
+ switch name {
+ case channelmonitor.FieldGroupName:
+ m.ClearGroupName()
+ return nil
+ case channelmonitor.FieldLastCheckedAt:
+ m.ClearLastCheckedAt()
+ return nil
+ case channelmonitor.FieldTemplateID:
+ m.ClearTemplateID()
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ m.ClearBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitor.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case channelmonitor.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case channelmonitor.FieldName:
+ m.ResetName()
+ return nil
+ case channelmonitor.FieldProvider:
+ m.ResetProvider()
+ return nil
+ case channelmonitor.FieldEndpoint:
+ m.ResetEndpoint()
+ return nil
+ case channelmonitor.FieldAPIKeyEncrypted:
+ m.ResetAPIKeyEncrypted()
+ return nil
+ case channelmonitor.FieldPrimaryModel:
+ m.ResetPrimaryModel()
+ return nil
+ case channelmonitor.FieldExtraModels:
+ m.ResetExtraModels()
+ return nil
+ case channelmonitor.FieldGroupName:
+ m.ResetGroupName()
+ return nil
+ case channelmonitor.FieldEnabled:
+ m.ResetEnabled()
+ return nil
+ case channelmonitor.FieldIntervalSeconds:
+ m.ResetIntervalSeconds()
+ return nil
+ case channelmonitor.FieldLastCheckedAt:
+ m.ResetLastCheckedAt()
+ return nil
+ case channelmonitor.FieldCreatedBy:
+ m.ResetCreatedBy()
+ return nil
+ case channelmonitor.FieldTemplateID:
+ m.ResetTemplateID()
+ return nil
+ case channelmonitor.FieldExtraHeaders:
+ m.ResetExtraHeaders()
+ return nil
+ case channelmonitor.FieldBodyOverrideMode:
+ m.ResetBodyOverrideMode()
+ return nil
+ case channelmonitor.FieldBodyOverride:
+ m.ResetBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorMutation) AddedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.history != nil {
+ edges = append(edges, channelmonitor.EdgeHistory)
+ }
+ if m.daily_rollups != nil {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
+ if m.request_template != nil {
+ edges = append(edges, channelmonitor.EdgeRequestTemplate)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ ids := make([]ent.Value, 0, len(m.history))
+ for id := range m.history {
+ ids = append(ids, id)
+ }
+ return ids
+ case channelmonitor.EdgeDailyRollups:
+ ids := make([]ent.Value, 0, len(m.daily_rollups))
+ for id := range m.daily_rollups {
+ ids = append(ids, id)
+ }
+ return ids
+ case channelmonitor.EdgeRequestTemplate:
+ if id := m.request_template; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.removedhistory != nil {
+ edges = append(edges, channelmonitor.EdgeHistory)
+ }
+ if m.removeddaily_rollups != nil {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ ids := make([]ent.Value, 0, len(m.removedhistory))
+ for id := range m.removedhistory {
+ ids = append(ids, id)
+ }
+ return ids
+ case channelmonitor.EdgeDailyRollups:
+ ids := make([]ent.Value, 0, len(m.removeddaily_rollups))
+ for id := range m.removeddaily_rollups {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 3)
+ if m.clearedhistory {
+ edges = append(edges, channelmonitor.EdgeHistory)
+ }
+ if m.cleareddaily_rollups {
+ edges = append(edges, channelmonitor.EdgeDailyRollups)
+ }
+ if m.clearedrequest_template {
+ edges = append(edges, channelmonitor.EdgeRequestTemplate)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ return m.clearedhistory
+ case channelmonitor.EdgeDailyRollups:
+ return m.cleareddaily_rollups
+ case channelmonitor.EdgeRequestTemplate:
+ return m.clearedrequest_template
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorMutation) ClearEdge(name string) error {
+ switch name {
+ case channelmonitor.EdgeRequestTemplate:
+ m.ClearRequestTemplate()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitor.EdgeHistory:
+ m.ResetHistory()
+ return nil
+ case channelmonitor.EdgeDailyRollups:
+ m.ResetDailyRollups()
+ return nil
+ case channelmonitor.EdgeRequestTemplate:
+ m.ResetRequestTemplate()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitor edge %s", name)
+}
+
+// ChannelMonitorDailyRollupMutation represents an operation that mutates the ChannelMonitorDailyRollup nodes in the graph.
+type ChannelMonitorDailyRollupMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ model *string
+ bucket_date *time.Time
+ total_checks *int
+ addtotal_checks *int
+ ok_count *int
+ addok_count *int
+ operational_count *int
+ addoperational_count *int
+ degraded_count *int
+ adddegraded_count *int
+ failed_count *int
+ addfailed_count *int
+ error_count *int
+ adderror_count *int
+ sum_latency_ms *int64
+ addsum_latency_ms *int64
+ count_latency *int
+ addcount_latency *int
+ sum_ping_latency_ms *int64
+ addsum_ping_latency_ms *int64
+ count_ping_latency *int
+ addcount_ping_latency *int
+ computed_at *time.Time
+ clearedFields map[string]struct{}
+ monitor *int64
+ clearedmonitor bool
+ done bool
+ oldValue func(context.Context) (*ChannelMonitorDailyRollup, error)
+ predicates []predicate.ChannelMonitorDailyRollup
+}
+
+var _ ent.Mutation = (*ChannelMonitorDailyRollupMutation)(nil)
+
+// channelmonitordailyrollupOption allows management of the mutation configuration using functional options.
+type channelmonitordailyrollupOption func(*ChannelMonitorDailyRollupMutation)
+
+// newChannelMonitorDailyRollupMutation creates new mutation for the ChannelMonitorDailyRollup entity.
+func newChannelMonitorDailyRollupMutation(c config, op Op, opts ...channelmonitordailyrollupOption) *ChannelMonitorDailyRollupMutation {
+ m := &ChannelMonitorDailyRollupMutation{
+ config: c,
+ op: op,
+ typ: TypeChannelMonitorDailyRollup,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withChannelMonitorDailyRollupID sets the ID field of the mutation.
+func withChannelMonitorDailyRollupID(id int64) channelmonitordailyrollupOption {
+ return func(m *ChannelMonitorDailyRollupMutation) {
+ var (
+ err error
+ once sync.Once
+ value *ChannelMonitorDailyRollup
+ )
+ m.oldValue = func(ctx context.Context) (*ChannelMonitorDailyRollup, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().ChannelMonitorDailyRollup.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withChannelMonitorDailyRollup sets the old ChannelMonitorDailyRollup of the mutation.
+func withChannelMonitorDailyRollup(node *ChannelMonitorDailyRollup) channelmonitordailyrollupOption {
+ return func(m *ChannelMonitorDailyRollupMutation) {
+ m.oldValue = func(context.Context) (*ChannelMonitorDailyRollup, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m ChannelMonitorDailyRollupMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m ChannelMonitorDailyRollupMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorDailyRollupMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorDailyRollupMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitorDailyRollup.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetMonitorID sets the "monitor_id" field.
+func (m *ChannelMonitorDailyRollupMutation) SetMonitorID(i int64) {
+ m.monitor = &i
+}
+
+// MonitorID returns the value of the "monitor_id" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) MonitorID() (r int64, exists bool) {
+ v := m.monitor
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldMonitorID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMonitorID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMonitorID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMonitorID: %w", err)
+ }
+ return oldValue.MonitorID, nil
+}
+
+// ResetMonitorID resets all changes to the "monitor_id" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetMonitorID() {
+ m.monitor = nil
+}
+
+// SetModel sets the "model" field.
+func (m *ChannelMonitorDailyRollupMutation) SetModel(s string) {
+ m.model = &s
+}
+
+// Model returns the value of the "model" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) Model() (r string, exists bool) {
+ v := m.model
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldModel returns the old "model" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldModel(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModel is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModel requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModel: %w", err)
+ }
+ return oldValue.Model, nil
+}
+
+// ResetModel resets all changes to the "model" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetModel() {
+ m.model = nil
+}
+
+// SetBucketDate sets the "bucket_date" field.
+func (m *ChannelMonitorDailyRollupMutation) SetBucketDate(t time.Time) {
+ m.bucket_date = &t
+}
+
+// BucketDate returns the value of the "bucket_date" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) BucketDate() (r time.Time, exists bool) {
+ v := m.bucket_date
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBucketDate returns the old "bucket_date" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldBucketDate(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBucketDate is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBucketDate requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBucketDate: %w", err)
+ }
+ return oldValue.BucketDate, nil
+}
+
+// ResetBucketDate resets all changes to the "bucket_date" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetBucketDate() {
+ m.bucket_date = nil
+}
+
+// SetTotalChecks sets the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) SetTotalChecks(i int) {
+ m.total_checks = &i
+ m.addtotal_checks = nil
+}
+
+// TotalChecks returns the value of the "total_checks" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) TotalChecks() (r int, exists bool) {
+ v := m.total_checks
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotalChecks returns the old "total_checks" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldTotalChecks(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotalChecks is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotalChecks requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotalChecks: %w", err)
+ }
+ return oldValue.TotalChecks, nil
+}
+
+// AddTotalChecks adds i to the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) AddTotalChecks(i int) {
+ if m.addtotal_checks != nil {
+ *m.addtotal_checks += i
+ } else {
+ m.addtotal_checks = &i
+ }
+}
+
+// AddedTotalChecks returns the value that was added to the "total_checks" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedTotalChecks() (r int, exists bool) {
+ v := m.addtotal_checks
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetTotalChecks resets all changes to the "total_checks" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetTotalChecks() {
+ m.total_checks = nil
+ m.addtotal_checks = nil
+}
+
+// SetOkCount sets the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetOkCount(i int) {
+ m.ok_count = &i
+ m.addok_count = nil
+}
+
+// OkCount returns the value of the "ok_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) OkCount() (r int, exists bool) {
+ v := m.ok_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOkCount returns the old "ok_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldOkCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOkCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOkCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOkCount: %w", err)
+ }
+ return oldValue.OkCount, nil
+}
+
+// AddOkCount adds i to the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddOkCount(i int) {
+ if m.addok_count != nil {
+ *m.addok_count += i
+ } else {
+ m.addok_count = &i
+ }
+}
+
+// AddedOkCount returns the value that was added to the "ok_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedOkCount() (r int, exists bool) {
+ v := m.addok_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetOkCount resets all changes to the "ok_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetOkCount() {
+ m.ok_count = nil
+ m.addok_count = nil
+}
+
+// SetOperationalCount sets the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetOperationalCount(i int) {
+ m.operational_count = &i
+ m.addoperational_count = nil
+}
+
+// OperationalCount returns the value of the "operational_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) OperationalCount() (r int, exists bool) {
+ v := m.operational_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOperationalCount returns the old "operational_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldOperationalCount(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOperationalCount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOperationalCount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOperationalCount: %w", err)
+ }
+ return oldValue.OperationalCount, nil
+}
+
+// AddOperationalCount adds i to the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddOperationalCount(i int) {
+ if m.addoperational_count != nil {
+ *m.addoperational_count += i
+ } else {
+ m.addoperational_count = &i
+ }
+}
+
+// AddedOperationalCount returns the value that was added to the "operational_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedOperationalCount() (r int, exists bool) {
+ v := m.addoperational_count
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetOperationalCount resets all changes to the "operational_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetOperationalCount() {
+ m.operational_count = nil
+ m.addoperational_count = nil
+}
+
+// NOTE(review): ent-generated accessors — do not hand-edit; change ent/schema and re-run "go generate ./ent".
+// SetDegradedCount sets the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetDegradedCount(i int) {
+	m.degraded_count = &i
+	m.adddegraded_count = nil
+}
+
+// DegradedCount returns the value of the "degraded_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) DegradedCount() (r int, exists bool) {
+	v := m.degraded_count
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldDegradedCount returns the old "degraded_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldDegradedCount(ctx context.Context) (v int, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldDegradedCount is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldDegradedCount requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldDegradedCount: %w", err)
+	}
+	return oldValue.DegradedCount, nil
+}
+
+// AddDegradedCount adds i to the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddDegradedCount(i int) {
+	if m.adddegraded_count != nil {
+		*m.adddegraded_count += i
+	} else {
+		m.adddegraded_count = &i
+	}
+}
+
+// AddedDegradedCount returns the value that was added to the "degraded_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedDegradedCount() (r int, exists bool) {
+	v := m.adddegraded_count
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetDegradedCount resets all changes to the "degraded_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetDegradedCount() {
+	m.degraded_count = nil
+	m.adddegraded_count = nil
+}
+
+// SetFailedCount sets the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetFailedCount(i int) {
+	m.failed_count = &i
+	m.addfailed_count = nil
+}
+
+// FailedCount returns the value of the "failed_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) FailedCount() (r int, exists bool) {
+	v := m.failed_count
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldFailedCount returns the old "failed_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldFailedCount(ctx context.Context) (v int, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldFailedCount is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldFailedCount requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldFailedCount: %w", err)
+	}
+	return oldValue.FailedCount, nil
+}
+
+// AddFailedCount adds i to the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddFailedCount(i int) {
+	if m.addfailed_count != nil {
+		*m.addfailed_count += i
+	} else {
+		m.addfailed_count = &i
+	}
+}
+
+// AddedFailedCount returns the value that was added to the "failed_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedFailedCount() (r int, exists bool) {
+	v := m.addfailed_count
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetFailedCount resets all changes to the "failed_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetFailedCount() {
+	m.failed_count = nil
+	m.addfailed_count = nil
+}
+
+// SetErrorCount sets the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) SetErrorCount(i int) {
+	m.error_count = &i
+	m.adderror_count = nil
+}
+
+// ErrorCount returns the value of the "error_count" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) ErrorCount() (r int, exists bool) {
+	v := m.error_count
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldErrorCount returns the old "error_count" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldErrorCount(ctx context.Context) (v int, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldErrorCount is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldErrorCount requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldErrorCount: %w", err)
+	}
+	return oldValue.ErrorCount, nil
+}
+
+// AddErrorCount adds i to the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) AddErrorCount(i int) {
+	if m.adderror_count != nil {
+		*m.adderror_count += i
+	} else {
+		m.adderror_count = &i
+	}
+}
+
+// AddedErrorCount returns the value that was added to the "error_count" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedErrorCount() (r int, exists bool) {
+	v := m.adderror_count
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetErrorCount resets all changes to the "error_count" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetErrorCount() {
+	m.error_count = nil
+	m.adderror_count = nil
+}
+
+// NOTE(review): ent-generated accessors — do not hand-edit; change ent/schema and re-run "go generate ./ent".
+// SetSumLatencyMs sets the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) SetSumLatencyMs(i int64) {
+	m.sum_latency_ms = &i
+	m.addsum_latency_ms = nil
+}
+
+// SumLatencyMs returns the value of the "sum_latency_ms" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) SumLatencyMs() (r int64, exists bool) {
+	v := m.sum_latency_ms
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldSumLatencyMs returns the old "sum_latency_ms" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldSumLatencyMs(ctx context.Context) (v int64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldSumLatencyMs is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldSumLatencyMs requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldSumLatencyMs: %w", err)
+	}
+	return oldValue.SumLatencyMs, nil
+}
+
+// AddSumLatencyMs adds i to the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) AddSumLatencyMs(i int64) {
+	if m.addsum_latency_ms != nil {
+		*m.addsum_latency_ms += i
+	} else {
+		m.addsum_latency_ms = &i
+	}
+}
+
+// AddedSumLatencyMs returns the value that was added to the "sum_latency_ms" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedSumLatencyMs() (r int64, exists bool) {
+	v := m.addsum_latency_ms
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetSumLatencyMs resets all changes to the "sum_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetSumLatencyMs() {
+	m.sum_latency_ms = nil
+	m.addsum_latency_ms = nil
+}
+
+// SetCountLatency sets the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) SetCountLatency(i int) {
+	m.count_latency = &i
+	m.addcount_latency = nil
+}
+
+// CountLatency returns the value of the "count_latency" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) CountLatency() (r int, exists bool) {
+	v := m.count_latency
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldCountLatency returns the old "count_latency" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldCountLatency(ctx context.Context) (v int, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldCountLatency is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCountLatency requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCountLatency: %w", err)
+	}
+	return oldValue.CountLatency, nil
+}
+
+// AddCountLatency adds i to the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) AddCountLatency(i int) {
+	if m.addcount_latency != nil {
+		*m.addcount_latency += i
+	} else {
+		m.addcount_latency = &i
+	}
+}
+
+// AddedCountLatency returns the value that was added to the "count_latency" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedCountLatency() (r int, exists bool) {
+	v := m.addcount_latency
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetCountLatency resets all changes to the "count_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetCountLatency() {
+	m.count_latency = nil
+	m.addcount_latency = nil
+}
+
+// SetSumPingLatencyMs sets the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) SetSumPingLatencyMs(i int64) {
+	m.sum_ping_latency_ms = &i
+	m.addsum_ping_latency_ms = nil
+}
+
+// SumPingLatencyMs returns the value of the "sum_ping_latency_ms" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) SumPingLatencyMs() (r int64, exists bool) {
+	v := m.sum_ping_latency_ms
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldSumPingLatencyMs returns the old "sum_ping_latency_ms" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldSumPingLatencyMs(ctx context.Context) (v int64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldSumPingLatencyMs is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldSumPingLatencyMs requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldSumPingLatencyMs: %w", err)
+	}
+	return oldValue.SumPingLatencyMs, nil
+}
+
+// AddSumPingLatencyMs adds i to the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) AddSumPingLatencyMs(i int64) {
+	if m.addsum_ping_latency_ms != nil {
+		*m.addsum_ping_latency_ms += i
+	} else {
+		m.addsum_ping_latency_ms = &i
+	}
+}
+
+// AddedSumPingLatencyMs returns the value that was added to the "sum_ping_latency_ms" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedSumPingLatencyMs() (r int64, exists bool) {
+	v := m.addsum_ping_latency_ms
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetSumPingLatencyMs resets all changes to the "sum_ping_latency_ms" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetSumPingLatencyMs() {
+	m.sum_ping_latency_ms = nil
+	m.addsum_ping_latency_ms = nil
+}
+
+// SetCountPingLatency sets the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) SetCountPingLatency(i int) {
+	m.count_ping_latency = &i
+	m.addcount_ping_latency = nil
+}
+
+// CountPingLatency returns the value of the "count_ping_latency" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) CountPingLatency() (r int, exists bool) {
+	v := m.count_ping_latency
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldCountPingLatency returns the old "count_ping_latency" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldCountPingLatency(ctx context.Context) (v int, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldCountPingLatency is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCountPingLatency requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCountPingLatency: %w", err)
+	}
+	return oldValue.CountPingLatency, nil
+}
+
+// AddCountPingLatency adds i to the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) AddCountPingLatency(i int) {
+	if m.addcount_ping_latency != nil {
+		*m.addcount_ping_latency += i
+	} else {
+		m.addcount_ping_latency = &i
+	}
+}
+
+// AddedCountPingLatency returns the value that was added to the "count_ping_latency" field in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedCountPingLatency() (r int, exists bool) {
+	v := m.addcount_ping_latency
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetCountPingLatency resets all changes to the "count_ping_latency" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetCountPingLatency() {
+	m.count_ping_latency = nil
+	m.addcount_ping_latency = nil
+}
+
+// NOTE(review): ent-generated accessors — do not hand-edit; change ent/schema and re-run "go generate ./ent".
+// SetComputedAt sets the "computed_at" field.
+func (m *ChannelMonitorDailyRollupMutation) SetComputedAt(t time.Time) {
+	m.computed_at = &t
+}
+
+// ComputedAt returns the value of the "computed_at" field in the mutation.
+func (m *ChannelMonitorDailyRollupMutation) ComputedAt() (r time.Time, exists bool) {
+	v := m.computed_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldComputedAt returns the old "computed_at" field's value of the ChannelMonitorDailyRollup entity.
+// If the ChannelMonitorDailyRollup object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorDailyRollupMutation) OldComputedAt(ctx context.Context) (v time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldComputedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldComputedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldComputedAt: %w", err)
+	}
+	return oldValue.ComputedAt, nil
+}
+
+// ResetComputedAt resets all changes to the "computed_at" field.
+func (m *ChannelMonitorDailyRollupMutation) ResetComputedAt() {
+	m.computed_at = nil
+}
+
+// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorDailyRollupMutation) ClearMonitor() {
+	m.clearedmonitor = true
+	m.clearedFields[channelmonitordailyrollup.FieldMonitorID] = struct{}{}
+}
+
+// MonitorCleared reports if the "monitor" edge to the ChannelMonitor entity was cleared.
+func (m *ChannelMonitorDailyRollupMutation) MonitorCleared() bool {
+	return m.clearedmonitor
+}
+
+// MonitorIDs returns the "monitor" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// MonitorID instead. It exists only for internal usage by the builders.
+func (m *ChannelMonitorDailyRollupMutation) MonitorIDs() (ids []int64) {
+	if id := m.monitor; id != nil {
+		ids = append(ids, *id)
+	}
+	return
+}
+
+// ResetMonitor resets all changes to the "monitor" edge.
+func (m *ChannelMonitorDailyRollupMutation) ResetMonitor() {
+	m.monitor = nil
+	m.clearedmonitor = false
+}
+
+// NOTE(review): ent-generated code — do not hand-edit; regenerate via "go generate ./ent".
+// Where appends a list of predicates to the ChannelMonitorDailyRollupMutation builder.
+func (m *ChannelMonitorDailyRollupMutation) Where(ps ...predicate.ChannelMonitorDailyRollup) {
+	m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorDailyRollupMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorDailyRollupMutation) WhereP(ps ...func(*sql.Selector)) {
+	p := make([]predicate.ChannelMonitorDailyRollup, len(ps))
+	for i := range ps {
+		p[i] = ps[i]
+	}
+	m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorDailyRollupMutation) Op() Op {
+	return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorDailyRollupMutation) SetOp(op Op) {
+	m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitorDailyRollup).
+func (m *ChannelMonitorDailyRollupMutation) Type() string {
+	return m.typ
+}
+
+// NOTE(review): ent-generated reflection helpers — do not hand-edit; regenerate via "go generate ./ent".
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorDailyRollupMutation) Fields() []string {
+	fields := make([]string, 0, 14)
+	if m.monitor != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldMonitorID)
+	}
+	if m.model != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldModel)
+	}
+	if m.bucket_date != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldBucketDate)
+	}
+	if m.total_checks != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldTotalChecks)
+	}
+	if m.ok_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldOkCount)
+	}
+	if m.operational_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldOperationalCount)
+	}
+	if m.degraded_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldDegradedCount)
+	}
+	if m.failed_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldFailedCount)
+	}
+	if m.error_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldErrorCount)
+	}
+	if m.sum_latency_ms != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs)
+	}
+	if m.count_latency != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldCountLatency)
+	}
+	if m.sum_ping_latency_ms != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs)
+	}
+	if m.count_ping_latency != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency)
+	}
+	if m.computed_at != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldComputedAt)
+	}
+	return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorDailyRollupMutation) Field(name string) (ent.Value, bool) {
+	switch name {
+	case channelmonitordailyrollup.FieldMonitorID:
+		return m.MonitorID()
+	case channelmonitordailyrollup.FieldModel:
+		return m.Model()
+	case channelmonitordailyrollup.FieldBucketDate:
+		return m.BucketDate()
+	case channelmonitordailyrollup.FieldTotalChecks:
+		return m.TotalChecks()
+	case channelmonitordailyrollup.FieldOkCount:
+		return m.OkCount()
+	case channelmonitordailyrollup.FieldOperationalCount:
+		return m.OperationalCount()
+	case channelmonitordailyrollup.FieldDegradedCount:
+		return m.DegradedCount()
+	case channelmonitordailyrollup.FieldFailedCount:
+		return m.FailedCount()
+	case channelmonitordailyrollup.FieldErrorCount:
+		return m.ErrorCount()
+	case channelmonitordailyrollup.FieldSumLatencyMs:
+		return m.SumLatencyMs()
+	case channelmonitordailyrollup.FieldCountLatency:
+		return m.CountLatency()
+	case channelmonitordailyrollup.FieldSumPingLatencyMs:
+		return m.SumPingLatencyMs()
+	case channelmonitordailyrollup.FieldCountPingLatency:
+		return m.CountPingLatency()
+	case channelmonitordailyrollup.FieldComputedAt:
+		return m.ComputedAt()
+	}
+	return nil, false
+}
+
+// NOTE(review): ent-generated reflection helpers — do not hand-edit; regenerate via "go generate ./ent".
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorDailyRollupMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+	switch name {
+	case channelmonitordailyrollup.FieldMonitorID:
+		return m.OldMonitorID(ctx)
+	case channelmonitordailyrollup.FieldModel:
+		return m.OldModel(ctx)
+	case channelmonitordailyrollup.FieldBucketDate:
+		return m.OldBucketDate(ctx)
+	case channelmonitordailyrollup.FieldTotalChecks:
+		return m.OldTotalChecks(ctx)
+	case channelmonitordailyrollup.FieldOkCount:
+		return m.OldOkCount(ctx)
+	case channelmonitordailyrollup.FieldOperationalCount:
+		return m.OldOperationalCount(ctx)
+	case channelmonitordailyrollup.FieldDegradedCount:
+		return m.OldDegradedCount(ctx)
+	case channelmonitordailyrollup.FieldFailedCount:
+		return m.OldFailedCount(ctx)
+	case channelmonitordailyrollup.FieldErrorCount:
+		return m.OldErrorCount(ctx)
+	case channelmonitordailyrollup.FieldSumLatencyMs:
+		return m.OldSumLatencyMs(ctx)
+	case channelmonitordailyrollup.FieldCountLatency:
+		return m.OldCountLatency(ctx)
+	case channelmonitordailyrollup.FieldSumPingLatencyMs:
+		return m.OldSumPingLatencyMs(ctx)
+	case channelmonitordailyrollup.FieldCountPingLatency:
+		return m.OldCountPingLatency(ctx)
+	case channelmonitordailyrollup.FieldComputedAt:
+		return m.OldComputedAt(ctx)
+	}
+	return nil, fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorDailyRollupMutation) SetField(name string, value ent.Value) error {
+	switch name {
+	case channelmonitordailyrollup.FieldMonitorID:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetMonitorID(v)
+		return nil
+	case channelmonitordailyrollup.FieldModel:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetModel(v)
+		return nil
+	case channelmonitordailyrollup.FieldBucketDate:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetBucketDate(v)
+		return nil
+	case channelmonitordailyrollup.FieldTotalChecks:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetTotalChecks(v)
+		return nil
+	case channelmonitordailyrollup.FieldOkCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetOkCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldOperationalCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetOperationalCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldDegradedCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetDegradedCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldFailedCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetFailedCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldErrorCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetErrorCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldSumLatencyMs:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetSumLatencyMs(v)
+		return nil
+	case channelmonitordailyrollup.FieldCountLatency:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCountLatency(v)
+		return nil
+	case channelmonitordailyrollup.FieldSumPingLatencyMs:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetSumPingLatencyMs(v)
+		return nil
+	case channelmonitordailyrollup.FieldCountPingLatency:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCountPingLatency(v)
+		return nil
+	case channelmonitordailyrollup.FieldComputedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetComputedAt(v)
+		return nil
+	}
+	return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// NOTE(review): ent-generated reflection helpers — do not hand-edit; regenerate via "go generate ./ent".
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedFields() []string {
+	var fields []string
+	if m.addtotal_checks != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldTotalChecks)
+	}
+	if m.addok_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldOkCount)
+	}
+	if m.addoperational_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldOperationalCount)
+	}
+	if m.adddegraded_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldDegradedCount)
+	}
+	if m.addfailed_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldFailedCount)
+	}
+	if m.adderror_count != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldErrorCount)
+	}
+	if m.addsum_latency_ms != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldSumLatencyMs)
+	}
+	if m.addcount_latency != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldCountLatency)
+	}
+	if m.addsum_ping_latency_ms != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldSumPingLatencyMs)
+	}
+	if m.addcount_ping_latency != nil {
+		fields = append(fields, channelmonitordailyrollup.FieldCountPingLatency)
+	}
+	return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) AddedField(name string) (ent.Value, bool) {
+	switch name {
+	case channelmonitordailyrollup.FieldTotalChecks:
+		return m.AddedTotalChecks()
+	case channelmonitordailyrollup.FieldOkCount:
+		return m.AddedOkCount()
+	case channelmonitordailyrollup.FieldOperationalCount:
+		return m.AddedOperationalCount()
+	case channelmonitordailyrollup.FieldDegradedCount:
+		return m.AddedDegradedCount()
+	case channelmonitordailyrollup.FieldFailedCount:
+		return m.AddedFailedCount()
+	case channelmonitordailyrollup.FieldErrorCount:
+		return m.AddedErrorCount()
+	case channelmonitordailyrollup.FieldSumLatencyMs:
+		return m.AddedSumLatencyMs()
+	case channelmonitordailyrollup.FieldCountLatency:
+		return m.AddedCountLatency()
+	case channelmonitordailyrollup.FieldSumPingLatencyMs:
+		return m.AddedSumPingLatencyMs()
+	case channelmonitordailyrollup.FieldCountPingLatency:
+		return m.AddedCountPingLatency()
+	}
+	return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorDailyRollupMutation) AddField(name string, value ent.Value) error {
+	switch name {
+	case channelmonitordailyrollup.FieldTotalChecks:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddTotalChecks(v)
+		return nil
+	case channelmonitordailyrollup.FieldOkCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddOkCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldOperationalCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddOperationalCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldDegradedCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddDegradedCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldFailedCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddFailedCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldErrorCount:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddErrorCount(v)
+		return nil
+	case channelmonitordailyrollup.FieldSumLatencyMs:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddSumLatencyMs(v)
+		return nil
+	case channelmonitordailyrollup.FieldCountLatency:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddCountLatency(v)
+		return nil
+	case channelmonitordailyrollup.FieldSumPingLatencyMs:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddSumPingLatencyMs(v)
+		return nil
+	case channelmonitordailyrollup.FieldCountPingLatency:
+		v, ok := value.(int)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.AddCountPingLatency(v)
+		return nil
+	}
+	return fmt.Errorf("unknown ChannelMonitorDailyRollup numeric field %s", name)
+}
+
+// NOTE(review): ent-generated reflection helpers — do not hand-edit; regenerate via "go generate ./ent".
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorDailyRollupMutation) ClearedFields() []string {
+	return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) FieldCleared(name string) bool {
+	_, ok := m.clearedFields[name]
+	return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ClearField(name string) error {
+	return fmt.Errorf("unknown ChannelMonitorDailyRollup nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorDailyRollupMutation) ResetField(name string) error {
+	switch name {
+	case channelmonitordailyrollup.FieldMonitorID:
+		m.ResetMonitorID()
+		return nil
+	case channelmonitordailyrollup.FieldModel:
+		m.ResetModel()
+		return nil
+	case channelmonitordailyrollup.FieldBucketDate:
+		m.ResetBucketDate()
+		return nil
+	case channelmonitordailyrollup.FieldTotalChecks:
+		m.ResetTotalChecks()
+		return nil
+	case channelmonitordailyrollup.FieldOkCount:
+		m.ResetOkCount()
+		return nil
+	case channelmonitordailyrollup.FieldOperationalCount:
+		m.ResetOperationalCount()
+		return nil
+	case channelmonitordailyrollup.FieldDegradedCount:
+		m.ResetDegradedCount()
+		return nil
+	case channelmonitordailyrollup.FieldFailedCount:
+		m.ResetFailedCount()
+		return nil
+	case channelmonitordailyrollup.FieldErrorCount:
+		m.ResetErrorCount()
+		return nil
+	case channelmonitordailyrollup.FieldSumLatencyMs:
+		m.ResetSumLatencyMs()
+		return nil
+	case channelmonitordailyrollup.FieldCountLatency:
+		m.ResetCountLatency()
+		return nil
+	case channelmonitordailyrollup.FieldSumPingLatencyMs:
+		m.ResetSumPingLatencyMs()
+		return nil
+	case channelmonitordailyrollup.FieldCountPingLatency:
+		m.ResetCountPingLatency()
+		return nil
+	case channelmonitordailyrollup.FieldComputedAt:
+		m.ResetComputedAt()
+		return nil
+	}
+	return fmt.Errorf("unknown ChannelMonitorDailyRollup field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.monitor != nil {
+ edges = append(edges, channelmonitordailyrollup.EdgeMonitor)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ if id := m.monitor; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
func (m *ChannelMonitorDailyRollupMutation) RemovedIDs(name string) []ent.Value {
	// No edges are ever removed on this mutation type (see RemovedEdges).
	return nil
}

// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *ChannelMonitorDailyRollupMutation) ClearedEdges() []string {
	edges := make([]string, 0, 1)
	if m.clearedmonitor {
		edges = append(edges, channelmonitordailyrollup.EdgeMonitor)
	}
	return edges
}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorDailyRollupMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitordailyrollup.EdgeMonitor:
+ return m.clearedmonitor
+ }
+ return false
+}
+
// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
func (m *ChannelMonitorDailyRollupMutation) ClearEdge(name string) error {
	// "monitor" is the only edge defined on this type.
	switch name {
	case channelmonitordailyrollup.EdgeMonitor:
		m.ClearMonitor()
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorDailyRollup unique edge %s", name)
}

// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
func (m *ChannelMonitorDailyRollupMutation) ResetEdge(name string) error {
	switch name {
	case channelmonitordailyrollup.EdgeMonitor:
		m.ResetMonitor()
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorDailyRollup edge %s", name)
}
+
// ChannelMonitorHistoryMutation represents an operation that mutates the ChannelMonitorHistory nodes in the graph.
type ChannelMonitorHistoryMutation struct {
	config
	op  Op
	typ string
	id  *int64
	// Staged scalar values; a nil pointer means the field is untouched in
	// this mutation, and the add* fields hold pending numeric increments.
	model              *string
	status             *channelmonitorhistory.Status
	latency_ms         *int
	addlatency_ms      *int
	ping_latency_ms    *int
	addping_latency_ms *int
	message            *string
	checked_at         *time.Time
	// clearedFields records the names of fields explicitly cleared in this mutation.
	clearedFields map[string]struct{}
	// monitor doubles as the staged monitor_id value and the "monitor" edge ID.
	monitor        *int64
	clearedmonitor bool
	done           bool
	oldValue       func(context.Context) (*ChannelMonitorHistory, error)
	predicates     []predicate.ChannelMonitorHistory
}

// Compile-time assertion that the mutation satisfies ent.Mutation.
var _ ent.Mutation = (*ChannelMonitorHistoryMutation)(nil)

// channelmonitorhistoryOption allows management of the mutation configuration using functional options.
type channelmonitorhistoryOption func(*ChannelMonitorHistoryMutation)

// newChannelMonitorHistoryMutation creates new mutation for the ChannelMonitorHistory entity.
func newChannelMonitorHistoryMutation(c config, op Op, opts ...channelmonitorhistoryOption) *ChannelMonitorHistoryMutation {
	m := &ChannelMonitorHistoryMutation{
		config:        c,
		op:            op,
		typ:           TypeChannelMonitorHistory,
		clearedFields: make(map[string]struct{}),
	}
	for _, opt := range opts {
		opt(m)
	}
	return m
}
+
// withChannelMonitorHistoryID sets the ID field of the mutation.
func withChannelMonitorHistoryID(id int64) channelmonitorhistoryOption {
	return func(m *ChannelMonitorHistoryMutation) {
		var (
			err   error
			once  sync.Once
			value *ChannelMonitorHistory
		)
		// Lazily fetch the pre-mutation row at most once; after the mutation
		// has executed (m.done) old values are no longer available.
		m.oldValue = func(ctx context.Context) (*ChannelMonitorHistory, error) {
			once.Do(func() {
				if m.done {
					err = errors.New("querying old values post mutation is not allowed")
				} else {
					value, err = m.Client().ChannelMonitorHistory.Get(ctx, id)
				}
			})
			return value, err
		}
		m.id = &id
	}
}

// withChannelMonitorHistory sets the old ChannelMonitorHistory of the mutation.
func withChannelMonitorHistory(node *ChannelMonitorHistory) channelmonitorhistoryOption {
	return func(m *ChannelMonitorHistoryMutation) {
		// The node is already in hand, so no database round-trip is needed.
		m.oldValue = func(context.Context) (*ChannelMonitorHistory, error) {
			return node, nil
		}
		m.id = &node.ID
	}
}

// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
func (m ChannelMonitorHistoryMutation) Client() *Client {
	client := &Client{config: m.config}
	client.init()
	return client
}

// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
func (m ChannelMonitorHistoryMutation) Tx() (*Tx, error) {
	// The driver type reveals whether we are inside a transaction.
	if _, ok := m.driver.(*txDriver); !ok {
		return nil, errors.New("ent: mutation is not running in a transaction")
	}
	tx := &Tx{config: m.config}
	tx.init()
	return tx, nil
}

// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
func (m *ChannelMonitorHistoryMutation) ID() (id int64, exists bool) {
	if m.id == nil {
		return
	}
	return *m.id, true
}

// IDs queries the database and returns the entity ids that match the mutation's predicate.
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
func (m *ChannelMonitorHistoryMutation) IDs(ctx context.Context) ([]int64, error) {
	switch {
	case m.op.Is(OpUpdateOne | OpDeleteOne):
		id, exists := m.ID()
		if exists {
			return []int64{id}, nil
		}
		// No concrete ID on a *One op: fall back to the predicate query below.
		fallthrough
	case m.op.Is(OpUpdate | OpDelete):
		return m.Client().ChannelMonitorHistory.Query().Where(m.predicates...).IDs(ctx)
	default:
		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
	}
}
+
// SetMonitorID sets the "monitor_id" field.
func (m *ChannelMonitorHistoryMutation) SetMonitorID(i int64) {
	// monitor_id shares storage with the "monitor" edge ID (m.monitor).
	m.monitor = &i
}

// MonitorID returns the value of the "monitor_id" field in the mutation.
func (m *ChannelMonitorHistoryMutation) MonitorID() (r int64, exists bool) {
	v := m.monitor
	if v == nil {
		return
	}
	return *v, true
}

// OldMonitorID returns the old "monitor_id" field's value of the ChannelMonitorHistory entity.
// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorHistoryMutation) OldMonitorID(ctx context.Context) (v int64, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldMonitorID is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldMonitorID requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldMonitorID: %w", err)
	}
	return oldValue.MonitorID, nil
}

// ResetMonitorID resets all changes to the "monitor_id" field.
func (m *ChannelMonitorHistoryMutation) ResetMonitorID() {
	m.monitor = nil
}

// SetModel sets the "model" field.
func (m *ChannelMonitorHistoryMutation) SetModel(s string) {
	m.model = &s
}

// Model returns the value of the "model" field in the mutation.
func (m *ChannelMonitorHistoryMutation) Model() (r string, exists bool) {
	v := m.model
	if v == nil {
		return
	}
	return *v, true
}

// OldModel returns the old "model" field's value of the ChannelMonitorHistory entity.
// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorHistoryMutation) OldModel(ctx context.Context) (v string, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldModel is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldModel requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldModel: %w", err)
	}
	return oldValue.Model, nil
}

// ResetModel resets all changes to the "model" field.
func (m *ChannelMonitorHistoryMutation) ResetModel() {
	m.model = nil
}

// SetStatus sets the "status" field.
func (m *ChannelMonitorHistoryMutation) SetStatus(c channelmonitorhistory.Status) {
	m.status = &c
}

// Status returns the value of the "status" field in the mutation.
func (m *ChannelMonitorHistoryMutation) Status() (r channelmonitorhistory.Status, exists bool) {
	v := m.status
	if v == nil {
		return
	}
	return *v, true
}

// OldStatus returns the old "status" field's value of the ChannelMonitorHistory entity.
// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorHistoryMutation) OldStatus(ctx context.Context) (v channelmonitorhistory.Status, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldStatus is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldStatus requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldStatus: %w", err)
	}
	return oldValue.Status, nil
}

// ResetStatus resets all changes to the "status" field.
func (m *ChannelMonitorHistoryMutation) ResetStatus() {
	m.status = nil
}

// SetLatencyMs sets the "latency_ms" field.
func (m *ChannelMonitorHistoryMutation) SetLatencyMs(i int) {
	m.latency_ms = &i
	// An absolute set discards any pending AddLatencyMs increment.
	m.addlatency_ms = nil
}

// LatencyMs returns the value of the "latency_ms" field in the mutation.
func (m *ChannelMonitorHistoryMutation) LatencyMs() (r int, exists bool) {
	v := m.latency_ms
	if v == nil {
		return
	}
	return *v, true
}

// OldLatencyMs returns the old "latency_ms" field's value of the ChannelMonitorHistory entity.
// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorHistoryMutation) OldLatencyMs(ctx context.Context) (v *int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldLatencyMs is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldLatencyMs requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldLatencyMs: %w", err)
	}
	return oldValue.LatencyMs, nil
}

// AddLatencyMs adds i to the "latency_ms" field.
func (m *ChannelMonitorHistoryMutation) AddLatencyMs(i int) {
	// Successive increments accumulate into a single pending delta.
	if m.addlatency_ms != nil {
		*m.addlatency_ms += i
	} else {
		m.addlatency_ms = &i
	}
}

// AddedLatencyMs returns the value that was added to the "latency_ms" field in this mutation.
func (m *ChannelMonitorHistoryMutation) AddedLatencyMs() (r int, exists bool) {
	v := m.addlatency_ms
	if v == nil {
		return
	}
	return *v, true
}

// ClearLatencyMs clears the value of the "latency_ms" field.
func (m *ChannelMonitorHistoryMutation) ClearLatencyMs() {
	// Clearing drops both the staged value and any pending increment.
	m.latency_ms = nil
	m.addlatency_ms = nil
	m.clearedFields[channelmonitorhistory.FieldLatencyMs] = struct{}{}
}

// LatencyMsCleared returns if the "latency_ms" field was cleared in this mutation.
func (m *ChannelMonitorHistoryMutation) LatencyMsCleared() bool {
	_, ok := m.clearedFields[channelmonitorhistory.FieldLatencyMs]
	return ok
}

// ResetLatencyMs resets all changes to the "latency_ms" field.
func (m *ChannelMonitorHistoryMutation) ResetLatencyMs() {
	m.latency_ms = nil
	m.addlatency_ms = nil
	delete(m.clearedFields, channelmonitorhistory.FieldLatencyMs)
}
+
// SetPingLatencyMs sets the "ping_latency_ms" field.
func (m *ChannelMonitorHistoryMutation) SetPingLatencyMs(i int) {
	m.ping_latency_ms = &i
	// An absolute set discards any pending AddPingLatencyMs increment.
	m.addping_latency_ms = nil
}

// PingLatencyMs returns the value of the "ping_latency_ms" field in the mutation.
func (m *ChannelMonitorHistoryMutation) PingLatencyMs() (r int, exists bool) {
	v := m.ping_latency_ms
	if v == nil {
		return
	}
	return *v, true
}

// OldPingLatencyMs returns the old "ping_latency_ms" field's value of the ChannelMonitorHistory entity.
// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorHistoryMutation) OldPingLatencyMs(ctx context.Context) (v *int, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldPingLatencyMs is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldPingLatencyMs requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldPingLatencyMs: %w", err)
	}
	return oldValue.PingLatencyMs, nil
}

// AddPingLatencyMs adds i to the "ping_latency_ms" field.
func (m *ChannelMonitorHistoryMutation) AddPingLatencyMs(i int) {
	// Successive increments accumulate into a single pending delta.
	if m.addping_latency_ms != nil {
		*m.addping_latency_ms += i
	} else {
		m.addping_latency_ms = &i
	}
}

// AddedPingLatencyMs returns the value that was added to the "ping_latency_ms" field in this mutation.
func (m *ChannelMonitorHistoryMutation) AddedPingLatencyMs() (r int, exists bool) {
	v := m.addping_latency_ms
	if v == nil {
		return
	}
	return *v, true
}

// ClearPingLatencyMs clears the value of the "ping_latency_ms" field.
func (m *ChannelMonitorHistoryMutation) ClearPingLatencyMs() {
	// Clearing drops both the staged value and any pending increment.
	m.ping_latency_ms = nil
	m.addping_latency_ms = nil
	m.clearedFields[channelmonitorhistory.FieldPingLatencyMs] = struct{}{}
}

// PingLatencyMsCleared returns if the "ping_latency_ms" field was cleared in this mutation.
func (m *ChannelMonitorHistoryMutation) PingLatencyMsCleared() bool {
	_, ok := m.clearedFields[channelmonitorhistory.FieldPingLatencyMs]
	return ok
}

// ResetPingLatencyMs resets all changes to the "ping_latency_ms" field.
func (m *ChannelMonitorHistoryMutation) ResetPingLatencyMs() {
	m.ping_latency_ms = nil
	m.addping_latency_ms = nil
	delete(m.clearedFields, channelmonitorhistory.FieldPingLatencyMs)
}

// SetMessage sets the "message" field.
func (m *ChannelMonitorHistoryMutation) SetMessage(s string) {
	m.message = &s
}

// Message returns the value of the "message" field in the mutation.
func (m *ChannelMonitorHistoryMutation) Message() (r string, exists bool) {
	v := m.message
	if v == nil {
		return
	}
	return *v, true
}

// OldMessage returns the old "message" field's value of the ChannelMonitorHistory entity.
// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorHistoryMutation) OldMessage(ctx context.Context) (v string, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldMessage is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldMessage requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldMessage: %w", err)
	}
	return oldValue.Message, nil
}

// ClearMessage clears the value of the "message" field.
func (m *ChannelMonitorHistoryMutation) ClearMessage() {
	m.message = nil
	m.clearedFields[channelmonitorhistory.FieldMessage] = struct{}{}
}

// MessageCleared returns if the "message" field was cleared in this mutation.
func (m *ChannelMonitorHistoryMutation) MessageCleared() bool {
	_, ok := m.clearedFields[channelmonitorhistory.FieldMessage]
	return ok
}

// ResetMessage resets all changes to the "message" field.
func (m *ChannelMonitorHistoryMutation) ResetMessage() {
	m.message = nil
	delete(m.clearedFields, channelmonitorhistory.FieldMessage)
}

// SetCheckedAt sets the "checked_at" field.
func (m *ChannelMonitorHistoryMutation) SetCheckedAt(t time.Time) {
	m.checked_at = &t
}

// CheckedAt returns the value of the "checked_at" field in the mutation.
func (m *ChannelMonitorHistoryMutation) CheckedAt() (r time.Time, exists bool) {
	v := m.checked_at
	if v == nil {
		return
	}
	return *v, true
}

// OldCheckedAt returns the old "checked_at" field's value of the ChannelMonitorHistory entity.
// If the ChannelMonitorHistory object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorHistoryMutation) OldCheckedAt(ctx context.Context) (v time.Time, err error) {
	if !m.op.Is(OpUpdateOne) {
		return v, errors.New("OldCheckedAt is only allowed on UpdateOne operations")
	}
	if m.id == nil || m.oldValue == nil {
		return v, errors.New("OldCheckedAt requires an ID field in the mutation")
	}
	oldValue, err := m.oldValue(ctx)
	if err != nil {
		return v, fmt.Errorf("querying old value for OldCheckedAt: %w", err)
	}
	return oldValue.CheckedAt, nil
}

// ResetCheckedAt resets all changes to the "checked_at" field.
func (m *ChannelMonitorHistoryMutation) ResetCheckedAt() {
	m.checked_at = nil
}

// ClearMonitor clears the "monitor" edge to the ChannelMonitor entity.
func (m *ChannelMonitorHistoryMutation) ClearMonitor() {
	m.clearedmonitor = true
	// Clearing the edge also marks the backing monitor_id field as cleared.
	m.clearedFields[channelmonitorhistory.FieldMonitorID] = struct{}{}
}

// MonitorCleared reports if the "monitor" edge to the ChannelMonitor entity was cleared.
func (m *ChannelMonitorHistoryMutation) MonitorCleared() bool {
	return m.clearedmonitor
}

// MonitorIDs returns the "monitor" edge IDs in the mutation.
// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
// MonitorID instead. It exists only for internal usage by the builders.
func (m *ChannelMonitorHistoryMutation) MonitorIDs() (ids []int64) {
	if id := m.monitor; id != nil {
		ids = append(ids, *id)
	}
	return
}

// ResetMonitor resets all changes to the "monitor" edge.
func (m *ChannelMonitorHistoryMutation) ResetMonitor() {
	m.monitor = nil
	m.clearedmonitor = false
}
+
// Where appends a list predicates to the ChannelMonitorHistoryMutation builder.
func (m *ChannelMonitorHistoryMutation) Where(ps ...predicate.ChannelMonitorHistory) {
	m.predicates = append(m.predicates, ps...)
}

// WhereP appends storage-level predicates to the ChannelMonitorHistoryMutation builder. Using this method,
// users can use type-assertion to append predicates that do not depend on any generated package.
func (m *ChannelMonitorHistoryMutation) WhereP(ps ...func(*sql.Selector)) {
	// Convert each raw selector function to the typed predicate before appending.
	p := make([]predicate.ChannelMonitorHistory, len(ps))
	for i := range ps {
		p[i] = ps[i]
	}
	m.Where(p...)
}

// Op returns the operation name.
func (m *ChannelMonitorHistoryMutation) Op() Op {
	return m.op
}

// SetOp allows setting the mutation operation.
func (m *ChannelMonitorHistoryMutation) SetOp(op Op) {
	m.op = op
}

// Type returns the node type of this mutation (ChannelMonitorHistory).
func (m *ChannelMonitorHistoryMutation) Type() string {
	return m.typ
}

// Fields returns all fields that were changed during this mutation. Note that in
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorHistoryMutation) Fields() []string {
	// A non-nil staged pointer marks the corresponding field as set.
	fields := make([]string, 0, 7)
	if m.monitor != nil {
		fields = append(fields, channelmonitorhistory.FieldMonitorID)
	}
	if m.model != nil {
		fields = append(fields, channelmonitorhistory.FieldModel)
	}
	if m.status != nil {
		fields = append(fields, channelmonitorhistory.FieldStatus)
	}
	if m.latency_ms != nil {
		fields = append(fields, channelmonitorhistory.FieldLatencyMs)
	}
	if m.ping_latency_ms != nil {
		fields = append(fields, channelmonitorhistory.FieldPingLatencyMs)
	}
	if m.message != nil {
		fields = append(fields, channelmonitorhistory.FieldMessage)
	}
	if m.checked_at != nil {
		fields = append(fields, channelmonitorhistory.FieldCheckedAt)
	}
	return fields
}
+
// Field returns the value of a field with the given name. The second boolean
// return value indicates that this field was not set, or was not defined in the
// schema.
func (m *ChannelMonitorHistoryMutation) Field(name string) (ent.Value, bool) {
	switch name {
	case channelmonitorhistory.FieldMonitorID:
		return m.MonitorID()
	case channelmonitorhistory.FieldModel:
		return m.Model()
	case channelmonitorhistory.FieldStatus:
		return m.Status()
	case channelmonitorhistory.FieldLatencyMs:
		return m.LatencyMs()
	case channelmonitorhistory.FieldPingLatencyMs:
		return m.PingLatencyMs()
	case channelmonitorhistory.FieldMessage:
		return m.Message()
	case channelmonitorhistory.FieldCheckedAt:
		return m.CheckedAt()
	}
	return nil, false
}

// OldField returns the old value of the field from the database. An error is
// returned if the mutation operation is not UpdateOne, or the query to the
// database failed.
func (m *ChannelMonitorHistoryMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
	switch name {
	case channelmonitorhistory.FieldMonitorID:
		return m.OldMonitorID(ctx)
	case channelmonitorhistory.FieldModel:
		return m.OldModel(ctx)
	case channelmonitorhistory.FieldStatus:
		return m.OldStatus(ctx)
	case channelmonitorhistory.FieldLatencyMs:
		return m.OldLatencyMs(ctx)
	case channelmonitorhistory.FieldPingLatencyMs:
		return m.OldPingLatencyMs(ctx)
	case channelmonitorhistory.FieldMessage:
		return m.OldMessage(ctx)
	case channelmonitorhistory.FieldCheckedAt:
		return m.OldCheckedAt(ctx)
	}
	return nil, fmt.Errorf("unknown ChannelMonitorHistory field %s", name)
}

// SetField sets the value of a field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *ChannelMonitorHistoryMutation) SetField(name string, value ent.Value) error {
	// Type-assert the generic value, then delegate to the typed setter.
	switch name {
	case channelmonitorhistory.FieldMonitorID:
		v, ok := value.(int64)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetMonitorID(v)
		return nil
	case channelmonitorhistory.FieldModel:
		v, ok := value.(string)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetModel(v)
		return nil
	case channelmonitorhistory.FieldStatus:
		v, ok := value.(channelmonitorhistory.Status)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetStatus(v)
		return nil
	case channelmonitorhistory.FieldLatencyMs:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetLatencyMs(v)
		return nil
	case channelmonitorhistory.FieldPingLatencyMs:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetPingLatencyMs(v)
		return nil
	case channelmonitorhistory.FieldMessage:
		v, ok := value.(string)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetMessage(v)
		return nil
	case channelmonitorhistory.FieldCheckedAt:
		v, ok := value.(time.Time)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.SetCheckedAt(v)
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorHistory field %s", name)
}

// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
func (m *ChannelMonitorHistoryMutation) AddedFields() []string {
	var fields []string
	if m.addlatency_ms != nil {
		fields = append(fields, channelmonitorhistory.FieldLatencyMs)
	}
	if m.addping_latency_ms != nil {
		fields = append(fields, channelmonitorhistory.FieldPingLatencyMs)
	}
	return fields
}

// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
func (m *ChannelMonitorHistoryMutation) AddedField(name string) (ent.Value, bool) {
	switch name {
	case channelmonitorhistory.FieldLatencyMs:
		return m.AddedLatencyMs()
	case channelmonitorhistory.FieldPingLatencyMs:
		return m.AddedPingLatencyMs()
	}
	return nil, false
}

// AddField adds the value to the field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *ChannelMonitorHistoryMutation) AddField(name string, value ent.Value) error {
	// Only the two latency counters support numeric increments.
	switch name {
	case channelmonitorhistory.FieldLatencyMs:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.AddLatencyMs(v)
		return nil
	case channelmonitorhistory.FieldPingLatencyMs:
		v, ok := value.(int)
		if !ok {
			return fmt.Errorf("unexpected type %T for field %s", value, name)
		}
		m.AddPingLatencyMs(v)
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorHistory numeric field %s", name)
}

// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *ChannelMonitorHistoryMutation) ClearedFields() []string {
	var fields []string
	if m.FieldCleared(channelmonitorhistory.FieldLatencyMs) {
		fields = append(fields, channelmonitorhistory.FieldLatencyMs)
	}
	if m.FieldCleared(channelmonitorhistory.FieldPingLatencyMs) {
		fields = append(fields, channelmonitorhistory.FieldPingLatencyMs)
	}
	if m.FieldCleared(channelmonitorhistory.FieldMessage) {
		fields = append(fields, channelmonitorhistory.FieldMessage)
	}
	return fields
}

// FieldCleared returns a boolean indicating if a field with the given name was
// cleared in this mutation.
func (m *ChannelMonitorHistoryMutation) FieldCleared(name string) bool {
	_, ok := m.clearedFields[name]
	return ok
}

// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ClearField(name string) error {
	switch name {
	case channelmonitorhistory.FieldLatencyMs:
		m.ClearLatencyMs()
		return nil
	case channelmonitorhistory.FieldPingLatencyMs:
		m.ClearPingLatencyMs()
		return nil
	case channelmonitorhistory.FieldMessage:
		m.ClearMessage()
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorHistory nullable field %s", name)
}

// ResetField resets all changes in the mutation for the field with the given name.
// It returns an error if the field is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ResetField(name string) error {
	// Dispatch to the per-field Reset helper, which discards the staged value.
	switch name {
	case channelmonitorhistory.FieldMonitorID:
		m.ResetMonitorID()
		return nil
	case channelmonitorhistory.FieldModel:
		m.ResetModel()
		return nil
	case channelmonitorhistory.FieldStatus:
		m.ResetStatus()
		return nil
	case channelmonitorhistory.FieldLatencyMs:
		m.ResetLatencyMs()
		return nil
	case channelmonitorhistory.FieldPingLatencyMs:
		m.ResetPingLatencyMs()
		return nil
	case channelmonitorhistory.FieldMessage:
		m.ResetMessage()
		return nil
	case channelmonitorhistory.FieldCheckedAt:
		m.ResetCheckedAt()
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorHistory field %s", name)
}
+
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *ChannelMonitorHistoryMutation) AddedEdges() []string {
	edges := make([]string, 0, 1)
	// A non-nil m.monitor means a monitor edge ID was staged.
	if m.monitor != nil {
		edges = append(edges, channelmonitorhistory.EdgeMonitor)
	}
	return edges
}

// AddedIDs returns all IDs (to other nodes) that were added for the given edge
// name in this mutation.
func (m *ChannelMonitorHistoryMutation) AddedIDs(name string) []ent.Value {
	switch name {
	case channelmonitorhistory.EdgeMonitor:
		if id := m.monitor; id != nil {
			return []ent.Value{*id}
		}
	}
	// Unknown edge, or no ID staged for the monitor edge.
	return nil
}

// RemovedEdges returns all edge names that were removed in this mutation.
func (m *ChannelMonitorHistoryMutation) RemovedEdges() []string {
	edges := make([]string, 0, 1)
	return edges
}

// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
func (m *ChannelMonitorHistoryMutation) RemovedIDs(name string) []ent.Value {
	// No edges are ever removed on this mutation type (see RemovedEdges).
	return nil
}

// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *ChannelMonitorHistoryMutation) ClearedEdges() []string {
	edges := make([]string, 0, 1)
	if m.clearedmonitor {
		edges = append(edges, channelmonitorhistory.EdgeMonitor)
	}
	return edges
}

// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
func (m *ChannelMonitorHistoryMutation) EdgeCleared(name string) bool {
	switch name {
	case channelmonitorhistory.EdgeMonitor:
		return m.clearedmonitor
	}
	return false
}

// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ClearEdge(name string) error {
	// "monitor" is the only edge defined on this type.
	switch name {
	case channelmonitorhistory.EdgeMonitor:
		m.ClearMonitor()
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorHistory unique edge %s", name)
}

// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
func (m *ChannelMonitorHistoryMutation) ResetEdge(name string) error {
	switch name {
	case channelmonitorhistory.EdgeMonitor:
		m.ResetMonitor()
		return nil
	}
	return fmt.Errorf("unknown ChannelMonitorHistory edge %s", name)
}
+
// ChannelMonitorRequestTemplateMutation represents an operation that mutates the ChannelMonitorRequestTemplate nodes in the graph.
type ChannelMonitorRequestTemplateMutation struct {
	config
	op  Op
	typ string
	id  *int64
	// Staged scalar values; a nil pointer means the field is untouched in this mutation.
	created_at         *time.Time
	updated_at         *time.Time
	name               *string
	provider           *channelmonitorrequesttemplate.Provider
	description        *string
	extra_headers      *map[string]string
	body_override_mode *string
	body_override      *map[string]interface{}
	// clearedFields records the names of fields explicitly cleared in this mutation.
	clearedFields map[string]struct{}
	// monitors / removedmonitors / clearedmonitors track staged changes to the "monitors" edge.
	monitors        map[int64]struct{}
	removedmonitors map[int64]struct{}
	clearedmonitors bool
	done            bool
	oldValue        func(context.Context) (*ChannelMonitorRequestTemplate, error)
	predicates      []predicate.ChannelMonitorRequestTemplate
}

// Compile-time assertion that the mutation satisfies ent.Mutation.
var _ ent.Mutation = (*ChannelMonitorRequestTemplateMutation)(nil)

// channelmonitorrequesttemplateOption allows management of the mutation configuration using functional options.
type channelmonitorrequesttemplateOption func(*ChannelMonitorRequestTemplateMutation)

// newChannelMonitorRequestTemplateMutation creates new mutation for the ChannelMonitorRequestTemplate entity.
func newChannelMonitorRequestTemplateMutation(c config, op Op, opts ...channelmonitorrequesttemplateOption) *ChannelMonitorRequestTemplateMutation {
	m := &ChannelMonitorRequestTemplateMutation{
		config:        c,
		op:            op,
		typ:           TypeChannelMonitorRequestTemplate,
		clearedFields: make(map[string]struct{}),
	}
	for _, opt := range opts {
		opt(m)
	}
	return m
}

// withChannelMonitorRequestTemplateID sets the ID field of the mutation.
func withChannelMonitorRequestTemplateID(id int64) channelmonitorrequesttemplateOption {
	return func(m *ChannelMonitorRequestTemplateMutation) {
		var (
			err   error
			once  sync.Once
			value *ChannelMonitorRequestTemplate
		)
		// Lazily fetch the pre-mutation row at most once; after the mutation
		// has executed (m.done) old values are no longer available.
		m.oldValue = func(ctx context.Context) (*ChannelMonitorRequestTemplate, error) {
			once.Do(func() {
				if m.done {
					err = errors.New("querying old values post mutation is not allowed")
				} else {
					value, err = m.Client().ChannelMonitorRequestTemplate.Get(ctx, id)
				}
			})
			return value, err
		}
		m.id = &id
	}
}

// withChannelMonitorRequestTemplate sets the old ChannelMonitorRequestTemplate of the mutation.
func withChannelMonitorRequestTemplate(node *ChannelMonitorRequestTemplate) channelmonitorrequesttemplateOption {
	return func(m *ChannelMonitorRequestTemplateMutation) {
		// The node is already in hand, so no database round-trip is needed.
		m.oldValue = func(context.Context) (*ChannelMonitorRequestTemplate, error) {
			return node, nil
		}
		m.id = &node.ID
	}
}

// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
func (m ChannelMonitorRequestTemplateMutation) Client() *Client {
	client := &Client{config: m.config}
	client.init()
	return client
}

// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
func (m ChannelMonitorRequestTemplateMutation) Tx() (*Tx, error) {
	// The driver type reveals whether we are inside a transaction.
	if _, ok := m.driver.(*txDriver); !ok {
		return nil, errors.New("ent: mutation is not running in a transaction")
	}
	tx := &Tx{config: m.config}
	tx.init()
	return tx, nil
}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *ChannelMonitorRequestTemplateMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().ChannelMonitorRequestTemplate.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetName sets the "name" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetName() {
+ m.name = nil
+}
+
+// SetProvider sets the "provider" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetProvider(c channelmonitorrequesttemplate.Provider) {
+ m.provider = &c
+}
+
+// Provider returns the value of the "provider" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Provider() (r channelmonitorrequesttemplate.Provider, exists bool) {
+ v := m.provider
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProvider returns the old "provider" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldProvider(ctx context.Context) (v channelmonitorrequesttemplate.Provider, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProvider is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProvider requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProvider: %w", err)
+ }
+ return oldValue.Provider, nil
+}
+
+// ResetProvider resets all changes to the "provider" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetProvider() {
+ m.provider = nil
+}
+
+// SetDescription sets the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetDescription(s string) {
+ m.description = &s
+}
+
+// Description returns the value of the "description" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) Description() (r string, exists bool) {
+ v := m.description
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDescription returns the old "description" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldDescription(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDescription is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDescription requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDescription: %w", err)
+ }
+ return oldValue.Description, nil
+}
+
+// ClearDescription clears the value of the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) ClearDescription() {
+ m.description = nil
+ m.clearedFields[channelmonitorrequesttemplate.FieldDescription] = struct{}{}
+}
+
+// DescriptionCleared returns if the "description" field was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) DescriptionCleared() bool {
+ _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldDescription]
+ return ok
+}
+
+// ResetDescription resets all changes to the "description" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetDescription() {
+ m.description = nil
+ delete(m.clearedFields, channelmonitorrequesttemplate.FieldDescription)
+}
+
+// SetExtraHeaders sets the "extra_headers" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetExtraHeaders(value map[string]string) {
+ m.extra_headers = &value
+}
+
+// ExtraHeaders returns the value of the "extra_headers" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ExtraHeaders() (r map[string]string, exists bool) {
+ v := m.extra_headers
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExtraHeaders returns the old "extra_headers" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldExtraHeaders(ctx context.Context) (v map[string]string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExtraHeaders is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExtraHeaders requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExtraHeaders: %w", err)
+ }
+ return oldValue.ExtraHeaders, nil
+}
+
+// ResetExtraHeaders resets all changes to the "extra_headers" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetExtraHeaders() {
+ m.extra_headers = nil
+}
+
+// SetBodyOverrideMode sets the "body_override_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverrideMode(s string) {
+ m.body_override_mode = &s
+}
+
+// BodyOverrideMode returns the value of the "body_override_mode" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideMode() (r string, exists bool) {
+ v := m.body_override_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverrideMode returns the old "body_override_mode" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverrideMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverrideMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverrideMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverrideMode: %w", err)
+ }
+ return oldValue.BodyOverrideMode, nil
+}
+
+// ResetBodyOverrideMode resets all changes to the "body_override_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverrideMode() {
+ m.body_override_mode = nil
+}
+
+// SetBodyOverride sets the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetBodyOverride(value map[string]interface{}) {
+ m.body_override = &value
+}
+
+// BodyOverride returns the value of the "body_override" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverride() (r map[string]interface{}, exists bool) {
+ v := m.body_override
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBodyOverride returns the old "body_override" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldBodyOverride(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBodyOverride is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBodyOverride requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBodyOverride: %w", err)
+ }
+ return oldValue.BodyOverride, nil
+}
+
+// ClearBodyOverride clears the value of the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) ClearBodyOverride() {
+ m.body_override = nil
+ m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride] = struct{}{}
+}
+
+// BodyOverrideCleared returns if the "body_override" field was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) BodyOverrideCleared() bool {
+ _, ok := m.clearedFields[channelmonitorrequesttemplate.FieldBodyOverride]
+ return ok
+}
+
+// ResetBodyOverride resets all changes to the "body_override" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetBodyOverride() {
+ m.body_override = nil
+ delete(m.clearedFields, channelmonitorrequesttemplate.FieldBodyOverride)
+}
+
+// AddMonitorIDs adds the "monitors" edge to the ChannelMonitor entity by ids.
+func (m *ChannelMonitorRequestTemplateMutation) AddMonitorIDs(ids ...int64) {
+ if m.monitors == nil {
+ m.monitors = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.monitors[ids[i]] = struct{}{}
+ }
+}
+
+// ClearMonitors clears the "monitors" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorRequestTemplateMutation) ClearMonitors() {
+ m.clearedmonitors = true
+}
+
+// MonitorsCleared reports if the "monitors" edge to the ChannelMonitor entity was cleared.
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsCleared() bool {
+ return m.clearedmonitors
+}
+
+// RemoveMonitorIDs removes the "monitors" edge to the ChannelMonitor entity by IDs.
+func (m *ChannelMonitorRequestTemplateMutation) RemoveMonitorIDs(ids ...int64) {
+ if m.removedmonitors == nil {
+ m.removedmonitors = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.monitors, ids[i])
+ m.removedmonitors[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedMonitors returns the removed IDs of the "monitors" edge to the ChannelMonitor entity.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedMonitorsIDs() (ids []int64) {
+ for id := range m.removedmonitors {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// MonitorsIDs returns the "monitors" edge IDs in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) MonitorsIDs() (ids []int64) {
+ for id := range m.monitors {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetMonitors resets all changes to the "monitors" edge.
+func (m *ChannelMonitorRequestTemplateMutation) ResetMonitors() {
+ m.monitors = nil
+ m.clearedmonitors = false
+ m.removedmonitors = nil
+}
+
+// Where appends a list predicates to the ChannelMonitorRequestTemplateMutation builder.
+func (m *ChannelMonitorRequestTemplateMutation) Where(ps ...predicate.ChannelMonitorRequestTemplate) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the ChannelMonitorRequestTemplateMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *ChannelMonitorRequestTemplateMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.ChannelMonitorRequestTemplate, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *ChannelMonitorRequestTemplateMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *ChannelMonitorRequestTemplateMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (ChannelMonitorRequestTemplate).
+func (m *ChannelMonitorRequestTemplateMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
+ fields := make([]string, 0, 8)
+ if m.created_at != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldUpdatedAt)
+ }
+ if m.name != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldName)
+ }
+ if m.provider != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldProvider)
+ }
+ if m.description != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
+ }
+ if m.extra_headers != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldExtraHeaders)
+ }
+ if m.body_override_mode != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverrideMode)
+ }
+ if m.body_override != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *ChannelMonitorRequestTemplateMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ return m.CreatedAt()
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case channelmonitorrequesttemplate.FieldName:
+ return m.Name()
+ case channelmonitorrequesttemplate.FieldProvider:
+ return m.Provider()
+ case channelmonitorrequesttemplate.FieldDescription:
+ return m.Description()
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ return m.ExtraHeaders()
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ return m.BodyOverrideMode()
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ return m.BodyOverride()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *ChannelMonitorRequestTemplateMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case channelmonitorrequesttemplate.FieldName:
+ return m.OldName(ctx)
+ case channelmonitorrequesttemplate.FieldProvider:
+ return m.OldProvider(ctx)
+ case channelmonitorrequesttemplate.FieldDescription:
+ return m.OldDescription(ctx)
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ return m.OldExtraHeaders(ctx)
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ return m.OldBodyOverrideMode(ctx)
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ return m.OldBodyOverride(ctx)
+ }
+ return nil, fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorRequestTemplateMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldProvider:
+ v, ok := value.(channelmonitorrequesttemplate.Provider)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProvider(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldDescription:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDescription(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ v, ok := value.(map[string]string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExtraHeaders(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverrideMode(v)
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBodyOverride(v)
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedFields() []string {
+ return nil
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) AddedField(name string) (ent.Value, bool) {
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *ChannelMonitorRequestTemplateMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(channelmonitorrequesttemplate.FieldDescription) {
+ fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
+ }
+ if m.FieldCleared(channelmonitorrequesttemplate.FieldBodyOverride) {
+ fields = append(fields, channelmonitorrequesttemplate.FieldBodyOverride)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ClearField(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldDescription:
+ m.ClearDescription()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ m.ClearBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ResetField(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case channelmonitorrequesttemplate.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case channelmonitorrequesttemplate.FieldName:
+ m.ResetName()
+ return nil
+ case channelmonitorrequesttemplate.FieldProvider:
+ m.ResetProvider()
+ return nil
+ case channelmonitorrequesttemplate.FieldDescription:
+ m.ResetDescription()
+ return nil
+ case channelmonitorrequesttemplate.FieldExtraHeaders:
+ m.ResetExtraHeaders()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ m.ResetBodyOverrideMode()
+ return nil
+ case channelmonitorrequesttemplate.FieldBodyOverride:
+ m.ResetBodyOverride()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.monitors != nil {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ ids := make([]ent.Value, 0, len(m.monitors))
+ for id := range m.monitors {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.removedmonitors != nil {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) RemovedIDs(name string) []ent.Value {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ ids := make([]ent.Value, 0, len(m.removedmonitors))
+ for id := range m.removedmonitors {
+ ids = append(ids, id)
+ }
+ return ids
+ }
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.clearedmonitors {
+ edges = append(edges, channelmonitorrequesttemplate.EdgeMonitors)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *ChannelMonitorRequestTemplateMutation) EdgeCleared(name string) bool {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ return m.clearedmonitors
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ClearEdge(name string) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *ChannelMonitorRequestTemplateMutation) ResetEdge(name string) error {
+ switch name {
+ case channelmonitorrequesttemplate.EdgeMonitors:
+ m.ResetMonitors()
+ return nil
+ }
+ return fmt.Errorf("unknown ChannelMonitorRequestTemplate edge %s", name)
+}
+
// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph.
type ErrorPassthroughRuleMutation struct {
config
@@ -8230,16 +14770,6 @@ type GroupMutation struct {
addimage_price_2k *float64
image_price_4k *float64
addimage_price_4k *float64
- sora_image_price_360 *float64
- addsora_image_price_360 *float64
- sora_image_price_540 *float64
- addsora_image_price_540 *float64
- sora_video_price_per_request *float64
- addsora_video_price_per_request *float64
- sora_video_price_per_request_hd *float64
- addsora_video_price_per_request_hd *float64
- sora_storage_quota_bytes *int64
- addsora_storage_quota_bytes *int64
claude_code_only *bool
fallback_group_id *int64
addfallback_group_id *int64
@@ -8256,6 +14786,9 @@ type GroupMutation struct {
require_oauth_only *bool
require_privacy_set *bool
default_mapped_model *string
+ messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig
+ rpm_limit *int
+ addrpm_limit *int
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -9260,342 +15793,6 @@ func (m *GroupMutation) ResetImagePrice4k() {
delete(m.clearedFields, group.FieldImagePrice4k)
}
-// SetSoraImagePrice360 sets the "sora_image_price_360" field.
-func (m *GroupMutation) SetSoraImagePrice360(f float64) {
- m.sora_image_price_360 = &f
- m.addsora_image_price_360 = nil
-}
-
-// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation.
-func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) {
- v := m.sora_image_price_360
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err)
- }
- return oldValue.SoraImagePrice360, nil
-}
-
-// AddSoraImagePrice360 adds f to the "sora_image_price_360" field.
-func (m *GroupMutation) AddSoraImagePrice360(f float64) {
- if m.addsora_image_price_360 != nil {
- *m.addsora_image_price_360 += f
- } else {
- m.addsora_image_price_360 = &f
- }
-}
-
-// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation.
-func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) {
- v := m.addsora_image_price_360
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field.
-func (m *GroupMutation) ClearSoraImagePrice360() {
- m.sora_image_price_360 = nil
- m.addsora_image_price_360 = nil
- m.clearedFields[group.FieldSoraImagePrice360] = struct{}{}
-}
-
-// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation.
-func (m *GroupMutation) SoraImagePrice360Cleared() bool {
- _, ok := m.clearedFields[group.FieldSoraImagePrice360]
- return ok
-}
-
-// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field.
-func (m *GroupMutation) ResetSoraImagePrice360() {
- m.sora_image_price_360 = nil
- m.addsora_image_price_360 = nil
- delete(m.clearedFields, group.FieldSoraImagePrice360)
-}
-
-// SetSoraImagePrice540 sets the "sora_image_price_540" field.
-func (m *GroupMutation) SetSoraImagePrice540(f float64) {
- m.sora_image_price_540 = &f
- m.addsora_image_price_540 = nil
-}
-
-// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation.
-func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) {
- v := m.sora_image_price_540
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err)
- }
- return oldValue.SoraImagePrice540, nil
-}
-
-// AddSoraImagePrice540 adds f to the "sora_image_price_540" field.
-func (m *GroupMutation) AddSoraImagePrice540(f float64) {
- if m.addsora_image_price_540 != nil {
- *m.addsora_image_price_540 += f
- } else {
- m.addsora_image_price_540 = &f
- }
-}
-
-// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation.
-func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) {
- v := m.addsora_image_price_540
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field.
-func (m *GroupMutation) ClearSoraImagePrice540() {
- m.sora_image_price_540 = nil
- m.addsora_image_price_540 = nil
- m.clearedFields[group.FieldSoraImagePrice540] = struct{}{}
-}
-
-// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation.
-func (m *GroupMutation) SoraImagePrice540Cleared() bool {
- _, ok := m.clearedFields[group.FieldSoraImagePrice540]
- return ok
-}
-
-// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field.
-func (m *GroupMutation) ResetSoraImagePrice540() {
- m.sora_image_price_540 = nil
- m.addsora_image_price_540 = nil
- delete(m.clearedFields, group.FieldSoraImagePrice540)
-}
-
-// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field.
-func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) {
- m.sora_video_price_per_request = &f
- m.addsora_video_price_per_request = nil
-}
-
-// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation.
-func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) {
- v := m.sora_video_price_per_request
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err)
- }
- return oldValue.SoraVideoPricePerRequest, nil
-}
-
-// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field.
-func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) {
- if m.addsora_video_price_per_request != nil {
- *m.addsora_video_price_per_request += f
- } else {
- m.addsora_video_price_per_request = &f
- }
-}
-
-// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation.
-func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) {
- v := m.addsora_video_price_per_request
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field.
-func (m *GroupMutation) ClearSoraVideoPricePerRequest() {
- m.sora_video_price_per_request = nil
- m.addsora_video_price_per_request = nil
- m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{}
-}
-
-// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation.
-func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool {
- _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest]
- return ok
-}
-
-// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field.
-func (m *GroupMutation) ResetSoraVideoPricePerRequest() {
- m.sora_video_price_per_request = nil
- m.addsora_video_price_per_request = nil
- delete(m.clearedFields, group.FieldSoraVideoPricePerRequest)
-}
-
-// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) {
- m.sora_video_price_per_request_hd = &f
- m.addsora_video_price_per_request_hd = nil
-}
-
-// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation.
-func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) {
- v := m.sora_video_price_per_request_hd
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err)
- }
- return oldValue.SoraVideoPricePerRequestHd, nil
-}
-
-// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) {
- if m.addsora_video_price_per_request_hd != nil {
- *m.addsora_video_price_per_request_hd += f
- } else {
- m.addsora_video_price_per_request_hd = &f
- }
-}
-
-// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation.
-func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) {
- v := m.addsora_video_price_per_request_hd
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() {
- m.sora_video_price_per_request_hd = nil
- m.addsora_video_price_per_request_hd = nil
- m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{}
-}
-
-// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation.
-func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool {
- _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd]
- return ok
-}
-
-// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field.
-func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() {
- m.sora_video_price_per_request_hd = nil
- m.addsora_video_price_per_request_hd = nil
- delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd)
-}
-
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) {
- m.sora_storage_quota_bytes = &i
- m.addsora_storage_quota_bytes = nil
-}
-
-// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
-func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.sora_storage_quota_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity.
-// If the Group object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
- }
- return oldValue.SoraStorageQuotaBytes, nil
-}
-
-// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
-func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) {
- if m.addsora_storage_quota_bytes != nil {
- *m.addsora_storage_quota_bytes += i
- } else {
- m.addsora_storage_quota_bytes = &i
- }
-}
-
-// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
-func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.addsora_storage_quota_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
-func (m *GroupMutation) ResetSoraStorageQuotaBytes() {
- m.sora_storage_quota_bytes = nil
- m.addsora_storage_quota_bytes = nil
-}
-
// SetClaudeCodeOnly sets the "claude_code_only" field.
func (m *GroupMutation) SetClaudeCodeOnly(b bool) {
m.claude_code_only = &b
@@ -10144,6 +16341,98 @@ func (m *GroupMutation) ResetDefaultMappedModel() {
m.default_mapped_model = nil
}
+// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field.
+func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) {
+ m.messages_dispatch_model_config = &damdmc
+}
+
+// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation.
+func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) {
+ v := m.messages_dispatch_model_config
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err)
+ }
+ return oldValue.MessagesDispatchModelConfig, nil
+}
+
+// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field.
+func (m *GroupMutation) ResetMessagesDispatchModelConfig() {
+ m.messages_dispatch_model_config = nil
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (m *GroupMutation) SetRpmLimit(i int) {
+ m.rpm_limit = &i
+ m.addrpm_limit = nil
+}
+
+// RpmLimit returns the value of the "rpm_limit" field in the mutation.
+func (m *GroupMutation) RpmLimit() (r int, exists bool) {
+ v := m.rpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRpmLimit returns the old "rpm_limit" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRpmLimit requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
+ }
+ return oldValue.RpmLimit, nil
+}
+
+// AddRpmLimit adds i to the "rpm_limit" field.
+func (m *GroupMutation) AddRpmLimit(i int) {
+ if m.addrpm_limit != nil {
+ *m.addrpm_limit += i
+ } else {
+ m.addrpm_limit = &i
+ }
+}
+
+// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
+func (m *GroupMutation) AddedRpmLimit() (r int, exists bool) {
+ v := m.addrpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetRpmLimit resets all changes to the "rpm_limit" field.
+func (m *GroupMutation) ResetRpmLimit() {
+ m.rpm_limit = nil
+ m.addrpm_limit = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -10502,7 +16791,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 34)
+ fields := make([]string, 0, 31)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -10554,21 +16843,6 @@ func (m *GroupMutation) Fields() []string {
if m.image_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
- if m.sora_image_price_360 != nil {
- fields = append(fields, group.FieldSoraImagePrice360)
- }
- if m.sora_image_price_540 != nil {
- fields = append(fields, group.FieldSoraImagePrice540)
- }
- if m.sora_video_price_per_request != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequest)
- }
- if m.sora_video_price_per_request_hd != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
- }
- if m.sora_storage_quota_bytes != nil {
- fields = append(fields, group.FieldSoraStorageQuotaBytes)
- }
if m.claude_code_only != nil {
fields = append(fields, group.FieldClaudeCodeOnly)
}
@@ -10605,6 +16879,12 @@ func (m *GroupMutation) Fields() []string {
if m.default_mapped_model != nil {
fields = append(fields, group.FieldDefaultMappedModel)
}
+ if m.messages_dispatch_model_config != nil {
+ fields = append(fields, group.FieldMessagesDispatchModelConfig)
+ }
+ if m.rpm_limit != nil {
+ fields = append(fields, group.FieldRpmLimit)
+ }
return fields
}
@@ -10647,16 +16927,6 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ImagePrice2k()
case group.FieldImagePrice4k:
return m.ImagePrice4k()
- case group.FieldSoraImagePrice360:
- return m.SoraImagePrice360()
- case group.FieldSoraImagePrice540:
- return m.SoraImagePrice540()
- case group.FieldSoraVideoPricePerRequest:
- return m.SoraVideoPricePerRequest()
- case group.FieldSoraVideoPricePerRequestHd:
- return m.SoraVideoPricePerRequestHd()
- case group.FieldSoraStorageQuotaBytes:
- return m.SoraStorageQuotaBytes()
case group.FieldClaudeCodeOnly:
return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID:
@@ -10681,6 +16951,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.RequirePrivacySet()
case group.FieldDefaultMappedModel:
return m.DefaultMappedModel()
+ case group.FieldMessagesDispatchModelConfig:
+ return m.MessagesDispatchModelConfig()
+ case group.FieldRpmLimit:
+ return m.RpmLimit()
}
return nil, false
}
@@ -10724,16 +16998,6 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldImagePrice2k(ctx)
case group.FieldImagePrice4k:
return m.OldImagePrice4k(ctx)
- case group.FieldSoraImagePrice360:
- return m.OldSoraImagePrice360(ctx)
- case group.FieldSoraImagePrice540:
- return m.OldSoraImagePrice540(ctx)
- case group.FieldSoraVideoPricePerRequest:
- return m.OldSoraVideoPricePerRequest(ctx)
- case group.FieldSoraVideoPricePerRequestHd:
- return m.OldSoraVideoPricePerRequestHd(ctx)
- case group.FieldSoraStorageQuotaBytes:
- return m.OldSoraStorageQuotaBytes(ctx)
case group.FieldClaudeCodeOnly:
return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID:
@@ -10758,6 +17022,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldRequirePrivacySet(ctx)
case group.FieldDefaultMappedModel:
return m.OldDefaultMappedModel(ctx)
+ case group.FieldMessagesDispatchModelConfig:
+ return m.OldMessagesDispatchModelConfig(ctx)
+ case group.FieldRpmLimit:
+ return m.OldRpmLimit(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -10886,41 +17154,6 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetImagePrice4k(v)
return nil
- case group.FieldSoraImagePrice360:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraImagePrice360(v)
- return nil
- case group.FieldSoraImagePrice540:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraImagePrice540(v)
- return nil
- case group.FieldSoraVideoPricePerRequest:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraVideoPricePerRequest(v)
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraVideoPricePerRequestHd(v)
- return nil
- case group.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetSoraStorageQuotaBytes(v)
- return nil
case group.FieldClaudeCodeOnly:
v, ok := value.(bool)
if !ok {
@@ -11005,6 +17238,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetDefaultMappedModel(v)
return nil
+ case group.FieldMessagesDispatchModelConfig:
+ v, ok := value.(domain.OpenAIMessagesDispatchModelConfig)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetMessagesDispatchModelConfig(v)
+ return nil
+ case group.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRpmLimit(v)
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -11037,21 +17284,6 @@ func (m *GroupMutation) AddedFields() []string {
if m.addimage_price_4k != nil {
fields = append(fields, group.FieldImagePrice4k)
}
- if m.addsora_image_price_360 != nil {
- fields = append(fields, group.FieldSoraImagePrice360)
- }
- if m.addsora_image_price_540 != nil {
- fields = append(fields, group.FieldSoraImagePrice540)
- }
- if m.addsora_video_price_per_request != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequest)
- }
- if m.addsora_video_price_per_request_hd != nil {
- fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
- }
- if m.addsora_storage_quota_bytes != nil {
- fields = append(fields, group.FieldSoraStorageQuotaBytes)
- }
if m.addfallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
@@ -11061,6 +17293,9 @@ func (m *GroupMutation) AddedFields() []string {
if m.addsort_order != nil {
fields = append(fields, group.FieldSortOrder)
}
+ if m.addrpm_limit != nil {
+ fields = append(fields, group.FieldRpmLimit)
+ }
return fields
}
@@ -11085,22 +17320,14 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedImagePrice2k()
case group.FieldImagePrice4k:
return m.AddedImagePrice4k()
- case group.FieldSoraImagePrice360:
- return m.AddedSoraImagePrice360()
- case group.FieldSoraImagePrice540:
- return m.AddedSoraImagePrice540()
- case group.FieldSoraVideoPricePerRequest:
- return m.AddedSoraVideoPricePerRequest()
- case group.FieldSoraVideoPricePerRequestHd:
- return m.AddedSoraVideoPricePerRequestHd()
- case group.FieldSoraStorageQuotaBytes:
- return m.AddedSoraStorageQuotaBytes()
case group.FieldFallbackGroupID:
return m.AddedFallbackGroupID()
case group.FieldFallbackGroupIDOnInvalidRequest:
return m.AddedFallbackGroupIDOnInvalidRequest()
case group.FieldSortOrder:
return m.AddedSortOrder()
+ case group.FieldRpmLimit:
+ return m.AddedRpmLimit()
}
return nil, false
}
@@ -11166,41 +17393,6 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddImagePrice4k(v)
return nil
- case group.FieldSoraImagePrice360:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraImagePrice360(v)
- return nil
- case group.FieldSoraImagePrice540:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraImagePrice540(v)
- return nil
- case group.FieldSoraVideoPricePerRequest:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraVideoPricePerRequest(v)
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- v, ok := value.(float64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraVideoPricePerRequestHd(v)
- return nil
- case group.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.AddSoraStorageQuotaBytes(v)
- return nil
case group.FieldFallbackGroupID:
v, ok := value.(int64)
if !ok {
@@ -11222,6 +17414,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error {
}
m.AddSortOrder(v)
return nil
+ case group.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddRpmLimit(v)
+ return nil
}
return fmt.Errorf("unknown Group numeric field %s", name)
}
@@ -11254,18 +17453,6 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldImagePrice4k) {
fields = append(fields, group.FieldImagePrice4k)
}
- if m.FieldCleared(group.FieldSoraImagePrice360) {
- fields = append(fields, group.FieldSoraImagePrice360)
- }
- if m.FieldCleared(group.FieldSoraImagePrice540) {
- fields = append(fields, group.FieldSoraImagePrice540)
- }
- if m.FieldCleared(group.FieldSoraVideoPricePerRequest) {
- fields = append(fields, group.FieldSoraVideoPricePerRequest)
- }
- if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) {
- fields = append(fields, group.FieldSoraVideoPricePerRequestHd)
- }
if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID)
}
@@ -11313,18 +17500,6 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldImagePrice4k:
m.ClearImagePrice4k()
return nil
- case group.FieldSoraImagePrice360:
- m.ClearSoraImagePrice360()
- return nil
- case group.FieldSoraImagePrice540:
- m.ClearSoraImagePrice540()
- return nil
- case group.FieldSoraVideoPricePerRequest:
- m.ClearSoraVideoPricePerRequest()
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- m.ClearSoraVideoPricePerRequestHd()
- return nil
case group.FieldFallbackGroupID:
m.ClearFallbackGroupID()
return nil
@@ -11393,21 +17568,6 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldImagePrice4k:
m.ResetImagePrice4k()
return nil
- case group.FieldSoraImagePrice360:
- m.ResetSoraImagePrice360()
- return nil
- case group.FieldSoraImagePrice540:
- m.ResetSoraImagePrice540()
- return nil
- case group.FieldSoraVideoPricePerRequest:
- m.ResetSoraVideoPricePerRequest()
- return nil
- case group.FieldSoraVideoPricePerRequestHd:
- m.ResetSoraVideoPricePerRequestHd()
- return nil
- case group.FieldSoraStorageQuotaBytes:
- m.ResetSoraStorageQuotaBytes()
- return nil
case group.FieldClaudeCodeOnly:
m.ResetClaudeCodeOnly()
return nil
@@ -11444,6 +17604,12 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldDefaultMappedModel:
m.ResetDefaultMappedModel()
return nil
+ case group.FieldMessagesDispatchModelConfig:
+ m.ResetMessagesDispatchModelConfig()
+ return nil
+ case group.FieldRpmLimit:
+ m.ResetRpmLimit()
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -12644,6 +18810,6970 @@ func (m *IdempotencyRecordMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown IdempotencyRecord edge %s", name)
}
+// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph.
+type IdentityAdoptionDecisionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ adopt_display_name *bool
+ adopt_avatar *bool
+ decided_at *time.Time
+ clearedFields map[string]struct{}
+ pending_auth_session *int64
+ clearedpending_auth_session bool
+ identity *int64
+ clearedidentity bool
+ done bool
+ oldValue func(context.Context) (*IdentityAdoptionDecision, error)
+ predicates []predicate.IdentityAdoptionDecision
+}
+
+var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil)
+
+// identityadoptiondecisionOption allows management of the mutation configuration using functional options.
+type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation)
+
+// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity.
+func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation {
+ m := &IdentityAdoptionDecisionMutation{
+ config: c,
+ op: op,
+ typ: TypeIdentityAdoptionDecision,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withIdentityAdoptionDecisionID sets the ID field of the mutation.
+func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *IdentityAdoptionDecision
+ )
+ m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation.
+func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption {
+ return func(m *IdentityAdoptionDecisionMutation) {
+ m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m IdentityAdoptionDecisionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetPendingAuthSessionID sets the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) {
+ m.pending_auth_session = &i
+}
+
+// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) {
+ v := m.pending_auth_session
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err)
+ }
+ return oldValue.PendingAuthSessionID, nil
+}
+
+// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() {
+ m.pending_auth_session = nil
+}
+
+// SetIdentityID sets the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) {
+ m.identity = &i
+}
+
+// IdentityID returns the value of the "identity_id" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) {
+ v := m.identity
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIdentityID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIdentityID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIdentityID: %w", err)
+ }
+ return oldValue.IdentityID, nil
+}
+
+// ClearIdentityID clears the value of the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() {
+ m.identity = nil
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool {
+ _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID]
+ return ok
+}
+
+// ResetIdentityID resets all changes to the "identity_id" field.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() {
+ m.identity = nil
+ delete(m.clearedFields, identityadoptiondecision.FieldIdentityID)
+}
+
+// SetAdoptDisplayName sets the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) {
+ m.adopt_display_name = &b
+}
+
+// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) {
+ v := m.adopt_display_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err)
+ }
+ return oldValue.AdoptDisplayName, nil
+}
+
+// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() {
+ m.adopt_display_name = nil
+}
+
+// SetAdoptAvatar sets the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) {
+ m.adopt_avatar = &b
+}
+
+// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) {
+ v := m.adopt_avatar
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAdoptAvatar requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err)
+ }
+ return oldValue.AdoptAvatar, nil
+}
+
+// ResetAdoptAvatar resets all changes to the "adopt_avatar" field.
+func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() {
+ m.adopt_avatar = nil
+}
+
+// SetDecidedAt sets the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) {
+ m.decided_at = &t
+}
+
+// DecidedAt returns the value of the "decided_at" field in the mutation.
+func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) {
+ v := m.decided_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity.
+// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDecidedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err)
+ }
+ return oldValue.DecidedAt, nil
+}
+
+// ResetDecidedAt resets all changes to the "decided_at" field.
+func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() {
+ m.decided_at = nil
+}
+
+// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity.
+func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() {
+ m.clearedpending_auth_session = true
+ m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{}
+}
+
+// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool {
+ return m.clearedpending_auth_session
+}
+
+// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// PendingAuthSessionID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) {
+ if id := m.pending_auth_session; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() {
+ m.pending_auth_session = nil
+ m.clearedpending_auth_session = false
+}
+
+// ClearIdentity clears the "identity" edge to the AuthIdentity entity.
+func (m *IdentityAdoptionDecisionMutation) ClearIdentity() {
+ m.clearedidentity = true
+ m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{}
+}
+
+// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared.
+func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool {
+ return m.IdentityIDCleared() || m.clearedidentity
+}
+
+// IdentityIDs returns the "identity" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// IdentityID instead. It exists only for internal usage by the builders.
+func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) {
+ if id := m.identity; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetIdentity resets all changes to the "identity" edge.
+func (m *IdentityAdoptionDecisionMutation) ResetIdentity() {
+ m.identity = nil
+ m.clearedidentity = false
+}
+
+// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder.
+func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.IdentityAdoptionDecision, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *IdentityAdoptionDecisionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (IdentityAdoptionDecision).
+func (m *IdentityAdoptionDecisionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *IdentityAdoptionDecisionMutation) Fields() []string {
+ fields := make([]string, 0, 7)
+ if m.created_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldUpdatedAt)
+ }
+ if m.pending_auth_session != nil {
+ fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ if m.identity != nil {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ if m.adopt_display_name != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName)
+ }
+ if m.adopt_avatar != nil {
+ fields = append(fields, identityadoptiondecision.FieldAdoptAvatar)
+ }
+ if m.decided_at != nil {
+ fields = append(fields, identityadoptiondecision.FieldDecidedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.CreatedAt()
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.PendingAuthSessionID()
+ case identityadoptiondecision.FieldIdentityID:
+ return m.IdentityID()
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.AdoptDisplayName()
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.AdoptAvatar()
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.DecidedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case identityadoptiondecision.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ return m.OldPendingAuthSessionID(ctx)
+ case identityadoptiondecision.FieldIdentityID:
+ return m.OldIdentityID(ctx)
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ return m.OldAdoptDisplayName(ctx)
+ case identityadoptiondecision.FieldAdoptAvatar:
+ return m.OldAdoptAvatar(ctx)
+ case identityadoptiondecision.FieldDecidedAt:
+ return m.OldDecidedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPendingAuthSessionID(v)
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIdentityID(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptDisplayName(v)
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAdoptAvatar(v)
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDecidedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(identityadoptiondecision.FieldIdentityID) {
+ fields = append(fields, identityadoptiondecision.FieldIdentityID)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldIdentityID:
+ m.ClearIdentityID()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error {
+ switch name {
+ case identityadoptiondecision.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case identityadoptiondecision.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case identityadoptiondecision.FieldPendingAuthSessionID:
+ m.ResetPendingAuthSessionID()
+ return nil
+ case identityadoptiondecision.FieldIdentityID:
+ m.ResetIdentityID()
+ return nil
+ case identityadoptiondecision.FieldAdoptDisplayName:
+ m.ResetAdoptDisplayName()
+ return nil
+ case identityadoptiondecision.FieldAdoptAvatar:
+ m.ResetAdoptAvatar()
+ return nil
+ case identityadoptiondecision.FieldDecidedAt:
+ m.ResetDecidedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.pending_auth_session != nil {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.identity != nil {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ if id := m.pending_auth_session; id != nil {
+ return []ent.Value{*id}
+ }
+ case identityadoptiondecision.EdgeIdentity:
+ if id := m.identity; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedpending_auth_session {
+ edges = append(edges, identityadoptiondecision.EdgePendingAuthSession)
+ }
+ if m.clearedidentity {
+ edges = append(edges, identityadoptiondecision.EdgeIdentity)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ return m.clearedpending_auth_session
+ case identityadoptiondecision.EdgeIdentity:
+ return m.clearedidentity
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ClearPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ClearIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error {
+ switch name {
+ case identityadoptiondecision.EdgePendingAuthSession:
+ m.ResetPendingAuthSession()
+ return nil
+ case identityadoptiondecision.EdgeIdentity:
+ m.ResetIdentity()
+ return nil
+ }
+ return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name)
+}
+
+// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph.
+type PaymentAuditLogMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ order_id *string
+ action *string
+ detail *string
+ operator *string
+ created_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*PaymentAuditLog, error)
+ predicates []predicate.PaymentAuditLog
+}
+
+var _ ent.Mutation = (*PaymentAuditLogMutation)(nil)
+
+// paymentauditlogOption allows management of the mutation configuration using functional options.
+type paymentauditlogOption func(*PaymentAuditLogMutation)
+
+// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity.
+func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation {
+ m := &PaymentAuditLogMutation{
+ config: c,
+ op: op,
+ typ: TypePaymentAuditLog,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPaymentAuditLogID sets the ID field of the mutation.
+func withPaymentAuditLogID(id int64) paymentauditlogOption {
+ return func(m *PaymentAuditLogMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PaymentAuditLog
+ )
+ m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PaymentAuditLog.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withPaymentAuditLog sets the old PaymentAuditLog of the mutation.
+func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption {
+ return func(m *PaymentAuditLogMutation) {
+ m.oldValue = func(context.Context) (*PaymentAuditLog, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PaymentAuditLogMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PaymentAuditLogMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetOrderID sets the "order_id" field.
+func (m *PaymentAuditLogMutation) SetOrderID(s string) {
+ m.order_id = &s
+}
+
+// OrderID returns the value of the "order_id" field in the mutation.
+func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) {
+ v := m.order_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOrderID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOrderID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOrderID: %w", err)
+ }
+ return oldValue.OrderID, nil
+}
+
+// ResetOrderID resets all changes to the "order_id" field.
+func (m *PaymentAuditLogMutation) ResetOrderID() {
+ m.order_id = nil
+}
+
+// SetAction sets the "action" field.
+func (m *PaymentAuditLogMutation) SetAction(s string) {
+ m.action = &s
+}
+
+// Action returns the value of the "action" field in the mutation.
+func (m *PaymentAuditLogMutation) Action() (r string, exists bool) {
+ v := m.action
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAction returns the old "action" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAction is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAction requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAction: %w", err)
+ }
+ return oldValue.Action, nil
+}
+
+// ResetAction resets all changes to the "action" field.
+func (m *PaymentAuditLogMutation) ResetAction() {
+ m.action = nil
+}
+
+// SetDetail sets the "detail" field.
+func (m *PaymentAuditLogMutation) SetDetail(s string) {
+ m.detail = &s
+}
+
+// Detail returns the value of the "detail" field in the mutation.
+func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) {
+ v := m.detail
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDetail is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDetail requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDetail: %w", err)
+ }
+ return oldValue.Detail, nil
+}
+
+// ResetDetail resets all changes to the "detail" field.
+func (m *PaymentAuditLogMutation) ResetDetail() {
+ m.detail = nil
+}
+
+// SetOperator sets the "operator" field.
+func (m *PaymentAuditLogMutation) SetOperator(s string) {
+ m.operator = &s
+}
+
+// Operator returns the value of the "operator" field in the mutation.
+func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) {
+ v := m.operator
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOperator is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOperator requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOperator: %w", err)
+ }
+ return oldValue.Operator, nil
+}
+
+// ResetOperator resets all changes to the "operator" field.
+func (m *PaymentAuditLogMutation) ResetOperator() {
+ m.operator = nil
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity.
+// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PaymentAuditLogMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// Where appends a list of predicates to the PaymentAuditLogMutation builder.
+func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) {
+	m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PaymentAuditLog, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PaymentAuditLogMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PaymentAuditLogMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PaymentAuditLog).
+func (m *PaymentAuditLogMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PaymentAuditLogMutation) Fields() []string {
+ fields := make([]string, 0, 5)
+ if m.order_id != nil {
+ fields = append(fields, paymentauditlog.FieldOrderID)
+ }
+ if m.action != nil {
+ fields = append(fields, paymentauditlog.FieldAction)
+ }
+ if m.detail != nil {
+ fields = append(fields, paymentauditlog.FieldDetail)
+ }
+ if m.operator != nil {
+ fields = append(fields, paymentauditlog.FieldOperator)
+ }
+ if m.created_at != nil {
+ fields = append(fields, paymentauditlog.FieldCreatedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ return m.OrderID()
+ case paymentauditlog.FieldAction:
+ return m.Action()
+ case paymentauditlog.FieldDetail:
+ return m.Detail()
+ case paymentauditlog.FieldOperator:
+ return m.Operator()
+ case paymentauditlog.FieldCreatedAt:
+ return m.CreatedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ return m.OldOrderID(ctx)
+ case paymentauditlog.FieldAction:
+ return m.OldAction(ctx)
+ case paymentauditlog.FieldDetail:
+ return m.OldDetail(ctx)
+ case paymentauditlog.FieldOperator:
+ return m.OldOperator(ctx)
+ case paymentauditlog.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOrderID(v)
+ return nil
+ case paymentauditlog.FieldAction:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAction(v)
+ return nil
+ case paymentauditlog.FieldDetail:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDetail(v)
+ return nil
+ case paymentauditlog.FieldOperator:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOperator(v)
+ return nil
+ case paymentauditlog.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentAuditLog field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PaymentAuditLogMutation) AddedFields() []string {
+ return nil
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) {
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PaymentAuditLogMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PaymentAuditLogMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PaymentAuditLogMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PaymentAuditLogMutation) ResetField(name string) error {
+ switch name {
+ case paymentauditlog.FieldOrderID:
+ m.ResetOrderID()
+ return nil
+ case paymentauditlog.FieldAction:
+ m.ResetAction()
+ return nil
+ case paymentauditlog.FieldDetail:
+ m.ResetDetail()
+ return nil
+ case paymentauditlog.FieldOperator:
+ m.ResetOperator()
+ return nil
+ case paymentauditlog.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentAuditLog field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PaymentAuditLogMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value {
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PaymentAuditLogMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PaymentAuditLogMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool {
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PaymentAuditLogMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PaymentAuditLogMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown PaymentAuditLog edge %s", name)
+}
+
+// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph.
+type PaymentOrderMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ user_email *string
+ user_name *string
+ user_notes *string
+ amount *float64
+ addamount *float64
+ pay_amount *float64
+ addpay_amount *float64
+ fee_rate *float64
+ addfee_rate *float64
+ recharge_code *string
+ out_trade_no *string
+ payment_type *string
+ payment_trade_no *string
+ pay_url *string
+ qr_code *string
+ qr_code_img *string
+ order_type *string
+ plan_id *int64
+ addplan_id *int64
+ subscription_group_id *int64
+ addsubscription_group_id *int64
+ subscription_days *int
+ addsubscription_days *int
+ provider_instance_id *string
+ provider_key *string
+ provider_snapshot *map[string]interface{}
+ status *string
+ refund_amount *float64
+ addrefund_amount *float64
+ refund_reason *string
+ refund_at *time.Time
+ force_refund *bool
+ refund_requested_at *time.Time
+ refund_request_reason *string
+ refund_requested_by *string
+ expires_at *time.Time
+ paid_at *time.Time
+ completed_at *time.Time
+ failed_at *time.Time
+ failed_reason *string
+ client_ip *string
+ src_host *string
+ src_url *string
+ created_at *time.Time
+ updated_at *time.Time
+ clearedFields map[string]struct{}
+ user *int64
+ cleareduser bool
+ done bool
+ oldValue func(context.Context) (*PaymentOrder, error)
+ predicates []predicate.PaymentOrder
+}
+
+var _ ent.Mutation = (*PaymentOrderMutation)(nil)
+
+// paymentorderOption allows management of the mutation configuration using functional options.
+type paymentorderOption func(*PaymentOrderMutation)
+
+// newPaymentOrderMutation creates new mutation for the PaymentOrder entity.
+func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation {
+ m := &PaymentOrderMutation{
+ config: c,
+ op: op,
+ typ: TypePaymentOrder,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPaymentOrderID sets the ID field of the mutation.
+func withPaymentOrderID(id int64) paymentorderOption {
+ return func(m *PaymentOrderMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PaymentOrder
+ )
+ m.oldValue = func(ctx context.Context) (*PaymentOrder, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PaymentOrder.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withPaymentOrder sets the old PaymentOrder of the mutation.
+func withPaymentOrder(node *PaymentOrder) paymentorderOption {
+ return func(m *PaymentOrderMutation) {
+ m.oldValue = func(context.Context) (*PaymentOrder, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PaymentOrderMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PaymentOrderMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PaymentOrderMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) {
+	switch {
+	case m.op.Is(OpUpdateOne | OpDeleteOne):
+		id, exists := m.ID()
+		if exists {
+			return []int64{id}, nil
+		}
+		fallthrough
+	case m.op.Is(OpUpdate | OpDelete):
+		return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx)
+	default:
+		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+	}
+}
+
+// SetUserID sets the "user_id" field.
+func (m *PaymentOrderMutation) SetUserID(i int64) {
+ m.user = &i
+}
+
+// UserID returns the value of the "user_id" field in the mutation.
+func (m *PaymentOrderMutation) UserID() (r int64, exists bool) {
+ v := m.user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserID returns the old "user_id" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserID: %w", err)
+ }
+ return oldValue.UserID, nil
+}
+
+// ResetUserID resets all changes to the "user_id" field.
+func (m *PaymentOrderMutation) ResetUserID() {
+ m.user = nil
+}
+
+// SetUserEmail sets the "user_email" field.
+func (m *PaymentOrderMutation) SetUserEmail(s string) {
+ m.user_email = &s
+}
+
+// UserEmail returns the value of the "user_email" field in the mutation.
+func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) {
+ v := m.user_email
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserEmail is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserEmail requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserEmail: %w", err)
+ }
+ return oldValue.UserEmail, nil
+}
+
+// ResetUserEmail resets all changes to the "user_email" field.
+func (m *PaymentOrderMutation) ResetUserEmail() {
+ m.user_email = nil
+}
+
+// SetUserName sets the "user_name" field.
+func (m *PaymentOrderMutation) SetUserName(s string) {
+ m.user_name = &s
+}
+
+// UserName returns the value of the "user_name" field in the mutation.
+func (m *PaymentOrderMutation) UserName() (r string, exists bool) {
+ v := m.user_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserName returns the old "user_name" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserName: %w", err)
+ }
+ return oldValue.UserName, nil
+}
+
+// ResetUserName resets all changes to the "user_name" field.
+func (m *PaymentOrderMutation) ResetUserName() {
+ m.user_name = nil
+}
+
+// SetUserNotes sets the "user_notes" field.
+func (m *PaymentOrderMutation) SetUserNotes(s string) {
+ m.user_notes = &s
+}
+
+// UserNotes returns the value of the "user_notes" field in the mutation.
+func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) {
+ v := m.user_notes
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUserNotes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUserNotes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUserNotes: %w", err)
+ }
+ return oldValue.UserNotes, nil
+}
+
+// ClearUserNotes clears the value of the "user_notes" field.
+func (m *PaymentOrderMutation) ClearUserNotes() {
+ m.user_notes = nil
+ m.clearedFields[paymentorder.FieldUserNotes] = struct{}{}
+}
+
+// UserNotesCleared returns if the "user_notes" field was cleared in this mutation.
+func (m *PaymentOrderMutation) UserNotesCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldUserNotes]
+ return ok
+}
+
+// ResetUserNotes resets all changes to the "user_notes" field.
+func (m *PaymentOrderMutation) ResetUserNotes() {
+ m.user_notes = nil
+ delete(m.clearedFields, paymentorder.FieldUserNotes)
+}
+
+// SetAmount sets the "amount" field.
+func (m *PaymentOrderMutation) SetAmount(f float64) {
+ m.amount = &f
+ m.addamount = nil
+}
+
+// Amount returns the value of the "amount" field in the mutation.
+func (m *PaymentOrderMutation) Amount() (r float64, exists bool) {
+ v := m.amount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAmount returns the old "amount" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAmount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAmount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAmount: %w", err)
+ }
+ return oldValue.Amount, nil
+}
+
+// AddAmount adds f to the "amount" field.
+func (m *PaymentOrderMutation) AddAmount(f float64) {
+ if m.addamount != nil {
+ *m.addamount += f
+ } else {
+ m.addamount = &f
+ }
+}
+
+// AddedAmount returns the value that was added to the "amount" field in this mutation.
+func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) {
+ v := m.addamount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetAmount resets all changes to the "amount" field.
+func (m *PaymentOrderMutation) ResetAmount() {
+ m.amount = nil
+ m.addamount = nil
+}
+
+// SetPayAmount sets the "pay_amount" field.
+func (m *PaymentOrderMutation) SetPayAmount(f float64) {
+ m.pay_amount = &f
+ m.addpay_amount = nil
+}
+
+// PayAmount returns the value of the "pay_amount" field in the mutation.
+func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) {
+ v := m.pay_amount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPayAmount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPayAmount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPayAmount: %w", err)
+ }
+ return oldValue.PayAmount, nil
+}
+
+// AddPayAmount adds f to the "pay_amount" field.
+func (m *PaymentOrderMutation) AddPayAmount(f float64) {
+ if m.addpay_amount != nil {
+ *m.addpay_amount += f
+ } else {
+ m.addpay_amount = &f
+ }
+}
+
+// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation.
+func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) {
+ v := m.addpay_amount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetPayAmount resets all changes to the "pay_amount" field.
+func (m *PaymentOrderMutation) ResetPayAmount() {
+ m.pay_amount = nil
+ m.addpay_amount = nil
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (m *PaymentOrderMutation) SetFeeRate(f float64) {
+ m.fee_rate = &f
+ m.addfee_rate = nil
+}
+
+// FeeRate returns the value of the "fee_rate" field in the mutation.
+func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) {
+ v := m.fee_rate
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFeeRate is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFeeRate requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFeeRate: %w", err)
+ }
+ return oldValue.FeeRate, nil
+}
+
+// AddFeeRate adds f to the "fee_rate" field.
+func (m *PaymentOrderMutation) AddFeeRate(f float64) {
+ if m.addfee_rate != nil {
+ *m.addfee_rate += f
+ } else {
+ m.addfee_rate = &f
+ }
+}
+
+// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation.
+func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) {
+ v := m.addfee_rate
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetFeeRate resets all changes to the "fee_rate" field.
+func (m *PaymentOrderMutation) ResetFeeRate() {
+ m.fee_rate = nil
+ m.addfee_rate = nil
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (m *PaymentOrderMutation) SetRechargeCode(s string) {
+ m.recharge_code = &s
+}
+
+// RechargeCode returns the value of the "recharge_code" field in the mutation.
+func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) {
+ v := m.recharge_code
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRechargeCode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err)
+ }
+ return oldValue.RechargeCode, nil
+}
+
+// ResetRechargeCode resets all changes to the "recharge_code" field.
+func (m *PaymentOrderMutation) ResetRechargeCode() {
+ m.recharge_code = nil
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (m *PaymentOrderMutation) SetOutTradeNo(s string) {
+ m.out_trade_no = &s
+}
+
+// OutTradeNo returns the value of the "out_trade_no" field in the mutation.
+func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) {
+ v := m.out_trade_no
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOutTradeNo requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err)
+ }
+ return oldValue.OutTradeNo, nil
+}
+
+// ResetOutTradeNo resets all changes to the "out_trade_no" field.
+func (m *PaymentOrderMutation) ResetOutTradeNo() {
+ m.out_trade_no = nil
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (m *PaymentOrderMutation) SetPaymentType(s string) {
+ m.payment_type = &s
+}
+
+// PaymentType returns the value of the "payment_type" field in the mutation.
+func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) {
+ v := m.payment_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaymentType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaymentType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaymentType: %w", err)
+ }
+ return oldValue.PaymentType, nil
+}
+
+// ResetPaymentType resets all changes to the "payment_type" field.
+func (m *PaymentOrderMutation) ResetPaymentType() {
+ m.payment_type = nil
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) {
+ m.payment_trade_no = &s
+}
+
+// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation.
+func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) {
+ v := m.payment_trade_no
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err)
+ }
+ return oldValue.PaymentTradeNo, nil
+}
+
+// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field.
+func (m *PaymentOrderMutation) ResetPaymentTradeNo() {
+ m.payment_trade_no = nil
+}
+
+// SetPayURL sets the "pay_url" field.
+func (m *PaymentOrderMutation) SetPayURL(s string) {
+ m.pay_url = &s
+}
+
+// PayURL returns the value of the "pay_url" field in the mutation.
+func (m *PaymentOrderMutation) PayURL() (r string, exists bool) {
+ v := m.pay_url
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPayURL is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPayURL requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPayURL: %w", err)
+ }
+ return oldValue.PayURL, nil
+}
+
+// ClearPayURL clears the value of the "pay_url" field.
+func (m *PaymentOrderMutation) ClearPayURL() {
+ m.pay_url = nil
+ m.clearedFields[paymentorder.FieldPayURL] = struct{}{}
+}
+
+// PayURLCleared returns if the "pay_url" field was cleared in this mutation.
+func (m *PaymentOrderMutation) PayURLCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldPayURL]
+ return ok
+}
+
+// ResetPayURL resets all changes to the "pay_url" field.
+func (m *PaymentOrderMutation) ResetPayURL() {
+ m.pay_url = nil
+ delete(m.clearedFields, paymentorder.FieldPayURL)
+}
+
+// SetQrCode sets the "qr_code" field.
+func (m *PaymentOrderMutation) SetQrCode(s string) {
+ m.qr_code = &s
+}
+
+// QrCode returns the value of the "qr_code" field in the mutation.
+func (m *PaymentOrderMutation) QrCode() (r string, exists bool) {
+ v := m.qr_code
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldQrCode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldQrCode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldQrCode: %w", err)
+ }
+ return oldValue.QrCode, nil
+}
+
+// ClearQrCode clears the value of the "qr_code" field.
+func (m *PaymentOrderMutation) ClearQrCode() {
+ m.qr_code = nil
+ m.clearedFields[paymentorder.FieldQrCode] = struct{}{}
+}
+
+// QrCodeCleared returns if the "qr_code" field was cleared in this mutation.
+func (m *PaymentOrderMutation) QrCodeCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldQrCode]
+ return ok
+}
+
+// ResetQrCode resets all changes to the "qr_code" field.
+func (m *PaymentOrderMutation) ResetQrCode() {
+ m.qr_code = nil
+ delete(m.clearedFields, paymentorder.FieldQrCode)
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (m *PaymentOrderMutation) SetQrCodeImg(s string) {
+ m.qr_code_img = &s
+}
+
+// QrCodeImg returns the value of the "qr_code_img" field in the mutation.
+func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) {
+ v := m.qr_code_img
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldQrCodeImg requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err)
+ }
+ return oldValue.QrCodeImg, nil
+}
+
+// ClearQrCodeImg clears the value of the "qr_code_img" field.
+func (m *PaymentOrderMutation) ClearQrCodeImg() {
+ m.qr_code_img = nil
+ m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{}
+}
+
+// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation.
+func (m *PaymentOrderMutation) QrCodeImgCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldQrCodeImg]
+ return ok
+}
+
+// ResetQrCodeImg resets all changes to the "qr_code_img" field.
+func (m *PaymentOrderMutation) ResetQrCodeImg() {
+ m.qr_code_img = nil
+ delete(m.clearedFields, paymentorder.FieldQrCodeImg)
+}
+
+// SetOrderType sets the "order_type" field.
+func (m *PaymentOrderMutation) SetOrderType(s string) {
+ m.order_type = &s
+}
+
+// OrderType returns the value of the "order_type" field in the mutation.
+func (m *PaymentOrderMutation) OrderType() (r string, exists bool) {
+ v := m.order_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOrderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOrderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOrderType: %w", err)
+ }
+ return oldValue.OrderType, nil
+}
+
+// ResetOrderType resets all changes to the "order_type" field.
+func (m *PaymentOrderMutation) ResetOrderType() {
+ m.order_type = nil
+}
+
+// SetPlanID sets the "plan_id" field.
+func (m *PaymentOrderMutation) SetPlanID(i int64) {
+ m.plan_id = &i
+ m.addplan_id = nil
+}
+
+// PlanID returns the value of the "plan_id" field in the mutation.
+func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) {
+ v := m.plan_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPlanID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPlanID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPlanID: %w", err)
+ }
+ return oldValue.PlanID, nil
+}
+
+// AddPlanID adds i to the "plan_id" field.
+func (m *PaymentOrderMutation) AddPlanID(i int64) {
+ if m.addplan_id != nil {
+ *m.addplan_id += i
+ } else {
+ m.addplan_id = &i
+ }
+}
+
+// AddedPlanID returns the value that was added to the "plan_id" field in this mutation.
+func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) {
+ v := m.addplan_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearPlanID clears the value of the "plan_id" field.
+func (m *PaymentOrderMutation) ClearPlanID() {
+ m.plan_id = nil
+ m.addplan_id = nil
+ m.clearedFields[paymentorder.FieldPlanID] = struct{}{}
+}
+
+// PlanIDCleared returns if the "plan_id" field was cleared in this mutation.
+func (m *PaymentOrderMutation) PlanIDCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldPlanID]
+ return ok
+}
+
+// ResetPlanID resets all changes to the "plan_id" field.
+func (m *PaymentOrderMutation) ResetPlanID() {
+ m.plan_id = nil
+ m.addplan_id = nil
+ delete(m.clearedFields, paymentorder.FieldPlanID)
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) {
+ m.subscription_group_id = &i
+ m.addsubscription_group_id = nil
+}
+
+// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation.
+func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) {
+ v := m.subscription_group_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err)
+ }
+ return oldValue.SubscriptionGroupID, nil
+}
+
+// AddSubscriptionGroupID adds i to the "subscription_group_id" field.
+func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) {
+ if m.addsubscription_group_id != nil {
+ *m.addsubscription_group_id += i
+ } else {
+ m.addsubscription_group_id = &i
+ }
+}
+
+// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation.
+func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) {
+ v := m.addsubscription_group_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
+func (m *PaymentOrderMutation) ClearSubscriptionGroupID() {
+ m.subscription_group_id = nil
+ m.addsubscription_group_id = nil
+ m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{}
+}
+
+// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation.
+func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID]
+ return ok
+}
+
+// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field.
+func (m *PaymentOrderMutation) ResetSubscriptionGroupID() {
+ m.subscription_group_id = nil
+ m.addsubscription_group_id = nil
+ delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID)
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (m *PaymentOrderMutation) SetSubscriptionDays(i int) {
+ m.subscription_days = &i
+ m.addsubscription_days = nil
+}
+
+// SubscriptionDays returns the value of the "subscription_days" field in the mutation.
+func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) {
+ v := m.subscription_days
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSubscriptionDays requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err)
+ }
+ return oldValue.SubscriptionDays, nil
+}
+
+// AddSubscriptionDays adds i to the "subscription_days" field.
+func (m *PaymentOrderMutation) AddSubscriptionDays(i int) {
+ if m.addsubscription_days != nil {
+ *m.addsubscription_days += i
+ } else {
+ m.addsubscription_days = &i
+ }
+}
+
+// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation.
+func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) {
+ v := m.addsubscription_days
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearSubscriptionDays clears the value of the "subscription_days" field.
+func (m *PaymentOrderMutation) ClearSubscriptionDays() {
+ m.subscription_days = nil
+ m.addsubscription_days = nil
+ m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{}
+}
+
+// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation.
+func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays]
+ return ok
+}
+
+// ResetSubscriptionDays resets all changes to the "subscription_days" field.
+func (m *PaymentOrderMutation) ResetSubscriptionDays() {
+ m.subscription_days = nil
+ m.addsubscription_days = nil
+ delete(m.clearedFields, paymentorder.FieldSubscriptionDays)
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (m *PaymentOrderMutation) SetProviderInstanceID(s string) {
+ m.provider_instance_id = &s
+}
+
+// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation.
+func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) {
+ v := m.provider_instance_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderInstanceID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err)
+ }
+ return oldValue.ProviderInstanceID, nil
+}
+
+// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
+func (m *PaymentOrderMutation) ClearProviderInstanceID() {
+ m.provider_instance_id = nil
+ m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{}
+}
+
+// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID]
+ return ok
+}
+
+// ResetProviderInstanceID resets all changes to the "provider_instance_id" field.
+func (m *PaymentOrderMutation) ResetProviderInstanceID() {
+ m.provider_instance_id = nil
+ delete(m.clearedFields, paymentorder.FieldProviderInstanceID)
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PaymentOrderMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (m *PaymentOrderMutation) ClearProviderKey() {
+ m.provider_key = nil
+ m.clearedFields[paymentorder.FieldProviderKey] = struct{}{}
+}
+
+// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderKeyCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderKey]
+ return ok
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PaymentOrderMutation) ResetProviderKey() {
+ m.provider_key = nil
+ delete(m.clearedFields, paymentorder.FieldProviderKey)
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (m *PaymentOrderMutation) SetProviderSnapshot(value map[string]interface{}) {
+ m.provider_snapshot = &value
+}
+
+// ProviderSnapshot returns the value of the "provider_snapshot" field in the mutation.
+func (m *PaymentOrderMutation) ProviderSnapshot() (r map[string]interface{}, exists bool) {
+ v := m.provider_snapshot
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSnapshot returns the old "provider_snapshot" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldProviderSnapshot(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSnapshot is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSnapshot requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSnapshot: %w", err)
+ }
+ return oldValue.ProviderSnapshot, nil
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ClearProviderSnapshot() {
+ m.provider_snapshot = nil
+ m.clearedFields[paymentorder.FieldProviderSnapshot] = struct{}{}
+}
+
+// ProviderSnapshotCleared returns if the "provider_snapshot" field was cleared in this mutation.
+func (m *PaymentOrderMutation) ProviderSnapshotCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldProviderSnapshot]
+ return ok
+}
+
+// ResetProviderSnapshot resets all changes to the "provider_snapshot" field.
+func (m *PaymentOrderMutation) ResetProviderSnapshot() {
+ m.provider_snapshot = nil
+ delete(m.clearedFields, paymentorder.FieldProviderSnapshot)
+}
+
+// SetStatus sets the "status" field.
+func (m *PaymentOrderMutation) SetStatus(s string) {
+ m.status = &s
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *PaymentOrderMutation) Status() (r string, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *PaymentOrderMutation) ResetStatus() {
+ m.status = nil
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (m *PaymentOrderMutation) SetRefundAmount(f float64) {
+ m.refund_amount = &f
+ m.addrefund_amount = nil
+}
+
+// RefundAmount returns the value of the "refund_amount" field in the mutation.
+func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) {
+ v := m.refund_amount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundAmount requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err)
+ }
+ return oldValue.RefundAmount, nil
+}
+
+// AddRefundAmount adds f to the "refund_amount" field.
+func (m *PaymentOrderMutation) AddRefundAmount(f float64) {
+ if m.addrefund_amount != nil {
+ *m.addrefund_amount += f
+ } else {
+ m.addrefund_amount = &f
+ }
+}
+
+// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation.
+func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) {
+ v := m.addrefund_amount
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetRefundAmount resets all changes to the "refund_amount" field.
+func (m *PaymentOrderMutation) ResetRefundAmount() {
+ m.refund_amount = nil
+ m.addrefund_amount = nil
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (m *PaymentOrderMutation) SetRefundReason(s string) {
+ m.refund_reason = &s
+}
+
+// RefundReason returns the value of the "refund_reason" field in the mutation.
+func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) {
+ v := m.refund_reason
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundReason: %w", err)
+ }
+ return oldValue.RefundReason, nil
+}
+
+// ClearRefundReason clears the value of the "refund_reason" field.
+func (m *PaymentOrderMutation) ClearRefundReason() {
+ m.refund_reason = nil
+ m.clearedFields[paymentorder.FieldRefundReason] = struct{}{}
+}
+
+// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundReasonCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundReason]
+ return ok
+}
+
+// ResetRefundReason resets all changes to the "refund_reason" field.
+func (m *PaymentOrderMutation) ResetRefundReason() {
+ m.refund_reason = nil
+ delete(m.clearedFields, paymentorder.FieldRefundReason)
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (m *PaymentOrderMutation) SetRefundAt(t time.Time) {
+ m.refund_at = &t
+}
+
+// RefundAt returns the value of the "refund_at" field in the mutation.
+func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) {
+ v := m.refund_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundAt: %w", err)
+ }
+ return oldValue.RefundAt, nil
+}
+
+// ClearRefundAt clears the value of the "refund_at" field.
+func (m *PaymentOrderMutation) ClearRefundAt() {
+ m.refund_at = nil
+ m.clearedFields[paymentorder.FieldRefundAt] = struct{}{}
+}
+
+// RefundAtCleared returns if the "refund_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundAt]
+ return ok
+}
+
+// ResetRefundAt resets all changes to the "refund_at" field.
+func (m *PaymentOrderMutation) ResetRefundAt() {
+ m.refund_at = nil
+ delete(m.clearedFields, paymentorder.FieldRefundAt)
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (m *PaymentOrderMutation) SetForceRefund(b bool) {
+ m.force_refund = &b
+}
+
+// ForceRefund returns the value of the "force_refund" field in the mutation.
+func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) {
+ v := m.force_refund
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldForceRefund is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldForceRefund requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldForceRefund: %w", err)
+ }
+ return oldValue.ForceRefund, nil
+}
+
+// ResetForceRefund resets all changes to the "force_refund" field.
+func (m *PaymentOrderMutation) ResetForceRefund() {
+ m.force_refund = nil
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) {
+ m.refund_requested_at = &t
+}
+
+// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation.
+func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) {
+ v := m.refund_requested_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err)
+ }
+ return oldValue.RefundRequestedAt, nil
+}
+
+// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
+func (m *PaymentOrderMutation) ClearRefundRequestedAt() {
+ m.refund_requested_at = nil
+ m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{}
+}
+
+// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt]
+ return ok
+}
+
+// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field.
+func (m *PaymentOrderMutation) ResetRefundRequestedAt() {
+ m.refund_requested_at = nil
+ delete(m.clearedFields, paymentorder.FieldRefundRequestedAt)
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (m *PaymentOrderMutation) SetRefundRequestReason(s string) {
+ m.refund_request_reason = &s
+}
+
+// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation.
+func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) {
+ v := m.refund_request_reason
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundRequestReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err)
+ }
+ return oldValue.RefundRequestReason, nil
+}
+
+// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
+func (m *PaymentOrderMutation) ClearRefundRequestReason() {
+ m.refund_request_reason = nil
+ m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{}
+}
+
+// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason]
+ return ok
+}
+
+// ResetRefundRequestReason resets all changes to the "refund_request_reason" field.
+func (m *PaymentOrderMutation) ResetRefundRequestReason() {
+ m.refund_request_reason = nil
+ delete(m.clearedFields, paymentorder.FieldRefundRequestReason)
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) {
+ m.refund_requested_by = &s
+}
+
+// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation.
+func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) {
+ v := m.refund_requested_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err)
+ }
+ return oldValue.RefundRequestedBy, nil
+}
+
+// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
+func (m *PaymentOrderMutation) ClearRefundRequestedBy() {
+ m.refund_requested_by = nil
+ m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{}
+}
+
+// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation.
+func (m *PaymentOrderMutation) RefundRequestedByCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy]
+ return ok
+}
+
+// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field.
+func (m *PaymentOrderMutation) ResetRefundRequestedBy() {
+ m.refund_requested_by = nil
+ delete(m.clearedFields, paymentorder.FieldRefundRequestedBy)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *PaymentOrderMutation) ResetExpiresAt() {
+ m.expires_at = nil
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (m *PaymentOrderMutation) SetPaidAt(t time.Time) {
+ m.paid_at = &t
+}
+
+// PaidAt returns the value of the "paid_at" field in the mutation.
+func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) {
+ v := m.paid_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaidAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaidAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaidAt: %w", err)
+ }
+ return oldValue.PaidAt, nil
+}
+
+// ClearPaidAt clears the value of the "paid_at" field.
+func (m *PaymentOrderMutation) ClearPaidAt() {
+ m.paid_at = nil
+ m.clearedFields[paymentorder.FieldPaidAt] = struct{}{}
+}
+
+// PaidAtCleared returns if the "paid_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) PaidAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldPaidAt]
+ return ok
+}
+
+// ResetPaidAt resets all changes to the "paid_at" field.
+func (m *PaymentOrderMutation) ResetPaidAt() {
+ m.paid_at = nil
+ delete(m.clearedFields, paymentorder.FieldPaidAt)
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) {
+ m.completed_at = &t
+}
+
+// CompletedAt returns the value of the "completed_at" field in the mutation.
+func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) {
+ v := m.completed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err)
+ }
+ return oldValue.CompletedAt, nil
+}
+
+// ClearCompletedAt clears the value of the "completed_at" field.
+func (m *PaymentOrderMutation) ClearCompletedAt() {
+ m.completed_at = nil
+ m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{}
+}
+
+// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) CompletedAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldCompletedAt]
+ return ok
+}
+
+// ResetCompletedAt resets all changes to the "completed_at" field.
+func (m *PaymentOrderMutation) ResetCompletedAt() {
+ m.completed_at = nil
+ delete(m.clearedFields, paymentorder.FieldCompletedAt)
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (m *PaymentOrderMutation) SetFailedAt(t time.Time) {
+ m.failed_at = &t
+}
+
+// FailedAt returns the value of the "failed_at" field in the mutation.
+func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) {
+ v := m.failed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFailedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFailedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFailedAt: %w", err)
+ }
+ return oldValue.FailedAt, nil
+}
+
+// ClearFailedAt clears the value of the "failed_at" field.
+func (m *PaymentOrderMutation) ClearFailedAt() {
+ m.failed_at = nil
+ m.clearedFields[paymentorder.FieldFailedAt] = struct{}{}
+}
+
+// FailedAtCleared returns if the "failed_at" field was cleared in this mutation.
+func (m *PaymentOrderMutation) FailedAtCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldFailedAt]
+ return ok
+}
+
+// ResetFailedAt resets all changes to the "failed_at" field.
+func (m *PaymentOrderMutation) ResetFailedAt() {
+ m.failed_at = nil
+ delete(m.clearedFields, paymentorder.FieldFailedAt)
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (m *PaymentOrderMutation) SetFailedReason(s string) {
+ m.failed_reason = &s
+}
+
+// FailedReason returns the value of the "failed_reason" field in the mutation.
+func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) {
+ v := m.failed_reason
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFailedReason is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFailedReason requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFailedReason: %w", err)
+ }
+ return oldValue.FailedReason, nil
+}
+
+// ClearFailedReason clears the value of the "failed_reason" field.
+func (m *PaymentOrderMutation) ClearFailedReason() {
+ m.failed_reason = nil
+ m.clearedFields[paymentorder.FieldFailedReason] = struct{}{}
+}
+
+// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation.
+func (m *PaymentOrderMutation) FailedReasonCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldFailedReason]
+ return ok
+}
+
+// ResetFailedReason resets all changes to the "failed_reason" field.
+func (m *PaymentOrderMutation) ResetFailedReason() {
+ m.failed_reason = nil
+ delete(m.clearedFields, paymentorder.FieldFailedReason)
+}
+
+// SetClientIP sets the "client_ip" field.
+func (m *PaymentOrderMutation) SetClientIP(s string) {
+ m.client_ip = &s
+}
+
+// ClientIP returns the value of the "client_ip" field in the mutation.
+func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) {
+ v := m.client_ip
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldClientIP is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldClientIP requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldClientIP: %w", err)
+ }
+ return oldValue.ClientIP, nil
+}
+
+// ResetClientIP resets all changes to the "client_ip" field.
+func (m *PaymentOrderMutation) ResetClientIP() {
+ m.client_ip = nil
+}
+
+// SetSrcHost sets the "src_host" field.
+func (m *PaymentOrderMutation) SetSrcHost(s string) {
+ m.src_host = &s
+}
+
+// SrcHost returns the value of the "src_host" field in the mutation.
+func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) {
+ v := m.src_host
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSrcHost is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSrcHost requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSrcHost: %w", err)
+ }
+ return oldValue.SrcHost, nil
+}
+
+// ResetSrcHost resets all changes to the "src_host" field.
+func (m *PaymentOrderMutation) ResetSrcHost() {
+ m.src_host = nil
+}
+
+// SetSrcURL sets the "src_url" field.
+func (m *PaymentOrderMutation) SetSrcURL(s string) {
+ m.src_url = &s
+}
+
+// SrcURL returns the value of the "src_url" field in the mutation.
+func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) {
+ v := m.src_url
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSrcURL is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSrcURL requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSrcURL: %w", err)
+ }
+ return oldValue.SrcURL, nil
+}
+
+// ClearSrcURL clears the value of the "src_url" field.
+func (m *PaymentOrderMutation) ClearSrcURL() {
+ m.src_url = nil
+ m.clearedFields[paymentorder.FieldSrcURL] = struct{}{}
+}
+
+// SrcURLCleared returns if the "src_url" field was cleared in this mutation.
+func (m *PaymentOrderMutation) SrcURLCleared() bool {
+ _, ok := m.clearedFields[paymentorder.FieldSrcURL]
+ return ok
+}
+
+// ResetSrcURL resets all changes to the "src_url" field.
+func (m *PaymentOrderMutation) ResetSrcURL() {
+ m.src_url = nil
+ delete(m.clearedFields, paymentorder.FieldSrcURL)
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PaymentOrderMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity.
+// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PaymentOrderMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (m *PaymentOrderMutation) ClearUser() {
+ m.cleareduser = true
+ m.clearedFields[paymentorder.FieldUserID] = struct{}{}
+}
+
+// UserCleared reports if the "user" edge to the User entity was cleared.
+func (m *PaymentOrderMutation) UserCleared() bool {
+ return m.cleareduser
+}
+
+// UserIDs returns the "user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// UserID instead. It exists only for internal usage by the builders.
+func (m *PaymentOrderMutation) UserIDs() (ids []int64) {
+ if id := m.user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetUser resets all changes to the "user" edge.
+func (m *PaymentOrderMutation) ResetUser() {
+ m.user = nil
+ m.cleareduser = false
+}
+
+// Where appends a list predicates to the PaymentOrderMutation builder.
+func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PaymentOrder, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PaymentOrderMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PaymentOrderMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PaymentOrder).
+func (m *PaymentOrderMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PaymentOrderMutation) Fields() []string {
+ fields := make([]string, 0, 39)
+ if m.user != nil {
+ fields = append(fields, paymentorder.FieldUserID)
+ }
+ if m.user_email != nil {
+ fields = append(fields, paymentorder.FieldUserEmail)
+ }
+ if m.user_name != nil {
+ fields = append(fields, paymentorder.FieldUserName)
+ }
+ if m.user_notes != nil {
+ fields = append(fields, paymentorder.FieldUserNotes)
+ }
+ if m.amount != nil {
+ fields = append(fields, paymentorder.FieldAmount)
+ }
+ if m.pay_amount != nil {
+ fields = append(fields, paymentorder.FieldPayAmount)
+ }
+ if m.fee_rate != nil {
+ fields = append(fields, paymentorder.FieldFeeRate)
+ }
+ if m.recharge_code != nil {
+ fields = append(fields, paymentorder.FieldRechargeCode)
+ }
+ if m.out_trade_no != nil {
+ fields = append(fields, paymentorder.FieldOutTradeNo)
+ }
+ if m.payment_type != nil {
+ fields = append(fields, paymentorder.FieldPaymentType)
+ }
+ if m.payment_trade_no != nil {
+ fields = append(fields, paymentorder.FieldPaymentTradeNo)
+ }
+ if m.pay_url != nil {
+ fields = append(fields, paymentorder.FieldPayURL)
+ }
+ if m.qr_code != nil {
+ fields = append(fields, paymentorder.FieldQrCode)
+ }
+ if m.qr_code_img != nil {
+ fields = append(fields, paymentorder.FieldQrCodeImg)
+ }
+ if m.order_type != nil {
+ fields = append(fields, paymentorder.FieldOrderType)
+ }
+ if m.plan_id != nil {
+ fields = append(fields, paymentorder.FieldPlanID)
+ }
+ if m.subscription_group_id != nil {
+ fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+ }
+ if m.subscription_days != nil {
+ fields = append(fields, paymentorder.FieldSubscriptionDays)
+ }
+ if m.provider_instance_id != nil {
+ fields = append(fields, paymentorder.FieldProviderInstanceID)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
+ if m.provider_snapshot != nil {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
+ if m.status != nil {
+ fields = append(fields, paymentorder.FieldStatus)
+ }
+ if m.refund_amount != nil {
+ fields = append(fields, paymentorder.FieldRefundAmount)
+ }
+ if m.refund_reason != nil {
+ fields = append(fields, paymentorder.FieldRefundReason)
+ }
+ if m.refund_at != nil {
+ fields = append(fields, paymentorder.FieldRefundAt)
+ }
+ if m.force_refund != nil {
+ fields = append(fields, paymentorder.FieldForceRefund)
+ }
+ if m.refund_requested_at != nil {
+ fields = append(fields, paymentorder.FieldRefundRequestedAt)
+ }
+ if m.refund_request_reason != nil {
+ fields = append(fields, paymentorder.FieldRefundRequestReason)
+ }
+ if m.refund_requested_by != nil {
+ fields = append(fields, paymentorder.FieldRefundRequestedBy)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, paymentorder.FieldExpiresAt)
+ }
+ if m.paid_at != nil {
+ fields = append(fields, paymentorder.FieldPaidAt)
+ }
+ if m.completed_at != nil {
+ fields = append(fields, paymentorder.FieldCompletedAt)
+ }
+ if m.failed_at != nil {
+ fields = append(fields, paymentorder.FieldFailedAt)
+ }
+ if m.failed_reason != nil {
+ fields = append(fields, paymentorder.FieldFailedReason)
+ }
+ if m.client_ip != nil {
+ fields = append(fields, paymentorder.FieldClientIP)
+ }
+ if m.src_host != nil {
+ fields = append(fields, paymentorder.FieldSrcHost)
+ }
+ if m.src_url != nil {
+ fields = append(fields, paymentorder.FieldSrcURL)
+ }
+ if m.created_at != nil {
+ fields = append(fields, paymentorder.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, paymentorder.FieldUpdatedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case paymentorder.FieldUserID:
+ return m.UserID()
+ case paymentorder.FieldUserEmail:
+ return m.UserEmail()
+ case paymentorder.FieldUserName:
+ return m.UserName()
+ case paymentorder.FieldUserNotes:
+ return m.UserNotes()
+ case paymentorder.FieldAmount:
+ return m.Amount()
+ case paymentorder.FieldPayAmount:
+ return m.PayAmount()
+ case paymentorder.FieldFeeRate:
+ return m.FeeRate()
+ case paymentorder.FieldRechargeCode:
+ return m.RechargeCode()
+ case paymentorder.FieldOutTradeNo:
+ return m.OutTradeNo()
+ case paymentorder.FieldPaymentType:
+ return m.PaymentType()
+ case paymentorder.FieldPaymentTradeNo:
+ return m.PaymentTradeNo()
+ case paymentorder.FieldPayURL:
+ return m.PayURL()
+ case paymentorder.FieldQrCode:
+ return m.QrCode()
+ case paymentorder.FieldQrCodeImg:
+ return m.QrCodeImg()
+ case paymentorder.FieldOrderType:
+ return m.OrderType()
+ case paymentorder.FieldPlanID:
+ return m.PlanID()
+ case paymentorder.FieldSubscriptionGroupID:
+ return m.SubscriptionGroupID()
+ case paymentorder.FieldSubscriptionDays:
+ return m.SubscriptionDays()
+ case paymentorder.FieldProviderInstanceID:
+ return m.ProviderInstanceID()
+ case paymentorder.FieldProviderKey:
+ return m.ProviderKey()
+ case paymentorder.FieldProviderSnapshot:
+ return m.ProviderSnapshot()
+ case paymentorder.FieldStatus:
+ return m.Status()
+ case paymentorder.FieldRefundAmount:
+ return m.RefundAmount()
+ case paymentorder.FieldRefundReason:
+ return m.RefundReason()
+ case paymentorder.FieldRefundAt:
+ return m.RefundAt()
+ case paymentorder.FieldForceRefund:
+ return m.ForceRefund()
+ case paymentorder.FieldRefundRequestedAt:
+ return m.RefundRequestedAt()
+ case paymentorder.FieldRefundRequestReason:
+ return m.RefundRequestReason()
+ case paymentorder.FieldRefundRequestedBy:
+ return m.RefundRequestedBy()
+ case paymentorder.FieldExpiresAt:
+ return m.ExpiresAt()
+ case paymentorder.FieldPaidAt:
+ return m.PaidAt()
+ case paymentorder.FieldCompletedAt:
+ return m.CompletedAt()
+ case paymentorder.FieldFailedAt:
+ return m.FailedAt()
+ case paymentorder.FieldFailedReason:
+ return m.FailedReason()
+ case paymentorder.FieldClientIP:
+ return m.ClientIP()
+ case paymentorder.FieldSrcHost:
+ return m.SrcHost()
+ case paymentorder.FieldSrcURL:
+ return m.SrcURL()
+ case paymentorder.FieldCreatedAt:
+ return m.CreatedAt()
+ case paymentorder.FieldUpdatedAt:
+ return m.UpdatedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case paymentorder.FieldUserID:
+ return m.OldUserID(ctx)
+ case paymentorder.FieldUserEmail:
+ return m.OldUserEmail(ctx)
+ case paymentorder.FieldUserName:
+ return m.OldUserName(ctx)
+ case paymentorder.FieldUserNotes:
+ return m.OldUserNotes(ctx)
+ case paymentorder.FieldAmount:
+ return m.OldAmount(ctx)
+ case paymentorder.FieldPayAmount:
+ return m.OldPayAmount(ctx)
+ case paymentorder.FieldFeeRate:
+ return m.OldFeeRate(ctx)
+ case paymentorder.FieldRechargeCode:
+ return m.OldRechargeCode(ctx)
+ case paymentorder.FieldOutTradeNo:
+ return m.OldOutTradeNo(ctx)
+ case paymentorder.FieldPaymentType:
+ return m.OldPaymentType(ctx)
+ case paymentorder.FieldPaymentTradeNo:
+ return m.OldPaymentTradeNo(ctx)
+ case paymentorder.FieldPayURL:
+ return m.OldPayURL(ctx)
+ case paymentorder.FieldQrCode:
+ return m.OldQrCode(ctx)
+ case paymentorder.FieldQrCodeImg:
+ return m.OldQrCodeImg(ctx)
+ case paymentorder.FieldOrderType:
+ return m.OldOrderType(ctx)
+ case paymentorder.FieldPlanID:
+ return m.OldPlanID(ctx)
+ case paymentorder.FieldSubscriptionGroupID:
+ return m.OldSubscriptionGroupID(ctx)
+ case paymentorder.FieldSubscriptionDays:
+ return m.OldSubscriptionDays(ctx)
+ case paymentorder.FieldProviderInstanceID:
+ return m.OldProviderInstanceID(ctx)
+ case paymentorder.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case paymentorder.FieldProviderSnapshot:
+ return m.OldProviderSnapshot(ctx)
+ case paymentorder.FieldStatus:
+ return m.OldStatus(ctx)
+ case paymentorder.FieldRefundAmount:
+ return m.OldRefundAmount(ctx)
+ case paymentorder.FieldRefundReason:
+ return m.OldRefundReason(ctx)
+ case paymentorder.FieldRefundAt:
+ return m.OldRefundAt(ctx)
+ case paymentorder.FieldForceRefund:
+ return m.OldForceRefund(ctx)
+ case paymentorder.FieldRefundRequestedAt:
+ return m.OldRefundRequestedAt(ctx)
+ case paymentorder.FieldRefundRequestReason:
+ return m.OldRefundRequestReason(ctx)
+ case paymentorder.FieldRefundRequestedBy:
+ return m.OldRefundRequestedBy(ctx)
+ case paymentorder.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ case paymentorder.FieldPaidAt:
+ return m.OldPaidAt(ctx)
+ case paymentorder.FieldCompletedAt:
+ return m.OldCompletedAt(ctx)
+ case paymentorder.FieldFailedAt:
+ return m.OldFailedAt(ctx)
+ case paymentorder.FieldFailedReason:
+ return m.OldFailedReason(ctx)
+ case paymentorder.FieldClientIP:
+ return m.OldClientIP(ctx)
+ case paymentorder.FieldSrcHost:
+ return m.OldSrcHost(ctx)
+ case paymentorder.FieldSrcURL:
+ return m.OldSrcURL(ctx)
+ case paymentorder.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case paymentorder.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown PaymentOrder field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case paymentorder.FieldUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserID(v)
+ return nil
+ case paymentorder.FieldUserEmail:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserEmail(v)
+ return nil
+ case paymentorder.FieldUserName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserName(v)
+ return nil
+ case paymentorder.FieldUserNotes:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUserNotes(v)
+ return nil
+ case paymentorder.FieldAmount:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAmount(v)
+ return nil
+ case paymentorder.FieldPayAmount:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPayAmount(v)
+ return nil
+ case paymentorder.FieldFeeRate:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFeeRate(v)
+ return nil
+ case paymentorder.FieldRechargeCode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRechargeCode(v)
+ return nil
+ case paymentorder.FieldOutTradeNo:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOutTradeNo(v)
+ return nil
+ case paymentorder.FieldPaymentType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPaymentType(v)
+ return nil
+ case paymentorder.FieldPaymentTradeNo:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPaymentTradeNo(v)
+ return nil
+ case paymentorder.FieldPayURL:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPayURL(v)
+ return nil
+ case paymentorder.FieldQrCode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetQrCode(v)
+ return nil
+ case paymentorder.FieldQrCodeImg:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetQrCodeImg(v)
+ return nil
+ case paymentorder.FieldOrderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOrderType(v)
+ return nil
+ case paymentorder.FieldPlanID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPlanID(v)
+ return nil
+ case paymentorder.FieldSubscriptionGroupID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSubscriptionGroupID(v)
+ return nil
+ case paymentorder.FieldSubscriptionDays:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSubscriptionDays(v)
+ return nil
+ case paymentorder.FieldProviderInstanceID:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderInstanceID(v)
+ return nil
+ case paymentorder.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSnapshot(v)
+ return nil
+ case paymentorder.FieldStatus:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStatus(v)
+ return nil
+ case paymentorder.FieldRefundAmount:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundAmount(v)
+ return nil
+ case paymentorder.FieldRefundReason:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundReason(v)
+ return nil
+ case paymentorder.FieldRefundAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundAt(v)
+ return nil
+ case paymentorder.FieldForceRefund:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetForceRefund(v)
+ return nil
+ case paymentorder.FieldRefundRequestedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundRequestedAt(v)
+ return nil
+ case paymentorder.FieldRefundRequestReason:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundRequestReason(v)
+ return nil
+ case paymentorder.FieldRefundRequestedBy:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundRequestedBy(v)
+ return nil
+ case paymentorder.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ case paymentorder.FieldPaidAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPaidAt(v)
+ return nil
+ case paymentorder.FieldCompletedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletedAt(v)
+ return nil
+ case paymentorder.FieldFailedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFailedAt(v)
+ return nil
+ case paymentorder.FieldFailedReason:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFailedReason(v)
+ return nil
+ case paymentorder.FieldClientIP:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetClientIP(v)
+ return nil
+ case paymentorder.FieldSrcHost:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSrcHost(v)
+ return nil
+ case paymentorder.FieldSrcURL:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSrcURL(v)
+ return nil
+ case paymentorder.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case paymentorder.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentOrder field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PaymentOrderMutation) AddedFields() []string {
+ var fields []string
+ if m.addamount != nil {
+ fields = append(fields, paymentorder.FieldAmount)
+ }
+ if m.addpay_amount != nil {
+ fields = append(fields, paymentorder.FieldPayAmount)
+ }
+ if m.addfee_rate != nil {
+ fields = append(fields, paymentorder.FieldFeeRate)
+ }
+ if m.addplan_id != nil {
+ fields = append(fields, paymentorder.FieldPlanID)
+ }
+ if m.addsubscription_group_id != nil {
+ fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+ }
+ if m.addsubscription_days != nil {
+ fields = append(fields, paymentorder.FieldSubscriptionDays)
+ }
+ if m.addrefund_amount != nil {
+ fields = append(fields, paymentorder.FieldRefundAmount)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case paymentorder.FieldAmount:
+ return m.AddedAmount()
+ case paymentorder.FieldPayAmount:
+ return m.AddedPayAmount()
+ case paymentorder.FieldFeeRate:
+ return m.AddedFeeRate()
+ case paymentorder.FieldPlanID:
+ return m.AddedPlanID()
+ case paymentorder.FieldSubscriptionGroupID:
+ return m.AddedSubscriptionGroupID()
+ case paymentorder.FieldSubscriptionDays:
+ return m.AddedSubscriptionDays()
+ case paymentorder.FieldRefundAmount:
+ return m.AddedRefundAmount()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case paymentorder.FieldAmount:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddAmount(v)
+ return nil
+ case paymentorder.FieldPayAmount:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddPayAmount(v)
+ return nil
+ case paymentorder.FieldFeeRate:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddFeeRate(v)
+ return nil
+ case paymentorder.FieldPlanID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddPlanID(v)
+ return nil
+ case paymentorder.FieldSubscriptionGroupID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSubscriptionGroupID(v)
+ return nil
+ case paymentorder.FieldSubscriptionDays:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSubscriptionDays(v)
+ return nil
+ case paymentorder.FieldRefundAmount:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddRefundAmount(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentOrder numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PaymentOrderMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(paymentorder.FieldUserNotes) {
+ fields = append(fields, paymentorder.FieldUserNotes)
+ }
+ if m.FieldCleared(paymentorder.FieldPayURL) {
+ fields = append(fields, paymentorder.FieldPayURL)
+ }
+ if m.FieldCleared(paymentorder.FieldQrCode) {
+ fields = append(fields, paymentorder.FieldQrCode)
+ }
+ if m.FieldCleared(paymentorder.FieldQrCodeImg) {
+ fields = append(fields, paymentorder.FieldQrCodeImg)
+ }
+ if m.FieldCleared(paymentorder.FieldPlanID) {
+ fields = append(fields, paymentorder.FieldPlanID)
+ }
+ if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) {
+ fields = append(fields, paymentorder.FieldSubscriptionGroupID)
+ }
+ if m.FieldCleared(paymentorder.FieldSubscriptionDays) {
+ fields = append(fields, paymentorder.FieldSubscriptionDays)
+ }
+ if m.FieldCleared(paymentorder.FieldProviderInstanceID) {
+ fields = append(fields, paymentorder.FieldProviderInstanceID)
+ }
+ if m.FieldCleared(paymentorder.FieldProviderKey) {
+ fields = append(fields, paymentorder.FieldProviderKey)
+ }
+ if m.FieldCleared(paymentorder.FieldProviderSnapshot) {
+ fields = append(fields, paymentorder.FieldProviderSnapshot)
+ }
+ if m.FieldCleared(paymentorder.FieldRefundReason) {
+ fields = append(fields, paymentorder.FieldRefundReason)
+ }
+ if m.FieldCleared(paymentorder.FieldRefundAt) {
+ fields = append(fields, paymentorder.FieldRefundAt)
+ }
+ if m.FieldCleared(paymentorder.FieldRefundRequestedAt) {
+ fields = append(fields, paymentorder.FieldRefundRequestedAt)
+ }
+ if m.FieldCleared(paymentorder.FieldRefundRequestReason) {
+ fields = append(fields, paymentorder.FieldRefundRequestReason)
+ }
+ if m.FieldCleared(paymentorder.FieldRefundRequestedBy) {
+ fields = append(fields, paymentorder.FieldRefundRequestedBy)
+ }
+ if m.FieldCleared(paymentorder.FieldPaidAt) {
+ fields = append(fields, paymentorder.FieldPaidAt)
+ }
+ if m.FieldCleared(paymentorder.FieldCompletedAt) {
+ fields = append(fields, paymentorder.FieldCompletedAt)
+ }
+ if m.FieldCleared(paymentorder.FieldFailedAt) {
+ fields = append(fields, paymentorder.FieldFailedAt)
+ }
+ if m.FieldCleared(paymentorder.FieldFailedReason) {
+ fields = append(fields, paymentorder.FieldFailedReason)
+ }
+ if m.FieldCleared(paymentorder.FieldSrcURL) {
+ fields = append(fields, paymentorder.FieldSrcURL)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PaymentOrderMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PaymentOrderMutation) ClearField(name string) error {
+ switch name {
+ case paymentorder.FieldUserNotes:
+ m.ClearUserNotes()
+ return nil
+ case paymentorder.FieldPayURL:
+ m.ClearPayURL()
+ return nil
+ case paymentorder.FieldQrCode:
+ m.ClearQrCode()
+ return nil
+ case paymentorder.FieldQrCodeImg:
+ m.ClearQrCodeImg()
+ return nil
+ case paymentorder.FieldPlanID:
+ m.ClearPlanID()
+ return nil
+ case paymentorder.FieldSubscriptionGroupID:
+ m.ClearSubscriptionGroupID()
+ return nil
+ case paymentorder.FieldSubscriptionDays:
+ m.ClearSubscriptionDays()
+ return nil
+ case paymentorder.FieldProviderInstanceID:
+ m.ClearProviderInstanceID()
+ return nil
+ case paymentorder.FieldProviderKey:
+ m.ClearProviderKey()
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ClearProviderSnapshot()
+ return nil
+ case paymentorder.FieldRefundReason:
+ m.ClearRefundReason()
+ return nil
+ case paymentorder.FieldRefundAt:
+ m.ClearRefundAt()
+ return nil
+ case paymentorder.FieldRefundRequestedAt:
+ m.ClearRefundRequestedAt()
+ return nil
+ case paymentorder.FieldRefundRequestReason:
+ m.ClearRefundRequestReason()
+ return nil
+ case paymentorder.FieldRefundRequestedBy:
+ m.ClearRefundRequestedBy()
+ return nil
+ case paymentorder.FieldPaidAt:
+ m.ClearPaidAt()
+ return nil
+ case paymentorder.FieldCompletedAt:
+ m.ClearCompletedAt()
+ return nil
+ case paymentorder.FieldFailedAt:
+ m.ClearFailedAt()
+ return nil
+ case paymentorder.FieldFailedReason:
+ m.ClearFailedReason()
+ return nil
+ case paymentorder.FieldSrcURL:
+ m.ClearSrcURL()
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentOrder nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PaymentOrderMutation) ResetField(name string) error {
+ switch name {
+ case paymentorder.FieldUserID:
+ m.ResetUserID()
+ return nil
+ case paymentorder.FieldUserEmail:
+ m.ResetUserEmail()
+ return nil
+ case paymentorder.FieldUserName:
+ m.ResetUserName()
+ return nil
+ case paymentorder.FieldUserNotes:
+ m.ResetUserNotes()
+ return nil
+ case paymentorder.FieldAmount:
+ m.ResetAmount()
+ return nil
+ case paymentorder.FieldPayAmount:
+ m.ResetPayAmount()
+ return nil
+ case paymentorder.FieldFeeRate:
+ m.ResetFeeRate()
+ return nil
+ case paymentorder.FieldRechargeCode:
+ m.ResetRechargeCode()
+ return nil
+ case paymentorder.FieldOutTradeNo:
+ m.ResetOutTradeNo()
+ return nil
+ case paymentorder.FieldPaymentType:
+ m.ResetPaymentType()
+ return nil
+ case paymentorder.FieldPaymentTradeNo:
+ m.ResetPaymentTradeNo()
+ return nil
+ case paymentorder.FieldPayURL:
+ m.ResetPayURL()
+ return nil
+ case paymentorder.FieldQrCode:
+ m.ResetQrCode()
+ return nil
+ case paymentorder.FieldQrCodeImg:
+ m.ResetQrCodeImg()
+ return nil
+ case paymentorder.FieldOrderType:
+ m.ResetOrderType()
+ return nil
+ case paymentorder.FieldPlanID:
+ m.ResetPlanID()
+ return nil
+ case paymentorder.FieldSubscriptionGroupID:
+ m.ResetSubscriptionGroupID()
+ return nil
+ case paymentorder.FieldSubscriptionDays:
+ m.ResetSubscriptionDays()
+ return nil
+ case paymentorder.FieldProviderInstanceID:
+ m.ResetProviderInstanceID()
+ return nil
+ case paymentorder.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case paymentorder.FieldProviderSnapshot:
+ m.ResetProviderSnapshot()
+ return nil
+ case paymentorder.FieldStatus:
+ m.ResetStatus()
+ return nil
+ case paymentorder.FieldRefundAmount:
+ m.ResetRefundAmount()
+ return nil
+ case paymentorder.FieldRefundReason:
+ m.ResetRefundReason()
+ return nil
+ case paymentorder.FieldRefundAt:
+ m.ResetRefundAt()
+ return nil
+ case paymentorder.FieldForceRefund:
+ m.ResetForceRefund()
+ return nil
+ case paymentorder.FieldRefundRequestedAt:
+ m.ResetRefundRequestedAt()
+ return nil
+ case paymentorder.FieldRefundRequestReason:
+ m.ResetRefundRequestReason()
+ return nil
+ case paymentorder.FieldRefundRequestedBy:
+ m.ResetRefundRequestedBy()
+ return nil
+ case paymentorder.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ case paymentorder.FieldPaidAt:
+ m.ResetPaidAt()
+ return nil
+ case paymentorder.FieldCompletedAt:
+ m.ResetCompletedAt()
+ return nil
+ case paymentorder.FieldFailedAt:
+ m.ResetFailedAt()
+ return nil
+ case paymentorder.FieldFailedReason:
+ m.ResetFailedReason()
+ return nil
+ case paymentorder.FieldClientIP:
+ m.ResetClientIP()
+ return nil
+ case paymentorder.FieldSrcHost:
+ m.ResetSrcHost()
+ return nil
+ case paymentorder.FieldSrcURL:
+ m.ResetSrcURL()
+ return nil
+ case paymentorder.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case paymentorder.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentOrder field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PaymentOrderMutation) AddedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.user != nil {
+ edges = append(edges, paymentorder.EdgeUser)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case paymentorder.EdgeUser:
+ if id := m.user; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PaymentOrderMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 1)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PaymentOrderMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 1)
+ if m.cleareduser {
+ edges = append(edges, paymentorder.EdgeUser)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PaymentOrderMutation) EdgeCleared(name string) bool {
+ switch name {
+ case paymentorder.EdgeUser:
+ return m.cleareduser
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PaymentOrderMutation) ClearEdge(name string) error {
+ switch name {
+ case paymentorder.EdgeUser:
+ m.ClearUser()
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentOrder unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PaymentOrderMutation) ResetEdge(name string) error {
+ switch name {
+ case paymentorder.EdgeUser:
+ m.ResetUser()
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentOrder edge %s", name)
+}
+
+// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
+type PaymentProviderInstanceMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ provider_key *string
+ name *string
+ _config *string
+ supported_types *string
+ enabled *bool
+ payment_mode *string
+ sort_order *int
+ addsort_order *int
+ limits *string
+ refund_enabled *bool
+ allow_user_refund *bool
+ created_at *time.Time
+ updated_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*PaymentProviderInstance, error)
+ predicates []predicate.PaymentProviderInstance
+}
+
+var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
+
+// paymentproviderinstanceOption allows management of the mutation configuration using functional options.
+type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation)
+
+// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity.
+func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation {
+ m := &PaymentProviderInstanceMutation{
+ config: c,
+ op: op,
+ typ: TypePaymentProviderInstance,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPaymentProviderInstanceID sets the ID field of the mutation.
+func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption {
+ return func(m *PaymentProviderInstanceMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PaymentProviderInstance
+ )
+ m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PaymentProviderInstance.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation.
+func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption {
+ return func(m *PaymentProviderInstanceMutation) {
+ m.oldValue = func(context.Context) (*PaymentProviderInstance, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PaymentProviderInstanceMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PaymentProviderInstanceMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetName sets the "name" field.
+func (m *PaymentProviderInstanceMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *PaymentProviderInstanceMutation) ResetName() {
+ m.name = nil
+}
+
+// SetConfig sets the "config" field.
+func (m *PaymentProviderInstanceMutation) SetConfig(s string) {
+ m._config = &s
+}
+
+// Config returns the value of the "config" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) {
+ v := m._config
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldConfig is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldConfig requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldConfig: %w", err)
+ }
+ return oldValue.Config, nil
+}
+
+// ResetConfig resets all changes to the "config" field.
+func (m *PaymentProviderInstanceMutation) ResetConfig() {
+ m._config = nil
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) {
+ m.supported_types = &s
+}
+
+// SupportedTypes returns the value of the "supported_types" field in the mutation.
+func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) {
+ v := m.supported_types
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSupportedTypes requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err)
+ }
+ return oldValue.SupportedTypes, nil
+}
+
+// ResetSupportedTypes resets all changes to the "supported_types" field.
+func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() {
+ m.supported_types = nil
+}
+
+// SetEnabled sets the "enabled" field.
+func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) {
+ m.enabled = &b
+}
+
+// Enabled returns the value of the "enabled" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) {
+ v := m.enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEnabled: %w", err)
+ }
+ return oldValue.Enabled, nil
+}
+
+// ResetEnabled resets all changes to the "enabled" field.
+func (m *PaymentProviderInstanceMutation) ResetEnabled() {
+ m.enabled = nil
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) {
+ m.payment_mode = &s
+}
+
+// PaymentMode returns the value of the "payment_mode" field in the mutation.
+func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) {
+ v := m.payment_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPaymentMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err)
+ }
+ return oldValue.PaymentMode, nil
+}
+
+// ResetPaymentMode resets all changes to the "payment_mode" field.
+func (m *PaymentProviderInstanceMutation) ResetPaymentMode() {
+ m.payment_mode = nil
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) {
+ m.sort_order = &i
+ m.addsort_order = nil
+}
+
+// SortOrder returns the value of the "sort_order" field in the mutation.
+func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) {
+ v := m.sort_order
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSortOrder requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
+ }
+ return oldValue.SortOrder, nil
+}
+
+// AddSortOrder adds i to the "sort_order" field.
+func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) {
+ if m.addsort_order != nil {
+ *m.addsort_order += i
+ } else {
+ m.addsort_order = &i
+ }
+}
+
+// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
+func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) {
+ v := m.addsort_order
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSortOrder resets all changes to the "sort_order" field.
+func (m *PaymentProviderInstanceMutation) ResetSortOrder() {
+ m.sort_order = nil
+ m.addsort_order = nil
+}
+
+// SetLimits sets the "limits" field.
+func (m *PaymentProviderInstanceMutation) SetLimits(s string) {
+ m.limits = &s
+}
+
+// Limits returns the value of the "limits" field in the mutation.
+func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) {
+ v := m.limits
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLimits is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLimits requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLimits: %w", err)
+ }
+ return oldValue.Limits, nil
+}
+
+// ResetLimits resets all changes to the "limits" field.
+func (m *PaymentProviderInstanceMutation) ResetLimits() {
+ m.limits = nil
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) {
+ m.refund_enabled = &b
+}
+
+// RefundEnabled returns the value of the "refund_enabled" field in the mutation.
+func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) {
+ v := m.refund_enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRefundEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err)
+ }
+ return oldValue.RefundEnabled, nil
+}
+
+// ResetRefundEnabled resets all changes to the "refund_enabled" field.
+func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
+ m.refund_enabled = nil
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
+ m.allow_user_refund = &b
+}
+
+// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
+func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
+ v := m.allow_user_refund
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
+ }
+ return oldValue.AllowUserRefund, nil
+}
+
+// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
+func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
+ m.allow_user_refund = nil
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PaymentProviderInstanceMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity.
+// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// Where appends a list of predicates to the PaymentProviderInstanceMutation builder.
+func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) {
+	m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PaymentProviderInstance, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PaymentProviderInstanceMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PaymentProviderInstanceMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PaymentProviderInstance).
+func (m *PaymentProviderInstanceMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PaymentProviderInstanceMutation) Fields() []string {
+ fields := make([]string, 0, 12)
+ if m.provider_key != nil {
+ fields = append(fields, paymentproviderinstance.FieldProviderKey)
+ }
+ if m.name != nil {
+ fields = append(fields, paymentproviderinstance.FieldName)
+ }
+ if m._config != nil {
+ fields = append(fields, paymentproviderinstance.FieldConfig)
+ }
+ if m.supported_types != nil {
+ fields = append(fields, paymentproviderinstance.FieldSupportedTypes)
+ }
+ if m.enabled != nil {
+ fields = append(fields, paymentproviderinstance.FieldEnabled)
+ }
+ if m.payment_mode != nil {
+ fields = append(fields, paymentproviderinstance.FieldPaymentMode)
+ }
+ if m.sort_order != nil {
+ fields = append(fields, paymentproviderinstance.FieldSortOrder)
+ }
+ if m.limits != nil {
+ fields = append(fields, paymentproviderinstance.FieldLimits)
+ }
+ if m.refund_enabled != nil {
+ fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
+ }
+ if m.allow_user_refund != nil {
+ fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
+ }
+ if m.created_at != nil {
+ fields = append(fields, paymentproviderinstance.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, paymentproviderinstance.FieldUpdatedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case paymentproviderinstance.FieldProviderKey:
+ return m.ProviderKey()
+ case paymentproviderinstance.FieldName:
+ return m.Name()
+ case paymentproviderinstance.FieldConfig:
+ return m.Config()
+ case paymentproviderinstance.FieldSupportedTypes:
+ return m.SupportedTypes()
+ case paymentproviderinstance.FieldEnabled:
+ return m.Enabled()
+ case paymentproviderinstance.FieldPaymentMode:
+ return m.PaymentMode()
+ case paymentproviderinstance.FieldSortOrder:
+ return m.SortOrder()
+ case paymentproviderinstance.FieldLimits:
+ return m.Limits()
+ case paymentproviderinstance.FieldRefundEnabled:
+ return m.RefundEnabled()
+ case paymentproviderinstance.FieldAllowUserRefund:
+ return m.AllowUserRefund()
+ case paymentproviderinstance.FieldCreatedAt:
+ return m.CreatedAt()
+ case paymentproviderinstance.FieldUpdatedAt:
+ return m.UpdatedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case paymentproviderinstance.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case paymentproviderinstance.FieldName:
+ return m.OldName(ctx)
+ case paymentproviderinstance.FieldConfig:
+ return m.OldConfig(ctx)
+ case paymentproviderinstance.FieldSupportedTypes:
+ return m.OldSupportedTypes(ctx)
+ case paymentproviderinstance.FieldEnabled:
+ return m.OldEnabled(ctx)
+ case paymentproviderinstance.FieldPaymentMode:
+ return m.OldPaymentMode(ctx)
+ case paymentproviderinstance.FieldSortOrder:
+ return m.OldSortOrder(ctx)
+ case paymentproviderinstance.FieldLimits:
+ return m.OldLimits(ctx)
+ case paymentproviderinstance.FieldRefundEnabled:
+ return m.OldRefundEnabled(ctx)
+ case paymentproviderinstance.FieldAllowUserRefund:
+ return m.OldAllowUserRefund(ctx)
+ case paymentproviderinstance.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case paymentproviderinstance.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case paymentproviderinstance.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case paymentproviderinstance.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case paymentproviderinstance.FieldConfig:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetConfig(v)
+ return nil
+ case paymentproviderinstance.FieldSupportedTypes:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSupportedTypes(v)
+ return nil
+ case paymentproviderinstance.FieldEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEnabled(v)
+ return nil
+ case paymentproviderinstance.FieldPaymentMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPaymentMode(v)
+ return nil
+ case paymentproviderinstance.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSortOrder(v)
+ return nil
+ case paymentproviderinstance.FieldLimits:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLimits(v)
+ return nil
+ case paymentproviderinstance.FieldRefundEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRefundEnabled(v)
+ return nil
+ case paymentproviderinstance.FieldAllowUserRefund:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAllowUserRefund(v)
+ return nil
+ case paymentproviderinstance.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case paymentproviderinstance.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentProviderInstance field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PaymentProviderInstanceMutation) AddedFields() []string {
+ var fields []string
+ if m.addsort_order != nil {
+ fields = append(fields, paymentproviderinstance.FieldSortOrder)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case paymentproviderinstance.FieldSortOrder:
+ return m.AddedSortOrder()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case paymentproviderinstance.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSortOrder(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PaymentProviderInstanceMutation) ClearedFields() []string {
+ return nil
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ClearField(name string) error {
+ return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
+ switch name {
+ case paymentproviderinstance.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case paymentproviderinstance.FieldName:
+ m.ResetName()
+ return nil
+ case paymentproviderinstance.FieldConfig:
+ m.ResetConfig()
+ return nil
+ case paymentproviderinstance.FieldSupportedTypes:
+ m.ResetSupportedTypes()
+ return nil
+ case paymentproviderinstance.FieldEnabled:
+ m.ResetEnabled()
+ return nil
+ case paymentproviderinstance.FieldPaymentMode:
+ m.ResetPaymentMode()
+ return nil
+ case paymentproviderinstance.FieldSortOrder:
+ m.ResetSortOrder()
+ return nil
+ case paymentproviderinstance.FieldLimits:
+ m.ResetLimits()
+ return nil
+ case paymentproviderinstance.FieldRefundEnabled:
+ m.ResetRefundEnabled()
+ return nil
+ case paymentproviderinstance.FieldAllowUserRefund:
+ m.ResetAllowUserRefund()
+ return nil
+ case paymentproviderinstance.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case paymentproviderinstance.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PaymentProviderInstance field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PaymentProviderInstanceMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value {
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PaymentProviderInstanceMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PaymentProviderInstanceMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool {
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown PaymentProviderInstance edge %s", name)
+}
+
+// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph.
+type PendingAuthSessionMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ session_token *string
+ intent *string
+ provider_type *string
+ provider_key *string
+ provider_subject *string
+ redirect_to *string
+ resolved_email *string
+ registration_password_hash *string
+ upstream_identity_claims *map[string]interface{}
+ local_flow_state *map[string]interface{}
+ browser_session_key *string
+ completion_code_hash *string
+ completion_code_expires_at *time.Time
+ email_verified_at *time.Time
+ password_verified_at *time.Time
+ totp_verified_at *time.Time
+ expires_at *time.Time
+ consumed_at *time.Time
+ clearedFields map[string]struct{}
+ target_user *int64
+ clearedtarget_user bool
+ adoption_decision *int64
+ clearedadoption_decision bool
+ done bool
+ oldValue func(context.Context) (*PendingAuthSession, error)
+ predicates []predicate.PendingAuthSession
+}
+
+var _ ent.Mutation = (*PendingAuthSessionMutation)(nil)
+
+// pendingauthsessionOption allows management of the mutation configuration using functional options.
+type pendingauthsessionOption func(*PendingAuthSessionMutation)
+
+// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity.
+func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation {
+ m := &PendingAuthSessionMutation{
+ config: c,
+ op: op,
+ typ: TypePendingAuthSession,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withPendingAuthSessionID sets the ID field of the mutation.
+func withPendingAuthSessionID(id int64) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ var (
+ err error
+ once sync.Once
+ value *PendingAuthSession
+ )
+ m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().PendingAuthSession.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withPendingAuthSession sets the old PendingAuthSession of the mutation.
+func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption {
+ return func(m *PendingAuthSessionMutation) {
+ m.oldValue = func(context.Context) (*PendingAuthSession, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m PendingAuthSessionMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m PendingAuthSessionMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) {
+	switch {
+	case m.op.Is(OpUpdateOne | OpDeleteOne):
+		id, exists := m.ID()
+		if exists {
+			return []int64{id}, nil
+		}
+		fallthrough
+	case m.op.Is(OpUpdate | OpDelete):
+		return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx)
+	default:
+		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+	}
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *PendingAuthSessionMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *PendingAuthSessionMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetSessionToken sets the "session_token" field.
+func (m *PendingAuthSessionMutation) SetSessionToken(s string) {
+ m.session_token = &s
+}
+
+// SessionToken returns the value of the "session_token" field in the mutation.
+func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) {
+ v := m.session_token
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSessionToken is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSessionToken requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSessionToken: %w", err)
+ }
+ return oldValue.SessionToken, nil
+}
+
+// ResetSessionToken resets all changes to the "session_token" field.
+func (m *PendingAuthSessionMutation) ResetSessionToken() {
+ m.session_token = nil
+}
+
+// SetIntent sets the "intent" field.
+func (m *PendingAuthSessionMutation) SetIntent(s string) {
+ m.intent = &s
+}
+
+// Intent returns the value of the "intent" field in the mutation.
+func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) {
+ v := m.intent
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldIntent returns the old "intent" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldIntent is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldIntent requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldIntent: %w", err)
+ }
+ return oldValue.Intent, nil
+}
+
+// ResetIntent resets all changes to the "intent" field.
+func (m *PendingAuthSessionMutation) ResetIntent() {
+ m.intent = nil
+}
+
+// SetProviderType sets the "provider_type" field.
+func (m *PendingAuthSessionMutation) SetProviderType(s string) {
+ m.provider_type = &s
+}
+
+// ProviderType returns the value of the "provider_type" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) {
+ v := m.provider_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderType: %w", err)
+ }
+ return oldValue.ProviderType, nil
+}
+
+// ResetProviderType resets all changes to the "provider_type" field.
+func (m *PendingAuthSessionMutation) ResetProviderType() {
+ m.provider_type = nil
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (m *PendingAuthSessionMutation) SetProviderKey(s string) {
+ m.provider_key = &s
+}
+
+// ProviderKey returns the value of the "provider_key" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) {
+ v := m.provider_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderKey: %w", err)
+ }
+ return oldValue.ProviderKey, nil
+}
+
+// ResetProviderKey resets all changes to the "provider_key" field.
+func (m *PendingAuthSessionMutation) ResetProviderKey() {
+ m.provider_key = nil
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (m *PendingAuthSessionMutation) SetProviderSubject(s string) {
+ m.provider_subject = &s
+}
+
+// ProviderSubject returns the value of the "provider_subject" field in the mutation.
+func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) {
+ v := m.provider_subject
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProviderSubject requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err)
+ }
+ return oldValue.ProviderSubject, nil
+}
+
+// ResetProviderSubject resets all changes to the "provider_subject" field.
+func (m *PendingAuthSessionMutation) ResetProviderSubject() {
+ m.provider_subject = nil
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) {
+ m.target_user = &i
+}
+
+// TargetUserID returns the value of the "target_user_id" field in the mutation.
+func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) {
+ v := m.target_user
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTargetUserID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err)
+ }
+ return oldValue.TargetUserID, nil
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ClearTargetUserID() {
+ m.target_user = nil
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID]
+ return ok
+}
+
+// ResetTargetUserID resets all changes to the "target_user_id" field.
+func (m *PendingAuthSessionMutation) ResetTargetUserID() {
+ m.target_user = nil
+ delete(m.clearedFields, pendingauthsession.FieldTargetUserID)
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (m *PendingAuthSessionMutation) SetRedirectTo(s string) {
+ m.redirect_to = &s
+}
+
+// RedirectTo returns the value of the "redirect_to" field in the mutation.
+func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) {
+ v := m.redirect_to
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRedirectTo requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err)
+ }
+ return oldValue.RedirectTo, nil
+}
+
+// ResetRedirectTo resets all changes to the "redirect_to" field.
+func (m *PendingAuthSessionMutation) ResetRedirectTo() {
+ m.redirect_to = nil
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) {
+ m.resolved_email = &s
+}
+
+// ResolvedEmail returns the value of the "resolved_email" field in the mutation.
+func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) {
+ v := m.resolved_email
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldResolvedEmail requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err)
+ }
+ return oldValue.ResolvedEmail, nil
+}
+
+// ResetResolvedEmail resets all changes to the "resolved_email" field.
+func (m *PendingAuthSessionMutation) ResetResolvedEmail() {
+ m.resolved_email = nil
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) {
+ m.registration_password_hash = &s
+}
+
+// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) {
+ v := m.registration_password_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err)
+ }
+ return oldValue.RegistrationPasswordHash, nil
+}
+
+// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field.
+func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() {
+ m.registration_password_hash = nil
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) {
+ m.upstream_identity_claims = &value
+}
+
+// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation.
+func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) {
+ v := m.upstream_identity_claims
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err)
+ }
+ return oldValue.UpstreamIdentityClaims, nil
+}
+
+// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field.
+func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() {
+ m.upstream_identity_claims = nil
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) {
+ m.local_flow_state = &value
+}
+
+// LocalFlowState returns the value of the "local_flow_state" field in the mutation.
+func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) {
+ v := m.local_flow_state
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLocalFlowState requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err)
+ }
+ return oldValue.LocalFlowState, nil
+}
+
+// ResetLocalFlowState resets all changes to the "local_flow_state" field.
+func (m *PendingAuthSessionMutation) ResetLocalFlowState() {
+ m.local_flow_state = nil
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) {
+ m.browser_session_key = &s
+}
+
+// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation.
+func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) {
+ v := m.browser_session_key
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err)
+ }
+ return oldValue.BrowserSessionKey, nil
+}
+
+// ResetBrowserSessionKey resets all changes to the "browser_session_key" field.
+func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() {
+ m.browser_session_key = nil
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) {
+ m.completion_code_hash = &s
+}
+
+// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) {
+ v := m.completion_code_hash
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err)
+ }
+ return oldValue.CompletionCodeHash, nil
+}
+
+// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() {
+ m.completion_code_hash = nil
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) {
+ m.completion_code_expires_at = &t
+}
+
+// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) {
+ v := m.completion_code_expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err)
+ }
+ return oldValue.CompletionCodeExpiresAt, nil
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{}
+}
+
+// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt]
+ return ok
+}
+
+// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field.
+func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() {
+ m.completion_code_expires_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt)
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) {
+ m.email_verified_at = &t
+}
+
+// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) {
+ v := m.email_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err)
+ }
+ return oldValue.EmailVerifiedAt, nil
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() {
+ m.email_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{}
+}
+
+// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt]
+ return ok
+}
+
+// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() {
+ m.email_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt)
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) {
+ m.password_verified_at = &t
+}
+
+// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) {
+ v := m.password_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err)
+ }
+ return oldValue.PasswordVerifiedAt, nil
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{}
+}
+
+// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt]
+ return ok
+}
+
+// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() {
+ m.password_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt)
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) {
+ m.totp_verified_at = &t
+}
+
+// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) {
+ v := m.totp_verified_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err)
+ }
+ return oldValue.TotpVerifiedAt, nil
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{}
+}
+
+// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt]
+ return ok
+}
+
+// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field.
+func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() {
+ m.totp_verified_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt)
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *PendingAuthSessionMutation) ResetExpiresAt() {
+ m.expires_at = nil
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) {
+ m.consumed_at = &t
+}
+
+// ConsumedAt returns the value of the "consumed_at" field in the mutation.
+func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) {
+ v := m.consumed_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity.
+// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldConsumedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err)
+ }
+ return oldValue.ConsumedAt, nil
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ClearConsumedAt() {
+ m.consumed_at = nil
+ m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{}
+}
+
+// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation.
+func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool {
+ _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt]
+ return ok
+}
+
+// ResetConsumedAt resets all changes to the "consumed_at" field.
+func (m *PendingAuthSessionMutation) ResetConsumedAt() {
+ m.consumed_at = nil
+ delete(m.clearedFields, pendingauthsession.FieldConsumedAt)
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (m *PendingAuthSessionMutation) ClearTargetUser() {
+ m.clearedtarget_user = true
+ m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{}
+}
+
+// TargetUserCleared reports if the "target_user" edge to the User entity was cleared.
+func (m *PendingAuthSessionMutation) TargetUserCleared() bool {
+ return m.TargetUserIDCleared() || m.clearedtarget_user
+}
+
+// TargetUserIDs returns the "target_user" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// TargetUserID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) {
+ if id := m.target_user; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetTargetUser resets all changes to the "target_user" edge.
+func (m *PendingAuthSessionMutation) ResetTargetUser() {
+ m.target_user = nil
+ m.clearedtarget_user = false
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id.
+func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) {
+ m.adoption_decision = &id
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (m *PendingAuthSessionMutation) ClearAdoptionDecision() {
+ m.clearedadoption_decision = true
+}
+
+// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared.
+func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool {
+ return m.clearedadoption_decision
+}
+
+// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation.
+func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) {
+ if m.adoption_decision != nil {
+ return *m.adoption_decision, true
+ }
+ return
+}
+
+// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation.
+// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use
+// AdoptionDecisionID instead. It exists only for internal usage by the builders.
+func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) {
+ if id := m.adoption_decision; id != nil {
+ ids = append(ids, *id)
+ }
+ return
+}
+
+// ResetAdoptionDecision resets all changes to the "adoption_decision" edge.
+func (m *PendingAuthSessionMutation) ResetAdoptionDecision() {
+ m.adoption_decision = nil
+ m.clearedadoption_decision = false
+}
+
+// Where appends a list predicates to the PendingAuthSessionMutation builder.
+func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.PendingAuthSession, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *PendingAuthSessionMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *PendingAuthSessionMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (PendingAuthSession).
+func (m *PendingAuthSessionMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *PendingAuthSessionMutation) Fields() []string {
+ fields := make([]string, 0, 21)
+ if m.created_at != nil {
+ fields = append(fields, pendingauthsession.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, pendingauthsession.FieldUpdatedAt)
+ }
+ if m.session_token != nil {
+ fields = append(fields, pendingauthsession.FieldSessionToken)
+ }
+ if m.intent != nil {
+ fields = append(fields, pendingauthsession.FieldIntent)
+ }
+ if m.provider_type != nil {
+ fields = append(fields, pendingauthsession.FieldProviderType)
+ }
+ if m.provider_key != nil {
+ fields = append(fields, pendingauthsession.FieldProviderKey)
+ }
+ if m.provider_subject != nil {
+ fields = append(fields, pendingauthsession.FieldProviderSubject)
+ }
+ if m.target_user != nil {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.redirect_to != nil {
+ fields = append(fields, pendingauthsession.FieldRedirectTo)
+ }
+ if m.resolved_email != nil {
+ fields = append(fields, pendingauthsession.FieldResolvedEmail)
+ }
+ if m.registration_password_hash != nil {
+ fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash)
+ }
+ if m.upstream_identity_claims != nil {
+ fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims)
+ }
+ if m.local_flow_state != nil {
+ fields = append(fields, pendingauthsession.FieldLocalFlowState)
+ }
+ if m.browser_session_key != nil {
+ fields = append(fields, pendingauthsession.FieldBrowserSessionKey)
+ }
+ if m.completion_code_hash != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeHash)
+ }
+ if m.completion_code_expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.email_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.password_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.totp_verified_at != nil {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.expires_at != nil {
+ fields = append(fields, pendingauthsession.FieldExpiresAt)
+ }
+ if m.consumed_at != nil {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.CreatedAt()
+ case pendingauthsession.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case pendingauthsession.FieldSessionToken:
+ return m.SessionToken()
+ case pendingauthsession.FieldIntent:
+ return m.Intent()
+ case pendingauthsession.FieldProviderType:
+ return m.ProviderType()
+ case pendingauthsession.FieldProviderKey:
+ return m.ProviderKey()
+ case pendingauthsession.FieldProviderSubject:
+ return m.ProviderSubject()
+ case pendingauthsession.FieldTargetUserID:
+ return m.TargetUserID()
+ case pendingauthsession.FieldRedirectTo:
+ return m.RedirectTo()
+ case pendingauthsession.FieldResolvedEmail:
+ return m.ResolvedEmail()
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.RegistrationPasswordHash()
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.UpstreamIdentityClaims()
+ case pendingauthsession.FieldLocalFlowState:
+ return m.LocalFlowState()
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.BrowserSessionKey()
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.CompletionCodeHash()
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.CompletionCodeExpiresAt()
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.EmailVerifiedAt()
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.PasswordVerifiedAt()
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.TotpVerifiedAt()
+ case pendingauthsession.FieldExpiresAt:
+ return m.ExpiresAt()
+ case pendingauthsession.FieldConsumedAt:
+ return m.ConsumedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case pendingauthsession.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case pendingauthsession.FieldSessionToken:
+ return m.OldSessionToken(ctx)
+ case pendingauthsession.FieldIntent:
+ return m.OldIntent(ctx)
+ case pendingauthsession.FieldProviderType:
+ return m.OldProviderType(ctx)
+ case pendingauthsession.FieldProviderKey:
+ return m.OldProviderKey(ctx)
+ case pendingauthsession.FieldProviderSubject:
+ return m.OldProviderSubject(ctx)
+ case pendingauthsession.FieldTargetUserID:
+ return m.OldTargetUserID(ctx)
+ case pendingauthsession.FieldRedirectTo:
+ return m.OldRedirectTo(ctx)
+ case pendingauthsession.FieldResolvedEmail:
+ return m.OldResolvedEmail(ctx)
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ return m.OldRegistrationPasswordHash(ctx)
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ return m.OldUpstreamIdentityClaims(ctx)
+ case pendingauthsession.FieldLocalFlowState:
+ return m.OldLocalFlowState(ctx)
+ case pendingauthsession.FieldBrowserSessionKey:
+ return m.OldBrowserSessionKey(ctx)
+ case pendingauthsession.FieldCompletionCodeHash:
+ return m.OldCompletionCodeHash(ctx)
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ return m.OldCompletionCodeExpiresAt(ctx)
+ case pendingauthsession.FieldEmailVerifiedAt:
+ return m.OldEmailVerifiedAt(ctx)
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ return m.OldPasswordVerifiedAt(ctx)
+ case pendingauthsession.FieldTotpVerifiedAt:
+ return m.OldTotpVerifiedAt(ctx)
+ case pendingauthsession.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
+ case pendingauthsession.FieldConsumedAt:
+ return m.OldConsumedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSessionToken(v)
+ return nil
+ case pendingauthsession.FieldIntent:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetIntent(v)
+ return nil
+ case pendingauthsession.FieldProviderType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderType(v)
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderKey(v)
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProviderSubject(v)
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTargetUserID(v)
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRedirectTo(v)
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetResolvedEmail(v)
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRegistrationPasswordHash(v)
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpstreamIdentityClaims(v)
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ v, ok := value.(map[string]interface{})
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLocalFlowState(v)
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBrowserSessionKey(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeHash(v)
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCompletionCodeExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetEmailVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPasswordVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpVerifiedAt(v)
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetConsumedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *PendingAuthSessionMutation) AddedFields() []string {
+ var fields []string
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ }
+ return fmt.Errorf("unknown PendingAuthSession numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *PendingAuthSessionMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(pendingauthsession.FieldTargetUserID) {
+ fields = append(fields, pendingauthsession.FieldTargetUserID)
+ }
+ if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) {
+ fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldEmailVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) {
+ fields = append(fields, pendingauthsession.FieldTotpVerifiedAt)
+ }
+ if m.FieldCleared(pendingauthsession.FieldConsumedAt) {
+ fields = append(fields, pendingauthsession.FieldConsumedAt)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *PendingAuthSessionMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearField(name string) error {
+ switch name {
+ case pendingauthsession.FieldTargetUserID:
+ m.ClearTargetUserID()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ClearCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ClearEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ClearPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ClearTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ClearConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetField(name string) error {
+ switch name {
+ case pendingauthsession.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case pendingauthsession.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case pendingauthsession.FieldSessionToken:
+ m.ResetSessionToken()
+ return nil
+ case pendingauthsession.FieldIntent:
+ m.ResetIntent()
+ return nil
+ case pendingauthsession.FieldProviderType:
+ m.ResetProviderType()
+ return nil
+ case pendingauthsession.FieldProviderKey:
+ m.ResetProviderKey()
+ return nil
+ case pendingauthsession.FieldProviderSubject:
+ m.ResetProviderSubject()
+ return nil
+ case pendingauthsession.FieldTargetUserID:
+ m.ResetTargetUserID()
+ return nil
+ case pendingauthsession.FieldRedirectTo:
+ m.ResetRedirectTo()
+ return nil
+ case pendingauthsession.FieldResolvedEmail:
+ m.ResetResolvedEmail()
+ return nil
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ m.ResetRegistrationPasswordHash()
+ return nil
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ m.ResetUpstreamIdentityClaims()
+ return nil
+ case pendingauthsession.FieldLocalFlowState:
+ m.ResetLocalFlowState()
+ return nil
+ case pendingauthsession.FieldBrowserSessionKey:
+ m.ResetBrowserSessionKey()
+ return nil
+ case pendingauthsession.FieldCompletionCodeHash:
+ m.ResetCompletionCodeHash()
+ return nil
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ m.ResetCompletionCodeExpiresAt()
+ return nil
+ case pendingauthsession.FieldEmailVerifiedAt:
+ m.ResetEmailVerifiedAt()
+ return nil
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ m.ResetPasswordVerifiedAt()
+ return nil
+ case pendingauthsession.FieldTotpVerifiedAt:
+ m.ResetTotpVerifiedAt()
+ return nil
+ case pendingauthsession.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
+ case pendingauthsession.FieldConsumedAt:
+ m.ResetConsumedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *PendingAuthSessionMutation) AddedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.target_user != nil {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.adoption_decision != nil {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ if id := m.target_user; id != nil {
+ return []ent.Value{*id}
+ }
+ case pendingauthsession.EdgeAdoptionDecision:
+ if id := m.adoption_decision; id != nil {
+ return []ent.Value{*id}
+ }
+ }
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *PendingAuthSessionMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 2)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *PendingAuthSessionMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 2)
+ if m.clearedtarget_user {
+ edges = append(edges, pendingauthsession.EdgeTargetUser)
+ }
+ if m.clearedadoption_decision {
+ edges = append(edges, pendingauthsession.EdgeAdoptionDecision)
+ }
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ return m.clearedtarget_user
+ case pendingauthsession.EdgeAdoptionDecision:
+ return m.clearedadoption_decision
+ }
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ClearEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ClearTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ClearAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *PendingAuthSessionMutation) ResetEdge(name string) error {
+ switch name {
+ case pendingauthsession.EdgeTargetUser:
+ m.ResetTargetUser()
+ return nil
+ case pendingauthsession.EdgeAdoptionDecision:
+ m.ResetAdoptionDecision()
+ return nil
+ }
+ return fmt.Errorf("unknown PendingAuthSession edge %s", name)
+}
+
// PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph.
type PromoCodeMutation struct {
config
@@ -17258,6 +30388,1171 @@ func (m *SettingMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown Setting edge %s", name)
}
+// SubscriptionPlanMutation represents an operation that mutates the SubscriptionPlan nodes in the graph.
+type SubscriptionPlanMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ group_id *int64
+ addgroup_id *int64
+ name *string
+ description *string
+ price *float64
+ addprice *float64
+ original_price *float64
+ addoriginal_price *float64
+ validity_days *int
+ addvalidity_days *int
+ validity_unit *string
+ features *string
+ product_name *string
+ for_sale *bool
+ sort_order *int
+ addsort_order *int
+ created_at *time.Time
+ updated_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*SubscriptionPlan, error)
+ predicates []predicate.SubscriptionPlan
+}
+
+var _ ent.Mutation = (*SubscriptionPlanMutation)(nil)
+
+// subscriptionplanOption allows management of the mutation configuration using functional options.
+type subscriptionplanOption func(*SubscriptionPlanMutation)
+
+// newSubscriptionPlanMutation creates new mutation for the SubscriptionPlan entity.
+func newSubscriptionPlanMutation(c config, op Op, opts ...subscriptionplanOption) *SubscriptionPlanMutation {
+ m := &SubscriptionPlanMutation{
+ config: c,
+ op: op,
+ typ: TypeSubscriptionPlan,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withSubscriptionPlanID sets the ID field of the mutation.
+func withSubscriptionPlanID(id int64) subscriptionplanOption {
+ return func(m *SubscriptionPlanMutation) {
+ var (
+ err error
+ once sync.Once
+ value *SubscriptionPlan
+ )
+ m.oldValue = func(ctx context.Context) (*SubscriptionPlan, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().SubscriptionPlan.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withSubscriptionPlan sets the old SubscriptionPlan of the mutation.
+func withSubscriptionPlan(node *SubscriptionPlan) subscriptionplanOption {
+ return func(m *SubscriptionPlanMutation) {
+ m.oldValue = func(context.Context) (*SubscriptionPlan, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m SubscriptionPlanMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m SubscriptionPlanMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *SubscriptionPlanMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *SubscriptionPlanMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().SubscriptionPlan.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetGroupID sets the "group_id" field.
+func (m *SubscriptionPlanMutation) SetGroupID(i int64) {
+ m.group_id = &i
+ m.addgroup_id = nil
+}
+
+// GroupID returns the value of the "group_id" field in the mutation.
+func (m *SubscriptionPlanMutation) GroupID() (r int64, exists bool) {
+ v := m.group_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldGroupID returns the old "group_id" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldGroupID(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldGroupID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldGroupID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldGroupID: %w", err)
+ }
+ return oldValue.GroupID, nil
+}
+
+// AddGroupID adds i to the "group_id" field.
+func (m *SubscriptionPlanMutation) AddGroupID(i int64) {
+ if m.addgroup_id != nil {
+ *m.addgroup_id += i
+ } else {
+ m.addgroup_id = &i
+ }
+}
+
+// AddedGroupID returns the value that was added to the "group_id" field in this mutation.
+func (m *SubscriptionPlanMutation) AddedGroupID() (r int64, exists bool) {
+ v := m.addgroup_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetGroupID resets all changes to the "group_id" field.
+func (m *SubscriptionPlanMutation) ResetGroupID() {
+ m.group_id = nil
+ m.addgroup_id = nil
+}
+
+// SetName sets the "name" field.
+func (m *SubscriptionPlanMutation) SetName(s string) {
+ m.name = &s
+}
+
+// Name returns the value of the "name" field in the mutation.
+func (m *SubscriptionPlanMutation) Name() (r string, exists bool) {
+ v := m.name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldName returns the old "name" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldName: %w", err)
+ }
+ return oldValue.Name, nil
+}
+
+// ResetName resets all changes to the "name" field.
+func (m *SubscriptionPlanMutation) ResetName() {
+ m.name = nil
+}
+
+// SetDescription sets the "description" field.
+func (m *SubscriptionPlanMutation) SetDescription(s string) {
+ m.description = &s
+}
+
+// Description returns the value of the "description" field in the mutation.
+func (m *SubscriptionPlanMutation) Description() (r string, exists bool) {
+ v := m.description
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDescription returns the old "description" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldDescription(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDescription is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDescription requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDescription: %w", err)
+ }
+ return oldValue.Description, nil
+}
+
+// ResetDescription resets all changes to the "description" field.
+func (m *SubscriptionPlanMutation) ResetDescription() {
+ m.description = nil
+}
+
+// SetPrice sets the "price" field.
+func (m *SubscriptionPlanMutation) SetPrice(f float64) {
+ m.price = &f
+ m.addprice = nil
+}
+
+// Price returns the value of the "price" field in the mutation.
+func (m *SubscriptionPlanMutation) Price() (r float64, exists bool) {
+ v := m.price
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldPrice returns the old "price" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldPrice(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldPrice is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldPrice requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldPrice: %w", err)
+ }
+ return oldValue.Price, nil
+}
+
+// AddPrice adds f to the "price" field.
+func (m *SubscriptionPlanMutation) AddPrice(f float64) {
+ if m.addprice != nil {
+ *m.addprice += f
+ } else {
+ m.addprice = &f
+ }
+}
+
+// AddedPrice returns the value that was added to the "price" field in this mutation.
+func (m *SubscriptionPlanMutation) AddedPrice() (r float64, exists bool) {
+ v := m.addprice
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetPrice resets all changes to the "price" field.
+func (m *SubscriptionPlanMutation) ResetPrice() {
+ m.price = nil
+ m.addprice = nil
+}
+
+// SetOriginalPrice sets the "original_price" field.
+func (m *SubscriptionPlanMutation) SetOriginalPrice(f float64) {
+ m.original_price = &f
+ m.addoriginal_price = nil
+}
+
+// OriginalPrice returns the value of the "original_price" field in the mutation.
+func (m *SubscriptionPlanMutation) OriginalPrice() (r float64, exists bool) {
+ v := m.original_price
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldOriginalPrice returns the old "original_price" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldOriginalPrice(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldOriginalPrice is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldOriginalPrice requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldOriginalPrice: %w", err)
+ }
+ return oldValue.OriginalPrice, nil
+}
+
+// AddOriginalPrice adds f to the "original_price" field.
+func (m *SubscriptionPlanMutation) AddOriginalPrice(f float64) {
+ if m.addoriginal_price != nil {
+ *m.addoriginal_price += f
+ } else {
+ m.addoriginal_price = &f
+ }
+}
+
+// AddedOriginalPrice returns the value that was added to the "original_price" field in this mutation.
+func (m *SubscriptionPlanMutation) AddedOriginalPrice() (r float64, exists bool) {
+ v := m.addoriginal_price
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearOriginalPrice clears the value of the "original_price" field.
+func (m *SubscriptionPlanMutation) ClearOriginalPrice() {
+ m.original_price = nil
+ m.addoriginal_price = nil
+ m.clearedFields[subscriptionplan.FieldOriginalPrice] = struct{}{}
+}
+
+// OriginalPriceCleared returns if the "original_price" field was cleared in this mutation.
+func (m *SubscriptionPlanMutation) OriginalPriceCleared() bool {
+ _, ok := m.clearedFields[subscriptionplan.FieldOriginalPrice]
+ return ok
+}
+
+// ResetOriginalPrice resets all changes to the "original_price" field.
+func (m *SubscriptionPlanMutation) ResetOriginalPrice() {
+ m.original_price = nil
+ m.addoriginal_price = nil
+ delete(m.clearedFields, subscriptionplan.FieldOriginalPrice)
+}
+
+// SetValidityDays sets the "validity_days" field.
+func (m *SubscriptionPlanMutation) SetValidityDays(i int) {
+ m.validity_days = &i
+ m.addvalidity_days = nil
+}
+
+// ValidityDays returns the value of the "validity_days" field in the mutation.
+func (m *SubscriptionPlanMutation) ValidityDays() (r int, exists bool) {
+ v := m.validity_days
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldValidityDays returns the old "validity_days" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldValidityDays(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldValidityDays is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldValidityDays requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldValidityDays: %w", err)
+ }
+ return oldValue.ValidityDays, nil
+}
+
+// AddValidityDays adds i to the "validity_days" field.
+func (m *SubscriptionPlanMutation) AddValidityDays(i int) {
+ if m.addvalidity_days != nil {
+ *m.addvalidity_days += i
+ } else {
+ m.addvalidity_days = &i
+ }
+}
+
+// AddedValidityDays returns the value that was added to the "validity_days" field in this mutation.
+func (m *SubscriptionPlanMutation) AddedValidityDays() (r int, exists bool) {
+ v := m.addvalidity_days
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetValidityDays resets all changes to the "validity_days" field.
+func (m *SubscriptionPlanMutation) ResetValidityDays() {
+ m.validity_days = nil
+ m.addvalidity_days = nil
+}
+
+// SetValidityUnit sets the "validity_unit" field.
+func (m *SubscriptionPlanMutation) SetValidityUnit(s string) {
+ m.validity_unit = &s
+}
+
+// ValidityUnit returns the value of the "validity_unit" field in the mutation.
+func (m *SubscriptionPlanMutation) ValidityUnit() (r string, exists bool) {
+ v := m.validity_unit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldValidityUnit returns the old "validity_unit" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldValidityUnit(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldValidityUnit is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldValidityUnit requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldValidityUnit: %w", err)
+ }
+ return oldValue.ValidityUnit, nil
+}
+
+// ResetValidityUnit resets all changes to the "validity_unit" field.
+func (m *SubscriptionPlanMutation) ResetValidityUnit() {
+ m.validity_unit = nil
+}
+
+// SetFeatures sets the "features" field.
+func (m *SubscriptionPlanMutation) SetFeatures(s string) {
+ m.features = &s
+}
+
+// Features returns the value of the "features" field in the mutation.
+func (m *SubscriptionPlanMutation) Features() (r string, exists bool) {
+ v := m.features
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFeatures returns the old "features" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldFeatures(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFeatures is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFeatures requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFeatures: %w", err)
+ }
+ return oldValue.Features, nil
+}
+
+// ResetFeatures resets all changes to the "features" field.
+func (m *SubscriptionPlanMutation) ResetFeatures() {
+ m.features = nil
+}
+
+// SetProductName sets the "product_name" field.
+func (m *SubscriptionPlanMutation) SetProductName(s string) {
+ m.product_name = &s
+}
+
+// ProductName returns the value of the "product_name" field in the mutation.
+func (m *SubscriptionPlanMutation) ProductName() (r string, exists bool) {
+ v := m.product_name
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldProductName returns the old "product_name" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldProductName(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldProductName is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldProductName requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldProductName: %w", err)
+ }
+ return oldValue.ProductName, nil
+}
+
+// ResetProductName resets all changes to the "product_name" field.
+func (m *SubscriptionPlanMutation) ResetProductName() {
+ m.product_name = nil
+}
+
+// SetForSale sets the "for_sale" field.
+func (m *SubscriptionPlanMutation) SetForSale(b bool) {
+ m.for_sale = &b
+}
+
+// ForSale returns the value of the "for_sale" field in the mutation.
+func (m *SubscriptionPlanMutation) ForSale() (r bool, exists bool) {
+ v := m.for_sale
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldForSale returns the old "for_sale" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldForSale(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldForSale is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldForSale requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldForSale: %w", err)
+ }
+ return oldValue.ForSale, nil
+}
+
+// ResetForSale resets all changes to the "for_sale" field.
+func (m *SubscriptionPlanMutation) ResetForSale() {
+ m.for_sale = nil
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (m *SubscriptionPlanMutation) SetSortOrder(i int) {
+ m.sort_order = &i
+ m.addsort_order = nil
+}
+
+// SortOrder returns the value of the "sort_order" field in the mutation.
+func (m *SubscriptionPlanMutation) SortOrder() (r int, exists bool) {
+ v := m.sort_order
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldSortOrder returns the old "sort_order" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldSortOrder(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldSortOrder is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldSortOrder requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldSortOrder: %w", err)
+ }
+ return oldValue.SortOrder, nil
+}
+
+// AddSortOrder adds i to the "sort_order" field.
+func (m *SubscriptionPlanMutation) AddSortOrder(i int) {
+ if m.addsort_order != nil {
+ *m.addsort_order += i
+ } else {
+ m.addsort_order = &i
+ }
+}
+
+// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation.
+func (m *SubscriptionPlanMutation) AddedSortOrder() (r int, exists bool) {
+ v := m.addsort_order
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetSortOrder resets all changes to the "sort_order" field.
+func (m *SubscriptionPlanMutation) ResetSortOrder() {
+ m.sort_order = nil
+ m.addsort_order = nil
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *SubscriptionPlanMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *SubscriptionPlanMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *SubscriptionPlanMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *SubscriptionPlanMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *SubscriptionPlanMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the SubscriptionPlan entity.
+// If the SubscriptionPlan object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *SubscriptionPlanMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *SubscriptionPlanMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// Where appends a list predicates to the SubscriptionPlanMutation builder.
+func (m *SubscriptionPlanMutation) Where(ps ...predicate.SubscriptionPlan) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the SubscriptionPlanMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *SubscriptionPlanMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.SubscriptionPlan, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *SubscriptionPlanMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *SubscriptionPlanMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (SubscriptionPlan).
+func (m *SubscriptionPlanMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *SubscriptionPlanMutation) Fields() []string {
+ fields := make([]string, 0, 13)
+ if m.group_id != nil {
+ fields = append(fields, subscriptionplan.FieldGroupID)
+ }
+ if m.name != nil {
+ fields = append(fields, subscriptionplan.FieldName)
+ }
+ if m.description != nil {
+ fields = append(fields, subscriptionplan.FieldDescription)
+ }
+ if m.price != nil {
+ fields = append(fields, subscriptionplan.FieldPrice)
+ }
+ if m.original_price != nil {
+ fields = append(fields, subscriptionplan.FieldOriginalPrice)
+ }
+ if m.validity_days != nil {
+ fields = append(fields, subscriptionplan.FieldValidityDays)
+ }
+ if m.validity_unit != nil {
+ fields = append(fields, subscriptionplan.FieldValidityUnit)
+ }
+ if m.features != nil {
+ fields = append(fields, subscriptionplan.FieldFeatures)
+ }
+ if m.product_name != nil {
+ fields = append(fields, subscriptionplan.FieldProductName)
+ }
+ if m.for_sale != nil {
+ fields = append(fields, subscriptionplan.FieldForSale)
+ }
+ if m.sort_order != nil {
+ fields = append(fields, subscriptionplan.FieldSortOrder)
+ }
+ if m.created_at != nil {
+ fields = append(fields, subscriptionplan.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, subscriptionplan.FieldUpdatedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *SubscriptionPlanMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case subscriptionplan.FieldGroupID:
+ return m.GroupID()
+ case subscriptionplan.FieldName:
+ return m.Name()
+ case subscriptionplan.FieldDescription:
+ return m.Description()
+ case subscriptionplan.FieldPrice:
+ return m.Price()
+ case subscriptionplan.FieldOriginalPrice:
+ return m.OriginalPrice()
+ case subscriptionplan.FieldValidityDays:
+ return m.ValidityDays()
+ case subscriptionplan.FieldValidityUnit:
+ return m.ValidityUnit()
+ case subscriptionplan.FieldFeatures:
+ return m.Features()
+ case subscriptionplan.FieldProductName:
+ return m.ProductName()
+ case subscriptionplan.FieldForSale:
+ return m.ForSale()
+ case subscriptionplan.FieldSortOrder:
+ return m.SortOrder()
+ case subscriptionplan.FieldCreatedAt:
+ return m.CreatedAt()
+ case subscriptionplan.FieldUpdatedAt:
+ return m.UpdatedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *SubscriptionPlanMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case subscriptionplan.FieldGroupID:
+ return m.OldGroupID(ctx)
+ case subscriptionplan.FieldName:
+ return m.OldName(ctx)
+ case subscriptionplan.FieldDescription:
+ return m.OldDescription(ctx)
+ case subscriptionplan.FieldPrice:
+ return m.OldPrice(ctx)
+ case subscriptionplan.FieldOriginalPrice:
+ return m.OldOriginalPrice(ctx)
+ case subscriptionplan.FieldValidityDays:
+ return m.OldValidityDays(ctx)
+ case subscriptionplan.FieldValidityUnit:
+ return m.OldValidityUnit(ctx)
+ case subscriptionplan.FieldFeatures:
+ return m.OldFeatures(ctx)
+ case subscriptionplan.FieldProductName:
+ return m.OldProductName(ctx)
+ case subscriptionplan.FieldForSale:
+ return m.OldForSale(ctx)
+ case subscriptionplan.FieldSortOrder:
+ return m.OldSortOrder(ctx)
+ case subscriptionplan.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case subscriptionplan.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown SubscriptionPlan field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *SubscriptionPlanMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case subscriptionplan.FieldGroupID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetGroupID(v)
+ return nil
+ case subscriptionplan.FieldName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetName(v)
+ return nil
+ case subscriptionplan.FieldDescription:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDescription(v)
+ return nil
+ case subscriptionplan.FieldPrice:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetPrice(v)
+ return nil
+ case subscriptionplan.FieldOriginalPrice:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetOriginalPrice(v)
+ return nil
+ case subscriptionplan.FieldValidityDays:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetValidityDays(v)
+ return nil
+ case subscriptionplan.FieldValidityUnit:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetValidityUnit(v)
+ return nil
+ case subscriptionplan.FieldFeatures:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFeatures(v)
+ return nil
+ case subscriptionplan.FieldProductName:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetProductName(v)
+ return nil
+ case subscriptionplan.FieldForSale:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetForSale(v)
+ return nil
+ case subscriptionplan.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetSortOrder(v)
+ return nil
+ case subscriptionplan.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case subscriptionplan.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown SubscriptionPlan field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *SubscriptionPlanMutation) AddedFields() []string {
+ var fields []string
+ if m.addgroup_id != nil {
+ fields = append(fields, subscriptionplan.FieldGroupID)
+ }
+ if m.addprice != nil {
+ fields = append(fields, subscriptionplan.FieldPrice)
+ }
+ if m.addoriginal_price != nil {
+ fields = append(fields, subscriptionplan.FieldOriginalPrice)
+ }
+ if m.addvalidity_days != nil {
+ fields = append(fields, subscriptionplan.FieldValidityDays)
+ }
+ if m.addsort_order != nil {
+ fields = append(fields, subscriptionplan.FieldSortOrder)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *SubscriptionPlanMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case subscriptionplan.FieldGroupID:
+ return m.AddedGroupID()
+ case subscriptionplan.FieldPrice:
+ return m.AddedPrice()
+ case subscriptionplan.FieldOriginalPrice:
+ return m.AddedOriginalPrice()
+ case subscriptionplan.FieldValidityDays:
+ return m.AddedValidityDays()
+ case subscriptionplan.FieldSortOrder:
+ return m.AddedSortOrder()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *SubscriptionPlanMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case subscriptionplan.FieldGroupID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddGroupID(v)
+ return nil
+ case subscriptionplan.FieldPrice:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddPrice(v)
+ return nil
+ case subscriptionplan.FieldOriginalPrice:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddOriginalPrice(v)
+ return nil
+ case subscriptionplan.FieldValidityDays:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddValidityDays(v)
+ return nil
+ case subscriptionplan.FieldSortOrder:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddSortOrder(v)
+ return nil
+ }
+ return fmt.Errorf("unknown SubscriptionPlan numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *SubscriptionPlanMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(subscriptionplan.FieldOriginalPrice) {
+ fields = append(fields, subscriptionplan.FieldOriginalPrice)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *SubscriptionPlanMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *SubscriptionPlanMutation) ClearField(name string) error {
+ switch name {
+ case subscriptionplan.FieldOriginalPrice:
+ m.ClearOriginalPrice()
+ return nil
+ }
+ return fmt.Errorf("unknown SubscriptionPlan nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *SubscriptionPlanMutation) ResetField(name string) error {
+ switch name {
+ case subscriptionplan.FieldGroupID:
+ m.ResetGroupID()
+ return nil
+ case subscriptionplan.FieldName:
+ m.ResetName()
+ return nil
+ case subscriptionplan.FieldDescription:
+ m.ResetDescription()
+ return nil
+ case subscriptionplan.FieldPrice:
+ m.ResetPrice()
+ return nil
+ case subscriptionplan.FieldOriginalPrice:
+ m.ResetOriginalPrice()
+ return nil
+ case subscriptionplan.FieldValidityDays:
+ m.ResetValidityDays()
+ return nil
+ case subscriptionplan.FieldValidityUnit:
+ m.ResetValidityUnit()
+ return nil
+ case subscriptionplan.FieldFeatures:
+ m.ResetFeatures()
+ return nil
+ case subscriptionplan.FieldProductName:
+ m.ResetProductName()
+ return nil
+ case subscriptionplan.FieldForSale:
+ m.ResetForSale()
+ return nil
+ case subscriptionplan.FieldSortOrder:
+ m.ResetSortOrder()
+ return nil
+ case subscriptionplan.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case subscriptionplan.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown SubscriptionPlan field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *SubscriptionPlanMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *SubscriptionPlanMutation) AddedIDs(name string) []ent.Value {
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *SubscriptionPlanMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *SubscriptionPlanMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *SubscriptionPlanMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *SubscriptionPlanMutation) EdgeCleared(name string) bool {
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *SubscriptionPlanMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown SubscriptionPlan unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *SubscriptionPlanMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown SubscriptionPlan edge %s", name)
+}
+
// TLSFingerprintProfileMutation represents an operation that mutates the TLSFingerprintProfile nodes in the graph.
type TLSFingerprintProfileMutation struct {
config
@@ -19725,6 +34020,11 @@ type UsageLogMutation struct {
model *string
requested_model *string
upstream_model *string
+ channel_id *int64
+ addchannel_id *int64
+ model_mapping_chain *string
+ billing_tier *string
+ billing_mode *string
input_tokens *int
addinput_tokens *int
output_tokens *int
@@ -19765,7 +34065,6 @@ type UsageLogMutation struct {
image_count *int
addimage_count *int
image_size *string
- media_type *string
cache_ttl_overridden *bool
created_at *time.Time
clearedFields map[string]struct{}
@@ -20160,6 +34459,223 @@ func (m *UsageLogMutation) ResetUpstreamModel() {
delete(m.clearedFields, usagelog.FieldUpstreamModel)
}
+// SetChannelID sets the "channel_id" field.
+func (m *UsageLogMutation) SetChannelID(i int64) {
+ m.channel_id = &i
+ m.addchannel_id = nil
+}
+
+// ChannelID returns the value of the "channel_id" field in the mutation.
+func (m *UsageLogMutation) ChannelID() (r int64, exists bool) {
+ v := m.channel_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldChannelID returns the old "channel_id" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldChannelID(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldChannelID is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldChannelID requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldChannelID: %w", err)
+ }
+ return oldValue.ChannelID, nil
+}
+
+// AddChannelID adds i to the "channel_id" field.
+func (m *UsageLogMutation) AddChannelID(i int64) {
+ if m.addchannel_id != nil {
+ *m.addchannel_id += i
+ } else {
+ m.addchannel_id = &i
+ }
+}
+
+// AddedChannelID returns the value that was added to the "channel_id" field in this mutation.
+func (m *UsageLogMutation) AddedChannelID() (r int64, exists bool) {
+ v := m.addchannel_id
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearChannelID clears the value of the "channel_id" field.
+func (m *UsageLogMutation) ClearChannelID() {
+ m.channel_id = nil
+ m.addchannel_id = nil
+ m.clearedFields[usagelog.FieldChannelID] = struct{}{}
+}
+
+// ChannelIDCleared returns if the "channel_id" field was cleared in this mutation.
+func (m *UsageLogMutation) ChannelIDCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldChannelID]
+ return ok
+}
+
+// ResetChannelID resets all changes to the "channel_id" field.
+func (m *UsageLogMutation) ResetChannelID() {
+ m.channel_id = nil
+ m.addchannel_id = nil
+ delete(m.clearedFields, usagelog.FieldChannelID)
+}
+
+// SetModelMappingChain sets the "model_mapping_chain" field.
+func (m *UsageLogMutation) SetModelMappingChain(s string) {
+ m.model_mapping_chain = &s
+}
+
+// ModelMappingChain returns the value of the "model_mapping_chain" field in the mutation.
+func (m *UsageLogMutation) ModelMappingChain() (r string, exists bool) {
+ v := m.model_mapping_chain
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldModelMappingChain returns the old "model_mapping_chain" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldModelMappingChain(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModelMappingChain is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModelMappingChain requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModelMappingChain: %w", err)
+ }
+ return oldValue.ModelMappingChain, nil
+}
+
+// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
+func (m *UsageLogMutation) ClearModelMappingChain() {
+ m.model_mapping_chain = nil
+ m.clearedFields[usagelog.FieldModelMappingChain] = struct{}{}
+}
+
+// ModelMappingChainCleared returns if the "model_mapping_chain" field was cleared in this mutation.
+func (m *UsageLogMutation) ModelMappingChainCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldModelMappingChain]
+ return ok
+}
+
+// ResetModelMappingChain resets all changes to the "model_mapping_chain" field.
+func (m *UsageLogMutation) ResetModelMappingChain() {
+ m.model_mapping_chain = nil
+ delete(m.clearedFields, usagelog.FieldModelMappingChain)
+}
+
+// SetBillingTier sets the "billing_tier" field.
+func (m *UsageLogMutation) SetBillingTier(s string) {
+ m.billing_tier = &s
+}
+
+// BillingTier returns the value of the "billing_tier" field in the mutation.
+func (m *UsageLogMutation) BillingTier() (r string, exists bool) {
+ v := m.billing_tier
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBillingTier returns the old "billing_tier" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldBillingTier(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBillingTier is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBillingTier requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBillingTier: %w", err)
+ }
+ return oldValue.BillingTier, nil
+}
+
+// ClearBillingTier clears the value of the "billing_tier" field.
+func (m *UsageLogMutation) ClearBillingTier() {
+ m.billing_tier = nil
+ m.clearedFields[usagelog.FieldBillingTier] = struct{}{}
+}
+
+// BillingTierCleared returns if the "billing_tier" field was cleared in this mutation.
+func (m *UsageLogMutation) BillingTierCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldBillingTier]
+ return ok
+}
+
+// ResetBillingTier resets all changes to the "billing_tier" field.
+func (m *UsageLogMutation) ResetBillingTier() {
+ m.billing_tier = nil
+ delete(m.clearedFields, usagelog.FieldBillingTier)
+}
+
+// SetBillingMode sets the "billing_mode" field.
+func (m *UsageLogMutation) SetBillingMode(s string) {
+ m.billing_mode = &s
+}
+
+// BillingMode returns the value of the "billing_mode" field in the mutation.
+func (m *UsageLogMutation) BillingMode() (r string, exists bool) {
+ v := m.billing_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBillingMode returns the old "billing_mode" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldBillingMode(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBillingMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBillingMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBillingMode: %w", err)
+ }
+ return oldValue.BillingMode, nil
+}
+
+// ClearBillingMode clears the value of the "billing_mode" field.
+func (m *UsageLogMutation) ClearBillingMode() {
+ m.billing_mode = nil
+ m.clearedFields[usagelog.FieldBillingMode] = struct{}{}
+}
+
+// BillingModeCleared returns if the "billing_mode" field was cleared in this mutation.
+func (m *UsageLogMutation) BillingModeCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldBillingMode]
+ return ok
+}
+
+// ResetBillingMode resets all changes to the "billing_mode" field.
+func (m *UsageLogMutation) ResetBillingMode() {
+ m.billing_mode = nil
+ delete(m.clearedFields, usagelog.FieldBillingMode)
+}
+
// SetGroupID sets the "group_id" field.
func (m *UsageLogMutation) SetGroupID(i int64) {
m.group = &i
@@ -21491,55 +36007,6 @@ func (m *UsageLogMutation) ResetImageSize() {
delete(m.clearedFields, usagelog.FieldImageSize)
}
-// SetMediaType sets the "media_type" field.
-func (m *UsageLogMutation) SetMediaType(s string) {
- m.media_type = &s
-}
-
-// MediaType returns the value of the "media_type" field in the mutation.
-func (m *UsageLogMutation) MediaType() (r string, exists bool) {
- v := m.media_type
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldMediaType returns the old "media_type" field's value of the UsageLog entity.
-// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
-// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) {
- if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldMediaType is only allowed on UpdateOne operations")
- }
- if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldMediaType requires an ID field in the mutation")
- }
- oldValue, err := m.oldValue(ctx)
- if err != nil {
- return v, fmt.Errorf("querying old value for OldMediaType: %w", err)
- }
- return oldValue.MediaType, nil
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (m *UsageLogMutation) ClearMediaType() {
- m.media_type = nil
- m.clearedFields[usagelog.FieldMediaType] = struct{}{}
-}
-
-// MediaTypeCleared returns if the "media_type" field was cleared in this mutation.
-func (m *UsageLogMutation) MediaTypeCleared() bool {
- _, ok := m.clearedFields[usagelog.FieldMediaType]
- return ok
-}
-
-// ResetMediaType resets all changes to the "media_type" field.
-func (m *UsageLogMutation) ResetMediaType() {
- m.media_type = nil
- delete(m.clearedFields, usagelog.FieldMediaType)
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
m.cache_ttl_overridden = &b
@@ -21781,7 +36248,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
- fields := make([]string, 0, 34)
+ fields := make([]string, 0, 37)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -21803,6 +36270,18 @@ func (m *UsageLogMutation) Fields() []string {
if m.upstream_model != nil {
fields = append(fields, usagelog.FieldUpstreamModel)
}
+ if m.channel_id != nil {
+ fields = append(fields, usagelog.FieldChannelID)
+ }
+ if m.model_mapping_chain != nil {
+ fields = append(fields, usagelog.FieldModelMappingChain)
+ }
+ if m.billing_tier != nil {
+ fields = append(fields, usagelog.FieldBillingTier)
+ }
+ if m.billing_mode != nil {
+ fields = append(fields, usagelog.FieldBillingMode)
+ }
if m.group != nil {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -21875,9 +36354,6 @@ func (m *UsageLogMutation) Fields() []string {
if m.image_size != nil {
fields = append(fields, usagelog.FieldImageSize)
}
- if m.media_type != nil {
- fields = append(fields, usagelog.FieldMediaType)
- }
if m.cache_ttl_overridden != nil {
fields = append(fields, usagelog.FieldCacheTTLOverridden)
}
@@ -21906,6 +36382,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.RequestedModel()
case usagelog.FieldUpstreamModel:
return m.UpstreamModel()
+ case usagelog.FieldChannelID:
+ return m.ChannelID()
+ case usagelog.FieldModelMappingChain:
+ return m.ModelMappingChain()
+ case usagelog.FieldBillingTier:
+ return m.BillingTier()
+ case usagelog.FieldBillingMode:
+ return m.BillingMode()
case usagelog.FieldGroupID:
return m.GroupID()
case usagelog.FieldSubscriptionID:
@@ -21954,8 +36438,6 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageCount()
case usagelog.FieldImageSize:
return m.ImageSize()
- case usagelog.FieldMediaType:
- return m.MediaType()
case usagelog.FieldCacheTTLOverridden:
return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt:
@@ -21983,6 +36465,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldRequestedModel(ctx)
case usagelog.FieldUpstreamModel:
return m.OldUpstreamModel(ctx)
+ case usagelog.FieldChannelID:
+ return m.OldChannelID(ctx)
+ case usagelog.FieldModelMappingChain:
+ return m.OldModelMappingChain(ctx)
+ case usagelog.FieldBillingTier:
+ return m.OldBillingTier(ctx)
+ case usagelog.FieldBillingMode:
+ return m.OldBillingMode(ctx)
case usagelog.FieldGroupID:
return m.OldGroupID(ctx)
case usagelog.FieldSubscriptionID:
@@ -22031,8 +36521,6 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageCount(ctx)
case usagelog.FieldImageSize:
return m.OldImageSize(ctx)
- case usagelog.FieldMediaType:
- return m.OldMediaType(ctx)
case usagelog.FieldCacheTTLOverridden:
return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt:
@@ -22095,6 +36583,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetUpstreamModel(v)
return nil
+ case usagelog.FieldChannelID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetChannelID(v)
+ return nil
+ case usagelog.FieldModelMappingChain:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModelMappingChain(v)
+ return nil
+ case usagelog.FieldBillingTier:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBillingTier(v)
+ return nil
+ case usagelog.FieldBillingMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBillingMode(v)
+ return nil
case usagelog.FieldGroupID:
v, ok := value.(int64)
if !ok {
@@ -22263,13 +36779,6 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetImageSize(v)
return nil
- case usagelog.FieldMediaType:
- v, ok := value.(string)
- if !ok {
- return fmt.Errorf("unexpected type %T for field %s", value, name)
- }
- m.SetMediaType(v)
- return nil
case usagelog.FieldCacheTTLOverridden:
v, ok := value.(bool)
if !ok {
@@ -22292,6 +36801,9 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
// this mutation.
func (m *UsageLogMutation) AddedFields() []string {
var fields []string
+ if m.addchannel_id != nil {
+ fields = append(fields, usagelog.FieldChannelID)
+ }
if m.addinput_tokens != nil {
fields = append(fields, usagelog.FieldInputTokens)
}
@@ -22354,6 +36866,8 @@ func (m *UsageLogMutation) AddedFields() []string {
// was not set, or was not defined in the schema.
func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
switch name {
+ case usagelog.FieldChannelID:
+ return m.AddedChannelID()
case usagelog.FieldInputTokens:
return m.AddedInputTokens()
case usagelog.FieldOutputTokens:
@@ -22399,6 +36913,13 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
// type.
func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
switch name {
+ case usagelog.FieldChannelID:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddChannelID(v)
+ return nil
case usagelog.FieldInputTokens:
v, ok := value.(int)
if !ok {
@@ -22539,6 +37060,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldUpstreamModel) {
fields = append(fields, usagelog.FieldUpstreamModel)
}
+ if m.FieldCleared(usagelog.FieldChannelID) {
+ fields = append(fields, usagelog.FieldChannelID)
+ }
+ if m.FieldCleared(usagelog.FieldModelMappingChain) {
+ fields = append(fields, usagelog.FieldModelMappingChain)
+ }
+ if m.FieldCleared(usagelog.FieldBillingTier) {
+ fields = append(fields, usagelog.FieldBillingTier)
+ }
+ if m.FieldCleared(usagelog.FieldBillingMode) {
+ fields = append(fields, usagelog.FieldBillingMode)
+ }
if m.FieldCleared(usagelog.FieldGroupID) {
fields = append(fields, usagelog.FieldGroupID)
}
@@ -22563,9 +37096,6 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize)
}
- if m.FieldCleared(usagelog.FieldMediaType) {
- fields = append(fields, usagelog.FieldMediaType)
- }
return fields
}
@@ -22586,6 +37116,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldUpstreamModel:
m.ClearUpstreamModel()
return nil
+ case usagelog.FieldChannelID:
+ m.ClearChannelID()
+ return nil
+ case usagelog.FieldModelMappingChain:
+ m.ClearModelMappingChain()
+ return nil
+ case usagelog.FieldBillingTier:
+ m.ClearBillingTier()
+ return nil
+ case usagelog.FieldBillingMode:
+ m.ClearBillingMode()
+ return nil
case usagelog.FieldGroupID:
m.ClearGroupID()
return nil
@@ -22610,9 +37152,6 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldImageSize:
m.ClearImageSize()
return nil
- case usagelog.FieldMediaType:
- m.ClearMediaType()
- return nil
}
return fmt.Errorf("unknown UsageLog nullable field %s", name)
}
@@ -22642,6 +37181,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldUpstreamModel:
m.ResetUpstreamModel()
return nil
+ case usagelog.FieldChannelID:
+ m.ResetChannelID()
+ return nil
+ case usagelog.FieldModelMappingChain:
+ m.ResetModelMappingChain()
+ return nil
+ case usagelog.FieldBillingTier:
+ m.ResetBillingTier()
+ return nil
+ case usagelog.FieldBillingMode:
+ m.ResetBillingMode()
+ return nil
case usagelog.FieldGroupID:
m.ResetGroupID()
return nil
@@ -22714,9 +37265,6 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldImageSize:
m.ResetImageSize()
return nil
- case usagelog.FieldMediaType:
- m.ResetMediaType()
- return nil
case usagelog.FieldCacheTTLOverridden:
m.ResetCacheTTLOverridden()
return nil
@@ -22895,10 +37443,18 @@ type UserMutation struct {
totp_secret_encrypted *string
totp_enabled *bool
totp_enabled_at *time.Time
- sora_storage_quota_bytes *int64
- addsora_storage_quota_bytes *int64
- sora_storage_used_bytes *int64
- addsora_storage_used_bytes *int64
+ signup_source *string
+ last_login_at *time.Time
+ last_active_at *time.Time
+ balance_notify_enabled *bool
+ balance_notify_threshold_type *string
+ balance_notify_threshold *float64
+ addbalance_notify_threshold *float64
+ balance_notify_extra_emails *string
+ total_recharged *float64
+ addtotal_recharged *float64
+ rpm_limit *int
+ addrpm_limit *int
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -22927,6 +37483,15 @@ type UserMutation struct {
promo_code_usages map[int64]struct{}
removedpromo_code_usages map[int64]struct{}
clearedpromo_code_usages bool
+ payment_orders map[int64]struct{}
+ removedpayment_orders map[int64]struct{}
+ clearedpayment_orders bool
+ auth_identities map[int64]struct{}
+ removedauth_identities map[int64]struct{}
+ clearedauth_identities bool
+ pending_auth_sessions map[int64]struct{}
+ removedpending_auth_sessions map[int64]struct{}
+ clearedpending_auth_sessions bool
done bool
oldValue func(context.Context) (*User, error)
predicates []predicate.User
@@ -23613,116 +38178,428 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt)
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) {
- m.sora_storage_quota_bytes = &i
- m.addsora_storage_quota_bytes = nil
+// SetSignupSource sets the "signup_source" field.
+func (m *UserMutation) SetSignupSource(s string) {
+ m.signup_source = &s
}
-// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation.
-func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.sora_storage_quota_bytes
+// SignupSource returns the value of the "signup_source" field in the mutation.
+func (m *UserMutation) SignupSource() (r string, exists bool) {
+ v := m.signup_source
if v == nil {
return
}
return *v, true
}
-// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity.
+// OldSignupSource returns the old "signup_source" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) {
+func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations")
+ return v, errors.New("OldSignupSource is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation")
+ return v, errors.New("OldSignupSource requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err)
+ return v, fmt.Errorf("querying old value for OldSignupSource: %w", err)
}
- return oldValue.SoraStorageQuotaBytes, nil
+ return oldValue.SignupSource, nil
}
-// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field.
-func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) {
- if m.addsora_storage_quota_bytes != nil {
- *m.addsora_storage_quota_bytes += i
- } else {
- m.addsora_storage_quota_bytes = &i
- }
+// ResetSignupSource resets all changes to the "signup_source" field.
+func (m *UserMutation) ResetSignupSource() {
+ m.signup_source = nil
}
-// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation.
-func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) {
- v := m.addsora_storage_quota_bytes
+// SetLastLoginAt sets the "last_login_at" field.
+func (m *UserMutation) SetLastLoginAt(t time.Time) {
+ m.last_login_at = &t
+}
+
+// LastLoginAt returns the value of the "last_login_at" field in the mutation.
+func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) {
+ v := m.last_login_at
if v == nil {
return
}
return *v, true
}
-// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field.
-func (m *UserMutation) ResetSoraStorageQuotaBytes() {
- m.sora_storage_quota_bytes = nil
- m.addsora_storage_quota_bytes = nil
-}
-
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (m *UserMutation) SetSoraStorageUsedBytes(i int64) {
- m.sora_storage_used_bytes = &i
- m.addsora_storage_used_bytes = nil
-}
-
-// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation.
-func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) {
- v := m.sora_storage_used_bytes
- if v == nil {
- return
- }
- return *v, true
-}
-
-// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity.
+// OldLastLoginAt returns the old "last_login_at" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
-func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) {
+func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
- return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations")
+ return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
- return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation")
+ return v, errors.New("OldLastLoginAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
- return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err)
+ return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err)
}
- return oldValue.SoraStorageUsedBytes, nil
+ return oldValue.LastLoginAt, nil
}
-// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field.
-func (m *UserMutation) AddSoraStorageUsedBytes(i int64) {
- if m.addsora_storage_used_bytes != nil {
- *m.addsora_storage_used_bytes += i
- } else {
- m.addsora_storage_used_bytes = &i
- }
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (m *UserMutation) ClearLastLoginAt() {
+ m.last_login_at = nil
+ m.clearedFields[user.FieldLastLoginAt] = struct{}{}
}
-// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation.
-func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) {
- v := m.addsora_storage_used_bytes
+// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation.
+func (m *UserMutation) LastLoginAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastLoginAt]
+ return ok
+}
+
+// ResetLastLoginAt resets all changes to the "last_login_at" field.
+func (m *UserMutation) ResetLastLoginAt() {
+ m.last_login_at = nil
+ delete(m.clearedFields, user.FieldLastLoginAt)
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (m *UserMutation) SetLastActiveAt(t time.Time) {
+ m.last_active_at = &t
+}
+
+// LastActiveAt returns the value of the "last_active_at" field in the mutation.
+func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) {
+ v := m.last_active_at
if v == nil {
return
}
return *v, true
}
-// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field.
-func (m *UserMutation) ResetSoraStorageUsedBytes() {
- m.sora_storage_used_bytes = nil
- m.addsora_storage_used_bytes = nil
+// OldLastActiveAt returns the old "last_active_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldLastActiveAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err)
+ }
+ return oldValue.LastActiveAt, nil
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (m *UserMutation) ClearLastActiveAt() {
+ m.last_active_at = nil
+ m.clearedFields[user.FieldLastActiveAt] = struct{}{}
+}
+
+// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation.
+func (m *UserMutation) LastActiveAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldLastActiveAt]
+ return ok
+}
+
+// ResetLastActiveAt resets all changes to the "last_active_at" field.
+func (m *UserMutation) ResetLastActiveAt() {
+ m.last_active_at = nil
+ delete(m.clearedFields, user.FieldLastActiveAt)
+}
+
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
+ m.balance_notify_enabled = &b
+}
+
+// BalanceNotifyEnabled returns the value of the "balance_notify_enabled" field in the mutation.
+func (m *UserMutation) BalanceNotifyEnabled() (r bool, exists bool) {
+ v := m.balance_notify_enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyEnabled returns the old "balance_notify_enabled" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyEnabled: %w", err)
+ }
+ return oldValue.BalanceNotifyEnabled, nil
+}
+
+// ResetBalanceNotifyEnabled resets all changes to the "balance_notify_enabled" field.
+func (m *UserMutation) ResetBalanceNotifyEnabled() {
+ m.balance_notify_enabled = nil
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (m *UserMutation) SetBalanceNotifyThresholdType(s string) {
+ m.balance_notify_threshold_type = &s
+}
+
+// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation.
+func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) {
+ v := m.balance_notify_threshold_type
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err)
+ }
+ return oldValue.BalanceNotifyThresholdType, nil
+}
+
+// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field.
+func (m *UserMutation) ResetBalanceNotifyThresholdType() {
+ m.balance_notify_threshold_type = nil
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (m *UserMutation) SetBalanceNotifyThreshold(f float64) {
+ m.balance_notify_threshold = &f
+ m.addbalance_notify_threshold = nil
+}
+
+// BalanceNotifyThreshold returns the value of the "balance_notify_threshold" field in the mutation.
+func (m *UserMutation) BalanceNotifyThreshold() (r float64, exists bool) {
+ v := m.balance_notify_threshold
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyThreshold returns the old "balance_notify_threshold" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyThreshold(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyThreshold is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyThreshold requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyThreshold: %w", err)
+ }
+ return oldValue.BalanceNotifyThreshold, nil
+}
+
+// AddBalanceNotifyThreshold adds f to the "balance_notify_threshold" field.
+func (m *UserMutation) AddBalanceNotifyThreshold(f float64) {
+ if m.addbalance_notify_threshold != nil {
+ *m.addbalance_notify_threshold += f
+ } else {
+ m.addbalance_notify_threshold = &f
+ }
+}
+
+// AddedBalanceNotifyThreshold returns the value that was added to the "balance_notify_threshold" field in this mutation.
+func (m *UserMutation) AddedBalanceNotifyThreshold() (r float64, exists bool) {
+ v := m.addbalance_notify_threshold
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (m *UserMutation) ClearBalanceNotifyThreshold() {
+ m.balance_notify_threshold = nil
+ m.addbalance_notify_threshold = nil
+ m.clearedFields[user.FieldBalanceNotifyThreshold] = struct{}{}
+}
+
+// BalanceNotifyThresholdCleared returns if the "balance_notify_threshold" field was cleared in this mutation.
+func (m *UserMutation) BalanceNotifyThresholdCleared() bool {
+ _, ok := m.clearedFields[user.FieldBalanceNotifyThreshold]
+ return ok
+}
+
+// ResetBalanceNotifyThreshold resets all changes to the "balance_notify_threshold" field.
+func (m *UserMutation) ResetBalanceNotifyThreshold() {
+ m.balance_notify_threshold = nil
+ m.addbalance_notify_threshold = nil
+ delete(m.clearedFields, user.FieldBalanceNotifyThreshold)
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (m *UserMutation) SetBalanceNotifyExtraEmails(s string) {
+ m.balance_notify_extra_emails = &s
+}
+
+// BalanceNotifyExtraEmails returns the value of the "balance_notify_extra_emails" field in the mutation.
+func (m *UserMutation) BalanceNotifyExtraEmails() (r string, exists bool) {
+ v := m.balance_notify_extra_emails
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldBalanceNotifyExtraEmails returns the old "balance_notify_extra_emails" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldBalanceNotifyExtraEmails(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldBalanceNotifyExtraEmails is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldBalanceNotifyExtraEmails requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldBalanceNotifyExtraEmails: %w", err)
+ }
+ return oldValue.BalanceNotifyExtraEmails, nil
+}
+
+// ResetBalanceNotifyExtraEmails resets all changes to the "balance_notify_extra_emails" field.
+func (m *UserMutation) ResetBalanceNotifyExtraEmails() {
+ m.balance_notify_extra_emails = nil
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (m *UserMutation) SetTotalRecharged(f float64) {
+ m.total_recharged = &f
+ m.addtotal_recharged = nil
+}
+
+// TotalRecharged returns the value of the "total_recharged" field in the mutation.
+func (m *UserMutation) TotalRecharged() (r float64, exists bool) {
+ v := m.total_recharged
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotalRecharged returns the old "total_recharged" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotalRecharged requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err)
+ }
+ return oldValue.TotalRecharged, nil
+}
+
+// AddTotalRecharged adds f to the "total_recharged" field.
+func (m *UserMutation) AddTotalRecharged(f float64) {
+ if m.addtotal_recharged != nil {
+ *m.addtotal_recharged += f
+ } else {
+ m.addtotal_recharged = &f
+ }
+}
+
+// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation.
+func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) {
+ v := m.addtotal_recharged
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetTotalRecharged resets all changes to the "total_recharged" field.
+func (m *UserMutation) ResetTotalRecharged() {
+ m.total_recharged = nil
+ m.addtotal_recharged = nil
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (m *UserMutation) SetRpmLimit(i int) {
+ m.rpm_limit = &i
+ m.addrpm_limit = nil
+}
+
+// RpmLimit returns the value of the "rpm_limit" field in the mutation.
+func (m *UserMutation) RpmLimit() (r int, exists bool) {
+ v := m.rpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRpmLimit returns the old "rpm_limit" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldRpmLimit(ctx context.Context) (v int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRpmLimit requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err)
+ }
+ return oldValue.RpmLimit, nil
+}
+
+// AddRpmLimit adds i to the "rpm_limit" field.
+func (m *UserMutation) AddRpmLimit(i int) {
+ if m.addrpm_limit != nil {
+ *m.addrpm_limit += i
+ } else {
+ m.addrpm_limit = &i
+ }
+}
+
+// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation.
+func (m *UserMutation) AddedRpmLimit() (r int, exists bool) {
+ v := m.addrpm_limit
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetRpmLimit resets all changes to the "rpm_limit" field.
+func (m *UserMutation) ResetRpmLimit() {
+ m.rpm_limit = nil
+ m.addrpm_limit = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
@@ -24211,6 +39088,168 @@ func (m *UserMutation) ResetPromoCodeUsages() {
m.removedpromo_code_usages = nil
}
+// AddPaymentOrderIDs adds the "payment_orders" edge to the PaymentOrder entity by ids.
+func (m *UserMutation) AddPaymentOrderIDs(ids ...int64) {
+ if m.payment_orders == nil {
+ m.payment_orders = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.payment_orders[ids[i]] = struct{}{}
+ }
+}
+
+// ClearPaymentOrders clears the "payment_orders" edge to the PaymentOrder entity.
+func (m *UserMutation) ClearPaymentOrders() {
+ m.clearedpayment_orders = true
+}
+
+// PaymentOrdersCleared reports if the "payment_orders" edge to the PaymentOrder entity was cleared.
+func (m *UserMutation) PaymentOrdersCleared() bool {
+ return m.clearedpayment_orders
+}
+
+// RemovePaymentOrderIDs removes the "payment_orders" edge to the PaymentOrder entity by IDs.
+func (m *UserMutation) RemovePaymentOrderIDs(ids ...int64) {
+ if m.removedpayment_orders == nil {
+ m.removedpayment_orders = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.payment_orders, ids[i])
+ m.removedpayment_orders[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedPaymentOrders returns the removed IDs of the "payment_orders" edge to the PaymentOrder entity.
+func (m *UserMutation) RemovedPaymentOrdersIDs() (ids []int64) {
+ for id := range m.removedpayment_orders {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// PaymentOrdersIDs returns the "payment_orders" edge IDs in the mutation.
+func (m *UserMutation) PaymentOrdersIDs() (ids []int64) {
+ for id := range m.payment_orders {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetPaymentOrders resets all changes to the "payment_orders" edge.
+func (m *UserMutation) ResetPaymentOrders() {
+ m.payment_orders = nil
+ m.clearedpayment_orders = false
+ m.removedpayment_orders = nil
+}
+
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids.
+func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) {
+ if m.auth_identities == nil {
+ m.auth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.auth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) ClearAuthIdentities() {
+ m.clearedauth_identities = true
+}
+
+// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared.
+func (m *UserMutation) AuthIdentitiesCleared() bool {
+ return m.clearedauth_identities
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) {
+ if m.removedauth_identities == nil {
+ m.removedauth_identities = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.auth_identities, ids[i])
+ m.removedauth_identities[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity.
+func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) {
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation.
+func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) {
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetAuthIdentities resets all changes to the "auth_identities" edge.
+func (m *UserMutation) ResetAuthIdentities() {
+ m.auth_identities = nil
+ m.clearedauth_identities = false
+ m.removedauth_identities = nil
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids.
+func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) {
+ if m.pending_auth_sessions == nil {
+ m.pending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ m.pending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) ClearPendingAuthSessions() {
+ m.clearedpending_auth_sessions = true
+}
+
+// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared.
+func (m *UserMutation) PendingAuthSessionsCleared() bool {
+ return m.clearedpending_auth_sessions
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) {
+ if m.removedpending_auth_sessions == nil {
+ m.removedpending_auth_sessions = make(map[int64]struct{})
+ }
+ for i := range ids {
+ delete(m.pending_auth_sessions, ids[i])
+ m.removedpending_auth_sessions[ids[i]] = struct{}{}
+ }
+}
+
+// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity.
+func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation.
+func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) {
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return
+}
+
+// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge.
+func (m *UserMutation) ResetPendingAuthSessions() {
+ m.pending_auth_sessions = nil
+ m.clearedpending_auth_sessions = false
+ m.removedpending_auth_sessions = nil
+}
+
// Where appends a list predicates to the UserMutation builder.
func (m *UserMutation) Where(ps ...predicate.User) {
m.predicates = append(m.predicates, ps...)
@@ -24245,7 +39284,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 16)
+ fields := make([]string, 0, 23)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -24288,11 +39327,32 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt)
}
- if m.sora_storage_quota_bytes != nil {
- fields = append(fields, user.FieldSoraStorageQuotaBytes)
+ if m.signup_source != nil {
+ fields = append(fields, user.FieldSignupSource)
}
- if m.sora_storage_used_bytes != nil {
- fields = append(fields, user.FieldSoraStorageUsedBytes)
+ if m.last_login_at != nil {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.last_active_at != nil {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
+ if m.balance_notify_enabled != nil {
+ fields = append(fields, user.FieldBalanceNotifyEnabled)
+ }
+ if m.balance_notify_threshold_type != nil {
+ fields = append(fields, user.FieldBalanceNotifyThresholdType)
+ }
+ if m.balance_notify_threshold != nil {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
+ if m.balance_notify_extra_emails != nil {
+ fields = append(fields, user.FieldBalanceNotifyExtraEmails)
+ }
+ if m.total_recharged != nil {
+ fields = append(fields, user.FieldTotalRecharged)
+ }
+ if m.rpm_limit != nil {
+ fields = append(fields, user.FieldRpmLimit)
}
return fields
}
@@ -24330,10 +39390,24 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled()
case user.FieldTotpEnabledAt:
return m.TotpEnabledAt()
- case user.FieldSoraStorageQuotaBytes:
- return m.SoraStorageQuotaBytes()
- case user.FieldSoraStorageUsedBytes:
- return m.SoraStorageUsedBytes()
+ case user.FieldSignupSource:
+ return m.SignupSource()
+ case user.FieldLastLoginAt:
+ return m.LastLoginAt()
+ case user.FieldLastActiveAt:
+ return m.LastActiveAt()
+ case user.FieldBalanceNotifyEnabled:
+ return m.BalanceNotifyEnabled()
+ case user.FieldBalanceNotifyThresholdType:
+ return m.BalanceNotifyThresholdType()
+ case user.FieldBalanceNotifyThreshold:
+ return m.BalanceNotifyThreshold()
+ case user.FieldBalanceNotifyExtraEmails:
+ return m.BalanceNotifyExtraEmails()
+ case user.FieldTotalRecharged:
+ return m.TotalRecharged()
+ case user.FieldRpmLimit:
+ return m.RpmLimit()
}
return nil, false
}
@@ -24371,10 +39445,24 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx)
- case user.FieldSoraStorageQuotaBytes:
- return m.OldSoraStorageQuotaBytes(ctx)
- case user.FieldSoraStorageUsedBytes:
- return m.OldSoraStorageUsedBytes(ctx)
+ case user.FieldSignupSource:
+ return m.OldSignupSource(ctx)
+ case user.FieldLastLoginAt:
+ return m.OldLastLoginAt(ctx)
+ case user.FieldLastActiveAt:
+ return m.OldLastActiveAt(ctx)
+ case user.FieldBalanceNotifyEnabled:
+ return m.OldBalanceNotifyEnabled(ctx)
+ case user.FieldBalanceNotifyThresholdType:
+ return m.OldBalanceNotifyThresholdType(ctx)
+ case user.FieldBalanceNotifyThreshold:
+ return m.OldBalanceNotifyThreshold(ctx)
+ case user.FieldBalanceNotifyExtraEmails:
+ return m.OldBalanceNotifyExtraEmails(ctx)
+ case user.FieldTotalRecharged:
+ return m.OldTotalRecharged(ctx)
+ case user.FieldRpmLimit:
+ return m.OldRpmLimit(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -24482,19 +39570,68 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetTotpEnabledAt(v)
return nil
- case user.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
+ case user.FieldSignupSource:
+ v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetSoraStorageQuotaBytes(v)
+ m.SetSignupSource(v)
return nil
- case user.FieldSoraStorageUsedBytes:
- v, ok := value.(int64)
+ case user.FieldLastLoginAt:
+ v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.SetSoraStorageUsedBytes(v)
+ m.SetLastLoginAt(v)
+ return nil
+ case user.FieldLastActiveAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetLastActiveAt(v)
+ return nil
+ case user.FieldBalanceNotifyEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyEnabled(v)
+ return nil
+ case user.FieldBalanceNotifyThresholdType:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyThresholdType(v)
+ return nil
+ case user.FieldBalanceNotifyThreshold:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyThreshold(v)
+ return nil
+ case user.FieldBalanceNotifyExtraEmails:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetBalanceNotifyExtraEmails(v)
+ return nil
+ case user.FieldTotalRecharged:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotalRecharged(v)
+ return nil
+ case user.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRpmLimit(v)
return nil
}
return fmt.Errorf("unknown User field %s", name)
@@ -24510,11 +39647,14 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency)
}
- if m.addsora_storage_quota_bytes != nil {
- fields = append(fields, user.FieldSoraStorageQuotaBytes)
+ if m.addbalance_notify_threshold != nil {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
}
- if m.addsora_storage_used_bytes != nil {
- fields = append(fields, user.FieldSoraStorageUsedBytes)
+ if m.addtotal_recharged != nil {
+ fields = append(fields, user.FieldTotalRecharged)
+ }
+ if m.addrpm_limit != nil {
+ fields = append(fields, user.FieldRpmLimit)
}
return fields
}
@@ -24528,10 +39668,12 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance()
case user.FieldConcurrency:
return m.AddedConcurrency()
- case user.FieldSoraStorageQuotaBytes:
- return m.AddedSoraStorageQuotaBytes()
- case user.FieldSoraStorageUsedBytes:
- return m.AddedSoraStorageUsedBytes()
+ case user.FieldBalanceNotifyThreshold:
+ return m.AddedBalanceNotifyThreshold()
+ case user.FieldTotalRecharged:
+ return m.AddedTotalRecharged()
+ case user.FieldRpmLimit:
+ return m.AddedRpmLimit()
}
return nil, false
}
@@ -24555,19 +39697,26 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
}
m.AddConcurrency(v)
return nil
- case user.FieldSoraStorageQuotaBytes:
- v, ok := value.(int64)
+ case user.FieldBalanceNotifyThreshold:
+ v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.AddSoraStorageQuotaBytes(v)
+ m.AddBalanceNotifyThreshold(v)
return nil
- case user.FieldSoraStorageUsedBytes:
- v, ok := value.(int64)
+ case user.FieldTotalRecharged:
+ v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
- m.AddSoraStorageUsedBytes(v)
+ m.AddTotalRecharged(v)
+ return nil
+ case user.FieldRpmLimit:
+ v, ok := value.(int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddRpmLimit(v)
return nil
}
return fmt.Errorf("unknown User numeric field %s", name)
@@ -24586,6 +39735,15 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt)
}
+ if m.FieldCleared(user.FieldLastLoginAt) {
+ fields = append(fields, user.FieldLastLoginAt)
+ }
+ if m.FieldCleared(user.FieldLastActiveAt) {
+ fields = append(fields, user.FieldLastActiveAt)
+ }
+ if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
+ fields = append(fields, user.FieldBalanceNotifyThreshold)
+ }
return fields
}
@@ -24609,6 +39767,15 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt()
return nil
+ case user.FieldLastLoginAt:
+ m.ClearLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ClearLastActiveAt()
+ return nil
+ case user.FieldBalanceNotifyThreshold:
+ m.ClearBalanceNotifyThreshold()
+ return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -24659,11 +39826,32 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt()
return nil
- case user.FieldSoraStorageQuotaBytes:
- m.ResetSoraStorageQuotaBytes()
+ case user.FieldSignupSource:
+ m.ResetSignupSource()
return nil
- case user.FieldSoraStorageUsedBytes:
- m.ResetSoraStorageUsedBytes()
+ case user.FieldLastLoginAt:
+ m.ResetLastLoginAt()
+ return nil
+ case user.FieldLastActiveAt:
+ m.ResetLastActiveAt()
+ return nil
+ case user.FieldBalanceNotifyEnabled:
+ m.ResetBalanceNotifyEnabled()
+ return nil
+ case user.FieldBalanceNotifyThresholdType:
+ m.ResetBalanceNotifyThresholdType()
+ return nil
+ case user.FieldBalanceNotifyThreshold:
+ m.ResetBalanceNotifyThreshold()
+ return nil
+ case user.FieldBalanceNotifyExtraEmails:
+ m.ResetBalanceNotifyExtraEmails()
+ return nil
+ case user.FieldTotalRecharged:
+ m.ResetTotalRecharged()
+ return nil
+ case user.FieldRpmLimit:
+ m.ResetRpmLimit()
return nil
}
return fmt.Errorf("unknown User field %s", name)
@@ -24671,7 +39859,7 @@ func (m *UserMutation) ResetField(name string) error {
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *UserMutation) AddedEdges() []string {
- edges := make([]string, 0, 9)
+ edges := make([]string, 0, 12)
if m.api_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -24699,6 +39887,15 @@ func (m *UserMutation) AddedEdges() []string {
if m.promo_code_usages != nil {
edges = append(edges, user.EdgePromoCodeUsages)
}
+ if m.payment_orders != nil {
+ edges = append(edges, user.EdgePaymentOrders)
+ }
+ if m.auth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.pending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -24760,13 +39957,31 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgePaymentOrders:
+ ids := make([]ent.Value, 0, len(m.payment_orders))
+ for id := range m.payment_orders {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.auth_identities))
+ for id := range m.auth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.pending_auth_sessions))
+ for id := range m.pending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *UserMutation) RemovedEdges() []string {
- edges := make([]string, 0, 9)
+ edges := make([]string, 0, 12)
if m.removedapi_keys != nil {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -24794,6 +40009,15 @@ func (m *UserMutation) RemovedEdges() []string {
if m.removedpromo_code_usages != nil {
edges = append(edges, user.EdgePromoCodeUsages)
}
+ if m.removedpayment_orders != nil {
+ edges = append(edges, user.EdgePaymentOrders)
+ }
+ if m.removedauth_identities != nil {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.removedpending_auth_sessions != nil {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -24855,13 +40079,31 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value {
ids = append(ids, id)
}
return ids
+ case user.EdgePaymentOrders:
+ ids := make([]ent.Value, 0, len(m.removedpayment_orders))
+ for id := range m.removedpayment_orders {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgeAuthIdentities:
+ ids := make([]ent.Value, 0, len(m.removedauth_identities))
+ for id := range m.removedauth_identities {
+ ids = append(ids, id)
+ }
+ return ids
+ case user.EdgePendingAuthSessions:
+ ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions))
+ for id := range m.removedpending_auth_sessions {
+ ids = append(ids, id)
+ }
+ return ids
}
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *UserMutation) ClearedEdges() []string {
- edges := make([]string, 0, 9)
+ edges := make([]string, 0, 12)
if m.clearedapi_keys {
edges = append(edges, user.EdgeAPIKeys)
}
@@ -24889,6 +40131,15 @@ func (m *UserMutation) ClearedEdges() []string {
if m.clearedpromo_code_usages {
edges = append(edges, user.EdgePromoCodeUsages)
}
+ if m.clearedpayment_orders {
+ edges = append(edges, user.EdgePaymentOrders)
+ }
+ if m.clearedauth_identities {
+ edges = append(edges, user.EdgeAuthIdentities)
+ }
+ if m.clearedpending_auth_sessions {
+ edges = append(edges, user.EdgePendingAuthSessions)
+ }
return edges
}
@@ -24914,6 +40165,12 @@ func (m *UserMutation) EdgeCleared(name string) bool {
return m.clearedattribute_values
case user.EdgePromoCodeUsages:
return m.clearedpromo_code_usages
+ case user.EdgePaymentOrders:
+ return m.clearedpayment_orders
+ case user.EdgeAuthIdentities:
+ return m.clearedauth_identities
+ case user.EdgePendingAuthSessions:
+ return m.clearedpending_auth_sessions
}
return false
}
@@ -24957,6 +40214,15 @@ func (m *UserMutation) ResetEdge(name string) error {
case user.EdgePromoCodeUsages:
m.ResetPromoCodeUsages()
return nil
+ case user.EdgePaymentOrders:
+ m.ResetPaymentOrders()
+ return nil
+ case user.EdgeAuthIdentities:
+ m.ResetAuthIdentities()
+ return nil
+ case user.EdgePendingAuthSessions:
+ m.ResetPendingAuthSessions()
+ return nil
}
return fmt.Errorf("unknown User edge %s", name)
}
diff --git a/backend/ent/paymentauditlog.go b/backend/ent/paymentauditlog.go
new file mode 100644
index 00000000..ffcdcfcd
--- /dev/null
+++ b/backend/ent/paymentauditlog.go
@@ -0,0 +1,150 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+)
+
+// PaymentAuditLog is the model entity for the PaymentAuditLog schema.
+type PaymentAuditLog struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // OrderID holds the value of the "order_id" field.
+ OrderID string `json:"order_id,omitempty"`
+ // Action holds the value of the "action" field.
+ Action string `json:"action,omitempty"`
+ // Detail holds the value of the "detail" field.
+ Detail string `json:"detail,omitempty"`
+ // Operator holds the value of the "operator" field.
+ Operator string `json:"operator,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ selectValues sql.SelectValues
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PaymentAuditLog) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case paymentauditlog.FieldID:
+ values[i] = new(sql.NullInt64)
+ case paymentauditlog.FieldOrderID, paymentauditlog.FieldAction, paymentauditlog.FieldDetail, paymentauditlog.FieldOperator:
+ values[i] = new(sql.NullString)
+ case paymentauditlog.FieldCreatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PaymentAuditLog fields.
+func (_m *PaymentAuditLog) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case paymentauditlog.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case paymentauditlog.FieldOrderID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field order_id", values[i])
+ } else if value.Valid {
+ _m.OrderID = value.String
+ }
+ case paymentauditlog.FieldAction:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field action", values[i])
+ } else if value.Valid {
+ _m.Action = value.String
+ }
+ case paymentauditlog.FieldDetail:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field detail", values[i])
+ } else if value.Valid {
+ _m.Detail = value.String
+ }
+ case paymentauditlog.FieldOperator:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field operator", values[i])
+ } else if value.Valid {
+ _m.Operator = value.String
+ }
+ case paymentauditlog.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PaymentAuditLog.
+// This includes values selected through modifiers, order, etc.
+func (_m *PaymentAuditLog) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// Update returns a builder for updating this PaymentAuditLog.
+// Note that you need to call PaymentAuditLog.Unwrap() before calling this method if this PaymentAuditLog
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PaymentAuditLog) Update() *PaymentAuditLogUpdateOne {
+ return NewPaymentAuditLogClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PaymentAuditLog entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PaymentAuditLog) Unwrap() *PaymentAuditLog {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PaymentAuditLog is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PaymentAuditLog) String() string {
+ var builder strings.Builder
+ builder.WriteString("PaymentAuditLog(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("order_id=")
+ builder.WriteString(_m.OrderID)
+ builder.WriteString(", ")
+ builder.WriteString("action=")
+ builder.WriteString(_m.Action)
+ builder.WriteString(", ")
+ builder.WriteString("detail=")
+ builder.WriteString(_m.Detail)
+ builder.WriteString(", ")
+ builder.WriteString("operator=")
+ builder.WriteString(_m.Operator)
+ builder.WriteString(", ")
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PaymentAuditLogs is a parsable slice of PaymentAuditLog.
+type PaymentAuditLogs []*PaymentAuditLog
diff --git a/backend/ent/paymentauditlog/paymentauditlog.go b/backend/ent/paymentauditlog/paymentauditlog.go
new file mode 100644
index 00000000..9d480eef
--- /dev/null
+++ b/backend/ent/paymentauditlog/paymentauditlog.go
@@ -0,0 +1,96 @@
+// Code generated by ent, DO NOT EDIT.
+
+package paymentauditlog
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+)
+
+const (
+ // Label holds the string label denoting the paymentauditlog type in the database.
+ Label = "payment_audit_log"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldOrderID holds the string denoting the order_id field in the database.
+ FieldOrderID = "order_id"
+ // FieldAction holds the string denoting the action field in the database.
+ FieldAction = "action"
+ // FieldDetail holds the string denoting the detail field in the database.
+ FieldDetail = "detail"
+ // FieldOperator holds the string denoting the operator field in the database.
+ FieldOperator = "operator"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // Table holds the table name of the paymentauditlog in the database.
+ Table = "payment_audit_logs"
+)
+
+// Columns holds all SQL columns for paymentauditlog fields.
+var Columns = []string{
+ FieldID,
+ FieldOrderID,
+ FieldAction,
+ FieldDetail,
+ FieldOperator,
+ FieldCreatedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // OrderIDValidator is a validator for the "order_id" field. It is called by the builders before save.
+ OrderIDValidator func(string) error
+ // ActionValidator is a validator for the "action" field. It is called by the builders before save.
+ ActionValidator func(string) error
+ // DefaultDetail holds the default value on creation for the "detail" field.
+ DefaultDetail string
+ // DefaultOperator holds the default value on creation for the "operator" field.
+ DefaultOperator string
+ // OperatorValidator is a validator for the "operator" field. It is called by the builders before save.
+ OperatorValidator func(string) error
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the PaymentAuditLog queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByOrderID orders the results by the order_id field.
+func ByOrderID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOrderID, opts...).ToFunc()
+}
+
+// ByAction orders the results by the action field.
+func ByAction(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAction, opts...).ToFunc()
+}
+
+// ByDetail orders the results by the detail field.
+func ByDetail(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDetail, opts...).ToFunc()
+}
+
+// ByOperator orders the results by the operator field.
+func ByOperator(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOperator, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
diff --git a/backend/ent/paymentauditlog/where.go b/backend/ent/paymentauditlog/where.go
new file mode 100644
index 00000000..2fd80a42
--- /dev/null
+++ b/backend/ent/paymentauditlog/where.go
@@ -0,0 +1,395 @@
+// Code generated by ent, DO NOT EDIT.
+
+package paymentauditlog
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLTE(FieldID, id))
+}
+
+// OrderID applies equality check predicate on the "order_id" field. It's identical to OrderIDEQ.
+func OrderID(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldOrderID, v))
+}
+
+// Action applies equality check predicate on the "action" field. It's identical to ActionEQ.
+func Action(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldAction, v))
+}
+
+// Detail applies equality check predicate on the "detail" field. It's identical to DetailEQ.
+func Detail(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldDetail, v))
+}
+
+// Operator applies equality check predicate on the "operator" field. It's identical to OperatorEQ.
+func Operator(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldOperator, v))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// OrderIDEQ applies the EQ predicate on the "order_id" field.
+func OrderIDEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldOrderID, v))
+}
+
+// OrderIDNEQ applies the NEQ predicate on the "order_id" field.
+func OrderIDNEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNEQ(FieldOrderID, v))
+}
+
+// OrderIDIn applies the In predicate on the "order_id" field.
+func OrderIDIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldIn(FieldOrderID, vs...))
+}
+
+// OrderIDNotIn applies the NotIn predicate on the "order_id" field.
+func OrderIDNotIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNotIn(FieldOrderID, vs...))
+}
+
+// OrderIDGT applies the GT predicate on the "order_id" field.
+func OrderIDGT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGT(FieldOrderID, v))
+}
+
+// OrderIDGTE applies the GTE predicate on the "order_id" field.
+func OrderIDGTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGTE(FieldOrderID, v))
+}
+
+// OrderIDLT applies the LT predicate on the "order_id" field.
+func OrderIDLT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLT(FieldOrderID, v))
+}
+
+// OrderIDLTE applies the LTE predicate on the "order_id" field.
+func OrderIDLTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLTE(FieldOrderID, v))
+}
+
+// OrderIDContains applies the Contains predicate on the "order_id" field.
+func OrderIDContains(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContains(FieldOrderID, v))
+}
+
+// OrderIDHasPrefix applies the HasPrefix predicate on the "order_id" field.
+func OrderIDHasPrefix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasPrefix(FieldOrderID, v))
+}
+
+// OrderIDHasSuffix applies the HasSuffix predicate on the "order_id" field.
+func OrderIDHasSuffix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasSuffix(FieldOrderID, v))
+}
+
+// OrderIDEqualFold applies the EqualFold predicate on the "order_id" field.
+func OrderIDEqualFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEqualFold(FieldOrderID, v))
+}
+
+// OrderIDContainsFold applies the ContainsFold predicate on the "order_id" field.
+func OrderIDContainsFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContainsFold(FieldOrderID, v))
+}
+
+// ActionEQ applies the EQ predicate on the "action" field.
+func ActionEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldAction, v))
+}
+
+// ActionNEQ applies the NEQ predicate on the "action" field.
+func ActionNEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNEQ(FieldAction, v))
+}
+
+// ActionIn applies the In predicate on the "action" field.
+func ActionIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldIn(FieldAction, vs...))
+}
+
+// ActionNotIn applies the NotIn predicate on the "action" field.
+func ActionNotIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNotIn(FieldAction, vs...))
+}
+
+// ActionGT applies the GT predicate on the "action" field.
+func ActionGT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGT(FieldAction, v))
+}
+
+// ActionGTE applies the GTE predicate on the "action" field.
+func ActionGTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGTE(FieldAction, v))
+}
+
+// ActionLT applies the LT predicate on the "action" field.
+func ActionLT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLT(FieldAction, v))
+}
+
+// ActionLTE applies the LTE predicate on the "action" field.
+func ActionLTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLTE(FieldAction, v))
+}
+
+// ActionContains applies the Contains predicate on the "action" field.
+func ActionContains(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContains(FieldAction, v))
+}
+
+// ActionHasPrefix applies the HasPrefix predicate on the "action" field.
+func ActionHasPrefix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasPrefix(FieldAction, v))
+}
+
+// ActionHasSuffix applies the HasSuffix predicate on the "action" field.
+func ActionHasSuffix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasSuffix(FieldAction, v))
+}
+
+// ActionEqualFold applies the EqualFold predicate on the "action" field.
+func ActionEqualFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEqualFold(FieldAction, v))
+}
+
+// ActionContainsFold applies the ContainsFold predicate on the "action" field.
+func ActionContainsFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContainsFold(FieldAction, v))
+}
+
+// DetailEQ applies the EQ predicate on the "detail" field.
+func DetailEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldDetail, v))
+}
+
+// DetailNEQ applies the NEQ predicate on the "detail" field.
+func DetailNEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNEQ(FieldDetail, v))
+}
+
+// DetailIn applies the In predicate on the "detail" field.
+func DetailIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldIn(FieldDetail, vs...))
+}
+
+// DetailNotIn applies the NotIn predicate on the "detail" field.
+func DetailNotIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNotIn(FieldDetail, vs...))
+}
+
+// DetailGT applies the GT predicate on the "detail" field.
+func DetailGT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGT(FieldDetail, v))
+}
+
+// DetailGTE applies the GTE predicate on the "detail" field.
+func DetailGTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGTE(FieldDetail, v))
+}
+
+// DetailLT applies the LT predicate on the "detail" field.
+func DetailLT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLT(FieldDetail, v))
+}
+
+// DetailLTE applies the LTE predicate on the "detail" field.
+func DetailLTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLTE(FieldDetail, v))
+}
+
+// DetailContains applies the Contains predicate on the "detail" field.
+func DetailContains(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContains(FieldDetail, v))
+}
+
+// DetailHasPrefix applies the HasPrefix predicate on the "detail" field.
+func DetailHasPrefix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasPrefix(FieldDetail, v))
+}
+
+// DetailHasSuffix applies the HasSuffix predicate on the "detail" field.
+func DetailHasSuffix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasSuffix(FieldDetail, v))
+}
+
+// DetailEqualFold applies the EqualFold predicate on the "detail" field.
+func DetailEqualFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEqualFold(FieldDetail, v))
+}
+
+// DetailContainsFold applies the ContainsFold predicate on the "detail" field.
+func DetailContainsFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContainsFold(FieldDetail, v))
+}
+
+// OperatorEQ applies the EQ predicate on the "operator" field.
+func OperatorEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldOperator, v))
+}
+
+// OperatorNEQ applies the NEQ predicate on the "operator" field.
+func OperatorNEQ(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNEQ(FieldOperator, v))
+}
+
+// OperatorIn applies the In predicate on the "operator" field.
+func OperatorIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldIn(FieldOperator, vs...))
+}
+
+// OperatorNotIn applies the NotIn predicate on the "operator" field.
+func OperatorNotIn(vs ...string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNotIn(FieldOperator, vs...))
+}
+
+// OperatorGT applies the GT predicate on the "operator" field.
+func OperatorGT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGT(FieldOperator, v))
+}
+
+// OperatorGTE applies the GTE predicate on the "operator" field.
+func OperatorGTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGTE(FieldOperator, v))
+}
+
+// OperatorLT applies the LT predicate on the "operator" field.
+func OperatorLT(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLT(FieldOperator, v))
+}
+
+// OperatorLTE applies the LTE predicate on the "operator" field.
+func OperatorLTE(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLTE(FieldOperator, v))
+}
+
+// OperatorContains applies the Contains predicate on the "operator" field.
+func OperatorContains(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContains(FieldOperator, v))
+}
+
+// OperatorHasPrefix applies the HasPrefix predicate on the "operator" field.
+func OperatorHasPrefix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasPrefix(FieldOperator, v))
+}
+
+// OperatorHasSuffix applies the HasSuffix predicate on the "operator" field.
+func OperatorHasSuffix(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldHasSuffix(FieldOperator, v))
+}
+
+// OperatorEqualFold applies the EqualFold predicate on the "operator" field.
+func OperatorEqualFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEqualFold(FieldOperator, v))
+}
+
+// OperatorContainsFold applies the ContainsFold predicate on the "operator" field.
+func OperatorContainsFold(v string) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldContainsFold(FieldOperator, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PaymentAuditLog) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PaymentAuditLog) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PaymentAuditLog) predicate.PaymentAuditLog {
+ return predicate.PaymentAuditLog(sql.NotPredicates(p))
+}
diff --git a/backend/ent/paymentauditlog_create.go b/backend/ent/paymentauditlog_create.go
new file mode 100644
index 00000000..1906aba1
--- /dev/null
+++ b/backend/ent/paymentauditlog_create.go
@@ -0,0 +1,696 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+)
+
+// PaymentAuditLogCreate is the builder for creating a PaymentAuditLog entity.
+type PaymentAuditLogCreate struct {
+ config
+ mutation *PaymentAuditLogMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetOrderID sets the "order_id" field.
+func (_c *PaymentAuditLogCreate) SetOrderID(v string) *PaymentAuditLogCreate {
+ _c.mutation.SetOrderID(v)
+ return _c
+}
+
+// SetAction sets the "action" field.
+func (_c *PaymentAuditLogCreate) SetAction(v string) *PaymentAuditLogCreate {
+ _c.mutation.SetAction(v)
+ return _c
+}
+
+// SetDetail sets the "detail" field.
+func (_c *PaymentAuditLogCreate) SetDetail(v string) *PaymentAuditLogCreate {
+ _c.mutation.SetDetail(v)
+ return _c
+}
+
+// SetNillableDetail sets the "detail" field if the given value is not nil.
+func (_c *PaymentAuditLogCreate) SetNillableDetail(v *string) *PaymentAuditLogCreate {
+ if v != nil {
+ _c.SetDetail(*v)
+ }
+ return _c
+}
+
+// SetOperator sets the "operator" field.
+func (_c *PaymentAuditLogCreate) SetOperator(v string) *PaymentAuditLogCreate {
+ _c.mutation.SetOperator(v)
+ return _c
+}
+
+// SetNillableOperator sets the "operator" field if the given value is not nil.
+func (_c *PaymentAuditLogCreate) SetNillableOperator(v *string) *PaymentAuditLogCreate {
+ if v != nil {
+ _c.SetOperator(*v)
+ }
+ return _c
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PaymentAuditLogCreate) SetCreatedAt(v time.Time) *PaymentAuditLogCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PaymentAuditLogCreate) SetNillableCreatedAt(v *time.Time) *PaymentAuditLogCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// Mutation returns the PaymentAuditLogMutation object of the builder.
+func (_c *PaymentAuditLogCreate) Mutation() *PaymentAuditLogMutation {
+ return _c.mutation
+}
+
+// Save creates the PaymentAuditLog in the database.
+func (_c *PaymentAuditLogCreate) Save(ctx context.Context) (*PaymentAuditLog, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PaymentAuditLogCreate) SaveX(ctx context.Context) *PaymentAuditLog {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PaymentAuditLogCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PaymentAuditLogCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PaymentAuditLogCreate) defaults() {
+ if _, ok := _c.mutation.Detail(); !ok {
+ v := paymentauditlog.DefaultDetail
+ _c.mutation.SetDetail(v)
+ }
+ if _, ok := _c.mutation.Operator(); !ok {
+ v := paymentauditlog.DefaultOperator
+ _c.mutation.SetOperator(v)
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := paymentauditlog.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PaymentAuditLogCreate) check() error {
+ if _, ok := _c.mutation.OrderID(); !ok {
+ return &ValidationError{Name: "order_id", err: errors.New(`ent: missing required field "PaymentAuditLog.order_id"`)}
+ }
+ if v, ok := _c.mutation.OrderID(); ok {
+ if err := paymentauditlog.OrderIDValidator(v); err != nil {
+ return &ValidationError{Name: "order_id", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.order_id": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Action(); !ok {
+ return &ValidationError{Name: "action", err: errors.New(`ent: missing required field "PaymentAuditLog.action"`)}
+ }
+ if v, ok := _c.mutation.Action(); ok {
+ if err := paymentauditlog.ActionValidator(v); err != nil {
+ return &ValidationError{Name: "action", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.action": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Detail(); !ok {
+ return &ValidationError{Name: "detail", err: errors.New(`ent: missing required field "PaymentAuditLog.detail"`)}
+ }
+ if _, ok := _c.mutation.Operator(); !ok {
+ return &ValidationError{Name: "operator", err: errors.New(`ent: missing required field "PaymentAuditLog.operator"`)}
+ }
+ if v, ok := _c.mutation.Operator(); ok {
+ if err := paymentauditlog.OperatorValidator(v); err != nil {
+ return &ValidationError{Name: "operator", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.operator": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentAuditLog.created_at"`)}
+ }
+ return nil
+}
+
+func (_c *PaymentAuditLogCreate) sqlSave(ctx context.Context) (*PaymentAuditLog, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PaymentAuditLogCreate) createSpec() (*PaymentAuditLog, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PaymentAuditLog{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(paymentauditlog.Table, sqlgraph.NewFieldSpec(paymentauditlog.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.OrderID(); ok {
+ _spec.SetField(paymentauditlog.FieldOrderID, field.TypeString, value)
+ _node.OrderID = value
+ }
+ if value, ok := _c.mutation.Action(); ok {
+ _spec.SetField(paymentauditlog.FieldAction, field.TypeString, value)
+ _node.Action = value
+ }
+ if value, ok := _c.mutation.Detail(); ok {
+ _spec.SetField(paymentauditlog.FieldDetail, field.TypeString, value)
+ _node.Detail = value
+ }
+ if value, ok := _c.mutation.Operator(); ok {
+ _spec.SetField(paymentauditlog.FieldOperator, field.TypeString, value)
+ _node.Operator = value
+ }
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(paymentauditlog.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PaymentAuditLog.Create().
+// SetOrderID(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PaymentAuditLogUpsert) {
+// SetOrderID(v+v).
+// }).
+// Exec(ctx)
+func (_c *PaymentAuditLogCreate) OnConflict(opts ...sql.ConflictOption) *PaymentAuditLogUpsertOne {
+ _c.conflict = opts
+ return &PaymentAuditLogUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PaymentAuditLog.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PaymentAuditLogCreate) OnConflictColumns(columns ...string) *PaymentAuditLogUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PaymentAuditLogUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PaymentAuditLogUpsertOne is the builder for "upsert"-ing
+ // one PaymentAuditLog node.
+ PaymentAuditLogUpsertOne struct {
+ create *PaymentAuditLogCreate
+ }
+
+ // PaymentAuditLogUpsert is the "OnConflict" setter.
+ PaymentAuditLogUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetOrderID sets the "order_id" field.
+func (u *PaymentAuditLogUpsert) SetOrderID(v string) *PaymentAuditLogUpsert {
+ u.Set(paymentauditlog.FieldOrderID, v)
+ return u
+}
+
+// UpdateOrderID sets the "order_id" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsert) UpdateOrderID() *PaymentAuditLogUpsert {
+ u.SetExcluded(paymentauditlog.FieldOrderID)
+ return u
+}
+
+// SetAction sets the "action" field.
+func (u *PaymentAuditLogUpsert) SetAction(v string) *PaymentAuditLogUpsert {
+ u.Set(paymentauditlog.FieldAction, v)
+ return u
+}
+
+// UpdateAction sets the "action" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsert) UpdateAction() *PaymentAuditLogUpsert {
+ u.SetExcluded(paymentauditlog.FieldAction)
+ return u
+}
+
+// SetDetail sets the "detail" field.
+func (u *PaymentAuditLogUpsert) SetDetail(v string) *PaymentAuditLogUpsert {
+ u.Set(paymentauditlog.FieldDetail, v)
+ return u
+}
+
+// UpdateDetail sets the "detail" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsert) UpdateDetail() *PaymentAuditLogUpsert {
+ u.SetExcluded(paymentauditlog.FieldDetail)
+ return u
+}
+
+// SetOperator sets the "operator" field.
+func (u *PaymentAuditLogUpsert) SetOperator(v string) *PaymentAuditLogUpsert {
+ u.Set(paymentauditlog.FieldOperator, v)
+ return u
+}
+
+// UpdateOperator sets the "operator" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsert) UpdateOperator() *PaymentAuditLogUpsert {
+ u.SetExcluded(paymentauditlog.FieldOperator)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PaymentAuditLog.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PaymentAuditLogUpsertOne) UpdateNewValues() *PaymentAuditLogUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(paymentauditlog.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PaymentAuditLog.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PaymentAuditLogUpsertOne) Ignore() *PaymentAuditLogUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PaymentAuditLogUpsertOne) DoNothing() *PaymentAuditLogUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PaymentAuditLogCreate.OnConflict
+// documentation for more info.
+func (u *PaymentAuditLogUpsertOne) Update(set func(*PaymentAuditLogUpsert)) *PaymentAuditLogUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PaymentAuditLogUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetOrderID sets the "order_id" field.
+func (u *PaymentAuditLogUpsertOne) SetOrderID(v string) *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetOrderID(v)
+ })
+}
+
+// UpdateOrderID sets the "order_id" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertOne) UpdateOrderID() *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateOrderID()
+ })
+}
+
+// SetAction sets the "action" field.
+func (u *PaymentAuditLogUpsertOne) SetAction(v string) *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetAction(v)
+ })
+}
+
+// UpdateAction sets the "action" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertOne) UpdateAction() *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateAction()
+ })
+}
+
+// SetDetail sets the "detail" field.
+func (u *PaymentAuditLogUpsertOne) SetDetail(v string) *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetDetail(v)
+ })
+}
+
+// UpdateDetail sets the "detail" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertOne) UpdateDetail() *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateDetail()
+ })
+}
+
+// SetOperator sets the "operator" field.
+func (u *PaymentAuditLogUpsertOne) SetOperator(v string) *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetOperator(v)
+ })
+}
+
+// UpdateOperator sets the "operator" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertOne) UpdateOperator() *PaymentAuditLogUpsertOne {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateOperator()
+ })
+}
+
+// Exec executes the query.
+func (u *PaymentAuditLogUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PaymentAuditLogCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PaymentAuditLogUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *PaymentAuditLogUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PaymentAuditLogUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PaymentAuditLogCreateBulk is the builder for creating many PaymentAuditLog entities in bulk.
+type PaymentAuditLogCreateBulk struct {
+ config
+ err error
+ builders []*PaymentAuditLogCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PaymentAuditLog entities in the database.
+func (_c *PaymentAuditLogCreateBulk) Save(ctx context.Context) ([]*PaymentAuditLog, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*PaymentAuditLog, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*PaymentAuditLogMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PaymentAuditLogCreateBulk) SaveX(ctx context.Context) []*PaymentAuditLog {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PaymentAuditLogCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PaymentAuditLogCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PaymentAuditLog.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PaymentAuditLogUpsert) {
+// SetOrderID(v+v).
+// }).
+// Exec(ctx)
+func (_c *PaymentAuditLogCreateBulk) OnConflict(opts ...sql.ConflictOption) *PaymentAuditLogUpsertBulk {
+ _c.conflict = opts
+ return &PaymentAuditLogUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PaymentAuditLog.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PaymentAuditLogCreateBulk) OnConflictColumns(columns ...string) *PaymentAuditLogUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PaymentAuditLogUpsertBulk{
+ create: _c,
+ }
+}
+
+// PaymentAuditLogUpsertBulk is the builder for "upsert"-ing
+// a bulk of PaymentAuditLog nodes.
+type PaymentAuditLogUpsertBulk struct {
+ create *PaymentAuditLogCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.PaymentAuditLog.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PaymentAuditLogUpsertBulk) UpdateNewValues() *PaymentAuditLogUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(paymentauditlog.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PaymentAuditLog.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PaymentAuditLogUpsertBulk) Ignore() *PaymentAuditLogUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PaymentAuditLogUpsertBulk) DoNothing() *PaymentAuditLogUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PaymentAuditLogCreateBulk.OnConflict
+// documentation for more info.
+func (u *PaymentAuditLogUpsertBulk) Update(set func(*PaymentAuditLogUpsert)) *PaymentAuditLogUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PaymentAuditLogUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetOrderID sets the "order_id" field.
+func (u *PaymentAuditLogUpsertBulk) SetOrderID(v string) *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetOrderID(v)
+ })
+}
+
+// UpdateOrderID sets the "order_id" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertBulk) UpdateOrderID() *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateOrderID()
+ })
+}
+
+// SetAction sets the "action" field.
+func (u *PaymentAuditLogUpsertBulk) SetAction(v string) *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetAction(v)
+ })
+}
+
+// UpdateAction sets the "action" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertBulk) UpdateAction() *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateAction()
+ })
+}
+
+// SetDetail sets the "detail" field.
+func (u *PaymentAuditLogUpsertBulk) SetDetail(v string) *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetDetail(v)
+ })
+}
+
+// UpdateDetail sets the "detail" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertBulk) UpdateDetail() *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateDetail()
+ })
+}
+
+// SetOperator sets the "operator" field.
+func (u *PaymentAuditLogUpsertBulk) SetOperator(v string) *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.SetOperator(v)
+ })
+}
+
+// UpdateOperator sets the "operator" field to the value that was provided on create.
+func (u *PaymentAuditLogUpsertBulk) UpdateOperator() *PaymentAuditLogUpsertBulk {
+ return u.Update(func(s *PaymentAuditLogUpsert) {
+ s.UpdateOperator()
+ })
+}
+
+// Exec executes the query.
+func (u *PaymentAuditLogUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PaymentAuditLogCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PaymentAuditLogCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PaymentAuditLogUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/paymentauditlog_delete.go b/backend/ent/paymentauditlog_delete.go
new file mode 100644
index 00000000..ca22d8db
--- /dev/null
+++ b/backend/ent/paymentauditlog_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PaymentAuditLogDelete is the builder for deleting a PaymentAuditLog entity.
+type PaymentAuditLogDelete struct {
+ config
+ hooks []Hook
+ mutation *PaymentAuditLogMutation
+}
+
+// Where appends a list predicates to the PaymentAuditLogDelete builder.
+func (_d *PaymentAuditLogDelete) Where(ps ...predicate.PaymentAuditLog) *PaymentAuditLogDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PaymentAuditLogDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PaymentAuditLogDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *PaymentAuditLogDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(paymentauditlog.Table, sqlgraph.NewFieldSpec(paymentauditlog.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// PaymentAuditLogDeleteOne is the builder for deleting a single PaymentAuditLog entity.
+type PaymentAuditLogDeleteOne struct {
+ _d *PaymentAuditLogDelete
+}
+
+// Where appends a list predicates to the PaymentAuditLogDelete builder.
+func (_d *PaymentAuditLogDeleteOne) Where(ps ...predicate.PaymentAuditLog) *PaymentAuditLogDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *PaymentAuditLogDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{paymentauditlog.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PaymentAuditLogDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/paymentauditlog_query.go b/backend/ent/paymentauditlog_query.go
new file mode 100644
index 00000000..7a4e9115
--- /dev/null
+++ b/backend/ent/paymentauditlog_query.go
@@ -0,0 +1,564 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PaymentAuditLogQuery is the builder for querying PaymentAuditLog entities.
+type PaymentAuditLogQuery struct {
+ config
+ ctx *QueryContext
+ order []paymentauditlog.OrderOption
+ inters []Interceptor
+ predicates []predicate.PaymentAuditLog
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PaymentAuditLogQuery builder.
+func (_q *PaymentAuditLogQuery) Where(ps ...predicate.PaymentAuditLog) *PaymentAuditLogQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PaymentAuditLogQuery) Limit(limit int) *PaymentAuditLogQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *PaymentAuditLogQuery) Offset(offset int) *PaymentAuditLogQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PaymentAuditLogQuery) Unique(unique bool) *PaymentAuditLogQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PaymentAuditLogQuery) Order(o ...paymentauditlog.OrderOption) *PaymentAuditLogQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// First returns the first PaymentAuditLog entity from the query.
+// Returns a *NotFoundError when no PaymentAuditLog was found.
+func (_q *PaymentAuditLogQuery) First(ctx context.Context) (*PaymentAuditLog, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{paymentauditlog.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) FirstX(ctx context.Context) *PaymentAuditLog {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first PaymentAuditLog ID from the query.
+// Returns a *NotFoundError when no PaymentAuditLog ID was found.
+func (_q *PaymentAuditLogQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{paymentauditlog.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single PaymentAuditLog entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PaymentAuditLog entity is found.
+// Returns a *NotFoundError when no PaymentAuditLog entities are found.
+func (_q *PaymentAuditLogQuery) Only(ctx context.Context) (*PaymentAuditLog, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{paymentauditlog.Label}
+ default:
+ return nil, &NotSingularError{paymentauditlog.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) OnlyX(ctx context.Context) *PaymentAuditLog {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only PaymentAuditLog ID in the query.
+// Returns a *NotSingularError when more than one PaymentAuditLog ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PaymentAuditLogQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{paymentauditlog.Label}
+ default:
+ err = &NotSingularError{paymentauditlog.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of PaymentAuditLogs.
+func (_q *PaymentAuditLogQuery) All(ctx context.Context) ([]*PaymentAuditLog, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*PaymentAuditLog, *PaymentAuditLogQuery]()
+ return withInterceptors[[]*PaymentAuditLog](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) AllX(ctx context.Context) []*PaymentAuditLog {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of PaymentAuditLog IDs.
+func (_q *PaymentAuditLogQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(paymentauditlog.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PaymentAuditLogQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*PaymentAuditLogQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PaymentAuditLogQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PaymentAuditLogQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the PaymentAuditLogQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *PaymentAuditLogQuery) Clone() *PaymentAuditLogQuery {
+ if _q == nil {
+ return nil
+ }
+ return &PaymentAuditLogQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]paymentauditlog.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.PaymentAuditLog{}, _q.predicates...),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// OrderID string `json:"order_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.PaymentAuditLog.Query().
+// GroupBy(paymentauditlog.FieldOrderID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *PaymentAuditLogQuery) GroupBy(field string, fields ...string) *PaymentAuditLogGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &PaymentAuditLogGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = paymentauditlog.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// OrderID string `json:"order_id,omitempty"`
+// }
+//
+// client.PaymentAuditLog.Query().
+// Select(paymentauditlog.FieldOrderID).
+// Scan(ctx, &v)
+func (_q *PaymentAuditLogQuery) Select(fields ...string) *PaymentAuditLogSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &PaymentAuditLogSelect{PaymentAuditLogQuery: _q}
+ sbuild.label = paymentauditlog.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a PaymentAuditLogSelect configured with the given aggregations.
+func (_q *PaymentAuditLogQuery) Aggregate(fns ...AggregateFunc) *PaymentAuditLogSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *PaymentAuditLogQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !paymentauditlog.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *PaymentAuditLogQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PaymentAuditLog, error) {
+ var (
+ nodes = []*PaymentAuditLog{}
+ _spec = _q.querySpec()
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*PaymentAuditLog).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &PaymentAuditLog{config: _q.config}
+ nodes = append(nodes, node)
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ return nodes, nil
+}
+
+func (_q *PaymentAuditLogQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *PaymentAuditLogQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(paymentauditlog.Table, paymentauditlog.Columns, sqlgraph.NewFieldSpec(paymentauditlog.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, paymentauditlog.FieldID)
+ for i := range fields {
+ if fields[i] != paymentauditlog.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *PaymentAuditLogQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(paymentauditlog.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = paymentauditlog.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PaymentAuditLogQuery) ForUpdate(opts ...sql.LockOption) *PaymentAuditLogQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PaymentAuditLogQuery) ForShare(opts ...sql.LockOption) *PaymentAuditLogQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// PaymentAuditLogGroupBy is the group-by builder for PaymentAuditLog entities.
+type PaymentAuditLogGroupBy struct {
+ selector
+ build *PaymentAuditLogQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PaymentAuditLogGroupBy) Aggregate(fns ...AggregateFunc) *PaymentAuditLogGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PaymentAuditLogGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PaymentAuditLogQuery, *PaymentAuditLogGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *PaymentAuditLogGroupBy) sqlScan(ctx context.Context, root *PaymentAuditLogQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// PaymentAuditLogSelect is the builder for selecting fields of PaymentAuditLog entities.
+type PaymentAuditLogSelect struct {
+ *PaymentAuditLogQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PaymentAuditLogSelect) Aggregate(fns ...AggregateFunc) *PaymentAuditLogSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PaymentAuditLogSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PaymentAuditLogQuery, *PaymentAuditLogSelect](ctx, _s.PaymentAuditLogQuery, _s, _s.inters, v)
+}
+
+func (_s *PaymentAuditLogSelect) sqlScan(ctx context.Context, root *PaymentAuditLogQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/paymentauditlog_update.go b/backend/ent/paymentauditlog_update.go
new file mode 100644
index 00000000..52b4afe7
--- /dev/null
+++ b/backend/ent/paymentauditlog_update.go
@@ -0,0 +1,357 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PaymentAuditLogUpdate is the builder for updating PaymentAuditLog entities.
+type PaymentAuditLogUpdate struct {
+ config
+ hooks []Hook
+ mutation *PaymentAuditLogMutation
+}
+
+// Where appends a list predicates to the PaymentAuditLogUpdate builder.
+func (_u *PaymentAuditLogUpdate) Where(ps ...predicate.PaymentAuditLog) *PaymentAuditLogUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetOrderID sets the "order_id" field.
+func (_u *PaymentAuditLogUpdate) SetOrderID(v string) *PaymentAuditLogUpdate {
+ _u.mutation.SetOrderID(v)
+ return _u
+}
+
+// SetNillableOrderID sets the "order_id" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdate) SetNillableOrderID(v *string) *PaymentAuditLogUpdate {
+ if v != nil {
+ _u.SetOrderID(*v)
+ }
+ return _u
+}
+
+// SetAction sets the "action" field.
+func (_u *PaymentAuditLogUpdate) SetAction(v string) *PaymentAuditLogUpdate {
+ _u.mutation.SetAction(v)
+ return _u
+}
+
+// SetNillableAction sets the "action" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdate) SetNillableAction(v *string) *PaymentAuditLogUpdate {
+ if v != nil {
+ _u.SetAction(*v)
+ }
+ return _u
+}
+
+// SetDetail sets the "detail" field.
+func (_u *PaymentAuditLogUpdate) SetDetail(v string) *PaymentAuditLogUpdate {
+ _u.mutation.SetDetail(v)
+ return _u
+}
+
+// SetNillableDetail sets the "detail" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdate) SetNillableDetail(v *string) *PaymentAuditLogUpdate {
+ if v != nil {
+ _u.SetDetail(*v)
+ }
+ return _u
+}
+
+// SetOperator sets the "operator" field.
+func (_u *PaymentAuditLogUpdate) SetOperator(v string) *PaymentAuditLogUpdate {
+ _u.mutation.SetOperator(v)
+ return _u
+}
+
+// SetNillableOperator sets the "operator" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdate) SetNillableOperator(v *string) *PaymentAuditLogUpdate {
+ if v != nil {
+ _u.SetOperator(*v)
+ }
+ return _u
+}
+
+// Mutation returns the PaymentAuditLogMutation object of the builder.
+func (_u *PaymentAuditLogUpdate) Mutation() *PaymentAuditLogMutation {
+ return _u.mutation
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PaymentAuditLogUpdate) Save(ctx context.Context) (int, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PaymentAuditLogUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *PaymentAuditLogUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PaymentAuditLogUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PaymentAuditLogUpdate) check() error {
+ if v, ok := _u.mutation.OrderID(); ok {
+ if err := paymentauditlog.OrderIDValidator(v); err != nil {
+ return &ValidationError{Name: "order_id", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.order_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Action(); ok {
+ if err := paymentauditlog.ActionValidator(v); err != nil {
+ return &ValidationError{Name: "action", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.action": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Operator(); ok {
+ if err := paymentauditlog.OperatorValidator(v); err != nil {
+ return &ValidationError{Name: "operator", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.operator": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PaymentAuditLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(paymentauditlog.Table, paymentauditlog.Columns, sqlgraph.NewFieldSpec(paymentauditlog.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.OrderID(); ok {
+ _spec.SetField(paymentauditlog.FieldOrderID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Action(); ok {
+ _spec.SetField(paymentauditlog.FieldAction, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Detail(); ok {
+ _spec.SetField(paymentauditlog.FieldDetail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Operator(); ok {
+ _spec.SetField(paymentauditlog.FieldOperator, field.TypeString, value)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{paymentauditlog.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// PaymentAuditLogUpdateOne is the builder for updating a single PaymentAuditLog entity.
+type PaymentAuditLogUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *PaymentAuditLogMutation
+}
+
+// SetOrderID sets the "order_id" field.
+func (_u *PaymentAuditLogUpdateOne) SetOrderID(v string) *PaymentAuditLogUpdateOne {
+ _u.mutation.SetOrderID(v)
+ return _u
+}
+
+// SetNillableOrderID sets the "order_id" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdateOne) SetNillableOrderID(v *string) *PaymentAuditLogUpdateOne {
+ if v != nil {
+ _u.SetOrderID(*v)
+ }
+ return _u
+}
+
+// SetAction sets the "action" field.
+func (_u *PaymentAuditLogUpdateOne) SetAction(v string) *PaymentAuditLogUpdateOne {
+ _u.mutation.SetAction(v)
+ return _u
+}
+
+// SetNillableAction sets the "action" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdateOne) SetNillableAction(v *string) *PaymentAuditLogUpdateOne {
+ if v != nil {
+ _u.SetAction(*v)
+ }
+ return _u
+}
+
+// SetDetail sets the "detail" field.
+func (_u *PaymentAuditLogUpdateOne) SetDetail(v string) *PaymentAuditLogUpdateOne {
+ _u.mutation.SetDetail(v)
+ return _u
+}
+
+// SetNillableDetail sets the "detail" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdateOne) SetNillableDetail(v *string) *PaymentAuditLogUpdateOne {
+ if v != nil {
+ _u.SetDetail(*v)
+ }
+ return _u
+}
+
+// SetOperator sets the "operator" field.
+func (_u *PaymentAuditLogUpdateOne) SetOperator(v string) *PaymentAuditLogUpdateOne {
+ _u.mutation.SetOperator(v)
+ return _u
+}
+
+// SetNillableOperator sets the "operator" field if the given value is not nil.
+func (_u *PaymentAuditLogUpdateOne) SetNillableOperator(v *string) *PaymentAuditLogUpdateOne {
+ if v != nil {
+ _u.SetOperator(*v)
+ }
+ return _u
+}
+
+// Mutation returns the PaymentAuditLogMutation object of the builder.
+func (_u *PaymentAuditLogUpdateOne) Mutation() *PaymentAuditLogMutation {
+ return _u.mutation
+}
+
+// Where appends a list predicates to the PaymentAuditLogUpdate builder.
+func (_u *PaymentAuditLogUpdateOne) Where(ps ...predicate.PaymentAuditLog) *PaymentAuditLogUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PaymentAuditLogUpdateOne) Select(field string, fields ...string) *PaymentAuditLogUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated PaymentAuditLog entity.
+func (_u *PaymentAuditLogUpdateOne) Save(ctx context.Context) (*PaymentAuditLog, error) {
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PaymentAuditLogUpdateOne) SaveX(ctx context.Context) *PaymentAuditLog {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PaymentAuditLogUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PaymentAuditLogUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PaymentAuditLogUpdateOne) check() error {
+ if v, ok := _u.mutation.OrderID(); ok {
+ if err := paymentauditlog.OrderIDValidator(v); err != nil {
+ return &ValidationError{Name: "order_id", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.order_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Action(); ok {
+ if err := paymentauditlog.ActionValidator(v); err != nil {
+ return &ValidationError{Name: "action", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.action": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Operator(); ok {
+ if err := paymentauditlog.OperatorValidator(v); err != nil {
+ return &ValidationError{Name: "operator", err: fmt.Errorf(`ent: validator failed for field "PaymentAuditLog.operator": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PaymentAuditLogUpdateOne) sqlSave(ctx context.Context) (_node *PaymentAuditLog, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(paymentauditlog.Table, paymentauditlog.Columns, sqlgraph.NewFieldSpec(paymentauditlog.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PaymentAuditLog.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, paymentauditlog.FieldID)
+ for _, f := range fields {
+ if !paymentauditlog.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != paymentauditlog.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.OrderID(); ok {
+ _spec.SetField(paymentauditlog.FieldOrderID, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Action(); ok {
+ _spec.SetField(paymentauditlog.FieldAction, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Detail(); ok {
+ _spec.SetField(paymentauditlog.FieldDetail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Operator(); ok {
+ _spec.SetField(paymentauditlog.FieldOperator, field.TypeString, value)
+ }
+ _node = &PaymentAuditLog{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{paymentauditlog.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go
new file mode 100644
index 00000000..b131b8c8
--- /dev/null
+++ b/backend/ent/paymentorder.go
@@ -0,0 +1,619 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PaymentOrder is the model entity for the PaymentOrder schema.
+type PaymentOrder struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // UserID holds the value of the "user_id" field.
+ UserID int64 `json:"user_id,omitempty"`
+ // UserEmail holds the value of the "user_email" field.
+ UserEmail string `json:"user_email,omitempty"`
+ // UserName holds the value of the "user_name" field.
+ UserName string `json:"user_name,omitempty"`
+ // UserNotes holds the value of the "user_notes" field.
+ UserNotes *string `json:"user_notes,omitempty"`
+ // Amount holds the value of the "amount" field.
+ Amount float64 `json:"amount,omitempty"`
+ // PayAmount holds the value of the "pay_amount" field.
+ PayAmount float64 `json:"pay_amount,omitempty"`
+ // FeeRate holds the value of the "fee_rate" field.
+ FeeRate float64 `json:"fee_rate,omitempty"`
+ // RechargeCode holds the value of the "recharge_code" field.
+ RechargeCode string `json:"recharge_code,omitempty"`
+ // OutTradeNo holds the value of the "out_trade_no" field.
+ OutTradeNo string `json:"out_trade_no,omitempty"`
+ // PaymentType holds the value of the "payment_type" field.
+ PaymentType string `json:"payment_type,omitempty"`
+ // PaymentTradeNo holds the value of the "payment_trade_no" field.
+ PaymentTradeNo string `json:"payment_trade_no,omitempty"`
+ // PayURL holds the value of the "pay_url" field.
+ PayURL *string `json:"pay_url,omitempty"`
+ // QrCode holds the value of the "qr_code" field.
+ QrCode *string `json:"qr_code,omitempty"`
+ // QrCodeImg holds the value of the "qr_code_img" field.
+ QrCodeImg *string `json:"qr_code_img,omitempty"`
+ // OrderType holds the value of the "order_type" field.
+ OrderType string `json:"order_type,omitempty"`
+ // PlanID holds the value of the "plan_id" field.
+ PlanID *int64 `json:"plan_id,omitempty"`
+ // SubscriptionGroupID holds the value of the "subscription_group_id" field.
+ SubscriptionGroupID *int64 `json:"subscription_group_id,omitempty"`
+ // SubscriptionDays holds the value of the "subscription_days" field.
+ SubscriptionDays *int `json:"subscription_days,omitempty"`
+ // ProviderInstanceID holds the value of the "provider_instance_id" field.
+ ProviderInstanceID *string `json:"provider_instance_id,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey *string `json:"provider_key,omitempty"`
+ // ProviderSnapshot holds the value of the "provider_snapshot" field.
+ ProviderSnapshot map[string]interface{} `json:"provider_snapshot,omitempty"`
+ // Status holds the value of the "status" field.
+ Status string `json:"status,omitempty"`
+ // RefundAmount holds the value of the "refund_amount" field.
+ RefundAmount float64 `json:"refund_amount,omitempty"`
+ // RefundReason holds the value of the "refund_reason" field.
+ RefundReason *string `json:"refund_reason,omitempty"`
+ // RefundAt holds the value of the "refund_at" field.
+ RefundAt *time.Time `json:"refund_at,omitempty"`
+ // ForceRefund holds the value of the "force_refund" field.
+ ForceRefund bool `json:"force_refund,omitempty"`
+ // RefundRequestedAt holds the value of the "refund_requested_at" field.
+ RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"`
+ // RefundRequestReason holds the value of the "refund_request_reason" field.
+ RefundRequestReason *string `json:"refund_request_reason,omitempty"`
+ // RefundRequestedBy holds the value of the "refund_requested_by" field.
+ RefundRequestedBy *string `json:"refund_requested_by,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ // PaidAt holds the value of the "paid_at" field.
+ PaidAt *time.Time `json:"paid_at,omitempty"`
+ // CompletedAt holds the value of the "completed_at" field.
+ CompletedAt *time.Time `json:"completed_at,omitempty"`
+ // FailedAt holds the value of the "failed_at" field.
+ FailedAt *time.Time `json:"failed_at,omitempty"`
+ // FailedReason holds the value of the "failed_reason" field.
+ FailedReason *string `json:"failed_reason,omitempty"`
+ // ClientIP holds the value of the "client_ip" field.
+ ClientIP string `json:"client_ip,omitempty"`
+ // SrcHost holds the value of the "src_host" field.
+ SrcHost string `json:"src_host,omitempty"`
+ // SrcURL holds the value of the "src_url" field.
+ SrcURL *string `json:"src_url,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the PaymentOrderQuery when eager-loading is set.
+ Edges PaymentOrderEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// PaymentOrderEdges holds the relations/edges for other nodes in the graph.
+type PaymentOrderEdges struct {
+ // User holds the value of the user edge.
+ User *User `json:"user,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [1]bool
+}
+
+// UserOrErr returns the User value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PaymentOrderEdges) UserOrErr() (*User, error) {
+ if e.User != nil {
+ return e.User, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "user"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PaymentOrder) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case paymentorder.FieldProviderSnapshot:
+ values[i] = new([]byte)
+ case paymentorder.FieldForceRefund:
+ values[i] = new(sql.NullBool)
+ case paymentorder.FieldAmount, paymentorder.FieldPayAmount, paymentorder.FieldFeeRate, paymentorder.FieldRefundAmount:
+ values[i] = new(sql.NullFloat64)
+ case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays:
+ values[i] = new(sql.NullInt64)
+ case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL:
+ values[i] = new(sql.NullString)
+ case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PaymentOrder fields.
+func (_m *PaymentOrder) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case paymentorder.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case paymentorder.FieldUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field user_id", values[i])
+ } else if value.Valid {
+ _m.UserID = value.Int64
+ }
+ case paymentorder.FieldUserEmail:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field user_email", values[i])
+ } else if value.Valid {
+ _m.UserEmail = value.String
+ }
+ case paymentorder.FieldUserName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field user_name", values[i])
+ } else if value.Valid {
+ _m.UserName = value.String
+ }
+ case paymentorder.FieldUserNotes:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field user_notes", values[i])
+ } else if value.Valid {
+ _m.UserNotes = new(string)
+ *_m.UserNotes = value.String
+ }
+ case paymentorder.FieldAmount:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field amount", values[i])
+ } else if value.Valid {
+ _m.Amount = value.Float64
+ }
+ case paymentorder.FieldPayAmount:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field pay_amount", values[i])
+ } else if value.Valid {
+ _m.PayAmount = value.Float64
+ }
+ case paymentorder.FieldFeeRate:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field fee_rate", values[i])
+ } else if value.Valid {
+ _m.FeeRate = value.Float64
+ }
+ case paymentorder.FieldRechargeCode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field recharge_code", values[i])
+ } else if value.Valid {
+ _m.RechargeCode = value.String
+ }
+ case paymentorder.FieldOutTradeNo:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field out_trade_no", values[i])
+ } else if value.Valid {
+ _m.OutTradeNo = value.String
+ }
+ case paymentorder.FieldPaymentType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field payment_type", values[i])
+ } else if value.Valid {
+ _m.PaymentType = value.String
+ }
+ case paymentorder.FieldPaymentTradeNo:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field payment_trade_no", values[i])
+ } else if value.Valid {
+ _m.PaymentTradeNo = value.String
+ }
+ case paymentorder.FieldPayURL:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field pay_url", values[i])
+ } else if value.Valid {
+ _m.PayURL = new(string)
+ *_m.PayURL = value.String
+ }
+ case paymentorder.FieldQrCode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field qr_code", values[i])
+ } else if value.Valid {
+ _m.QrCode = new(string)
+ *_m.QrCode = value.String
+ }
+ case paymentorder.FieldQrCodeImg:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field qr_code_img", values[i])
+ } else if value.Valid {
+ _m.QrCodeImg = new(string)
+ *_m.QrCodeImg = value.String
+ }
+ case paymentorder.FieldOrderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field order_type", values[i])
+ } else if value.Valid {
+ _m.OrderType = value.String
+ }
+ case paymentorder.FieldPlanID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field plan_id", values[i])
+ } else if value.Valid {
+ _m.PlanID = new(int64)
+ *_m.PlanID = value.Int64
+ }
+ case paymentorder.FieldSubscriptionGroupID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field subscription_group_id", values[i])
+ } else if value.Valid {
+ _m.SubscriptionGroupID = new(int64)
+ *_m.SubscriptionGroupID = value.Int64
+ }
+ case paymentorder.FieldSubscriptionDays:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field subscription_days", values[i])
+ } else if value.Valid {
+ _m.SubscriptionDays = new(int)
+ *_m.SubscriptionDays = int(value.Int64)
+ }
+ case paymentorder.FieldProviderInstanceID:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_instance_id", values[i])
+ } else if value.Valid {
+ _m.ProviderInstanceID = new(string)
+ *_m.ProviderInstanceID = value.String
+ }
+ case paymentorder.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = new(string)
+ *_m.ProviderKey = value.String
+ }
+ case paymentorder.FieldProviderSnapshot:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_snapshot", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ProviderSnapshot); err != nil {
+ return fmt.Errorf("unmarshal field provider_snapshot: %w", err)
+ }
+ }
+ case paymentorder.FieldStatus:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field status", values[i])
+ } else if value.Valid {
+ _m.Status = value.String
+ }
+ case paymentorder.FieldRefundAmount:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field refund_amount", values[i])
+ } else if value.Valid {
+ _m.RefundAmount = value.Float64
+ }
+ case paymentorder.FieldRefundReason:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field refund_reason", values[i])
+ } else if value.Valid {
+ _m.RefundReason = new(string)
+ *_m.RefundReason = value.String
+ }
+ case paymentorder.FieldRefundAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field refund_at", values[i])
+ } else if value.Valid {
+ _m.RefundAt = new(time.Time)
+ *_m.RefundAt = value.Time
+ }
+ case paymentorder.FieldForceRefund:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field force_refund", values[i])
+ } else if value.Valid {
+ _m.ForceRefund = value.Bool
+ }
+ case paymentorder.FieldRefundRequestedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field refund_requested_at", values[i])
+ } else if value.Valid {
+ _m.RefundRequestedAt = new(time.Time)
+ *_m.RefundRequestedAt = value.Time
+ }
+ case paymentorder.FieldRefundRequestReason:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field refund_request_reason", values[i])
+ } else if value.Valid {
+ _m.RefundRequestReason = new(string)
+ *_m.RefundRequestReason = value.String
+ }
+ case paymentorder.FieldRefundRequestedBy:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field refund_requested_by", values[i])
+ } else if value.Valid {
+ _m.RefundRequestedBy = new(string)
+ *_m.RefundRequestedBy = value.String
+ }
+ case paymentorder.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = value.Time
+ }
+ case paymentorder.FieldPaidAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field paid_at", values[i])
+ } else if value.Valid {
+ _m.PaidAt = new(time.Time)
+ *_m.PaidAt = value.Time
+ }
+ case paymentorder.FieldCompletedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field completed_at", values[i])
+ } else if value.Valid {
+ _m.CompletedAt = new(time.Time)
+ *_m.CompletedAt = value.Time
+ }
+ case paymentorder.FieldFailedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field failed_at", values[i])
+ } else if value.Valid {
+ _m.FailedAt = new(time.Time)
+ *_m.FailedAt = value.Time
+ }
+ case paymentorder.FieldFailedReason:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field failed_reason", values[i])
+ } else if value.Valid {
+ _m.FailedReason = new(string)
+ *_m.FailedReason = value.String
+ }
+ case paymentorder.FieldClientIP:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field client_ip", values[i])
+ } else if value.Valid {
+ _m.ClientIP = value.String
+ }
+ case paymentorder.FieldSrcHost:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field src_host", values[i])
+ } else if value.Valid {
+ _m.SrcHost = value.String
+ }
+ case paymentorder.FieldSrcURL:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field src_url", values[i])
+ } else if value.Valid {
+ _m.SrcURL = new(string)
+ *_m.SrcURL = value.String
+ }
+ case paymentorder.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case paymentorder.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PaymentOrder.
+// This includes values selected through modifiers, order, etc.
+func (_m *PaymentOrder) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryUser queries the "user" edge of the PaymentOrder entity.
+func (_m *PaymentOrder) QueryUser() *UserQuery {
+ return NewPaymentOrderClient(_m.config).QueryUser(_m)
+}
+
+// Update returns a builder for updating this PaymentOrder.
+// Note that you need to call PaymentOrder.Unwrap() before calling this method if this PaymentOrder
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PaymentOrder) Update() *PaymentOrderUpdateOne {
+ return NewPaymentOrderClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PaymentOrder entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PaymentOrder) Unwrap() *PaymentOrder {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PaymentOrder is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PaymentOrder) String() string {
+ var builder strings.Builder
+ builder.WriteString("PaymentOrder(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("user_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UserID))
+ builder.WriteString(", ")
+ builder.WriteString("user_email=")
+ builder.WriteString(_m.UserEmail)
+ builder.WriteString(", ")
+ builder.WriteString("user_name=")
+ builder.WriteString(_m.UserName)
+ builder.WriteString(", ")
+ if v := _m.UserNotes; v != nil {
+ builder.WriteString("user_notes=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("amount=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Amount))
+ builder.WriteString(", ")
+ builder.WriteString("pay_amount=")
+ builder.WriteString(fmt.Sprintf("%v", _m.PayAmount))
+ builder.WriteString(", ")
+ builder.WriteString("fee_rate=")
+ builder.WriteString(fmt.Sprintf("%v", _m.FeeRate))
+ builder.WriteString(", ")
+ builder.WriteString("recharge_code=")
+ builder.WriteString(_m.RechargeCode)
+ builder.WriteString(", ")
+ builder.WriteString("out_trade_no=")
+ builder.WriteString(_m.OutTradeNo)
+ builder.WriteString(", ")
+ builder.WriteString("payment_type=")
+ builder.WriteString(_m.PaymentType)
+ builder.WriteString(", ")
+ builder.WriteString("payment_trade_no=")
+ builder.WriteString(_m.PaymentTradeNo)
+ builder.WriteString(", ")
+ if v := _m.PayURL; v != nil {
+ builder.WriteString("pay_url=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.QrCode; v != nil {
+ builder.WriteString("qr_code=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.QrCodeImg; v != nil {
+ builder.WriteString("qr_code_img=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("order_type=")
+ builder.WriteString(_m.OrderType)
+ builder.WriteString(", ")
+ if v := _m.PlanID; v != nil {
+ builder.WriteString("plan_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.SubscriptionGroupID; v != nil {
+ builder.WriteString("subscription_group_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.SubscriptionDays; v != nil {
+ builder.WriteString("subscription_days=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.ProviderInstanceID; v != nil {
+ builder.WriteString("provider_instance_id=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.ProviderKey; v != nil {
+ builder.WriteString("provider_key=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("provider_snapshot=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ProviderSnapshot))
+ builder.WriteString(", ")
+ builder.WriteString("status=")
+ builder.WriteString(_m.Status)
+ builder.WriteString(", ")
+ builder.WriteString("refund_amount=")
+ builder.WriteString(fmt.Sprintf("%v", _m.RefundAmount))
+ builder.WriteString(", ")
+ if v := _m.RefundReason; v != nil {
+ builder.WriteString("refund_reason=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.RefundAt; v != nil {
+ builder.WriteString("refund_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("force_refund=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ForceRefund))
+ builder.WriteString(", ")
+ if v := _m.RefundRequestedAt; v != nil {
+ builder.WriteString("refund_requested_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.RefundRequestReason; v != nil {
+ builder.WriteString("refund_request_reason=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.RefundRequestedBy; v != nil {
+ builder.WriteString("refund_requested_by=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ if v := _m.PaidAt; v != nil {
+ builder.WriteString("paid_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.CompletedAt; v != nil {
+ builder.WriteString("completed_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.FailedAt; v != nil {
+ builder.WriteString("failed_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.FailedReason; v != nil {
+ builder.WriteString("failed_reason=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("client_ip=")
+ builder.WriteString(_m.ClientIP)
+ builder.WriteString(", ")
+ builder.WriteString("src_host=")
+ builder.WriteString(_m.SrcHost)
+ builder.WriteString(", ")
+ if v := _m.SrcURL; v != nil {
+ builder.WriteString("src_url=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PaymentOrders is a parsable slice of PaymentOrder.
+type PaymentOrders []*PaymentOrder
diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go
new file mode 100644
index 00000000..62883794
--- /dev/null
+++ b/backend/ent/paymentorder/paymentorder.go
@@ -0,0 +1,419 @@
+// Code generated by ent, DO NOT EDIT.
+
+package paymentorder
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the paymentorder type in the database.
+ Label = "payment_order"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldUserID holds the string denoting the user_id field in the database.
+ FieldUserID = "user_id"
+ // FieldUserEmail holds the string denoting the user_email field in the database.
+ FieldUserEmail = "user_email"
+ // FieldUserName holds the string denoting the user_name field in the database.
+ FieldUserName = "user_name"
+ // FieldUserNotes holds the string denoting the user_notes field in the database.
+ FieldUserNotes = "user_notes"
+ // FieldAmount holds the string denoting the amount field in the database.
+ FieldAmount = "amount"
+ // FieldPayAmount holds the string denoting the pay_amount field in the database.
+ FieldPayAmount = "pay_amount"
+ // FieldFeeRate holds the string denoting the fee_rate field in the database.
+ FieldFeeRate = "fee_rate"
+ // FieldRechargeCode holds the string denoting the recharge_code field in the database.
+ FieldRechargeCode = "recharge_code"
+ // FieldOutTradeNo holds the string denoting the out_trade_no field in the database.
+ FieldOutTradeNo = "out_trade_no"
+ // FieldPaymentType holds the string denoting the payment_type field in the database.
+ FieldPaymentType = "payment_type"
+ // FieldPaymentTradeNo holds the string denoting the payment_trade_no field in the database.
+ FieldPaymentTradeNo = "payment_trade_no"
+ // FieldPayURL holds the string denoting the pay_url field in the database.
+ FieldPayURL = "pay_url"
+ // FieldQrCode holds the string denoting the qr_code field in the database.
+ FieldQrCode = "qr_code"
+ // FieldQrCodeImg holds the string denoting the qr_code_img field in the database.
+ FieldQrCodeImg = "qr_code_img"
+ // FieldOrderType holds the string denoting the order_type field in the database.
+ FieldOrderType = "order_type"
+ // FieldPlanID holds the string denoting the plan_id field in the database.
+ FieldPlanID = "plan_id"
+ // FieldSubscriptionGroupID holds the string denoting the subscription_group_id field in the database.
+ FieldSubscriptionGroupID = "subscription_group_id"
+ // FieldSubscriptionDays holds the string denoting the subscription_days field in the database.
+ FieldSubscriptionDays = "subscription_days"
+ // FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database.
+ FieldProviderInstanceID = "provider_instance_id"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSnapshot holds the string denoting the provider_snapshot field in the database.
+ FieldProviderSnapshot = "provider_snapshot"
+ // FieldStatus holds the string denoting the status field in the database.
+ FieldStatus = "status"
+ // FieldRefundAmount holds the string denoting the refund_amount field in the database.
+ FieldRefundAmount = "refund_amount"
+ // FieldRefundReason holds the string denoting the refund_reason field in the database.
+ FieldRefundReason = "refund_reason"
+ // FieldRefundAt holds the string denoting the refund_at field in the database.
+ FieldRefundAt = "refund_at"
+ // FieldForceRefund holds the string denoting the force_refund field in the database.
+ FieldForceRefund = "force_refund"
+ // FieldRefundRequestedAt holds the string denoting the refund_requested_at field in the database.
+ FieldRefundRequestedAt = "refund_requested_at"
+ // FieldRefundRequestReason holds the string denoting the refund_request_reason field in the database.
+ FieldRefundRequestReason = "refund_request_reason"
+ // FieldRefundRequestedBy holds the string denoting the refund_requested_by field in the database.
+ FieldRefundRequestedBy = "refund_requested_by"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // FieldPaidAt holds the string denoting the paid_at field in the database.
+ FieldPaidAt = "paid_at"
+ // FieldCompletedAt holds the string denoting the completed_at field in the database.
+ FieldCompletedAt = "completed_at"
+ // FieldFailedAt holds the string denoting the failed_at field in the database.
+ FieldFailedAt = "failed_at"
+ // FieldFailedReason holds the string denoting the failed_reason field in the database.
+ FieldFailedReason = "failed_reason"
+ // FieldClientIP holds the string denoting the client_ip field in the database.
+ FieldClientIP = "client_ip"
+ // FieldSrcHost holds the string denoting the src_host field in the database.
+ FieldSrcHost = "src_host"
+ // FieldSrcURL holds the string denoting the src_url field in the database.
+ FieldSrcURL = "src_url"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // EdgeUser holds the string denoting the user edge name in mutations.
+ EdgeUser = "user"
+ // Table holds the table name of the paymentorder in the database.
+ Table = "payment_orders"
+ // UserTable is the table that holds the user relation/edge.
+ UserTable = "payment_orders"
+ // UserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ UserInverseTable = "users"
+ // UserColumn is the table column denoting the user relation/edge.
+ UserColumn = "user_id"
+)
+
+// Columns holds all SQL columns for paymentorder fields.
+var Columns = []string{
+ FieldID,
+ FieldUserID,
+ FieldUserEmail,
+ FieldUserName,
+ FieldUserNotes,
+ FieldAmount,
+ FieldPayAmount,
+ FieldFeeRate,
+ FieldRechargeCode,
+ FieldOutTradeNo,
+ FieldPaymentType,
+ FieldPaymentTradeNo,
+ FieldPayURL,
+ FieldQrCode,
+ FieldQrCodeImg,
+ FieldOrderType,
+ FieldPlanID,
+ FieldSubscriptionGroupID,
+ FieldSubscriptionDays,
+ FieldProviderInstanceID,
+ FieldProviderKey,
+ FieldProviderSnapshot,
+ FieldStatus,
+ FieldRefundAmount,
+ FieldRefundReason,
+ FieldRefundAt,
+ FieldForceRefund,
+ FieldRefundRequestedAt,
+ FieldRefundRequestReason,
+ FieldRefundRequestedBy,
+ FieldExpiresAt,
+ FieldPaidAt,
+ FieldCompletedAt,
+ FieldFailedAt,
+ FieldFailedReason,
+ FieldClientIP,
+ FieldSrcHost,
+ FieldSrcURL,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // UserEmailValidator is a validator for the "user_email" field. It is called by the builders before save.
+ UserEmailValidator func(string) error
+ // UserNameValidator is a validator for the "user_name" field. It is called by the builders before save.
+ UserNameValidator func(string) error
+ // DefaultFeeRate holds the default value on creation for the "fee_rate" field.
+ DefaultFeeRate float64
+ // RechargeCodeValidator is a validator for the "recharge_code" field. It is called by the builders before save.
+ RechargeCodeValidator func(string) error
+ // DefaultOutTradeNo holds the default value on creation for the "out_trade_no" field.
+ DefaultOutTradeNo string
+ // OutTradeNoValidator is a validator for the "out_trade_no" field. It is called by the builders before save.
+ OutTradeNoValidator func(string) error
+ // PaymentTypeValidator is a validator for the "payment_type" field. It is called by the builders before save.
+ PaymentTypeValidator func(string) error
+ // PaymentTradeNoValidator is a validator for the "payment_trade_no" field. It is called by the builders before save.
+ PaymentTradeNoValidator func(string) error
+ // DefaultOrderType holds the default value on creation for the "order_type" field.
+ DefaultOrderType string
+ // OrderTypeValidator is a validator for the "order_type" field. It is called by the builders before save.
+ OrderTypeValidator func(string) error
+ // ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
+ ProviderInstanceIDValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // DefaultStatus holds the default value on creation for the "status" field.
+ DefaultStatus string
+ // StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ StatusValidator func(string) error
+ // DefaultRefundAmount holds the default value on creation for the "refund_amount" field.
+ DefaultRefundAmount float64
+ // DefaultForceRefund holds the default value on creation for the "force_refund" field.
+ DefaultForceRefund bool
+ // RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save.
+ RefundRequestedByValidator func(string) error
+ // ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save.
+ ClientIPValidator func(string) error
+ // SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save.
+ SrcHostValidator func(string) error
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the PaymentOrder queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByUserID orders the results by the user_id field.
+func ByUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserID, opts...).ToFunc()
+}
+
+// ByUserEmail orders the results by the user_email field.
+func ByUserEmail(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserEmail, opts...).ToFunc()
+}
+
+// ByUserName orders the results by the user_name field.
+func ByUserName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserName, opts...).ToFunc()
+}
+
+// ByUserNotes orders the results by the user_notes field.
+func ByUserNotes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUserNotes, opts...).ToFunc()
+}
+
+// ByAmount orders the results by the amount field.
+func ByAmount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAmount, opts...).ToFunc()
+}
+
+// ByPayAmount orders the results by the pay_amount field.
+func ByPayAmount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPayAmount, opts...).ToFunc()
+}
+
+// ByFeeRate orders the results by the fee_rate field.
+func ByFeeRate(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFeeRate, opts...).ToFunc()
+}
+
+// ByRechargeCode orders the results by the recharge_code field.
+func ByRechargeCode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRechargeCode, opts...).ToFunc()
+}
+
+// ByOutTradeNo orders the results by the out_trade_no field.
+func ByOutTradeNo(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOutTradeNo, opts...).ToFunc()
+}
+
+// ByPaymentType orders the results by the payment_type field.
+func ByPaymentType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPaymentType, opts...).ToFunc()
+}
+
+// ByPaymentTradeNo orders the results by the payment_trade_no field.
+func ByPaymentTradeNo(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPaymentTradeNo, opts...).ToFunc()
+}
+
+// ByPayURL orders the results by the pay_url field.
+func ByPayURL(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPayURL, opts...).ToFunc()
+}
+
+// ByQrCode orders the results by the qr_code field.
+func ByQrCode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldQrCode, opts...).ToFunc()
+}
+
+// ByQrCodeImg orders the results by the qr_code_img field.
+func ByQrCodeImg(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldQrCodeImg, opts...).ToFunc()
+}
+
+// ByOrderType orders the results by the order_type field.
+func ByOrderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOrderType, opts...).ToFunc()
+}
+
+// ByPlanID orders the results by the plan_id field.
+func ByPlanID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPlanID, opts...).ToFunc()
+}
+
+// BySubscriptionGroupID orders the results by the subscription_group_id field.
+func BySubscriptionGroupID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSubscriptionGroupID, opts...).ToFunc()
+}
+
+// BySubscriptionDays orders the results by the subscription_days field.
+func BySubscriptionDays(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSubscriptionDays, opts...).ToFunc()
+}
+
+// ByProviderInstanceID orders the results by the provider_instance_id field.
+func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByStatus orders the results by the status field.
+func ByStatus(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStatus, opts...).ToFunc()
+}
+
+// ByRefundAmount orders the results by the refund_amount field.
+func ByRefundAmount(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRefundAmount, opts...).ToFunc()
+}
+
+// ByRefundReason orders the results by the refund_reason field.
+func ByRefundReason(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRefundReason, opts...).ToFunc()
+}
+
+// ByRefundAt orders the results by the refund_at field.
+func ByRefundAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRefundAt, opts...).ToFunc()
+}
+
+// ByForceRefund orders the results by the force_refund field.
+func ByForceRefund(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldForceRefund, opts...).ToFunc()
+}
+
+// ByRefundRequestedAt orders the results by the refund_requested_at field.
+func ByRefundRequestedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRefundRequestedAt, opts...).ToFunc()
+}
+
+// ByRefundRequestReason orders the results by the refund_request_reason field.
+func ByRefundRequestReason(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRefundRequestReason, opts...).ToFunc()
+}
+
+// ByRefundRequestedBy orders the results by the refund_requested_by field.
+func ByRefundRequestedBy(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRefundRequestedBy, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
+// ByPaidAt orders the results by the paid_at field.
+func ByPaidAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPaidAt, opts...).ToFunc()
+}
+
+// ByCompletedAt orders the results by the completed_at field.
+func ByCompletedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletedAt, opts...).ToFunc()
+}
+
+// ByFailedAt orders the results by the failed_at field.
+func ByFailedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFailedAt, opts...).ToFunc()
+}
+
+// ByFailedReason orders the results by the failed_reason field.
+func ByFailedReason(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFailedReason, opts...).ToFunc()
+}
+
+// ByClientIP orders the results by the client_ip field.
+func ByClientIP(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldClientIP, opts...).ToFunc()
+}
+
+// BySrcHost orders the results by the src_host field.
+func BySrcHost(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSrcHost, opts...).ToFunc()
+}
+
+// BySrcURL orders the results by the src_url field.
+func BySrcURL(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSrcURL, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByUserField orders the results by user field.
+func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(UserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+}
diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go
new file mode 100644
index 00000000..e96bf51e
--- /dev/null
+++ b/backend/ent/paymentorder/where.go
@@ -0,0 +1,2479 @@
+// Code generated by ent, DO NOT EDIT.
+
+package paymentorder
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldID, id))
+}
+
+// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ.
+func UserID(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserEmail applies equality check predicate on the "user_email" field. It's identical to UserEmailEQ.
+func UserEmail(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserEmail, v))
+}
+
+// UserName applies equality check predicate on the "user_name" field. It's identical to UserNameEQ.
+func UserName(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserName, v))
+}
+
+// UserNotes applies equality check predicate on the "user_notes" field. It's identical to UserNotesEQ.
+func UserNotes(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserNotes, v))
+}
+
+// Amount applies equality check predicate on the "amount" field. It's identical to AmountEQ.
+func Amount(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldAmount, v))
+}
+
+// PayAmount applies equality check predicate on the "pay_amount" field. It's identical to PayAmountEQ.
+func PayAmount(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPayAmount, v))
+}
+
+// FeeRate applies equality check predicate on the "fee_rate" field. It's identical to FeeRateEQ.
+func FeeRate(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldFeeRate, v))
+}
+
+// RechargeCode applies equality check predicate on the "recharge_code" field. It's identical to RechargeCodeEQ.
+func RechargeCode(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRechargeCode, v))
+}
+
+// OutTradeNo applies equality check predicate on the "out_trade_no" field. It's identical to OutTradeNoEQ.
+func OutTradeNo(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldOutTradeNo, v))
+}
+
+// PaymentType applies equality check predicate on the "payment_type" field. It's identical to PaymentTypeEQ.
+func PaymentType(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPaymentType, v))
+}
+
+// PaymentTradeNo applies equality check predicate on the "payment_trade_no" field. It's identical to PaymentTradeNoEQ.
+func PaymentTradeNo(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPaymentTradeNo, v))
+}
+
+// PayURL applies equality check predicate on the "pay_url" field. It's identical to PayURLEQ.
+func PayURL(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPayURL, v))
+}
+
+// QrCode applies equality check predicate on the "qr_code" field. It's identical to QrCodeEQ.
+func QrCode(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldQrCode, v))
+}
+
+// QrCodeImg applies equality check predicate on the "qr_code_img" field. It's identical to QrCodeImgEQ.
+func QrCodeImg(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldQrCodeImg, v))
+}
+
+// OrderType applies equality check predicate on the "order_type" field. It's identical to OrderTypeEQ.
+func OrderType(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldOrderType, v))
+}
+
+// PlanID applies equality check predicate on the "plan_id" field. It's identical to PlanIDEQ.
+func PlanID(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPlanID, v))
+}
+
+// SubscriptionGroupID applies equality check predicate on the "subscription_group_id" field. It's identical to SubscriptionGroupIDEQ.
+func SubscriptionGroupID(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSubscriptionGroupID, v))
+}
+
+// SubscriptionDays applies equality check predicate on the "subscription_days" field. It's identical to SubscriptionDaysEQ.
+func SubscriptionDays(v int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSubscriptionDays, v))
+}
+
+// ProviderInstanceID applies equality check predicate on the "provider_instance_id" field. It's identical to ProviderInstanceIDEQ.
+func ProviderInstanceID(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
+func Status(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
+}
+
+// RefundAmount applies equality check predicate on the "refund_amount" field. It's identical to RefundAmountEQ.
+func RefundAmount(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundAmount, v))
+}
+
+// RefundReason applies equality check predicate on the "refund_reason" field. It's identical to RefundReasonEQ.
+func RefundReason(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundReason, v))
+}
+
+// RefundAt applies equality check predicate on the "refund_at" field. It's identical to RefundAtEQ.
+func RefundAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundAt, v))
+}
+
+// ForceRefund applies equality check predicate on the "force_refund" field. It's identical to ForceRefundEQ.
+func ForceRefund(v bool) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldForceRefund, v))
+}
+
+// RefundRequestedAt applies equality check predicate on the "refund_requested_at" field. It's identical to RefundRequestedAtEQ.
+func RefundRequestedAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundRequestedAt, v))
+}
+
+// RefundRequestReason applies equality check predicate on the "refund_request_reason" field. It's identical to RefundRequestReasonEQ.
+func RefundRequestReason(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundRequestReason, v))
+}
+
+// RefundRequestedBy applies equality check predicate on the "refund_requested_by" field. It's identical to RefundRequestedByEQ.
+func RefundRequestedBy(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundRequestedBy, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// PaidAt applies equality check predicate on the "paid_at" field. It's identical to PaidAtEQ.
+func PaidAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPaidAt, v))
+}
+
+// CompletedAt applies equality check predicate on the "completed_at" field. It's identical to CompletedAtEQ.
+func CompletedAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldCompletedAt, v))
+}
+
+// FailedAt applies equality check predicate on the "failed_at" field. It's identical to FailedAtEQ.
+func FailedAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldFailedAt, v))
+}
+
+// FailedReason applies equality check predicate on the "failed_reason" field. It's identical to FailedReasonEQ.
+func FailedReason(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldFailedReason, v))
+}
+
+// ClientIP applies equality check predicate on the "client_ip" field. It's identical to ClientIPEQ.
+func ClientIP(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldClientIP, v))
+}
+
+// SrcHost applies equality check predicate on the "src_host" field. It's identical to SrcHostEQ.
+func SrcHost(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSrcHost, v))
+}
+
+// SrcURL applies equality check predicate on the "src_url" field. It's identical to SrcURLEQ.
+func SrcURL(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSrcURL, v))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UserIDEQ applies the EQ predicate on the "user_id" field.
+func UserIDEQ(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserID, v))
+}
+
+// UserIDNEQ applies the NEQ predicate on the "user_id" field.
+func UserIDNEQ(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldUserID, v))
+}
+
+// UserIDIn applies the In predicate on the "user_id" field.
+func UserIDIn(vs ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldUserID, vs...))
+}
+
+// UserIDNotIn applies the NotIn predicate on the "user_id" field.
+func UserIDNotIn(vs ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldUserID, vs...))
+}
+
+// UserEmailEQ applies the EQ predicate on the "user_email" field.
+func UserEmailEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserEmail, v))
+}
+
+// UserEmailNEQ applies the NEQ predicate on the "user_email" field.
+func UserEmailNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldUserEmail, v))
+}
+
+// UserEmailIn applies the In predicate on the "user_email" field.
+func UserEmailIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldUserEmail, vs...))
+}
+
+// UserEmailNotIn applies the NotIn predicate on the "user_email" field.
+func UserEmailNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldUserEmail, vs...))
+}
+
+// UserEmailGT applies the GT predicate on the "user_email" field.
+func UserEmailGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldUserEmail, v))
+}
+
+// UserEmailGTE applies the GTE predicate on the "user_email" field.
+func UserEmailGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldUserEmail, v))
+}
+
+// UserEmailLT applies the LT predicate on the "user_email" field.
+func UserEmailLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldUserEmail, v))
+}
+
+// UserEmailLTE applies the LTE predicate on the "user_email" field.
+func UserEmailLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldUserEmail, v))
+}
+
+// UserEmailContains applies the Contains predicate on the "user_email" field.
+func UserEmailContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldUserEmail, v))
+}
+
+// UserEmailHasPrefix applies the HasPrefix predicate on the "user_email" field.
+func UserEmailHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldUserEmail, v))
+}
+
+// UserEmailHasSuffix applies the HasSuffix predicate on the "user_email" field.
+func UserEmailHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldUserEmail, v))
+}
+
+// UserEmailEqualFold applies the EqualFold predicate on the "user_email" field.
+func UserEmailEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldUserEmail, v))
+}
+
+// UserEmailContainsFold applies the ContainsFold predicate on the "user_email" field.
+func UserEmailContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldUserEmail, v))
+}
+
+// UserNameEQ applies the EQ predicate on the "user_name" field.
+func UserNameEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserName, v))
+}
+
+// UserNameNEQ applies the NEQ predicate on the "user_name" field.
+func UserNameNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldUserName, v))
+}
+
+// UserNameIn applies the In predicate on the "user_name" field.
+func UserNameIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldUserName, vs...))
+}
+
+// UserNameNotIn applies the NotIn predicate on the "user_name" field.
+func UserNameNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldUserName, vs...))
+}
+
+// UserNameGT applies the GT predicate on the "user_name" field.
+func UserNameGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldUserName, v))
+}
+
+// UserNameGTE applies the GTE predicate on the "user_name" field.
+func UserNameGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldUserName, v))
+}
+
+// UserNameLT applies the LT predicate on the "user_name" field.
+func UserNameLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldUserName, v))
+}
+
+// UserNameLTE applies the LTE predicate on the "user_name" field.
+func UserNameLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldUserName, v))
+}
+
+// UserNameContains applies the Contains predicate on the "user_name" field.
+func UserNameContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldUserName, v))
+}
+
+// UserNameHasPrefix applies the HasPrefix predicate on the "user_name" field.
+func UserNameHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldUserName, v))
+}
+
+// UserNameHasSuffix applies the HasSuffix predicate on the "user_name" field.
+func UserNameHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldUserName, v))
+}
+
+// UserNameEqualFold applies the EqualFold predicate on the "user_name" field.
+func UserNameEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldUserName, v))
+}
+
+// UserNameContainsFold applies the ContainsFold predicate on the "user_name" field.
+func UserNameContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldUserName, v))
+}
+
+// UserNotesEQ applies the EQ predicate on the "user_notes" field.
+func UserNotesEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUserNotes, v))
+}
+
+// UserNotesNEQ applies the NEQ predicate on the "user_notes" field.
+func UserNotesNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldUserNotes, v))
+}
+
+// UserNotesIn applies the In predicate on the "user_notes" field.
+func UserNotesIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldUserNotes, vs...))
+}
+
+// UserNotesNotIn applies the NotIn predicate on the "user_notes" field.
+func UserNotesNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldUserNotes, vs...))
+}
+
+// UserNotesGT applies the GT predicate on the "user_notes" field.
+func UserNotesGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldUserNotes, v))
+}
+
+// UserNotesGTE applies the GTE predicate on the "user_notes" field.
+func UserNotesGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldUserNotes, v))
+}
+
+// UserNotesLT applies the LT predicate on the "user_notes" field.
+func UserNotesLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldUserNotes, v))
+}
+
+// UserNotesLTE applies the LTE predicate on the "user_notes" field.
+func UserNotesLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldUserNotes, v))
+}
+
+// UserNotesContains applies the Contains predicate on the "user_notes" field.
+func UserNotesContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldUserNotes, v))
+}
+
+// UserNotesHasPrefix applies the HasPrefix predicate on the "user_notes" field.
+func UserNotesHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldUserNotes, v))
+}
+
+// UserNotesHasSuffix applies the HasSuffix predicate on the "user_notes" field.
+func UserNotesHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldUserNotes, v))
+}
+
+// UserNotesIsNil applies the IsNil predicate on the "user_notes" field.
+func UserNotesIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldUserNotes))
+}
+
+// UserNotesNotNil applies the NotNil predicate on the "user_notes" field.
+func UserNotesNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldUserNotes))
+}
+
+// UserNotesEqualFold applies the EqualFold predicate on the "user_notes" field.
+func UserNotesEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldUserNotes, v))
+}
+
+// UserNotesContainsFold applies the ContainsFold predicate on the "user_notes" field.
+func UserNotesContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldUserNotes, v))
+}
+
+// AmountEQ applies the EQ predicate on the "amount" field.
+func AmountEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldAmount, v))
+}
+
+// AmountNEQ applies the NEQ predicate on the "amount" field.
+func AmountNEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldAmount, v))
+}
+
+// AmountIn applies the In predicate on the "amount" field.
+func AmountIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldAmount, vs...))
+}
+
+// AmountNotIn applies the NotIn predicate on the "amount" field.
+func AmountNotIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldAmount, vs...))
+}
+
+// AmountGT applies the GT predicate on the "amount" field.
+func AmountGT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldAmount, v))
+}
+
+// AmountGTE applies the GTE predicate on the "amount" field.
+func AmountGTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldAmount, v))
+}
+
+// AmountLT applies the LT predicate on the "amount" field.
+func AmountLT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldAmount, v))
+}
+
+// AmountLTE applies the LTE predicate on the "amount" field.
+func AmountLTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldAmount, v))
+}
+
+// PayAmountEQ applies the EQ predicate on the "pay_amount" field.
+func PayAmountEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPayAmount, v))
+}
+
+// PayAmountNEQ applies the NEQ predicate on the "pay_amount" field.
+func PayAmountNEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldPayAmount, v))
+}
+
+// PayAmountIn applies the In predicate on the "pay_amount" field.
+func PayAmountIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldPayAmount, vs...))
+}
+
+// PayAmountNotIn applies the NotIn predicate on the "pay_amount" field.
+func PayAmountNotIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldPayAmount, vs...))
+}
+
+// PayAmountGT applies the GT predicate on the "pay_amount" field.
+func PayAmountGT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldPayAmount, v))
+}
+
+// PayAmountGTE applies the GTE predicate on the "pay_amount" field.
+func PayAmountGTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldPayAmount, v))
+}
+
+// PayAmountLT applies the LT predicate on the "pay_amount" field.
+func PayAmountLT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldPayAmount, v))
+}
+
+// PayAmountLTE applies the LTE predicate on the "pay_amount" field.
+func PayAmountLTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldPayAmount, v))
+}
+
+// FeeRateEQ applies the EQ predicate on the "fee_rate" field.
+func FeeRateEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldFeeRate, v))
+}
+
+// FeeRateNEQ applies the NEQ predicate on the "fee_rate" field.
+func FeeRateNEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldFeeRate, v))
+}
+
+// FeeRateIn applies the In predicate on the "fee_rate" field.
+func FeeRateIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldFeeRate, vs...))
+}
+
+// FeeRateNotIn applies the NotIn predicate on the "fee_rate" field.
+func FeeRateNotIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldFeeRate, vs...))
+}
+
+// FeeRateGT applies the GT predicate on the "fee_rate" field.
+func FeeRateGT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldFeeRate, v))
+}
+
+// FeeRateGTE applies the GTE predicate on the "fee_rate" field.
+func FeeRateGTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldFeeRate, v))
+}
+
+// FeeRateLT applies the LT predicate on the "fee_rate" field.
+func FeeRateLT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldFeeRate, v))
+}
+
+// FeeRateLTE applies the LTE predicate on the "fee_rate" field.
+func FeeRateLTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldFeeRate, v))
+}
+
+// RechargeCodeEQ applies the EQ predicate on the "recharge_code" field.
+func RechargeCodeEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRechargeCode, v))
+}
+
+// RechargeCodeNEQ applies the NEQ predicate on the "recharge_code" field.
+func RechargeCodeNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldRechargeCode, v))
+}
+
+// RechargeCodeIn applies the In predicate on the "recharge_code" field.
+func RechargeCodeIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldRechargeCode, vs...))
+}
+
+// RechargeCodeNotIn applies the NotIn predicate on the "recharge_code" field.
+func RechargeCodeNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldRechargeCode, vs...))
+}
+
+// RechargeCodeGT applies the GT predicate on the "recharge_code" field.
+func RechargeCodeGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldRechargeCode, v))
+}
+
+// RechargeCodeGTE applies the GTE predicate on the "recharge_code" field.
+func RechargeCodeGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldRechargeCode, v))
+}
+
+// RechargeCodeLT applies the LT predicate on the "recharge_code" field.
+func RechargeCodeLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldRechargeCode, v))
+}
+
+// RechargeCodeLTE applies the LTE predicate on the "recharge_code" field.
+func RechargeCodeLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldRechargeCode, v))
+}
+
+// RechargeCodeContains applies the Contains predicate on the "recharge_code" field.
+func RechargeCodeContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldRechargeCode, v))
+}
+
+// RechargeCodeHasPrefix applies the HasPrefix predicate on the "recharge_code" field.
+func RechargeCodeHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldRechargeCode, v))
+}
+
+// RechargeCodeHasSuffix applies the HasSuffix predicate on the "recharge_code" field.
+func RechargeCodeHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldRechargeCode, v))
+}
+
+// RechargeCodeEqualFold applies the EqualFold predicate on the "recharge_code" field.
+func RechargeCodeEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldRechargeCode, v))
+}
+
+// RechargeCodeContainsFold applies the ContainsFold predicate on the "recharge_code" field.
+func RechargeCodeContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldRechargeCode, v))
+}
+
+// OutTradeNoEQ applies the EQ predicate on the "out_trade_no" field.
+func OutTradeNoEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldOutTradeNo, v))
+}
+
+// OutTradeNoNEQ applies the NEQ predicate on the "out_trade_no" field.
+func OutTradeNoNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldOutTradeNo, v))
+}
+
+// OutTradeNoIn applies the In predicate on the "out_trade_no" field.
+func OutTradeNoIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldOutTradeNo, vs...))
+}
+
+// OutTradeNoNotIn applies the NotIn predicate on the "out_trade_no" field.
+func OutTradeNoNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldOutTradeNo, vs...))
+}
+
+// OutTradeNoGT applies the GT predicate on the "out_trade_no" field.
+func OutTradeNoGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldOutTradeNo, v))
+}
+
+// OutTradeNoGTE applies the GTE predicate on the "out_trade_no" field.
+func OutTradeNoGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldOutTradeNo, v))
+}
+
+// OutTradeNoLT applies the LT predicate on the "out_trade_no" field.
+func OutTradeNoLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldOutTradeNo, v))
+}
+
+// OutTradeNoLTE applies the LTE predicate on the "out_trade_no" field.
+func OutTradeNoLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldOutTradeNo, v))
+}
+
+// OutTradeNoContains applies the Contains predicate on the "out_trade_no" field.
+func OutTradeNoContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldOutTradeNo, v))
+}
+
+// OutTradeNoHasPrefix applies the HasPrefix predicate on the "out_trade_no" field.
+func OutTradeNoHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldOutTradeNo, v))
+}
+
+// OutTradeNoHasSuffix applies the HasSuffix predicate on the "out_trade_no" field.
+func OutTradeNoHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldOutTradeNo, v))
+}
+
+// OutTradeNoEqualFold applies the EqualFold predicate on the "out_trade_no" field.
+func OutTradeNoEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldOutTradeNo, v))
+}
+
+// OutTradeNoContainsFold applies the ContainsFold predicate on the "out_trade_no" field.
+func OutTradeNoContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldOutTradeNo, v))
+}
+
+// PaymentTypeEQ applies the EQ predicate on the "payment_type" field.
+func PaymentTypeEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPaymentType, v))
+}
+
+// PaymentTypeNEQ applies the NEQ predicate on the "payment_type" field.
+func PaymentTypeNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldPaymentType, v))
+}
+
+// PaymentTypeIn applies the In predicate on the "payment_type" field.
+func PaymentTypeIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldPaymentType, vs...))
+}
+
+// PaymentTypeNotIn applies the NotIn predicate on the "payment_type" field.
+func PaymentTypeNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldPaymentType, vs...))
+}
+
+// PaymentTypeGT applies the GT predicate on the "payment_type" field.
+func PaymentTypeGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldPaymentType, v))
+}
+
+// PaymentTypeGTE applies the GTE predicate on the "payment_type" field.
+func PaymentTypeGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldPaymentType, v))
+}
+
+// PaymentTypeLT applies the LT predicate on the "payment_type" field.
+func PaymentTypeLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldPaymentType, v))
+}
+
+// PaymentTypeLTE applies the LTE predicate on the "payment_type" field.
+func PaymentTypeLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldPaymentType, v))
+}
+
+// PaymentTypeContains applies the Contains predicate on the "payment_type" field.
+func PaymentTypeContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldPaymentType, v))
+}
+
+// PaymentTypeHasPrefix applies the HasPrefix predicate on the "payment_type" field.
+func PaymentTypeHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldPaymentType, v))
+}
+
+// PaymentTypeHasSuffix applies the HasSuffix predicate on the "payment_type" field.
+func PaymentTypeHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldPaymentType, v))
+}
+
+// PaymentTypeEqualFold applies the EqualFold predicate on the "payment_type" field.
+func PaymentTypeEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldPaymentType, v))
+}
+
+// PaymentTypeContainsFold applies the ContainsFold predicate on the "payment_type" field.
+func PaymentTypeContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldPaymentType, v))
+}
+
+// PaymentTradeNoEQ applies the EQ predicate on the "payment_trade_no" field.
+func PaymentTradeNoEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoNEQ applies the NEQ predicate on the "payment_trade_no" field.
+func PaymentTradeNoNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoIn applies the In predicate on the "payment_trade_no" field.
+func PaymentTradeNoIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldPaymentTradeNo, vs...))
+}
+
+// PaymentTradeNoNotIn applies the NotIn predicate on the "payment_trade_no" field.
+func PaymentTradeNoNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldPaymentTradeNo, vs...))
+}
+
+// PaymentTradeNoGT applies the GT predicate on the "payment_trade_no" field.
+func PaymentTradeNoGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoGTE applies the GTE predicate on the "payment_trade_no" field.
+func PaymentTradeNoGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoLT applies the LT predicate on the "payment_trade_no" field.
+func PaymentTradeNoLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoLTE applies the LTE predicate on the "payment_trade_no" field.
+func PaymentTradeNoLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoContains applies the Contains predicate on the "payment_trade_no" field.
+func PaymentTradeNoContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoHasPrefix applies the HasPrefix predicate on the "payment_trade_no" field.
+func PaymentTradeNoHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoHasSuffix applies the HasSuffix predicate on the "payment_trade_no" field.
+func PaymentTradeNoHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoEqualFold applies the EqualFold predicate on the "payment_trade_no" field.
+func PaymentTradeNoEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldPaymentTradeNo, v))
+}
+
+// PaymentTradeNoContainsFold applies the ContainsFold predicate on the "payment_trade_no" field.
+func PaymentTradeNoContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldPaymentTradeNo, v))
+}
+
+// PayURLEQ applies the EQ predicate on the "pay_url" field.
+func PayURLEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPayURL, v))
+}
+
+// PayURLNEQ applies the NEQ predicate on the "pay_url" field.
+func PayURLNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldPayURL, v))
+}
+
+// PayURLIn applies the In predicate on the "pay_url" field.
+func PayURLIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldPayURL, vs...))
+}
+
+// PayURLNotIn applies the NotIn predicate on the "pay_url" field.
+func PayURLNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldPayURL, vs...))
+}
+
+// PayURLGT applies the GT predicate on the "pay_url" field.
+func PayURLGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldPayURL, v))
+}
+
+// PayURLGTE applies the GTE predicate on the "pay_url" field.
+func PayURLGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldPayURL, v))
+}
+
+// PayURLLT applies the LT predicate on the "pay_url" field.
+func PayURLLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldPayURL, v))
+}
+
+// PayURLLTE applies the LTE predicate on the "pay_url" field.
+func PayURLLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldPayURL, v))
+}
+
+// PayURLContains applies the Contains predicate on the "pay_url" field.
+func PayURLContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldPayURL, v))
+}
+
+// PayURLHasPrefix applies the HasPrefix predicate on the "pay_url" field.
+func PayURLHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldPayURL, v))
+}
+
+// PayURLHasSuffix applies the HasSuffix predicate on the "pay_url" field.
+func PayURLHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldPayURL, v))
+}
+
+// PayURLIsNil applies the IsNil predicate on the "pay_url" field.
+func PayURLIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldPayURL))
+}
+
+// PayURLNotNil applies the NotNil predicate on the "pay_url" field.
+func PayURLNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldPayURL))
+}
+
+// PayURLEqualFold applies the EqualFold predicate on the "pay_url" field.
+func PayURLEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldPayURL, v))
+}
+
+// PayURLContainsFold applies the ContainsFold predicate on the "pay_url" field.
+func PayURLContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldPayURL, v))
+}
+
+// QrCodeEQ applies the EQ predicate on the "qr_code" field.
+func QrCodeEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldQrCode, v))
+}
+
+// QrCodeNEQ applies the NEQ predicate on the "qr_code" field.
+func QrCodeNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldQrCode, v))
+}
+
+// QrCodeIn applies the In predicate on the "qr_code" field.
+func QrCodeIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldQrCode, vs...))
+}
+
+// QrCodeNotIn applies the NotIn predicate on the "qr_code" field.
+func QrCodeNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldQrCode, vs...))
+}
+
+// QrCodeGT applies the GT predicate on the "qr_code" field.
+func QrCodeGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldQrCode, v))
+}
+
+// QrCodeGTE applies the GTE predicate on the "qr_code" field.
+func QrCodeGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldQrCode, v))
+}
+
+// QrCodeLT applies the LT predicate on the "qr_code" field.
+func QrCodeLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldQrCode, v))
+}
+
+// QrCodeLTE applies the LTE predicate on the "qr_code" field.
+func QrCodeLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldQrCode, v))
+}
+
+// QrCodeContains applies the Contains predicate on the "qr_code" field.
+func QrCodeContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldQrCode, v))
+}
+
+// QrCodeHasPrefix applies the HasPrefix predicate on the "qr_code" field.
+func QrCodeHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldQrCode, v))
+}
+
+// QrCodeHasSuffix applies the HasSuffix predicate on the "qr_code" field.
+func QrCodeHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldQrCode, v))
+}
+
+// QrCodeIsNil applies the IsNil predicate on the "qr_code" field.
+func QrCodeIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldQrCode))
+}
+
+// QrCodeNotNil applies the NotNil predicate on the "qr_code" field.
+func QrCodeNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldQrCode))
+}
+
+// QrCodeEqualFold applies the EqualFold predicate on the "qr_code" field.
+func QrCodeEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldQrCode, v))
+}
+
+// QrCodeContainsFold applies the ContainsFold predicate on the "qr_code" field.
+func QrCodeContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldQrCode, v))
+}
+
+// QrCodeImgEQ applies the EQ predicate on the "qr_code_img" field.
+func QrCodeImgEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldQrCodeImg, v))
+}
+
+// QrCodeImgNEQ applies the NEQ predicate on the "qr_code_img" field.
+func QrCodeImgNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldQrCodeImg, v))
+}
+
+// QrCodeImgIn applies the In predicate on the "qr_code_img" field.
+func QrCodeImgIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldQrCodeImg, vs...))
+}
+
+// QrCodeImgNotIn applies the NotIn predicate on the "qr_code_img" field.
+func QrCodeImgNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldQrCodeImg, vs...))
+}
+
+// QrCodeImgGT applies the GT predicate on the "qr_code_img" field.
+func QrCodeImgGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldQrCodeImg, v))
+}
+
+// QrCodeImgGTE applies the GTE predicate on the "qr_code_img" field.
+func QrCodeImgGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldQrCodeImg, v))
+}
+
+// QrCodeImgLT applies the LT predicate on the "qr_code_img" field.
+func QrCodeImgLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldQrCodeImg, v))
+}
+
+// QrCodeImgLTE applies the LTE predicate on the "qr_code_img" field.
+func QrCodeImgLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldQrCodeImg, v))
+}
+
+// QrCodeImgContains applies the Contains predicate on the "qr_code_img" field.
+func QrCodeImgContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldQrCodeImg, v))
+}
+
+// QrCodeImgHasPrefix applies the HasPrefix predicate on the "qr_code_img" field.
+func QrCodeImgHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldQrCodeImg, v))
+}
+
+// QrCodeImgHasSuffix applies the HasSuffix predicate on the "qr_code_img" field.
+func QrCodeImgHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldQrCodeImg, v))
+}
+
+// QrCodeImgIsNil applies the IsNil predicate on the "qr_code_img" field.
+func QrCodeImgIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldQrCodeImg))
+}
+
+// QrCodeImgNotNil applies the NotNil predicate on the "qr_code_img" field.
+func QrCodeImgNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldQrCodeImg))
+}
+
+// QrCodeImgEqualFold applies the EqualFold predicate on the "qr_code_img" field.
+func QrCodeImgEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldQrCodeImg, v))
+}
+
+// QrCodeImgContainsFold applies the ContainsFold predicate on the "qr_code_img" field.
+func QrCodeImgContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldQrCodeImg, v))
+}
+
+// OrderTypeEQ applies the EQ predicate on the "order_type" field.
+func OrderTypeEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldOrderType, v))
+}
+
+// OrderTypeNEQ applies the NEQ predicate on the "order_type" field.
+func OrderTypeNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldOrderType, v))
+}
+
+// OrderTypeIn applies the In predicate on the "order_type" field.
+func OrderTypeIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldOrderType, vs...))
+}
+
+// OrderTypeNotIn applies the NotIn predicate on the "order_type" field.
+func OrderTypeNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldOrderType, vs...))
+}
+
+// OrderTypeGT applies the GT predicate on the "order_type" field.
+func OrderTypeGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldOrderType, v))
+}
+
+// OrderTypeGTE applies the GTE predicate on the "order_type" field.
+func OrderTypeGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldOrderType, v))
+}
+
+// OrderTypeLT applies the LT predicate on the "order_type" field.
+func OrderTypeLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldOrderType, v))
+}
+
+// OrderTypeLTE applies the LTE predicate on the "order_type" field.
+func OrderTypeLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldOrderType, v))
+}
+
+// OrderTypeContains applies the Contains predicate on the "order_type" field.
+func OrderTypeContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldOrderType, v))
+}
+
+// OrderTypeHasPrefix applies the HasPrefix predicate on the "order_type" field.
+func OrderTypeHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldOrderType, v))
+}
+
+// OrderTypeHasSuffix applies the HasSuffix predicate on the "order_type" field.
+func OrderTypeHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldOrderType, v))
+}
+
+// OrderTypeEqualFold applies the EqualFold predicate on the "order_type" field.
+func OrderTypeEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldOrderType, v))
+}
+
+// OrderTypeContainsFold applies the ContainsFold predicate on the "order_type" field.
+func OrderTypeContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldOrderType, v))
+}
+
+// PlanIDEQ applies the EQ predicate on the "plan_id" field.
+func PlanIDEQ(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPlanID, v))
+}
+
+// PlanIDNEQ applies the NEQ predicate on the "plan_id" field.
+func PlanIDNEQ(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldPlanID, v))
+}
+
+// PlanIDIn applies the In predicate on the "plan_id" field.
+func PlanIDIn(vs ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldPlanID, vs...))
+}
+
+// PlanIDNotIn applies the NotIn predicate on the "plan_id" field.
+func PlanIDNotIn(vs ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldPlanID, vs...))
+}
+
+// PlanIDGT applies the GT predicate on the "plan_id" field.
+func PlanIDGT(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldPlanID, v))
+}
+
+// PlanIDGTE applies the GTE predicate on the "plan_id" field.
+func PlanIDGTE(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldPlanID, v))
+}
+
+// PlanIDLT applies the LT predicate on the "plan_id" field.
+func PlanIDLT(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldPlanID, v))
+}
+
+// PlanIDLTE applies the LTE predicate on the "plan_id" field.
+func PlanIDLTE(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldPlanID, v))
+}
+
+// PlanIDIsNil applies the IsNil predicate on the "plan_id" field.
+func PlanIDIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldPlanID))
+}
+
+// PlanIDNotNil applies the NotNil predicate on the "plan_id" field.
+func PlanIDNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldPlanID))
+}
+
+// SubscriptionGroupIDEQ applies the EQ predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDEQ(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSubscriptionGroupID, v))
+}
+
+// SubscriptionGroupIDNEQ applies the NEQ predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDNEQ(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldSubscriptionGroupID, v))
+}
+
+// SubscriptionGroupIDIn applies the In predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDIn(vs ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldSubscriptionGroupID, vs...))
+}
+
+// SubscriptionGroupIDNotIn applies the NotIn predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDNotIn(vs ...int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldSubscriptionGroupID, vs...))
+}
+
+// SubscriptionGroupIDGT applies the GT predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDGT(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldSubscriptionGroupID, v))
+}
+
+// SubscriptionGroupIDGTE applies the GTE predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDGTE(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldSubscriptionGroupID, v))
+}
+
+// SubscriptionGroupIDLT applies the LT predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDLT(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldSubscriptionGroupID, v))
+}
+
+// SubscriptionGroupIDLTE applies the LTE predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDLTE(v int64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldSubscriptionGroupID, v))
+}
+
+// SubscriptionGroupIDIsNil applies the IsNil predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldSubscriptionGroupID))
+}
+
+// SubscriptionGroupIDNotNil applies the NotNil predicate on the "subscription_group_id" field.
+func SubscriptionGroupIDNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldSubscriptionGroupID))
+}
+
+// SubscriptionDaysEQ applies the EQ predicate on the "subscription_days" field.
+func SubscriptionDaysEQ(v int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSubscriptionDays, v))
+}
+
+// SubscriptionDaysNEQ applies the NEQ predicate on the "subscription_days" field.
+func SubscriptionDaysNEQ(v int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldSubscriptionDays, v))
+}
+
+// SubscriptionDaysIn applies the In predicate on the "subscription_days" field.
+func SubscriptionDaysIn(vs ...int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldSubscriptionDays, vs...))
+}
+
+// SubscriptionDaysNotIn applies the NotIn predicate on the "subscription_days" field.
+func SubscriptionDaysNotIn(vs ...int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldSubscriptionDays, vs...))
+}
+
+// SubscriptionDaysGT applies the GT predicate on the "subscription_days" field.
+func SubscriptionDaysGT(v int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldSubscriptionDays, v))
+}
+
+// SubscriptionDaysGTE applies the GTE predicate on the "subscription_days" field.
+func SubscriptionDaysGTE(v int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldSubscriptionDays, v))
+}
+
+// SubscriptionDaysLT applies the LT predicate on the "subscription_days" field.
+func SubscriptionDaysLT(v int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldSubscriptionDays, v))
+}
+
+// SubscriptionDaysLTE applies the LTE predicate on the "subscription_days" field.
+func SubscriptionDaysLTE(v int) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldSubscriptionDays, v))
+}
+
+// SubscriptionDaysIsNil applies the IsNil predicate on the "subscription_days" field.
+func SubscriptionDaysIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldSubscriptionDays))
+}
+
+// SubscriptionDaysNotNil applies the NotNil predicate on the "subscription_days" field.
+func SubscriptionDaysNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldSubscriptionDays))
+}
+
+// ProviderInstanceIDEQ applies the EQ predicate on the "provider_instance_id" field.
+func ProviderInstanceIDEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDNEQ applies the NEQ predicate on the "provider_instance_id" field.
+func ProviderInstanceIDNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDIn applies the In predicate on the "provider_instance_id" field.
+func ProviderInstanceIDIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldProviderInstanceID, vs...))
+}
+
+// ProviderInstanceIDNotIn applies the NotIn predicate on the "provider_instance_id" field.
+func ProviderInstanceIDNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderInstanceID, vs...))
+}
+
+// ProviderInstanceIDGT applies the GT predicate on the "provider_instance_id" field.
+func ProviderInstanceIDGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDGTE applies the GTE predicate on the "provider_instance_id" field.
+func ProviderInstanceIDGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDLT applies the LT predicate on the "provider_instance_id" field.
+func ProviderInstanceIDLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDLTE applies the LTE predicate on the "provider_instance_id" field.
+func ProviderInstanceIDLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDContains applies the Contains predicate on the "provider_instance_id" field.
+func ProviderInstanceIDContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDHasPrefix applies the HasPrefix predicate on the "provider_instance_id" field.
+func ProviderInstanceIDHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDHasSuffix applies the HasSuffix predicate on the "provider_instance_id" field.
+func ProviderInstanceIDHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDIsNil applies the IsNil predicate on the "provider_instance_id" field.
+func ProviderInstanceIDIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderInstanceID))
+}
+
+// ProviderInstanceIDNotNil applies the NotNil predicate on the "provider_instance_id" field.
+func ProviderInstanceIDNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderInstanceID))
+}
+
+// ProviderInstanceIDEqualFold applies the EqualFold predicate on the "provider_instance_id" field.
+func ProviderInstanceIDEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderInstanceID, v))
+}
+
+// ProviderInstanceIDContainsFold applies the ContainsFold predicate on the "provider_instance_id" field.
+func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field.
+func ProviderKeyIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey))
+}
+
+// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field.
+func ProviderKeyNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSnapshotIsNil applies the IsNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderSnapshot))
+}
+
+// ProviderSnapshotNotNil applies the NotNil predicate on the "provider_snapshot" field.
+func ProviderSnapshotNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderSnapshot))
+}
+
+// StatusEQ applies the EQ predicate on the "status" field.
+func StatusEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v))
+}
+
+// StatusNEQ applies the NEQ predicate on the "status" field.
+func StatusNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldStatus, v))
+}
+
+// StatusIn applies the In predicate on the "status" field.
+func StatusIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldStatus, vs...))
+}
+
+// StatusNotIn applies the NotIn predicate on the "status" field.
+func StatusNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldStatus, vs...))
+}
+
+// StatusGT applies the GT predicate on the "status" field.
+func StatusGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldStatus, v))
+}
+
+// StatusGTE applies the GTE predicate on the "status" field.
+func StatusGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldStatus, v))
+}
+
+// StatusLT applies the LT predicate on the "status" field.
+func StatusLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldStatus, v))
+}
+
+// StatusLTE applies the LTE predicate on the "status" field.
+func StatusLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldStatus, v))
+}
+
+// StatusContains applies the Contains predicate on the "status" field.
+func StatusContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldStatus, v))
+}
+
+// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
+func StatusHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldStatus, v))
+}
+
+// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
+func StatusHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldStatus, v))
+}
+
+// StatusEqualFold applies the EqualFold predicate on the "status" field.
+func StatusEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldStatus, v))
+}
+
+// StatusContainsFold applies the ContainsFold predicate on the "status" field.
+func StatusContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldStatus, v))
+}
+
+// RefundAmountEQ applies the EQ predicate on the "refund_amount" field.
+func RefundAmountEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundAmount, v))
+}
+
+// RefundAmountNEQ applies the NEQ predicate on the "refund_amount" field.
+func RefundAmountNEQ(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldRefundAmount, v))
+}
+
+// RefundAmountIn applies the In predicate on the "refund_amount" field.
+func RefundAmountIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldRefundAmount, vs...))
+}
+
+// RefundAmountNotIn applies the NotIn predicate on the "refund_amount" field.
+func RefundAmountNotIn(vs ...float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldRefundAmount, vs...))
+}
+
+// RefundAmountGT applies the GT predicate on the "refund_amount" field.
+func RefundAmountGT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldRefundAmount, v))
+}
+
+// RefundAmountGTE applies the GTE predicate on the "refund_amount" field.
+func RefundAmountGTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldRefundAmount, v))
+}
+
+// RefundAmountLT applies the LT predicate on the "refund_amount" field.
+func RefundAmountLT(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldRefundAmount, v))
+}
+
+// RefundAmountLTE applies the LTE predicate on the "refund_amount" field.
+func RefundAmountLTE(v float64) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldRefundAmount, v))
+}
+
+// RefundReasonEQ applies the EQ predicate on the "refund_reason" field.
+func RefundReasonEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundReason, v))
+}
+
+// RefundReasonNEQ applies the NEQ predicate on the "refund_reason" field.
+func RefundReasonNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldRefundReason, v))
+}
+
+// RefundReasonIn applies the In predicate on the "refund_reason" field.
+func RefundReasonIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldRefundReason, vs...))
+}
+
+// RefundReasonNotIn applies the NotIn predicate on the "refund_reason" field.
+func RefundReasonNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldRefundReason, vs...))
+}
+
+// RefundReasonGT applies the GT predicate on the "refund_reason" field.
+func RefundReasonGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldRefundReason, v))
+}
+
+// RefundReasonGTE applies the GTE predicate on the "refund_reason" field.
+func RefundReasonGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldRefundReason, v))
+}
+
+// RefundReasonLT applies the LT predicate on the "refund_reason" field.
+func RefundReasonLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldRefundReason, v))
+}
+
+// RefundReasonLTE applies the LTE predicate on the "refund_reason" field.
+func RefundReasonLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldRefundReason, v))
+}
+
+// RefundReasonContains applies the Contains predicate on the "refund_reason" field.
+func RefundReasonContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldRefundReason, v))
+}
+
+// RefundReasonHasPrefix applies the HasPrefix predicate on the "refund_reason" field.
+func RefundReasonHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldRefundReason, v))
+}
+
+// RefundReasonHasSuffix applies the HasSuffix predicate on the "refund_reason" field.
+func RefundReasonHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldRefundReason, v))
+}
+
+// RefundReasonIsNil applies the IsNil predicate on the "refund_reason" field.
+func RefundReasonIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldRefundReason))
+}
+
+// RefundReasonNotNil applies the NotNil predicate on the "refund_reason" field.
+func RefundReasonNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldRefundReason))
+}
+
+// RefundReasonEqualFold applies the EqualFold predicate on the "refund_reason" field.
+func RefundReasonEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldRefundReason, v))
+}
+
+// RefundReasonContainsFold applies the ContainsFold predicate on the "refund_reason" field.
+func RefundReasonContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldRefundReason, v))
+}
+
+// RefundAtEQ applies the EQ predicate on the "refund_at" field.
+func RefundAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundAt, v))
+}
+
+// RefundAtNEQ applies the NEQ predicate on the "refund_at" field.
+func RefundAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldRefundAt, v))
+}
+
+// RefundAtIn applies the In predicate on the "refund_at" field.
+func RefundAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldRefundAt, vs...))
+}
+
+// RefundAtNotIn applies the NotIn predicate on the "refund_at" field.
+func RefundAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldRefundAt, vs...))
+}
+
+// RefundAtGT applies the GT predicate on the "refund_at" field.
+func RefundAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldRefundAt, v))
+}
+
+// RefundAtGTE applies the GTE predicate on the "refund_at" field.
+func RefundAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldRefundAt, v))
+}
+
+// RefundAtLT applies the LT predicate on the "refund_at" field.
+func RefundAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldRefundAt, v))
+}
+
+// RefundAtLTE applies the LTE predicate on the "refund_at" field.
+func RefundAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldRefundAt, v))
+}
+
+// RefundAtIsNil applies the IsNil predicate on the "refund_at" field.
+func RefundAtIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldRefundAt))
+}
+
+// RefundAtNotNil applies the NotNil predicate on the "refund_at" field.
+func RefundAtNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldRefundAt))
+}
+
+// ForceRefundEQ applies the EQ predicate on the "force_refund" field.
+func ForceRefundEQ(v bool) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldForceRefund, v))
+}
+
+// ForceRefundNEQ applies the NEQ predicate on the "force_refund" field.
+func ForceRefundNEQ(v bool) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldForceRefund, v))
+}
+
+// RefundRequestedAtEQ applies the EQ predicate on the "refund_requested_at" field.
+func RefundRequestedAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundRequestedAt, v))
+}
+
+// RefundRequestedAtNEQ applies the NEQ predicate on the "refund_requested_at" field.
+func RefundRequestedAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldRefundRequestedAt, v))
+}
+
+// RefundRequestedAtIn applies the In predicate on the "refund_requested_at" field.
+func RefundRequestedAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldRefundRequestedAt, vs...))
+}
+
+// RefundRequestedAtNotIn applies the NotIn predicate on the "refund_requested_at" field.
+func RefundRequestedAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldRefundRequestedAt, vs...))
+}
+
+// RefundRequestedAtGT applies the GT predicate on the "refund_requested_at" field.
+func RefundRequestedAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldRefundRequestedAt, v))
+}
+
+// RefundRequestedAtGTE applies the GTE predicate on the "refund_requested_at" field.
+func RefundRequestedAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldRefundRequestedAt, v))
+}
+
+// RefundRequestedAtLT applies the LT predicate on the "refund_requested_at" field.
+func RefundRequestedAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldRefundRequestedAt, v))
+}
+
+// RefundRequestedAtLTE applies the LTE predicate on the "refund_requested_at" field.
+func RefundRequestedAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldRefundRequestedAt, v))
+}
+
+// RefundRequestedAtIsNil applies the IsNil predicate on the "refund_requested_at" field.
+func RefundRequestedAtIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldRefundRequestedAt))
+}
+
+// RefundRequestedAtNotNil applies the NotNil predicate on the "refund_requested_at" field.
+func RefundRequestedAtNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldRefundRequestedAt))
+}
+
+// RefundRequestReasonEQ applies the EQ predicate on the "refund_request_reason" field.
+func RefundRequestReasonEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonNEQ applies the NEQ predicate on the "refund_request_reason" field.
+func RefundRequestReasonNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonIn applies the In predicate on the "refund_request_reason" field.
+func RefundRequestReasonIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldRefundRequestReason, vs...))
+}
+
+// RefundRequestReasonNotIn applies the NotIn predicate on the "refund_request_reason" field.
+func RefundRequestReasonNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldRefundRequestReason, vs...))
+}
+
+// RefundRequestReasonGT applies the GT predicate on the "refund_request_reason" field.
+func RefundRequestReasonGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonGTE applies the GTE predicate on the "refund_request_reason" field.
+func RefundRequestReasonGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonLT applies the LT predicate on the "refund_request_reason" field.
+func RefundRequestReasonLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonLTE applies the LTE predicate on the "refund_request_reason" field.
+func RefundRequestReasonLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonContains applies the Contains predicate on the "refund_request_reason" field.
+func RefundRequestReasonContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonHasPrefix applies the HasPrefix predicate on the "refund_request_reason" field.
+func RefundRequestReasonHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonHasSuffix applies the HasSuffix predicate on the "refund_request_reason" field.
+func RefundRequestReasonHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonIsNil applies the IsNil predicate on the "refund_request_reason" field.
+func RefundRequestReasonIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldRefundRequestReason))
+}
+
+// RefundRequestReasonNotNil applies the NotNil predicate on the "refund_request_reason" field.
+func RefundRequestReasonNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldRefundRequestReason))
+}
+
+// RefundRequestReasonEqualFold applies the EqualFold predicate on the "refund_request_reason" field.
+func RefundRequestReasonEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldRefundRequestReason, v))
+}
+
+// RefundRequestReasonContainsFold applies the ContainsFold predicate on the "refund_request_reason" field.
+func RefundRequestReasonContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldRefundRequestReason, v))
+}
+
+// RefundRequestedByEQ applies the EQ predicate on the "refund_requested_by" field.
+func RefundRequestedByEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByNEQ applies the NEQ predicate on the "refund_requested_by" field.
+func RefundRequestedByNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByIn applies the In predicate on the "refund_requested_by" field.
+func RefundRequestedByIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldRefundRequestedBy, vs...))
+}
+
+// RefundRequestedByNotIn applies the NotIn predicate on the "refund_requested_by" field.
+func RefundRequestedByNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldRefundRequestedBy, vs...))
+}
+
+// RefundRequestedByGT applies the GT predicate on the "refund_requested_by" field.
+func RefundRequestedByGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByGTE applies the GTE predicate on the "refund_requested_by" field.
+func RefundRequestedByGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByLT applies the LT predicate on the "refund_requested_by" field.
+func RefundRequestedByLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByLTE applies the LTE predicate on the "refund_requested_by" field.
+func RefundRequestedByLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByContains applies the Contains predicate on the "refund_requested_by" field.
+func RefundRequestedByContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByHasPrefix applies the HasPrefix predicate on the "refund_requested_by" field.
+func RefundRequestedByHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByHasSuffix applies the HasSuffix predicate on the "refund_requested_by" field.
+func RefundRequestedByHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByIsNil applies the IsNil predicate on the "refund_requested_by" field.
+func RefundRequestedByIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldRefundRequestedBy))
+}
+
+// RefundRequestedByNotNil applies the NotNil predicate on the "refund_requested_by" field.
+func RefundRequestedByNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldRefundRequestedBy))
+}
+
+// RefundRequestedByEqualFold applies the EqualFold predicate on the "refund_requested_by" field.
+func RefundRequestedByEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldRefundRequestedBy, v))
+}
+
+// RefundRequestedByContainsFold applies the ContainsFold predicate on the "refund_requested_by" field.
+func RefundRequestedByContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldRefundRequestedBy, v))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// PaidAtEQ applies the EQ predicate on the "paid_at" field.
+func PaidAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldPaidAt, v))
+}
+
+// PaidAtNEQ applies the NEQ predicate on the "paid_at" field.
+func PaidAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldPaidAt, v))
+}
+
+// PaidAtIn applies the In predicate on the "paid_at" field.
+func PaidAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldPaidAt, vs...))
+}
+
+// PaidAtNotIn applies the NotIn predicate on the "paid_at" field.
+func PaidAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldPaidAt, vs...))
+}
+
+// PaidAtGT applies the GT predicate on the "paid_at" field.
+func PaidAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldPaidAt, v))
+}
+
+// PaidAtGTE applies the GTE predicate on the "paid_at" field.
+func PaidAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldPaidAt, v))
+}
+
+// PaidAtLT applies the LT predicate on the "paid_at" field.
+func PaidAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldPaidAt, v))
+}
+
+// PaidAtLTE applies the LTE predicate on the "paid_at" field.
+func PaidAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldPaidAt, v))
+}
+
+// PaidAtIsNil applies the IsNil predicate on the "paid_at" field.
+func PaidAtIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldPaidAt))
+}
+
+// PaidAtNotNil applies the NotNil predicate on the "paid_at" field.
+func PaidAtNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldPaidAt))
+}
+
+// CompletedAtEQ applies the EQ predicate on the "completed_at" field.
+func CompletedAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldCompletedAt, v))
+}
+
+// CompletedAtNEQ applies the NEQ predicate on the "completed_at" field.
+func CompletedAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldCompletedAt, v))
+}
+
+// CompletedAtIn applies the In predicate on the "completed_at" field.
+func CompletedAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldCompletedAt, vs...))
+}
+
+// CompletedAtNotIn applies the NotIn predicate on the "completed_at" field.
+func CompletedAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldCompletedAt, vs...))
+}
+
+// CompletedAtGT applies the GT predicate on the "completed_at" field.
+func CompletedAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldCompletedAt, v))
+}
+
+// CompletedAtGTE applies the GTE predicate on the "completed_at" field.
+func CompletedAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldCompletedAt, v))
+}
+
+// CompletedAtLT applies the LT predicate on the "completed_at" field.
+func CompletedAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldCompletedAt, v))
+}
+
+// CompletedAtLTE applies the LTE predicate on the "completed_at" field.
+func CompletedAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldCompletedAt, v))
+}
+
+// CompletedAtIsNil applies the IsNil predicate on the "completed_at" field.
+func CompletedAtIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldCompletedAt))
+}
+
+// CompletedAtNotNil applies the NotNil predicate on the "completed_at" field.
+func CompletedAtNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldCompletedAt))
+}
+
+// FailedAtEQ applies the EQ predicate on the "failed_at" field.
+func FailedAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldFailedAt, v))
+}
+
+// FailedAtNEQ applies the NEQ predicate on the "failed_at" field.
+func FailedAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldFailedAt, v))
+}
+
+// FailedAtIn applies the In predicate on the "failed_at" field.
+func FailedAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldFailedAt, vs...))
+}
+
+// FailedAtNotIn applies the NotIn predicate on the "failed_at" field.
+func FailedAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldFailedAt, vs...))
+}
+
+// FailedAtGT applies the GT predicate on the "failed_at" field.
+func FailedAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldFailedAt, v))
+}
+
+// FailedAtGTE applies the GTE predicate on the "failed_at" field.
+func FailedAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldFailedAt, v))
+}
+
+// FailedAtLT applies the LT predicate on the "failed_at" field.
+func FailedAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldFailedAt, v))
+}
+
+// FailedAtLTE applies the LTE predicate on the "failed_at" field.
+func FailedAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldFailedAt, v))
+}
+
+// FailedAtIsNil applies the IsNil predicate on the "failed_at" field.
+func FailedAtIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldFailedAt))
+}
+
+// FailedAtNotNil applies the NotNil predicate on the "failed_at" field.
+func FailedAtNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldFailedAt))
+}
+
+// FailedReasonEQ applies the EQ predicate on the "failed_reason" field.
+func FailedReasonEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldFailedReason, v))
+}
+
+// FailedReasonNEQ applies the NEQ predicate on the "failed_reason" field.
+func FailedReasonNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldFailedReason, v))
+}
+
+// FailedReasonIn applies the In predicate on the "failed_reason" field.
+func FailedReasonIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldFailedReason, vs...))
+}
+
+// FailedReasonNotIn applies the NotIn predicate on the "failed_reason" field.
+func FailedReasonNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldFailedReason, vs...))
+}
+
+// FailedReasonGT applies the GT predicate on the "failed_reason" field.
+func FailedReasonGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldFailedReason, v))
+}
+
+// FailedReasonGTE applies the GTE predicate on the "failed_reason" field.
+func FailedReasonGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldFailedReason, v))
+}
+
+// FailedReasonLT applies the LT predicate on the "failed_reason" field.
+func FailedReasonLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldFailedReason, v))
+}
+
+// FailedReasonLTE applies the LTE predicate on the "failed_reason" field.
+func FailedReasonLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldFailedReason, v))
+}
+
+// FailedReasonContains applies the Contains predicate on the "failed_reason" field.
+func FailedReasonContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldFailedReason, v))
+}
+
+// FailedReasonHasPrefix applies the HasPrefix predicate on the "failed_reason" field.
+func FailedReasonHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldFailedReason, v))
+}
+
+// FailedReasonHasSuffix applies the HasSuffix predicate on the "failed_reason" field.
+func FailedReasonHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldFailedReason, v))
+}
+
+// FailedReasonIsNil applies the IsNil predicate on the "failed_reason" field.
+func FailedReasonIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldFailedReason))
+}
+
+// FailedReasonNotNil applies the NotNil predicate on the "failed_reason" field.
+func FailedReasonNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldFailedReason))
+}
+
+// FailedReasonEqualFold applies the EqualFold predicate on the "failed_reason" field.
+func FailedReasonEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldFailedReason, v))
+}
+
+// FailedReasonContainsFold applies the ContainsFold predicate on the "failed_reason" field.
+func FailedReasonContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldFailedReason, v))
+}
+
+// ClientIPEQ applies the EQ predicate on the "client_ip" field.
+func ClientIPEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldClientIP, v))
+}
+
+// ClientIPNEQ applies the NEQ predicate on the "client_ip" field.
+func ClientIPNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldClientIP, v))
+}
+
+// ClientIPIn applies the In predicate on the "client_ip" field.
+func ClientIPIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldClientIP, vs...))
+}
+
+// ClientIPNotIn applies the NotIn predicate on the "client_ip" field.
+func ClientIPNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldClientIP, vs...))
+}
+
+// ClientIPGT applies the GT predicate on the "client_ip" field.
+func ClientIPGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldClientIP, v))
+}
+
+// ClientIPGTE applies the GTE predicate on the "client_ip" field.
+func ClientIPGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldClientIP, v))
+}
+
+// ClientIPLT applies the LT predicate on the "client_ip" field.
+func ClientIPLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldClientIP, v))
+}
+
+// ClientIPLTE applies the LTE predicate on the "client_ip" field.
+func ClientIPLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldClientIP, v))
+}
+
+// ClientIPContains applies the Contains predicate on the "client_ip" field.
+func ClientIPContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldClientIP, v))
+}
+
+// ClientIPHasPrefix applies the HasPrefix predicate on the "client_ip" field.
+func ClientIPHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldClientIP, v))
+}
+
+// ClientIPHasSuffix applies the HasSuffix predicate on the "client_ip" field.
+func ClientIPHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldClientIP, v))
+}
+
+// ClientIPEqualFold applies the EqualFold predicate on the "client_ip" field.
+func ClientIPEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldClientIP, v))
+}
+
+// ClientIPContainsFold applies the ContainsFold predicate on the "client_ip" field.
+func ClientIPContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldClientIP, v))
+}
+
+// SrcHostEQ applies the EQ predicate on the "src_host" field.
+func SrcHostEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSrcHost, v))
+}
+
+// SrcHostNEQ applies the NEQ predicate on the "src_host" field.
+func SrcHostNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldSrcHost, v))
+}
+
+// SrcHostIn applies the In predicate on the "src_host" field.
+func SrcHostIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldSrcHost, vs...))
+}
+
+// SrcHostNotIn applies the NotIn predicate on the "src_host" field.
+func SrcHostNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldSrcHost, vs...))
+}
+
+// SrcHostGT applies the GT predicate on the "src_host" field.
+func SrcHostGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldSrcHost, v))
+}
+
+// SrcHostGTE applies the GTE predicate on the "src_host" field.
+func SrcHostGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldSrcHost, v))
+}
+
+// SrcHostLT applies the LT predicate on the "src_host" field.
+func SrcHostLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldSrcHost, v))
+}
+
+// SrcHostLTE applies the LTE predicate on the "src_host" field.
+func SrcHostLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldSrcHost, v))
+}
+
+// SrcHostContains applies the Contains predicate on the "src_host" field.
+func SrcHostContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldSrcHost, v))
+}
+
+// SrcHostHasPrefix applies the HasPrefix predicate on the "src_host" field.
+func SrcHostHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldSrcHost, v))
+}
+
+// SrcHostHasSuffix applies the HasSuffix predicate on the "src_host" field.
+func SrcHostHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldSrcHost, v))
+}
+
+// SrcHostEqualFold applies the EqualFold predicate on the "src_host" field.
+func SrcHostEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldSrcHost, v))
+}
+
+// SrcHostContainsFold applies the ContainsFold predicate on the "src_host" field.
+func SrcHostContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldSrcHost, v))
+}
+
+// SrcURLEQ applies the EQ predicate on the "src_url" field.
+func SrcURLEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldSrcURL, v))
+}
+
+// SrcURLNEQ applies the NEQ predicate on the "src_url" field.
+func SrcURLNEQ(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldSrcURL, v))
+}
+
+// SrcURLIn applies the In predicate on the "src_url" field.
+func SrcURLIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldSrcURL, vs...))
+}
+
+// SrcURLNotIn applies the NotIn predicate on the "src_url" field.
+func SrcURLNotIn(vs ...string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldSrcURL, vs...))
+}
+
+// SrcURLGT applies the GT predicate on the "src_url" field.
+func SrcURLGT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldSrcURL, v))
+}
+
+// SrcURLGTE applies the GTE predicate on the "src_url" field.
+func SrcURLGTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldSrcURL, v))
+}
+
+// SrcURLLT applies the LT predicate on the "src_url" field.
+func SrcURLLT(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldSrcURL, v))
+}
+
+// SrcURLLTE applies the LTE predicate on the "src_url" field.
+func SrcURLLTE(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldSrcURL, v))
+}
+
+// SrcURLContains applies the Contains predicate on the "src_url" field.
+func SrcURLContains(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContains(FieldSrcURL, v))
+}
+
+// SrcURLHasPrefix applies the HasPrefix predicate on the "src_url" field.
+func SrcURLHasPrefix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasPrefix(FieldSrcURL, v))
+}
+
+// SrcURLHasSuffix applies the HasSuffix predicate on the "src_url" field.
+func SrcURLHasSuffix(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldHasSuffix(FieldSrcURL, v))
+}
+
+// SrcURLIsNil applies the IsNil predicate on the "src_url" field.
+func SrcURLIsNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIsNull(FieldSrcURL))
+}
+
+// SrcURLNotNil applies the NotNil predicate on the "src_url" field.
+func SrcURLNotNil() predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotNull(FieldSrcURL))
+}
+
+// SrcURLEqualFold applies the EqualFold predicate on the "src_url" field.
+func SrcURLEqualFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEqualFold(FieldSrcURL, v))
+}
+
+// SrcURLContainsFold applies the ContainsFold predicate on the "src_url" field.
+func SrcURLContainsFold(v string) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldContainsFold(FieldSrcURL, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// HasUser applies the HasEdge predicate on the "user" edge.
+func HasUser() predicate.PaymentOrder {
+ return predicate.PaymentOrder(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates).
+func HasUserWith(preds ...predicate.User) predicate.PaymentOrder {
+ return predicate.PaymentOrder(func(s *sql.Selector) {
+ step := newUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PaymentOrder) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PaymentOrder) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PaymentOrder) predicate.PaymentOrder {
+ return predicate.PaymentOrder(sql.NotPredicates(p))
+}
diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go
new file mode 100644
index 00000000..3ee24f8e
--- /dev/null
+++ b/backend/ent/paymentorder_create.go
@@ -0,0 +1,3262 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PaymentOrderCreate is the builder for creating a PaymentOrder entity.
+type PaymentOrderCreate struct {
+ config
+ mutation *PaymentOrderMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetUserID sets the "user_id" field.
+func (_c *PaymentOrderCreate) SetUserID(v int64) *PaymentOrderCreate {
+ _c.mutation.SetUserID(v)
+ return _c
+}
+
+// SetUserEmail sets the "user_email" field.
+func (_c *PaymentOrderCreate) SetUserEmail(v string) *PaymentOrderCreate {
+ _c.mutation.SetUserEmail(v)
+ return _c
+}
+
+// SetUserName sets the "user_name" field.
+func (_c *PaymentOrderCreate) SetUserName(v string) *PaymentOrderCreate {
+ _c.mutation.SetUserName(v)
+ return _c
+}
+
+// SetUserNotes sets the "user_notes" field.
+func (_c *PaymentOrderCreate) SetUserNotes(v string) *PaymentOrderCreate {
+ _c.mutation.SetUserNotes(v)
+ return _c
+}
+
+// SetNillableUserNotes sets the "user_notes" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableUserNotes(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetUserNotes(*v)
+ }
+ return _c
+}
+
+// SetAmount sets the "amount" field.
+func (_c *PaymentOrderCreate) SetAmount(v float64) *PaymentOrderCreate {
+ _c.mutation.SetAmount(v)
+ return _c
+}
+
+// SetPayAmount sets the "pay_amount" field.
+func (_c *PaymentOrderCreate) SetPayAmount(v float64) *PaymentOrderCreate {
+ _c.mutation.SetPayAmount(v)
+ return _c
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (_c *PaymentOrderCreate) SetFeeRate(v float64) *PaymentOrderCreate {
+ _c.mutation.SetFeeRate(v)
+ return _c
+}
+
+// SetNillableFeeRate sets the "fee_rate" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableFeeRate(v *float64) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetFeeRate(*v)
+ }
+ return _c
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (_c *PaymentOrderCreate) SetRechargeCode(v string) *PaymentOrderCreate {
+ _c.mutation.SetRechargeCode(v)
+ return _c
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (_c *PaymentOrderCreate) SetOutTradeNo(v string) *PaymentOrderCreate {
+ _c.mutation.SetOutTradeNo(v)
+ return _c
+}
+
+// SetNillableOutTradeNo sets the "out_trade_no" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableOutTradeNo(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetOutTradeNo(*v)
+ }
+ return _c
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (_c *PaymentOrderCreate) SetPaymentType(v string) *PaymentOrderCreate {
+ _c.mutation.SetPaymentType(v)
+ return _c
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (_c *PaymentOrderCreate) SetPaymentTradeNo(v string) *PaymentOrderCreate {
+ _c.mutation.SetPaymentTradeNo(v)
+ return _c
+}
+
+// SetPayURL sets the "pay_url" field.
+func (_c *PaymentOrderCreate) SetPayURL(v string) *PaymentOrderCreate {
+ _c.mutation.SetPayURL(v)
+ return _c
+}
+
+// SetNillablePayURL sets the "pay_url" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillablePayURL(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetPayURL(*v)
+ }
+ return _c
+}
+
+// SetQrCode sets the "qr_code" field.
+func (_c *PaymentOrderCreate) SetQrCode(v string) *PaymentOrderCreate {
+ _c.mutation.SetQrCode(v)
+ return _c
+}
+
+// SetNillableQrCode sets the "qr_code" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableQrCode(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetQrCode(*v)
+ }
+ return _c
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (_c *PaymentOrderCreate) SetQrCodeImg(v string) *PaymentOrderCreate {
+ _c.mutation.SetQrCodeImg(v)
+ return _c
+}
+
+// SetNillableQrCodeImg sets the "qr_code_img" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableQrCodeImg(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetQrCodeImg(*v)
+ }
+ return _c
+}
+
+// SetOrderType sets the "order_type" field.
+func (_c *PaymentOrderCreate) SetOrderType(v string) *PaymentOrderCreate {
+ _c.mutation.SetOrderType(v)
+ return _c
+}
+
+// SetNillableOrderType sets the "order_type" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableOrderType(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetOrderType(*v)
+ }
+ return _c
+}
+
+// SetPlanID sets the "plan_id" field.
+func (_c *PaymentOrderCreate) SetPlanID(v int64) *PaymentOrderCreate {
+ _c.mutation.SetPlanID(v)
+ return _c
+}
+
+// SetNillablePlanID sets the "plan_id" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillablePlanID(v *int64) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetPlanID(*v)
+ }
+ return _c
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (_c *PaymentOrderCreate) SetSubscriptionGroupID(v int64) *PaymentOrderCreate {
+ _c.mutation.SetSubscriptionGroupID(v)
+ return _c
+}
+
+// SetNillableSubscriptionGroupID sets the "subscription_group_id" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableSubscriptionGroupID(v *int64) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetSubscriptionGroupID(*v)
+ }
+ return _c
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (_c *PaymentOrderCreate) SetSubscriptionDays(v int) *PaymentOrderCreate {
+ _c.mutation.SetSubscriptionDays(v)
+ return _c
+}
+
+// SetNillableSubscriptionDays sets the "subscription_days" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableSubscriptionDays(v *int) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetSubscriptionDays(*v)
+ }
+ return _c
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (_c *PaymentOrderCreate) SetProviderInstanceID(v string) *PaymentOrderCreate {
+ _c.mutation.SetProviderInstanceID(v)
+ return _c
+}
+
+// SetNillableProviderInstanceID sets the "provider_instance_id" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetProviderInstanceID(*v)
+ }
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetProviderKey(*v)
+ }
+ return _c
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_c *PaymentOrderCreate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderCreate {
+ _c.mutation.SetProviderSnapshot(v)
+ return _c
+}
+
+// SetStatus sets the "status" field.
+func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate {
+ _c.mutation.SetStatus(v)
+ return _c
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableStatus(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetStatus(*v)
+ }
+ return _c
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (_c *PaymentOrderCreate) SetRefundAmount(v float64) *PaymentOrderCreate {
+ _c.mutation.SetRefundAmount(v)
+ return _c
+}
+
+// SetNillableRefundAmount sets the "refund_amount" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableRefundAmount(v *float64) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetRefundAmount(*v)
+ }
+ return _c
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (_c *PaymentOrderCreate) SetRefundReason(v string) *PaymentOrderCreate {
+ _c.mutation.SetRefundReason(v)
+ return _c
+}
+
+// SetNillableRefundReason sets the "refund_reason" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableRefundReason(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetRefundReason(*v)
+ }
+ return _c
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (_c *PaymentOrderCreate) SetRefundAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetRefundAt(v)
+ return _c
+}
+
+// SetNillableRefundAt sets the "refund_at" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableRefundAt(v *time.Time) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetRefundAt(*v)
+ }
+ return _c
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (_c *PaymentOrderCreate) SetForceRefund(v bool) *PaymentOrderCreate {
+ _c.mutation.SetForceRefund(v)
+ return _c
+}
+
+// SetNillableForceRefund sets the "force_refund" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableForceRefund(v *bool) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetForceRefund(*v)
+ }
+ return _c
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (_c *PaymentOrderCreate) SetRefundRequestedAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetRefundRequestedAt(v)
+ return _c
+}
+
+// SetNillableRefundRequestedAt sets the "refund_requested_at" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableRefundRequestedAt(v *time.Time) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetRefundRequestedAt(*v)
+ }
+ return _c
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (_c *PaymentOrderCreate) SetRefundRequestReason(v string) *PaymentOrderCreate {
+ _c.mutation.SetRefundRequestReason(v)
+ return _c
+}
+
+// SetNillableRefundRequestReason sets the "refund_request_reason" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableRefundRequestReason(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetRefundRequestReason(*v)
+ }
+ return _c
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (_c *PaymentOrderCreate) SetRefundRequestedBy(v string) *PaymentOrderCreate {
+ _c.mutation.SetRefundRequestedBy(v)
+ return _c
+}
+
+// SetNillableRefundRequestedBy sets the "refund_requested_by" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableRefundRequestedBy(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetRefundRequestedBy(*v)
+ }
+ return _c
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_c *PaymentOrderCreate) SetExpiresAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (_c *PaymentOrderCreate) SetPaidAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetPaidAt(v)
+ return _c
+}
+
+// SetNillablePaidAt sets the "paid_at" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillablePaidAt(v *time.Time) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetPaidAt(*v)
+ }
+ return _c
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (_c *PaymentOrderCreate) SetCompletedAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetCompletedAt(v)
+ return _c
+}
+
+// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableCompletedAt(v *time.Time) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetCompletedAt(*v)
+ }
+ return _c
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (_c *PaymentOrderCreate) SetFailedAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetFailedAt(v)
+ return _c
+}
+
+// SetNillableFailedAt sets the "failed_at" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableFailedAt(v *time.Time) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetFailedAt(*v)
+ }
+ return _c
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (_c *PaymentOrderCreate) SetFailedReason(v string) *PaymentOrderCreate {
+ _c.mutation.SetFailedReason(v)
+ return _c
+}
+
+// SetNillableFailedReason sets the "failed_reason" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableFailedReason(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetFailedReason(*v)
+ }
+ return _c
+}
+
+// SetClientIP sets the "client_ip" field.
+func (_c *PaymentOrderCreate) SetClientIP(v string) *PaymentOrderCreate {
+ _c.mutation.SetClientIP(v)
+ return _c
+}
+
+// SetSrcHost sets the "src_host" field.
+func (_c *PaymentOrderCreate) SetSrcHost(v string) *PaymentOrderCreate {
+ _c.mutation.SetSrcHost(v)
+ return _c
+}
+
+// SetSrcURL sets the "src_url" field.
+func (_c *PaymentOrderCreate) SetSrcURL(v string) *PaymentOrderCreate {
+ _c.mutation.SetSrcURL(v)
+ return _c
+}
+
+// SetNillableSrcURL sets the "src_url" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableSrcURL(v *string) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetSrcURL(*v)
+ }
+ return _c
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PaymentOrderCreate) SetCreatedAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableCreatedAt(v *time.Time) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *PaymentOrderCreate) SetUpdatedAt(v time.Time) *PaymentOrderCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *PaymentOrderCreate) SetNillableUpdatedAt(v *time.Time) *PaymentOrderCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_c *PaymentOrderCreate) SetUser(v *User) *PaymentOrderCreate {
+ return _c.SetUserID(v.ID)
+}
+
+// Mutation returns the PaymentOrderMutation object of the builder.
+func (_c *PaymentOrderCreate) Mutation() *PaymentOrderMutation {
+ return _c.mutation
+}
+
+// Save creates the PaymentOrder in the database.
+func (_c *PaymentOrderCreate) Save(ctx context.Context) (*PaymentOrder, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PaymentOrderCreate) SaveX(ctx context.Context) *PaymentOrder {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PaymentOrderCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PaymentOrderCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PaymentOrderCreate) defaults() {
+ if _, ok := _c.mutation.FeeRate(); !ok {
+ v := paymentorder.DefaultFeeRate
+ _c.mutation.SetFeeRate(v)
+ }
+ if _, ok := _c.mutation.OutTradeNo(); !ok {
+ v := paymentorder.DefaultOutTradeNo
+ _c.mutation.SetOutTradeNo(v)
+ }
+ if _, ok := _c.mutation.OrderType(); !ok {
+ v := paymentorder.DefaultOrderType
+ _c.mutation.SetOrderType(v)
+ }
+ if _, ok := _c.mutation.Status(); !ok {
+ v := paymentorder.DefaultStatus
+ _c.mutation.SetStatus(v)
+ }
+ if _, ok := _c.mutation.RefundAmount(); !ok {
+ v := paymentorder.DefaultRefundAmount
+ _c.mutation.SetRefundAmount(v)
+ }
+ if _, ok := _c.mutation.ForceRefund(); !ok {
+ v := paymentorder.DefaultForceRefund
+ _c.mutation.SetForceRefund(v)
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := paymentorder.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := paymentorder.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PaymentOrderCreate) check() error {
+ if _, ok := _c.mutation.UserID(); !ok {
+ return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "PaymentOrder.user_id"`)}
+ }
+ if _, ok := _c.mutation.UserEmail(); !ok {
+ return &ValidationError{Name: "user_email", err: errors.New(`ent: missing required field "PaymentOrder.user_email"`)}
+ }
+ if v, ok := _c.mutation.UserEmail(); ok {
+ if err := paymentorder.UserEmailValidator(v); err != nil {
+ return &ValidationError{Name: "user_email", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.user_email": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.UserName(); !ok {
+ return &ValidationError{Name: "user_name", err: errors.New(`ent: missing required field "PaymentOrder.user_name"`)}
+ }
+ if v, ok := _c.mutation.UserName(); ok {
+ if err := paymentorder.UserNameValidator(v); err != nil {
+ return &ValidationError{Name: "user_name", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.user_name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Amount(); !ok {
+ return &ValidationError{Name: "amount", err: errors.New(`ent: missing required field "PaymentOrder.amount"`)}
+ }
+ if _, ok := _c.mutation.PayAmount(); !ok {
+ return &ValidationError{Name: "pay_amount", err: errors.New(`ent: missing required field "PaymentOrder.pay_amount"`)}
+ }
+ if _, ok := _c.mutation.FeeRate(); !ok {
+ return &ValidationError{Name: "fee_rate", err: errors.New(`ent: missing required field "PaymentOrder.fee_rate"`)}
+ }
+ if _, ok := _c.mutation.RechargeCode(); !ok {
+ return &ValidationError{Name: "recharge_code", err: errors.New(`ent: missing required field "PaymentOrder.recharge_code"`)}
+ }
+ if v, ok := _c.mutation.RechargeCode(); ok {
+ if err := paymentorder.RechargeCodeValidator(v); err != nil {
+ return &ValidationError{Name: "recharge_code", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.recharge_code": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.OutTradeNo(); !ok {
+ return &ValidationError{Name: "out_trade_no", err: errors.New(`ent: missing required field "PaymentOrder.out_trade_no"`)}
+ }
+ if v, ok := _c.mutation.OutTradeNo(); ok {
+ if err := paymentorder.OutTradeNoValidator(v); err != nil {
+ return &ValidationError{Name: "out_trade_no", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.out_trade_no": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.PaymentType(); !ok {
+ return &ValidationError{Name: "payment_type", err: errors.New(`ent: missing required field "PaymentOrder.payment_type"`)}
+ }
+ if v, ok := _c.mutation.PaymentType(); ok {
+ if err := paymentorder.PaymentTypeValidator(v); err != nil {
+ return &ValidationError{Name: "payment_type", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.payment_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.PaymentTradeNo(); !ok {
+ return &ValidationError{Name: "payment_trade_no", err: errors.New(`ent: missing required field "PaymentOrder.payment_trade_no"`)}
+ }
+ if v, ok := _c.mutation.PaymentTradeNo(); ok {
+ if err := paymentorder.PaymentTradeNoValidator(v); err != nil {
+ return &ValidationError{Name: "payment_trade_no", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.payment_trade_no": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.OrderType(); !ok {
+ return &ValidationError{Name: "order_type", err: errors.New(`ent: missing required field "PaymentOrder.order_type"`)}
+ }
+ if v, ok := _c.mutation.OrderType(); ok {
+ if err := paymentorder.OrderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "order_type", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.order_type": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.ProviderInstanceID(); ok {
+ if err := paymentorder.ProviderInstanceIDValidator(v); err != nil {
+ return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Status(); !ok {
+ return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)}
+ }
+ if v, ok := _c.mutation.Status(); ok {
+ if err := paymentorder.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.RefundAmount(); !ok {
+ return &ValidationError{Name: "refund_amount", err: errors.New(`ent: missing required field "PaymentOrder.refund_amount"`)}
+ }
+ if _, ok := _c.mutation.ForceRefund(); !ok {
+ return &ValidationError{Name: "force_refund", err: errors.New(`ent: missing required field "PaymentOrder.force_refund"`)}
+ }
+ if v, ok := _c.mutation.RefundRequestedBy(); ok {
+ if err := paymentorder.RefundRequestedByValidator(v); err != nil {
+ return &ValidationError{Name: "refund_requested_by", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.refund_requested_by": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ExpiresAt(); !ok {
+ return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PaymentOrder.expires_at"`)}
+ }
+ if _, ok := _c.mutation.ClientIP(); !ok {
+ return &ValidationError{Name: "client_ip", err: errors.New(`ent: missing required field "PaymentOrder.client_ip"`)}
+ }
+ if v, ok := _c.mutation.ClientIP(); ok {
+ if err := paymentorder.ClientIPValidator(v); err != nil {
+ return &ValidationError{Name: "client_ip", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.client_ip": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.SrcHost(); !ok {
+ return &ValidationError{Name: "src_host", err: errors.New(`ent: missing required field "PaymentOrder.src_host"`)}
+ }
+ if v, ok := _c.mutation.SrcHost(); ok {
+ if err := paymentorder.SrcHostValidator(v); err != nil {
+ return &ValidationError{Name: "src_host", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.src_host": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentOrder.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PaymentOrder.updated_at"`)}
+ }
+ if len(_c.mutation.UserIDs()) == 0 {
+ return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "PaymentOrder.user"`)}
+ }
+ return nil
+}
+
+func (_c *PaymentOrderCreate) sqlSave(ctx context.Context) (*PaymentOrder, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PaymentOrder{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(paymentorder.Table, sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.UserEmail(); ok {
+ _spec.SetField(paymentorder.FieldUserEmail, field.TypeString, value)
+ _node.UserEmail = value
+ }
+ if value, ok := _c.mutation.UserName(); ok {
+ _spec.SetField(paymentorder.FieldUserName, field.TypeString, value)
+ _node.UserName = value
+ }
+ if value, ok := _c.mutation.UserNotes(); ok {
+ _spec.SetField(paymentorder.FieldUserNotes, field.TypeString, value)
+ _node.UserNotes = &value
+ }
+ if value, ok := _c.mutation.Amount(); ok {
+ _spec.SetField(paymentorder.FieldAmount, field.TypeFloat64, value)
+ _node.Amount = value
+ }
+ if value, ok := _c.mutation.PayAmount(); ok {
+ _spec.SetField(paymentorder.FieldPayAmount, field.TypeFloat64, value)
+ _node.PayAmount = value
+ }
+ if value, ok := _c.mutation.FeeRate(); ok {
+ _spec.SetField(paymentorder.FieldFeeRate, field.TypeFloat64, value)
+ _node.FeeRate = value
+ }
+ if value, ok := _c.mutation.RechargeCode(); ok {
+ _spec.SetField(paymentorder.FieldRechargeCode, field.TypeString, value)
+ _node.RechargeCode = value
+ }
+ if value, ok := _c.mutation.OutTradeNo(); ok {
+ _spec.SetField(paymentorder.FieldOutTradeNo, field.TypeString, value)
+ _node.OutTradeNo = value
+ }
+ if value, ok := _c.mutation.PaymentType(); ok {
+ _spec.SetField(paymentorder.FieldPaymentType, field.TypeString, value)
+ _node.PaymentType = value
+ }
+ if value, ok := _c.mutation.PaymentTradeNo(); ok {
+ _spec.SetField(paymentorder.FieldPaymentTradeNo, field.TypeString, value)
+ _node.PaymentTradeNo = value
+ }
+ if value, ok := _c.mutation.PayURL(); ok {
+ _spec.SetField(paymentorder.FieldPayURL, field.TypeString, value)
+ _node.PayURL = &value
+ }
+ if value, ok := _c.mutation.QrCode(); ok {
+ _spec.SetField(paymentorder.FieldQrCode, field.TypeString, value)
+ _node.QrCode = &value
+ }
+ if value, ok := _c.mutation.QrCodeImg(); ok {
+ _spec.SetField(paymentorder.FieldQrCodeImg, field.TypeString, value)
+ _node.QrCodeImg = &value
+ }
+ if value, ok := _c.mutation.OrderType(); ok {
+ _spec.SetField(paymentorder.FieldOrderType, field.TypeString, value)
+ _node.OrderType = value
+ }
+ if value, ok := _c.mutation.PlanID(); ok {
+ _spec.SetField(paymentorder.FieldPlanID, field.TypeInt64, value)
+ _node.PlanID = &value
+ }
+ if value, ok := _c.mutation.SubscriptionGroupID(); ok {
+ _spec.SetField(paymentorder.FieldSubscriptionGroupID, field.TypeInt64, value)
+ _node.SubscriptionGroupID = &value
+ }
+ if value, ok := _c.mutation.SubscriptionDays(); ok {
+ _spec.SetField(paymentorder.FieldSubscriptionDays, field.TypeInt, value)
+ _node.SubscriptionDays = &value
+ }
+ if value, ok := _c.mutation.ProviderInstanceID(); ok {
+ _spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value)
+ _node.ProviderInstanceID = &value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = &value
+ }
+ if value, ok := _c.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ _node.ProviderSnapshot = value
+ }
+ if value, ok := _c.mutation.Status(); ok {
+ _spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
+ _node.Status = value
+ }
+ if value, ok := _c.mutation.RefundAmount(); ok {
+ _spec.SetField(paymentorder.FieldRefundAmount, field.TypeFloat64, value)
+ _node.RefundAmount = value
+ }
+ if value, ok := _c.mutation.RefundReason(); ok {
+ _spec.SetField(paymentorder.FieldRefundReason, field.TypeString, value)
+ _node.RefundReason = &value
+ }
+ if value, ok := _c.mutation.RefundAt(); ok {
+ _spec.SetField(paymentorder.FieldRefundAt, field.TypeTime, value)
+ _node.RefundAt = &value
+ }
+ if value, ok := _c.mutation.ForceRefund(); ok {
+ _spec.SetField(paymentorder.FieldForceRefund, field.TypeBool, value)
+ _node.ForceRefund = value
+ }
+ if value, ok := _c.mutation.RefundRequestedAt(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestedAt, field.TypeTime, value)
+ _node.RefundRequestedAt = &value
+ }
+ if value, ok := _c.mutation.RefundRequestReason(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestReason, field.TypeString, value)
+ _node.RefundRequestReason = &value
+ }
+ if value, ok := _c.mutation.RefundRequestedBy(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestedBy, field.TypeString, value)
+ _node.RefundRequestedBy = &value
+ }
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(paymentorder.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ if value, ok := _c.mutation.PaidAt(); ok {
+ _spec.SetField(paymentorder.FieldPaidAt, field.TypeTime, value)
+ _node.PaidAt = &value
+ }
+ if value, ok := _c.mutation.CompletedAt(); ok {
+ _spec.SetField(paymentorder.FieldCompletedAt, field.TypeTime, value)
+ _node.CompletedAt = &value
+ }
+ if value, ok := _c.mutation.FailedAt(); ok {
+ _spec.SetField(paymentorder.FieldFailedAt, field.TypeTime, value)
+ _node.FailedAt = &value
+ }
+ if value, ok := _c.mutation.FailedReason(); ok {
+ _spec.SetField(paymentorder.FieldFailedReason, field.TypeString, value)
+ _node.FailedReason = &value
+ }
+ if value, ok := _c.mutation.ClientIP(); ok {
+ _spec.SetField(paymentorder.FieldClientIP, field.TypeString, value)
+ _node.ClientIP = value
+ }
+ if value, ok := _c.mutation.SrcHost(); ok {
+ _spec.SetField(paymentorder.FieldSrcHost, field.TypeString, value)
+ _node.SrcHost = value
+ }
+ if value, ok := _c.mutation.SrcURL(); ok {
+ _spec.SetField(paymentorder.FieldSrcURL, field.TypeString, value)
+ _node.SrcURL = &value
+ }
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(paymentorder.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(paymentorder.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: paymentorder.UserTable,
+ Columns: []string{paymentorder.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.UserID = nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PaymentOrder.Create().
+// SetUserID(v).
+// OnConflict(
+// // Update the row with the new values
+// // that were proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PaymentOrderUpsert) {
+// SetUserID(v+v).
+// }).
+// Exec(ctx)
+func (_c *PaymentOrderCreate) OnConflict(opts ...sql.ConflictOption) *PaymentOrderUpsertOne {
+ _c.conflict = opts
+ return &PaymentOrderUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PaymentOrder.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PaymentOrderCreate) OnConflictColumns(columns ...string) *PaymentOrderUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PaymentOrderUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PaymentOrderUpsertOne is the builder for "upsert"-ing
+ // one PaymentOrder node.
+ PaymentOrderUpsertOne struct {
+ create *PaymentOrderCreate
+ }
+
+ // PaymentOrderUpsert is the "OnConflict" setter.
+ PaymentOrderUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUserID sets the "user_id" field.
+func (u *PaymentOrderUpsert) SetUserID(v int64) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldUserID, v)
+ return u
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateUserID() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldUserID)
+ return u
+}
+
+// SetUserEmail sets the "user_email" field.
+func (u *PaymentOrderUpsert) SetUserEmail(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldUserEmail, v)
+ return u
+}
+
+// UpdateUserEmail sets the "user_email" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateUserEmail() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldUserEmail)
+ return u
+}
+
+// SetUserName sets the "user_name" field.
+func (u *PaymentOrderUpsert) SetUserName(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldUserName, v)
+ return u
+}
+
+// UpdateUserName sets the "user_name" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateUserName() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldUserName)
+ return u
+}
+
+// SetUserNotes sets the "user_notes" field.
+func (u *PaymentOrderUpsert) SetUserNotes(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldUserNotes, v)
+ return u
+}
+
+// UpdateUserNotes sets the "user_notes" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateUserNotes() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldUserNotes)
+ return u
+}
+
+// ClearUserNotes clears the value of the "user_notes" field.
+func (u *PaymentOrderUpsert) ClearUserNotes() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldUserNotes)
+ return u
+}
+
+// SetAmount sets the "amount" field.
+func (u *PaymentOrderUpsert) SetAmount(v float64) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldAmount, v)
+ return u
+}
+
+// UpdateAmount sets the "amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateAmount() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldAmount)
+ return u
+}
+
+// AddAmount adds v to the "amount" field.
+func (u *PaymentOrderUpsert) AddAmount(v float64) *PaymentOrderUpsert {
+ u.Add(paymentorder.FieldAmount, v)
+ return u
+}
+
+// SetPayAmount sets the "pay_amount" field.
+func (u *PaymentOrderUpsert) SetPayAmount(v float64) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldPayAmount, v)
+ return u
+}
+
+// UpdatePayAmount sets the "pay_amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdatePayAmount() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldPayAmount)
+ return u
+}
+
+// AddPayAmount adds v to the "pay_amount" field.
+func (u *PaymentOrderUpsert) AddPayAmount(v float64) *PaymentOrderUpsert {
+ u.Add(paymentorder.FieldPayAmount, v)
+ return u
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (u *PaymentOrderUpsert) SetFeeRate(v float64) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldFeeRate, v)
+ return u
+}
+
+// UpdateFeeRate sets the "fee_rate" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateFeeRate() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldFeeRate)
+ return u
+}
+
+// AddFeeRate adds v to the "fee_rate" field.
+func (u *PaymentOrderUpsert) AddFeeRate(v float64) *PaymentOrderUpsert {
+ u.Add(paymentorder.FieldFeeRate, v)
+ return u
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (u *PaymentOrderUpsert) SetRechargeCode(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldRechargeCode, v)
+ return u
+}
+
+// UpdateRechargeCode sets the "recharge_code" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateRechargeCode() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldRechargeCode)
+ return u
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (u *PaymentOrderUpsert) SetOutTradeNo(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldOutTradeNo, v)
+ return u
+}
+
+// UpdateOutTradeNo sets the "out_trade_no" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateOutTradeNo() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldOutTradeNo)
+ return u
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (u *PaymentOrderUpsert) SetPaymentType(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldPaymentType, v)
+ return u
+}
+
+// UpdatePaymentType sets the "payment_type" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdatePaymentType() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldPaymentType)
+ return u
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (u *PaymentOrderUpsert) SetPaymentTradeNo(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldPaymentTradeNo, v)
+ return u
+}
+
+// UpdatePaymentTradeNo sets the "payment_trade_no" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdatePaymentTradeNo() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldPaymentTradeNo)
+ return u
+}
+
+// SetPayURL sets the "pay_url" field.
+func (u *PaymentOrderUpsert) SetPayURL(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldPayURL, v)
+ return u
+}
+
+// UpdatePayURL sets the "pay_url" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdatePayURL() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldPayURL)
+ return u
+}
+
+// ClearPayURL clears the value of the "pay_url" field.
+func (u *PaymentOrderUpsert) ClearPayURL() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldPayURL)
+ return u
+}
+
+// SetQrCode sets the "qr_code" field.
+func (u *PaymentOrderUpsert) SetQrCode(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldQrCode, v)
+ return u
+}
+
+// UpdateQrCode sets the "qr_code" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateQrCode() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldQrCode)
+ return u
+}
+
+// ClearQrCode clears the value of the "qr_code" field.
+func (u *PaymentOrderUpsert) ClearQrCode() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldQrCode)
+ return u
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (u *PaymentOrderUpsert) SetQrCodeImg(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldQrCodeImg, v)
+ return u
+}
+
+// UpdateQrCodeImg sets the "qr_code_img" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateQrCodeImg() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldQrCodeImg)
+ return u
+}
+
+// ClearQrCodeImg clears the value of the "qr_code_img" field.
+func (u *PaymentOrderUpsert) ClearQrCodeImg() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldQrCodeImg)
+ return u
+}
+
+// SetOrderType sets the "order_type" field.
+func (u *PaymentOrderUpsert) SetOrderType(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldOrderType, v)
+ return u
+}
+
+// UpdateOrderType sets the "order_type" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateOrderType() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldOrderType)
+ return u
+}
+
+// SetPlanID sets the "plan_id" field.
+func (u *PaymentOrderUpsert) SetPlanID(v int64) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldPlanID, v)
+ return u
+}
+
+// UpdatePlanID sets the "plan_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdatePlanID() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldPlanID)
+ return u
+}
+
+// AddPlanID adds v to the "plan_id" field.
+func (u *PaymentOrderUpsert) AddPlanID(v int64) *PaymentOrderUpsert {
+ u.Add(paymentorder.FieldPlanID, v)
+ return u
+}
+
+// ClearPlanID clears the value of the "plan_id" field.
+func (u *PaymentOrderUpsert) ClearPlanID() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldPlanID)
+ return u
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (u *PaymentOrderUpsert) SetSubscriptionGroupID(v int64) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldSubscriptionGroupID, v)
+ return u
+}
+
+// UpdateSubscriptionGroupID sets the "subscription_group_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateSubscriptionGroupID() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldSubscriptionGroupID)
+ return u
+}
+
+// AddSubscriptionGroupID adds v to the "subscription_group_id" field.
+func (u *PaymentOrderUpsert) AddSubscriptionGroupID(v int64) *PaymentOrderUpsert {
+ u.Add(paymentorder.FieldSubscriptionGroupID, v)
+ return u
+}
+
+// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
+func (u *PaymentOrderUpsert) ClearSubscriptionGroupID() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldSubscriptionGroupID)
+ return u
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (u *PaymentOrderUpsert) SetSubscriptionDays(v int) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldSubscriptionDays, v)
+ return u
+}
+
+// UpdateSubscriptionDays sets the "subscription_days" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateSubscriptionDays() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldSubscriptionDays)
+ return u
+}
+
+// AddSubscriptionDays adds v to the "subscription_days" field.
+func (u *PaymentOrderUpsert) AddSubscriptionDays(v int) *PaymentOrderUpsert {
+ u.Add(paymentorder.FieldSubscriptionDays, v)
+ return u
+}
+
+// ClearSubscriptionDays clears the value of the "subscription_days" field.
+func (u *PaymentOrderUpsert) ClearSubscriptionDays() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldSubscriptionDays)
+ return u
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (u *PaymentOrderUpsert) SetProviderInstanceID(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderInstanceID, v)
+ return u
+}
+
+// UpdateProviderInstanceID sets the "provider_instance_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderInstanceID() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderInstanceID)
+ return u
+}
+
+// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
+func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderInstanceID)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderKey)
+ return u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderKey)
+ return u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldProviderSnapshot, v)
+ return u
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateProviderSnapshot() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsert) ClearProviderSnapshot() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldProviderSnapshot)
+ return u
+}
+
+// SetStatus sets the "status" field.
+func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldStatus, v)
+ return u
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateStatus() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldStatus)
+ return u
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (u *PaymentOrderUpsert) SetRefundAmount(v float64) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldRefundAmount, v)
+ return u
+}
+
+// UpdateRefundAmount sets the "refund_amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateRefundAmount() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldRefundAmount)
+ return u
+}
+
+// AddRefundAmount adds v to the "refund_amount" field.
+func (u *PaymentOrderUpsert) AddRefundAmount(v float64) *PaymentOrderUpsert {
+ u.Add(paymentorder.FieldRefundAmount, v)
+ return u
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (u *PaymentOrderUpsert) SetRefundReason(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldRefundReason, v)
+ return u
+}
+
+// UpdateRefundReason sets the "refund_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateRefundReason() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldRefundReason)
+ return u
+}
+
+// ClearRefundReason clears the value of the "refund_reason" field.
+func (u *PaymentOrderUpsert) ClearRefundReason() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldRefundReason)
+ return u
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (u *PaymentOrderUpsert) SetRefundAt(v time.Time) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldRefundAt, v)
+ return u
+}
+
+// UpdateRefundAt sets the "refund_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateRefundAt() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldRefundAt)
+ return u
+}
+
+// ClearRefundAt clears the value of the "refund_at" field.
+func (u *PaymentOrderUpsert) ClearRefundAt() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldRefundAt)
+ return u
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (u *PaymentOrderUpsert) SetForceRefund(v bool) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldForceRefund, v)
+ return u
+}
+
+// UpdateForceRefund sets the "force_refund" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateForceRefund() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldForceRefund)
+ return u
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (u *PaymentOrderUpsert) SetRefundRequestedAt(v time.Time) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldRefundRequestedAt, v)
+ return u
+}
+
+// UpdateRefundRequestedAt sets the "refund_requested_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateRefundRequestedAt() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldRefundRequestedAt)
+ return u
+}
+
+// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
+func (u *PaymentOrderUpsert) ClearRefundRequestedAt() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldRefundRequestedAt)
+ return u
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (u *PaymentOrderUpsert) SetRefundRequestReason(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldRefundRequestReason, v)
+ return u
+}
+
+// UpdateRefundRequestReason sets the "refund_request_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateRefundRequestReason() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldRefundRequestReason)
+ return u
+}
+
+// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
+func (u *PaymentOrderUpsert) ClearRefundRequestReason() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldRefundRequestReason)
+ return u
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (u *PaymentOrderUpsert) SetRefundRequestedBy(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldRefundRequestedBy, v)
+ return u
+}
+
+// UpdateRefundRequestedBy sets the "refund_requested_by" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateRefundRequestedBy() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldRefundRequestedBy)
+ return u
+}
+
+// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
+func (u *PaymentOrderUpsert) ClearRefundRequestedBy() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldRefundRequestedBy)
+ return u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PaymentOrderUpsert) SetExpiresAt(v time.Time) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateExpiresAt() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldExpiresAt)
+ return u
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (u *PaymentOrderUpsert) SetPaidAt(v time.Time) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldPaidAt, v)
+ return u
+}
+
+// UpdatePaidAt sets the "paid_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdatePaidAt() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldPaidAt)
+ return u
+}
+
+// ClearPaidAt clears the value of the "paid_at" field.
+func (u *PaymentOrderUpsert) ClearPaidAt() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldPaidAt)
+ return u
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (u *PaymentOrderUpsert) SetCompletedAt(v time.Time) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldCompletedAt, v)
+ return u
+}
+
+// UpdateCompletedAt sets the "completed_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateCompletedAt() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldCompletedAt)
+ return u
+}
+
+// ClearCompletedAt clears the value of the "completed_at" field.
+func (u *PaymentOrderUpsert) ClearCompletedAt() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldCompletedAt)
+ return u
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (u *PaymentOrderUpsert) SetFailedAt(v time.Time) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldFailedAt, v)
+ return u
+}
+
+// UpdateFailedAt sets the "failed_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateFailedAt() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldFailedAt)
+ return u
+}
+
+// ClearFailedAt clears the value of the "failed_at" field.
+func (u *PaymentOrderUpsert) ClearFailedAt() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldFailedAt)
+ return u
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (u *PaymentOrderUpsert) SetFailedReason(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldFailedReason, v)
+ return u
+}
+
+// UpdateFailedReason sets the "failed_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateFailedReason() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldFailedReason)
+ return u
+}
+
+// ClearFailedReason clears the value of the "failed_reason" field.
+func (u *PaymentOrderUpsert) ClearFailedReason() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldFailedReason)
+ return u
+}
+
+// SetClientIP sets the "client_ip" field.
+func (u *PaymentOrderUpsert) SetClientIP(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldClientIP, v)
+ return u
+}
+
+// UpdateClientIP sets the "client_ip" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateClientIP() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldClientIP)
+ return u
+}
+
+// SetSrcHost sets the "src_host" field.
+func (u *PaymentOrderUpsert) SetSrcHost(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldSrcHost, v)
+ return u
+}
+
+// UpdateSrcHost sets the "src_host" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateSrcHost() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldSrcHost)
+ return u
+}
+
+// SetSrcURL sets the "src_url" field.
+func (u *PaymentOrderUpsert) SetSrcURL(v string) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldSrcURL, v)
+ return u
+}
+
+// UpdateSrcURL sets the "src_url" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateSrcURL() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldSrcURL)
+ return u
+}
+
+// ClearSrcURL clears the value of the "src_url" field.
+func (u *PaymentOrderUpsert) ClearSrcURL() *PaymentOrderUpsert {
+ u.SetNull(paymentorder.FieldSrcURL)
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PaymentOrderUpsert) SetUpdatedAt(v time.Time) *PaymentOrderUpsert {
+ u.Set(paymentorder.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsert) UpdateUpdatedAt() *PaymentOrderUpsert {
+ u.SetExcluded(paymentorder.FieldUpdatedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PaymentOrder.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PaymentOrderUpsertOne) UpdateNewValues() *PaymentOrderUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(paymentorder.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PaymentOrder.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PaymentOrderUpsertOne) Ignore() *PaymentOrderUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PaymentOrderUpsertOne) DoNothing() *PaymentOrderUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PaymentOrderCreate.OnConflict
+// documentation for more info.
+func (u *PaymentOrderUpsertOne) Update(set func(*PaymentOrderUpsert)) *PaymentOrderUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PaymentOrderUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *PaymentOrderUpsertOne) SetUserID(v int64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateUserID() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetUserEmail sets the "user_email" field.
+func (u *PaymentOrderUpsertOne) SetUserEmail(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserEmail(v)
+ })
+}
+
+// UpdateUserEmail sets the "user_email" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateUserEmail() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserEmail()
+ })
+}
+
+// SetUserName sets the "user_name" field.
+func (u *PaymentOrderUpsertOne) SetUserName(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserName(v)
+ })
+}
+
+// UpdateUserName sets the "user_name" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateUserName() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserName()
+ })
+}
+
+// SetUserNotes sets the "user_notes" field.
+func (u *PaymentOrderUpsertOne) SetUserNotes(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserNotes(v)
+ })
+}
+
+// UpdateUserNotes sets the "user_notes" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateUserNotes() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserNotes()
+ })
+}
+
+// ClearUserNotes clears the value of the "user_notes" field.
+func (u *PaymentOrderUpsertOne) ClearUserNotes() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearUserNotes()
+ })
+}
+
+// SetAmount sets the "amount" field.
+func (u *PaymentOrderUpsertOne) SetAmount(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetAmount(v)
+ })
+}
+
+// AddAmount adds v to the "amount" field.
+func (u *PaymentOrderUpsertOne) AddAmount(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddAmount(v)
+ })
+}
+
+// UpdateAmount sets the "amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateAmount() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateAmount()
+ })
+}
+
+// SetPayAmount sets the "pay_amount" field.
+func (u *PaymentOrderUpsertOne) SetPayAmount(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPayAmount(v)
+ })
+}
+
+// AddPayAmount adds v to the "pay_amount" field.
+func (u *PaymentOrderUpsertOne) AddPayAmount(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddPayAmount(v)
+ })
+}
+
+// UpdatePayAmount sets the "pay_amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdatePayAmount() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePayAmount()
+ })
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (u *PaymentOrderUpsertOne) SetFeeRate(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetFeeRate(v)
+ })
+}
+
+// AddFeeRate adds v to the "fee_rate" field.
+func (u *PaymentOrderUpsertOne) AddFeeRate(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddFeeRate(v)
+ })
+}
+
+// UpdateFeeRate sets the "fee_rate" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateFeeRate() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateFeeRate()
+ })
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (u *PaymentOrderUpsertOne) SetRechargeCode(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRechargeCode(v)
+ })
+}
+
+// UpdateRechargeCode sets the "recharge_code" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateRechargeCode() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRechargeCode()
+ })
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (u *PaymentOrderUpsertOne) SetOutTradeNo(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetOutTradeNo(v)
+ })
+}
+
+// UpdateOutTradeNo sets the "out_trade_no" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateOutTradeNo() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateOutTradeNo()
+ })
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (u *PaymentOrderUpsertOne) SetPaymentType(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPaymentType(v)
+ })
+}
+
+// UpdatePaymentType sets the "payment_type" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdatePaymentType() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePaymentType()
+ })
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (u *PaymentOrderUpsertOne) SetPaymentTradeNo(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPaymentTradeNo(v)
+ })
+}
+
+// UpdatePaymentTradeNo sets the "payment_trade_no" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdatePaymentTradeNo() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePaymentTradeNo()
+ })
+}
+
+// SetPayURL sets the "pay_url" field.
+func (u *PaymentOrderUpsertOne) SetPayURL(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPayURL(v)
+ })
+}
+
+// UpdatePayURL sets the "pay_url" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdatePayURL() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePayURL()
+ })
+}
+
+// ClearPayURL clears the value of the "pay_url" field.
+func (u *PaymentOrderUpsertOne) ClearPayURL() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearPayURL()
+ })
+}
+
+// SetQrCode sets the "qr_code" field.
+func (u *PaymentOrderUpsertOne) SetQrCode(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetQrCode(v)
+ })
+}
+
+// UpdateQrCode sets the "qr_code" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateQrCode() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateQrCode()
+ })
+}
+
+// ClearQrCode clears the value of the "qr_code" field.
+func (u *PaymentOrderUpsertOne) ClearQrCode() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearQrCode()
+ })
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (u *PaymentOrderUpsertOne) SetQrCodeImg(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetQrCodeImg(v)
+ })
+}
+
+// UpdateQrCodeImg sets the "qr_code_img" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateQrCodeImg() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateQrCodeImg()
+ })
+}
+
+// ClearQrCodeImg clears the value of the "qr_code_img" field.
+func (u *PaymentOrderUpsertOne) ClearQrCodeImg() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearQrCodeImg()
+ })
+}
+
+// SetOrderType sets the "order_type" field.
+func (u *PaymentOrderUpsertOne) SetOrderType(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetOrderType(v)
+ })
+}
+
+// UpdateOrderType sets the "order_type" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateOrderType() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateOrderType()
+ })
+}
+
+// SetPlanID sets the "plan_id" field.
+func (u *PaymentOrderUpsertOne) SetPlanID(v int64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPlanID(v)
+ })
+}
+
+// AddPlanID adds v to the "plan_id" field.
+func (u *PaymentOrderUpsertOne) AddPlanID(v int64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddPlanID(v)
+ })
+}
+
+// UpdatePlanID sets the "plan_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdatePlanID() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePlanID()
+ })
+}
+
+// ClearPlanID clears the value of the "plan_id" field.
+func (u *PaymentOrderUpsertOne) ClearPlanID() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearPlanID()
+ })
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (u *PaymentOrderUpsertOne) SetSubscriptionGroupID(v int64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSubscriptionGroupID(v)
+ })
+}
+
+// AddSubscriptionGroupID adds v to the "subscription_group_id" field.
+func (u *PaymentOrderUpsertOne) AddSubscriptionGroupID(v int64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddSubscriptionGroupID(v)
+ })
+}
+
+// UpdateSubscriptionGroupID sets the "subscription_group_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateSubscriptionGroupID() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSubscriptionGroupID()
+ })
+}
+
+// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
+func (u *PaymentOrderUpsertOne) ClearSubscriptionGroupID() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearSubscriptionGroupID()
+ })
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (u *PaymentOrderUpsertOne) SetSubscriptionDays(v int) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSubscriptionDays(v)
+ })
+}
+
+// AddSubscriptionDays adds v to the "subscription_days" field.
+func (u *PaymentOrderUpsertOne) AddSubscriptionDays(v int) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddSubscriptionDays(v)
+ })
+}
+
+// UpdateSubscriptionDays sets the "subscription_days" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateSubscriptionDays() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSubscriptionDays()
+ })
+}
+
+// ClearSubscriptionDays clears the value of the "subscription_days" field.
+func (u *PaymentOrderUpsertOne) ClearSubscriptionDays() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearSubscriptionDays()
+ })
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (u *PaymentOrderUpsertOne) SetProviderInstanceID(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderInstanceID(v)
+ })
+}
+
+// UpdateProviderInstanceID sets the "provider_instance_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderInstanceID() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderInstanceID()
+ })
+}
+
+// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
+func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderInstanceID()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertOne) ClearProviderSnapshot() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateStatus() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (u *PaymentOrderUpsertOne) SetRefundAmount(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundAmount(v)
+ })
+}
+
+// AddRefundAmount adds v to the "refund_amount" field.
+func (u *PaymentOrderUpsertOne) AddRefundAmount(v float64) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddRefundAmount(v)
+ })
+}
+
+// UpdateRefundAmount sets the "refund_amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateRefundAmount() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundAmount()
+ })
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (u *PaymentOrderUpsertOne) SetRefundReason(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundReason(v)
+ })
+}
+
+// UpdateRefundReason sets the "refund_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateRefundReason() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundReason()
+ })
+}
+
+// ClearRefundReason clears the value of the "refund_reason" field.
+func (u *PaymentOrderUpsertOne) ClearRefundReason() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundReason()
+ })
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (u *PaymentOrderUpsertOne) SetRefundAt(v time.Time) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundAt(v)
+ })
+}
+
+// UpdateRefundAt sets the "refund_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateRefundAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundAt()
+ })
+}
+
+// ClearRefundAt clears the value of the "refund_at" field.
+func (u *PaymentOrderUpsertOne) ClearRefundAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundAt()
+ })
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (u *PaymentOrderUpsertOne) SetForceRefund(v bool) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetForceRefund(v)
+ })
+}
+
+// UpdateForceRefund sets the "force_refund" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateForceRefund() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateForceRefund()
+ })
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (u *PaymentOrderUpsertOne) SetRefundRequestedAt(v time.Time) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundRequestedAt(v)
+ })
+}
+
+// UpdateRefundRequestedAt sets the "refund_requested_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateRefundRequestedAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundRequestedAt()
+ })
+}
+
+// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
+func (u *PaymentOrderUpsertOne) ClearRefundRequestedAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundRequestedAt()
+ })
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (u *PaymentOrderUpsertOne) SetRefundRequestReason(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundRequestReason(v)
+ })
+}
+
+// UpdateRefundRequestReason sets the "refund_request_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateRefundRequestReason() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundRequestReason()
+ })
+}
+
+// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
+func (u *PaymentOrderUpsertOne) ClearRefundRequestReason() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundRequestReason()
+ })
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (u *PaymentOrderUpsertOne) SetRefundRequestedBy(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundRequestedBy(v)
+ })
+}
+
+// UpdateRefundRequestedBy sets the "refund_requested_by" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateRefundRequestedBy() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundRequestedBy()
+ })
+}
+
+// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
+func (u *PaymentOrderUpsertOne) ClearRefundRequestedBy() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundRequestedBy()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PaymentOrderUpsertOne) SetExpiresAt(v time.Time) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateExpiresAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (u *PaymentOrderUpsertOne) SetPaidAt(v time.Time) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPaidAt(v)
+ })
+}
+
+// UpdatePaidAt sets the "paid_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdatePaidAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePaidAt()
+ })
+}
+
+// ClearPaidAt clears the value of the "paid_at" field.
+func (u *PaymentOrderUpsertOne) ClearPaidAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearPaidAt()
+ })
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (u *PaymentOrderUpsertOne) SetCompletedAt(v time.Time) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetCompletedAt(v)
+ })
+}
+
+// UpdateCompletedAt sets the "completed_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateCompletedAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateCompletedAt()
+ })
+}
+
+// ClearCompletedAt clears the value of the "completed_at" field.
+func (u *PaymentOrderUpsertOne) ClearCompletedAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearCompletedAt()
+ })
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (u *PaymentOrderUpsertOne) SetFailedAt(v time.Time) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetFailedAt(v)
+ })
+}
+
+// UpdateFailedAt sets the "failed_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateFailedAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateFailedAt()
+ })
+}
+
+// ClearFailedAt clears the value of the "failed_at" field.
+func (u *PaymentOrderUpsertOne) ClearFailedAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearFailedAt()
+ })
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (u *PaymentOrderUpsertOne) SetFailedReason(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetFailedReason(v)
+ })
+}
+
+// UpdateFailedReason sets the "failed_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateFailedReason() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateFailedReason()
+ })
+}
+
+// ClearFailedReason clears the value of the "failed_reason" field.
+func (u *PaymentOrderUpsertOne) ClearFailedReason() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearFailedReason()
+ })
+}
+
+// SetClientIP sets the "client_ip" field.
+func (u *PaymentOrderUpsertOne) SetClientIP(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetClientIP(v)
+ })
+}
+
+// UpdateClientIP sets the "client_ip" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateClientIP() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateClientIP()
+ })
+}
+
+// SetSrcHost sets the "src_host" field.
+func (u *PaymentOrderUpsertOne) SetSrcHost(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSrcHost(v)
+ })
+}
+
+// UpdateSrcHost sets the "src_host" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateSrcHost() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSrcHost()
+ })
+}
+
+// SetSrcURL sets the "src_url" field.
+func (u *PaymentOrderUpsertOne) SetSrcURL(v string) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSrcURL(v)
+ })
+}
+
+// UpdateSrcURL sets the "src_url" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateSrcURL() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSrcURL()
+ })
+}
+
+// ClearSrcURL clears the value of the "src_url" field.
+func (u *PaymentOrderUpsertOne) ClearSrcURL() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearSrcURL()
+ })
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PaymentOrderUpsertOne) SetUpdatedAt(v time.Time) *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertOne) UpdateUpdatedAt() *PaymentOrderUpsertOne {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PaymentOrderUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PaymentOrderCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PaymentOrderUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *PaymentOrderUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PaymentOrderUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PaymentOrderCreateBulk is the builder for creating many PaymentOrder entities in bulk.
+type PaymentOrderCreateBulk struct {
+ config
+ err error
+ builders []*PaymentOrderCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PaymentOrder entities in the database.
+func (_c *PaymentOrderCreateBulk) Save(ctx context.Context) ([]*PaymentOrder, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*PaymentOrder, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*PaymentOrderMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PaymentOrderCreateBulk) SaveX(ctx context.Context) []*PaymentOrder {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PaymentOrderCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PaymentOrderCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PaymentOrder.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PaymentOrderUpsert) {
+// SetUserID(v+v).
+// }).
+// Exec(ctx)
+func (_c *PaymentOrderCreateBulk) OnConflict(opts ...sql.ConflictOption) *PaymentOrderUpsertBulk {
+ _c.conflict = opts
+ return &PaymentOrderUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PaymentOrder.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PaymentOrderCreateBulk) OnConflictColumns(columns ...string) *PaymentOrderUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PaymentOrderUpsertBulk{
+ create: _c,
+ }
+}
+
+// PaymentOrderUpsertBulk is the builder for "upsert"-ing
+// a bulk of PaymentOrder nodes.
+type PaymentOrderUpsertBulk struct {
+ create *PaymentOrderCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.PaymentOrder.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PaymentOrderUpsertBulk) UpdateNewValues() *PaymentOrderUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(paymentorder.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PaymentOrder.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PaymentOrderUpsertBulk) Ignore() *PaymentOrderUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PaymentOrderUpsertBulk) DoNothing() *PaymentOrderUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PaymentOrderCreateBulk.OnConflict
+// documentation for more info.
+func (u *PaymentOrderUpsertBulk) Update(set func(*PaymentOrderUpsert)) *PaymentOrderUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PaymentOrderUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUserID sets the "user_id" field.
+func (u *PaymentOrderUpsertBulk) SetUserID(v int64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserID(v)
+ })
+}
+
+// UpdateUserID sets the "user_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateUserID() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserID()
+ })
+}
+
+// SetUserEmail sets the "user_email" field.
+func (u *PaymentOrderUpsertBulk) SetUserEmail(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserEmail(v)
+ })
+}
+
+// UpdateUserEmail sets the "user_email" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateUserEmail() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserEmail()
+ })
+}
+
+// SetUserName sets the "user_name" field.
+func (u *PaymentOrderUpsertBulk) SetUserName(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserName(v)
+ })
+}
+
+// UpdateUserName sets the "user_name" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateUserName() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserName()
+ })
+}
+
+// SetUserNotes sets the "user_notes" field.
+func (u *PaymentOrderUpsertBulk) SetUserNotes(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUserNotes(v)
+ })
+}
+
+// UpdateUserNotes sets the "user_notes" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateUserNotes() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUserNotes()
+ })
+}
+
+// ClearUserNotes clears the value of the "user_notes" field.
+func (u *PaymentOrderUpsertBulk) ClearUserNotes() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearUserNotes()
+ })
+}
+
+// SetAmount sets the "amount" field.
+func (u *PaymentOrderUpsertBulk) SetAmount(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetAmount(v)
+ })
+}
+
+// AddAmount adds v to the "amount" field.
+func (u *PaymentOrderUpsertBulk) AddAmount(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddAmount(v)
+ })
+}
+
+// UpdateAmount sets the "amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateAmount() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateAmount()
+ })
+}
+
+// SetPayAmount sets the "pay_amount" field.
+func (u *PaymentOrderUpsertBulk) SetPayAmount(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPayAmount(v)
+ })
+}
+
+// AddPayAmount adds v to the "pay_amount" field.
+func (u *PaymentOrderUpsertBulk) AddPayAmount(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddPayAmount(v)
+ })
+}
+
+// UpdatePayAmount sets the "pay_amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdatePayAmount() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePayAmount()
+ })
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (u *PaymentOrderUpsertBulk) SetFeeRate(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetFeeRate(v)
+ })
+}
+
+// AddFeeRate adds v to the "fee_rate" field.
+func (u *PaymentOrderUpsertBulk) AddFeeRate(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddFeeRate(v)
+ })
+}
+
+// UpdateFeeRate sets the "fee_rate" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateFeeRate() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateFeeRate()
+ })
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (u *PaymentOrderUpsertBulk) SetRechargeCode(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRechargeCode(v)
+ })
+}
+
+// UpdateRechargeCode sets the "recharge_code" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateRechargeCode() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRechargeCode()
+ })
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (u *PaymentOrderUpsertBulk) SetOutTradeNo(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetOutTradeNo(v)
+ })
+}
+
+// UpdateOutTradeNo sets the "out_trade_no" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateOutTradeNo() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateOutTradeNo()
+ })
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (u *PaymentOrderUpsertBulk) SetPaymentType(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPaymentType(v)
+ })
+}
+
+// UpdatePaymentType sets the "payment_type" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdatePaymentType() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePaymentType()
+ })
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (u *PaymentOrderUpsertBulk) SetPaymentTradeNo(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPaymentTradeNo(v)
+ })
+}
+
+// UpdatePaymentTradeNo sets the "payment_trade_no" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdatePaymentTradeNo() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePaymentTradeNo()
+ })
+}
+
+// SetPayURL sets the "pay_url" field.
+func (u *PaymentOrderUpsertBulk) SetPayURL(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPayURL(v)
+ })
+}
+
+// UpdatePayURL sets the "pay_url" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdatePayURL() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePayURL()
+ })
+}
+
+// ClearPayURL clears the value of the "pay_url" field.
+func (u *PaymentOrderUpsertBulk) ClearPayURL() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearPayURL()
+ })
+}
+
+// SetQrCode sets the "qr_code" field.
+func (u *PaymentOrderUpsertBulk) SetQrCode(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetQrCode(v)
+ })
+}
+
+// UpdateQrCode sets the "qr_code" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateQrCode() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateQrCode()
+ })
+}
+
+// ClearQrCode clears the value of the "qr_code" field.
+func (u *PaymentOrderUpsertBulk) ClearQrCode() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearQrCode()
+ })
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (u *PaymentOrderUpsertBulk) SetQrCodeImg(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetQrCodeImg(v)
+ })
+}
+
+// UpdateQrCodeImg sets the "qr_code_img" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateQrCodeImg() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateQrCodeImg()
+ })
+}
+
+// ClearQrCodeImg clears the value of the "qr_code_img" field.
+func (u *PaymentOrderUpsertBulk) ClearQrCodeImg() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearQrCodeImg()
+ })
+}
+
+// SetOrderType sets the "order_type" field.
+func (u *PaymentOrderUpsertBulk) SetOrderType(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetOrderType(v)
+ })
+}
+
+// UpdateOrderType sets the "order_type" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateOrderType() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateOrderType()
+ })
+}
+
+// SetPlanID sets the "plan_id" field.
+func (u *PaymentOrderUpsertBulk) SetPlanID(v int64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPlanID(v)
+ })
+}
+
+// AddPlanID adds v to the "plan_id" field.
+func (u *PaymentOrderUpsertBulk) AddPlanID(v int64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddPlanID(v)
+ })
+}
+
+// UpdatePlanID sets the "plan_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdatePlanID() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePlanID()
+ })
+}
+
+// ClearPlanID clears the value of the "plan_id" field.
+func (u *PaymentOrderUpsertBulk) ClearPlanID() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearPlanID()
+ })
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (u *PaymentOrderUpsertBulk) SetSubscriptionGroupID(v int64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSubscriptionGroupID(v)
+ })
+}
+
+// AddSubscriptionGroupID adds v to the "subscription_group_id" field.
+func (u *PaymentOrderUpsertBulk) AddSubscriptionGroupID(v int64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddSubscriptionGroupID(v)
+ })
+}
+
+// UpdateSubscriptionGroupID sets the "subscription_group_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateSubscriptionGroupID() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSubscriptionGroupID()
+ })
+}
+
+// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
+func (u *PaymentOrderUpsertBulk) ClearSubscriptionGroupID() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearSubscriptionGroupID()
+ })
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (u *PaymentOrderUpsertBulk) SetSubscriptionDays(v int) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSubscriptionDays(v)
+ })
+}
+
+// AddSubscriptionDays adds v to the "subscription_days" field.
+func (u *PaymentOrderUpsertBulk) AddSubscriptionDays(v int) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddSubscriptionDays(v)
+ })
+}
+
+// UpdateSubscriptionDays sets the "subscription_days" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateSubscriptionDays() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSubscriptionDays()
+ })
+}
+
+// ClearSubscriptionDays clears the value of the "subscription_days" field.
+func (u *PaymentOrderUpsertBulk) ClearSubscriptionDays() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearSubscriptionDays()
+ })
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (u *PaymentOrderUpsertBulk) SetProviderInstanceID(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderInstanceID(v)
+ })
+}
+
+// UpdateProviderInstanceID sets the "provider_instance_id" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderInstanceID() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderInstanceID()
+ })
+}
+
+// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderInstanceID()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderKey()
+ })
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetProviderSnapshot(v)
+ })
+}
+
+// UpdateProviderSnapshot sets the "provider_snapshot" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateProviderSnapshot()
+ })
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (u *PaymentOrderUpsertBulk) ClearProviderSnapshot() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearProviderSnapshot()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateStatus() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (u *PaymentOrderUpsertBulk) SetRefundAmount(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundAmount(v)
+ })
+}
+
+// AddRefundAmount adds v to the "refund_amount" field.
+func (u *PaymentOrderUpsertBulk) AddRefundAmount(v float64) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.AddRefundAmount(v)
+ })
+}
+
+// UpdateRefundAmount sets the "refund_amount" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateRefundAmount() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundAmount()
+ })
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (u *PaymentOrderUpsertBulk) SetRefundReason(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundReason(v)
+ })
+}
+
+// UpdateRefundReason sets the "refund_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateRefundReason() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundReason()
+ })
+}
+
+// ClearRefundReason clears the value of the "refund_reason" field.
+func (u *PaymentOrderUpsertBulk) ClearRefundReason() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundReason()
+ })
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (u *PaymentOrderUpsertBulk) SetRefundAt(v time.Time) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundAt(v)
+ })
+}
+
+// UpdateRefundAt sets the "refund_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateRefundAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundAt()
+ })
+}
+
+// ClearRefundAt clears the value of the "refund_at" field.
+func (u *PaymentOrderUpsertBulk) ClearRefundAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundAt()
+ })
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (u *PaymentOrderUpsertBulk) SetForceRefund(v bool) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetForceRefund(v)
+ })
+}
+
+// UpdateForceRefund sets the "force_refund" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateForceRefund() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateForceRefund()
+ })
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (u *PaymentOrderUpsertBulk) SetRefundRequestedAt(v time.Time) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundRequestedAt(v)
+ })
+}
+
+// UpdateRefundRequestedAt sets the "refund_requested_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateRefundRequestedAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundRequestedAt()
+ })
+}
+
+// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
+func (u *PaymentOrderUpsertBulk) ClearRefundRequestedAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundRequestedAt()
+ })
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (u *PaymentOrderUpsertBulk) SetRefundRequestReason(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundRequestReason(v)
+ })
+}
+
+// UpdateRefundRequestReason sets the "refund_request_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateRefundRequestReason() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundRequestReason()
+ })
+}
+
+// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
+func (u *PaymentOrderUpsertBulk) ClearRefundRequestReason() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundRequestReason()
+ })
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (u *PaymentOrderUpsertBulk) SetRefundRequestedBy(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetRefundRequestedBy(v)
+ })
+}
+
+// UpdateRefundRequestedBy sets the "refund_requested_by" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateRefundRequestedBy() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateRefundRequestedBy()
+ })
+}
+
+// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
+func (u *PaymentOrderUpsertBulk) ClearRefundRequestedBy() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearRefundRequestedBy()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PaymentOrderUpsertBulk) SetExpiresAt(v time.Time) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateExpiresAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (u *PaymentOrderUpsertBulk) SetPaidAt(v time.Time) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetPaidAt(v)
+ })
+}
+
+// UpdatePaidAt sets the "paid_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdatePaidAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdatePaidAt()
+ })
+}
+
+// ClearPaidAt clears the value of the "paid_at" field.
+func (u *PaymentOrderUpsertBulk) ClearPaidAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearPaidAt()
+ })
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (u *PaymentOrderUpsertBulk) SetCompletedAt(v time.Time) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetCompletedAt(v)
+ })
+}
+
+// UpdateCompletedAt sets the "completed_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateCompletedAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateCompletedAt()
+ })
+}
+
+// ClearCompletedAt clears the value of the "completed_at" field.
+func (u *PaymentOrderUpsertBulk) ClearCompletedAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearCompletedAt()
+ })
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (u *PaymentOrderUpsertBulk) SetFailedAt(v time.Time) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetFailedAt(v)
+ })
+}
+
+// UpdateFailedAt sets the "failed_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateFailedAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateFailedAt()
+ })
+}
+
+// ClearFailedAt clears the value of the "failed_at" field.
+func (u *PaymentOrderUpsertBulk) ClearFailedAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearFailedAt()
+ })
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (u *PaymentOrderUpsertBulk) SetFailedReason(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetFailedReason(v)
+ })
+}
+
+// UpdateFailedReason sets the "failed_reason" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateFailedReason() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateFailedReason()
+ })
+}
+
+// ClearFailedReason clears the value of the "failed_reason" field.
+func (u *PaymentOrderUpsertBulk) ClearFailedReason() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearFailedReason()
+ })
+}
+
+// SetClientIP sets the "client_ip" field.
+func (u *PaymentOrderUpsertBulk) SetClientIP(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetClientIP(v)
+ })
+}
+
+// UpdateClientIP sets the "client_ip" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateClientIP() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateClientIP()
+ })
+}
+
+// SetSrcHost sets the "src_host" field.
+func (u *PaymentOrderUpsertBulk) SetSrcHost(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSrcHost(v)
+ })
+}
+
+// UpdateSrcHost sets the "src_host" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateSrcHost() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSrcHost()
+ })
+}
+
+// SetSrcURL sets the "src_url" field.
+func (u *PaymentOrderUpsertBulk) SetSrcURL(v string) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetSrcURL(v)
+ })
+}
+
+// UpdateSrcURL sets the "src_url" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateSrcURL() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateSrcURL()
+ })
+}
+
+// ClearSrcURL clears the value of the "src_url" field.
+func (u *PaymentOrderUpsertBulk) ClearSrcURL() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.ClearSrcURL()
+ })
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PaymentOrderUpsertBulk) SetUpdatedAt(v time.Time) *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PaymentOrderUpsertBulk) UpdateUpdatedAt() *PaymentOrderUpsertBulk {
+ return u.Update(func(s *PaymentOrderUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PaymentOrderUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PaymentOrderCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PaymentOrderCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PaymentOrderUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/paymentorder_delete.go b/backend/ent/paymentorder_delete.go
new file mode 100644
index 00000000..a4bc1bdf
--- /dev/null
+++ b/backend/ent/paymentorder_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PaymentOrderDelete is the builder for deleting a PaymentOrder entity.
+type PaymentOrderDelete struct {
+ config
+ hooks []Hook
+ mutation *PaymentOrderMutation
+}
+
+// Where appends a list predicates to the PaymentOrderDelete builder.
+func (_d *PaymentOrderDelete) Where(ps ...predicate.PaymentOrder) *PaymentOrderDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PaymentOrderDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PaymentOrderDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *PaymentOrderDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(paymentorder.Table, sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// PaymentOrderDeleteOne is the builder for deleting a single PaymentOrder entity.
+type PaymentOrderDeleteOne struct {
+ _d *PaymentOrderDelete
+}
+
+// Where appends a list predicates to the PaymentOrderDelete builder.
+func (_d *PaymentOrderDeleteOne) Where(ps ...predicate.PaymentOrder) *PaymentOrderDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *PaymentOrderDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{paymentorder.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PaymentOrderDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/paymentorder_query.go b/backend/ent/paymentorder_query.go
new file mode 100644
index 00000000..92fd74a7
--- /dev/null
+++ b/backend/ent/paymentorder_query.go
@@ -0,0 +1,643 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PaymentOrderQuery is the builder for querying PaymentOrder entities.
+type PaymentOrderQuery struct {
+ config
+ ctx *QueryContext
+ order []paymentorder.OrderOption
+ inters []Interceptor
+ predicates []predicate.PaymentOrder
+ withUser *UserQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PaymentOrderQuery builder.
+func (_q *PaymentOrderQuery) Where(ps ...predicate.PaymentOrder) *PaymentOrderQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PaymentOrderQuery) Limit(limit int) *PaymentOrderQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *PaymentOrderQuery) Offset(offset int) *PaymentOrderQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PaymentOrderQuery) Unique(unique bool) *PaymentOrderQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PaymentOrderQuery) Order(o ...paymentorder.OrderOption) *PaymentOrderQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryUser chains the current query on the "user" edge.
+func (_q *PaymentOrderQuery) QueryUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(paymentorder.Table, paymentorder.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, paymentorder.UserTable, paymentorder.UserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first PaymentOrder entity from the query.
+// Returns a *NotFoundError when no PaymentOrder was found.
+func (_q *PaymentOrderQuery) First(ctx context.Context) (*PaymentOrder, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{paymentorder.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *PaymentOrderQuery) FirstX(ctx context.Context) *PaymentOrder {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first PaymentOrder ID from the query.
+// Returns a *NotFoundError when no PaymentOrder ID was found.
+func (_q *PaymentOrderQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{paymentorder.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *PaymentOrderQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single PaymentOrder entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PaymentOrder entity is found.
+// Returns a *NotFoundError when no PaymentOrder entities are found.
+func (_q *PaymentOrderQuery) Only(ctx context.Context) (*PaymentOrder, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{paymentorder.Label}
+ default:
+ return nil, &NotSingularError{paymentorder.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PaymentOrderQuery) OnlyX(ctx context.Context) *PaymentOrder {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only PaymentOrder ID in the query.
+// Returns a *NotSingularError when more than one PaymentOrder ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PaymentOrderQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{paymentorder.Label}
+ default:
+ err = &NotSingularError{paymentorder.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PaymentOrderQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of PaymentOrders.
+func (_q *PaymentOrderQuery) All(ctx context.Context) ([]*PaymentOrder, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*PaymentOrder, *PaymentOrderQuery]()
+ return withInterceptors[[]*PaymentOrder](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PaymentOrderQuery) AllX(ctx context.Context) []*PaymentOrder {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of PaymentOrder IDs.
+func (_q *PaymentOrderQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(paymentorder.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PaymentOrderQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PaymentOrderQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*PaymentOrderQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PaymentOrderQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PaymentOrderQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PaymentOrderQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the PaymentOrderQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *PaymentOrderQuery) Clone() *PaymentOrderQuery {
+ if _q == nil {
+ return nil
+ }
+ return &PaymentOrderQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]paymentorder.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.PaymentOrder{}, _q.predicates...),
+ withUser: _q.withUser.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithUser tells the query-builder to eager-load the nodes that are connected to
+// the "user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PaymentOrderQuery) WithUser(opts ...func(*UserQuery)) *PaymentOrderQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withUser = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// UserID int64 `json:"user_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.PaymentOrder.Query().
+// GroupBy(paymentorder.FieldUserID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *PaymentOrderQuery) GroupBy(field string, fields ...string) *PaymentOrderGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &PaymentOrderGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = paymentorder.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// UserID int64 `json:"user_id,omitempty"`
+// }
+//
+// client.PaymentOrder.Query().
+// Select(paymentorder.FieldUserID).
+// Scan(ctx, &v)
+func (_q *PaymentOrderQuery) Select(fields ...string) *PaymentOrderSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &PaymentOrderSelect{PaymentOrderQuery: _q}
+ sbuild.label = paymentorder.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a PaymentOrderSelect configured with the given aggregations.
+func (_q *PaymentOrderQuery) Aggregate(fns ...AggregateFunc) *PaymentOrderSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *PaymentOrderQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !paymentorder.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *PaymentOrderQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PaymentOrder, error) {
+ var (
+ nodes = []*PaymentOrder{}
+ _spec = _q.querySpec()
+ loadedTypes = [1]bool{
+ _q.withUser != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*PaymentOrder).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &PaymentOrder{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withUser; query != nil {
+ if err := _q.loadUser(ctx, query, nodes, nil,
+ func(n *PaymentOrder, e *User) { n.Edges.User = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *PaymentOrderQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*PaymentOrder, init func(*PaymentOrder), assign func(*PaymentOrder, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*PaymentOrder)
+ for i := range nodes {
+ fk := nodes[i].UserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+
+func (_q *PaymentOrderQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *PaymentOrderQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(paymentorder.Table, paymentorder.Columns, sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, paymentorder.FieldID)
+ for i := range fields {
+ if fields[i] != paymentorder.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withUser != nil {
+ _spec.Node.AddColumnOnce(paymentorder.FieldUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *PaymentOrderQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(paymentorder.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = paymentorder.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PaymentOrderQuery) ForUpdate(opts ...sql.LockOption) *PaymentOrderQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PaymentOrderQuery) ForShare(opts ...sql.LockOption) *PaymentOrderQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// PaymentOrderGroupBy is the group-by builder for PaymentOrder entities.
+type PaymentOrderGroupBy struct {
+ selector
+ build *PaymentOrderQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PaymentOrderGroupBy) Aggregate(fns ...AggregateFunc) *PaymentOrderGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PaymentOrderGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PaymentOrderQuery, *PaymentOrderGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *PaymentOrderGroupBy) sqlScan(ctx context.Context, root *PaymentOrderQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// PaymentOrderSelect is the builder for selecting fields of PaymentOrder entities.
+type PaymentOrderSelect struct {
+ *PaymentOrderQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PaymentOrderSelect) Aggregate(fns ...AggregateFunc) *PaymentOrderSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PaymentOrderSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PaymentOrderQuery, *PaymentOrderSelect](ctx, _s.PaymentOrderQuery, _s, _s.inters, v)
+}
+
+func (_s *PaymentOrderSelect) sqlScan(ctx context.Context, root *PaymentOrderQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go
new file mode 100644
index 00000000..378e0dad
--- /dev/null
+++ b/backend/ent/paymentorder_update.go
@@ -0,0 +1,2181 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PaymentOrderUpdate is the builder for updating PaymentOrder entities.
+type PaymentOrderUpdate struct {
+ config
+ hooks []Hook
+ mutation *PaymentOrderMutation
+}
+
+// Where appends a list predicates to the PaymentOrderUpdate builder.
+func (_u *PaymentOrderUpdate) Where(ps ...predicate.PaymentOrder) *PaymentOrderUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *PaymentOrderUpdate) SetUserID(v int64) *PaymentOrderUpdate {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableUserID(v *int64) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetUserEmail sets the "user_email" field.
+func (_u *PaymentOrderUpdate) SetUserEmail(v string) *PaymentOrderUpdate {
+ _u.mutation.SetUserEmail(v)
+ return _u
+}
+
+// SetNillableUserEmail sets the "user_email" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableUserEmail(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetUserEmail(*v)
+ }
+ return _u
+}
+
+// SetUserName sets the "user_name" field.
+func (_u *PaymentOrderUpdate) SetUserName(v string) *PaymentOrderUpdate {
+ _u.mutation.SetUserName(v)
+ return _u
+}
+
+// SetNillableUserName sets the "user_name" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableUserName(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetUserName(*v)
+ }
+ return _u
+}
+
+// SetUserNotes sets the "user_notes" field.
+func (_u *PaymentOrderUpdate) SetUserNotes(v string) *PaymentOrderUpdate {
+ _u.mutation.SetUserNotes(v)
+ return _u
+}
+
+// SetNillableUserNotes sets the "user_notes" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableUserNotes(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetUserNotes(*v)
+ }
+ return _u
+}
+
+// ClearUserNotes clears the value of the "user_notes" field.
+func (_u *PaymentOrderUpdate) ClearUserNotes() *PaymentOrderUpdate {
+ _u.mutation.ClearUserNotes()
+ return _u
+}
+
+// SetAmount sets the "amount" field.
+func (_u *PaymentOrderUpdate) SetAmount(v float64) *PaymentOrderUpdate {
+ _u.mutation.ResetAmount()
+ _u.mutation.SetAmount(v)
+ return _u
+}
+
+// SetNillableAmount sets the "amount" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableAmount(v *float64) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetAmount(*v)
+ }
+ return _u
+}
+
+// AddAmount adds value to the "amount" field.
+func (_u *PaymentOrderUpdate) AddAmount(v float64) *PaymentOrderUpdate {
+ _u.mutation.AddAmount(v)
+ return _u
+}
+
+// SetPayAmount sets the "pay_amount" field.
+func (_u *PaymentOrderUpdate) SetPayAmount(v float64) *PaymentOrderUpdate {
+ _u.mutation.ResetPayAmount()
+ _u.mutation.SetPayAmount(v)
+ return _u
+}
+
+// SetNillablePayAmount sets the "pay_amount" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillablePayAmount(v *float64) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetPayAmount(*v)
+ }
+ return _u
+}
+
+// AddPayAmount adds value to the "pay_amount" field.
+func (_u *PaymentOrderUpdate) AddPayAmount(v float64) *PaymentOrderUpdate {
+ _u.mutation.AddPayAmount(v)
+ return _u
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (_u *PaymentOrderUpdate) SetFeeRate(v float64) *PaymentOrderUpdate {
+ _u.mutation.ResetFeeRate()
+ _u.mutation.SetFeeRate(v)
+ return _u
+}
+
+// SetNillableFeeRate sets the "fee_rate" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableFeeRate(v *float64) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetFeeRate(*v)
+ }
+ return _u
+}
+
+// AddFeeRate adds value to the "fee_rate" field.
+func (_u *PaymentOrderUpdate) AddFeeRate(v float64) *PaymentOrderUpdate {
+ _u.mutation.AddFeeRate(v)
+ return _u
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (_u *PaymentOrderUpdate) SetRechargeCode(v string) *PaymentOrderUpdate {
+ _u.mutation.SetRechargeCode(v)
+ return _u
+}
+
+// SetNillableRechargeCode sets the "recharge_code" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableRechargeCode(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetRechargeCode(*v)
+ }
+ return _u
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (_u *PaymentOrderUpdate) SetOutTradeNo(v string) *PaymentOrderUpdate {
+ _u.mutation.SetOutTradeNo(v)
+ return _u
+}
+
+// SetNillableOutTradeNo sets the "out_trade_no" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableOutTradeNo(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetOutTradeNo(*v)
+ }
+ return _u
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (_u *PaymentOrderUpdate) SetPaymentType(v string) *PaymentOrderUpdate {
+ _u.mutation.SetPaymentType(v)
+ return _u
+}
+
+// SetNillablePaymentType sets the "payment_type" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillablePaymentType(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetPaymentType(*v)
+ }
+ return _u
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (_u *PaymentOrderUpdate) SetPaymentTradeNo(v string) *PaymentOrderUpdate {
+ _u.mutation.SetPaymentTradeNo(v)
+ return _u
+}
+
+// SetNillablePaymentTradeNo sets the "payment_trade_no" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillablePaymentTradeNo(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetPaymentTradeNo(*v)
+ }
+ return _u
+}
+
+// SetPayURL sets the "pay_url" field.
+func (_u *PaymentOrderUpdate) SetPayURL(v string) *PaymentOrderUpdate {
+ _u.mutation.SetPayURL(v)
+ return _u
+}
+
+// SetNillablePayURL sets the "pay_url" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillablePayURL(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetPayURL(*v)
+ }
+ return _u
+}
+
+// ClearPayURL clears the value of the "pay_url" field.
+func (_u *PaymentOrderUpdate) ClearPayURL() *PaymentOrderUpdate {
+ _u.mutation.ClearPayURL()
+ return _u
+}
+
+// SetQrCode sets the "qr_code" field.
+func (_u *PaymentOrderUpdate) SetQrCode(v string) *PaymentOrderUpdate {
+ _u.mutation.SetQrCode(v)
+ return _u
+}
+
+// SetNillableQrCode sets the "qr_code" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableQrCode(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetQrCode(*v)
+ }
+ return _u
+}
+
+// ClearQrCode clears the value of the "qr_code" field.
+func (_u *PaymentOrderUpdate) ClearQrCode() *PaymentOrderUpdate {
+ _u.mutation.ClearQrCode()
+ return _u
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (_u *PaymentOrderUpdate) SetQrCodeImg(v string) *PaymentOrderUpdate {
+ _u.mutation.SetQrCodeImg(v)
+ return _u
+}
+
+// SetNillableQrCodeImg sets the "qr_code_img" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableQrCodeImg(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetQrCodeImg(*v)
+ }
+ return _u
+}
+
+// ClearQrCodeImg clears the value of the "qr_code_img" field.
+func (_u *PaymentOrderUpdate) ClearQrCodeImg() *PaymentOrderUpdate {
+ _u.mutation.ClearQrCodeImg()
+ return _u
+}
+
+// SetOrderType sets the "order_type" field.
+func (_u *PaymentOrderUpdate) SetOrderType(v string) *PaymentOrderUpdate {
+ _u.mutation.SetOrderType(v)
+ return _u
+}
+
+// SetNillableOrderType sets the "order_type" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableOrderType(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetOrderType(*v)
+ }
+ return _u
+}
+
+// SetPlanID sets the "plan_id" field.
+func (_u *PaymentOrderUpdate) SetPlanID(v int64) *PaymentOrderUpdate {
+ _u.mutation.ResetPlanID()
+ _u.mutation.SetPlanID(v)
+ return _u
+}
+
+// SetNillablePlanID sets the "plan_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillablePlanID(v *int64) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetPlanID(*v)
+ }
+ return _u
+}
+
+// AddPlanID adds value to the "plan_id" field.
+func (_u *PaymentOrderUpdate) AddPlanID(v int64) *PaymentOrderUpdate {
+ _u.mutation.AddPlanID(v)
+ return _u
+}
+
+// ClearPlanID clears the value of the "plan_id" field.
+func (_u *PaymentOrderUpdate) ClearPlanID() *PaymentOrderUpdate {
+ _u.mutation.ClearPlanID()
+ return _u
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (_u *PaymentOrderUpdate) SetSubscriptionGroupID(v int64) *PaymentOrderUpdate {
+ _u.mutation.ResetSubscriptionGroupID()
+ _u.mutation.SetSubscriptionGroupID(v)
+ return _u
+}
+
+// SetNillableSubscriptionGroupID sets the "subscription_group_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableSubscriptionGroupID(v *int64) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetSubscriptionGroupID(*v)
+ }
+ return _u
+}
+
+// AddSubscriptionGroupID adds value to the "subscription_group_id" field.
+func (_u *PaymentOrderUpdate) AddSubscriptionGroupID(v int64) *PaymentOrderUpdate {
+ _u.mutation.AddSubscriptionGroupID(v)
+ return _u
+}
+
+// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
+func (_u *PaymentOrderUpdate) ClearSubscriptionGroupID() *PaymentOrderUpdate {
+ _u.mutation.ClearSubscriptionGroupID()
+ return _u
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (_u *PaymentOrderUpdate) SetSubscriptionDays(v int) *PaymentOrderUpdate {
+ _u.mutation.ResetSubscriptionDays()
+ _u.mutation.SetSubscriptionDays(v)
+ return _u
+}
+
+// SetNillableSubscriptionDays sets the "subscription_days" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableSubscriptionDays(v *int) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetSubscriptionDays(*v)
+ }
+ return _u
+}
+
+// AddSubscriptionDays adds value to the "subscription_days" field.
+func (_u *PaymentOrderUpdate) AddSubscriptionDays(v int) *PaymentOrderUpdate {
+ _u.mutation.AddSubscriptionDays(v)
+ return _u
+}
+
+// ClearSubscriptionDays clears the value of the "subscription_days" field.
+func (_u *PaymentOrderUpdate) ClearSubscriptionDays() *PaymentOrderUpdate {
+ _u.mutation.ClearSubscriptionDays()
+ return _u
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (_u *PaymentOrderUpdate) SetProviderInstanceID(v string) *PaymentOrderUpdate {
+ _u.mutation.SetProviderInstanceID(v)
+ return _u
+}
+
+// SetNillableProviderInstanceID sets the "provider_instance_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableProviderInstanceID(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetProviderInstanceID(*v)
+ }
+ return _u
+}
+
+// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
+func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderInstanceID()
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdate {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdate) ClearProviderSnapshot() *PaymentOrderUpdate {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableStatus(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (_u *PaymentOrderUpdate) SetRefundAmount(v float64) *PaymentOrderUpdate {
+ _u.mutation.ResetRefundAmount()
+ _u.mutation.SetRefundAmount(v)
+ return _u
+}
+
+// SetNillableRefundAmount sets the "refund_amount" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableRefundAmount(v *float64) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetRefundAmount(*v)
+ }
+ return _u
+}
+
+// AddRefundAmount adds value to the "refund_amount" field.
+func (_u *PaymentOrderUpdate) AddRefundAmount(v float64) *PaymentOrderUpdate {
+ _u.mutation.AddRefundAmount(v)
+ return _u
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (_u *PaymentOrderUpdate) SetRefundReason(v string) *PaymentOrderUpdate {
+ _u.mutation.SetRefundReason(v)
+ return _u
+}
+
+// SetNillableRefundReason sets the "refund_reason" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableRefundReason(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetRefundReason(*v)
+ }
+ return _u
+}
+
+// ClearRefundReason clears the value of the "refund_reason" field.
+func (_u *PaymentOrderUpdate) ClearRefundReason() *PaymentOrderUpdate {
+ _u.mutation.ClearRefundReason()
+ return _u
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (_u *PaymentOrderUpdate) SetRefundAt(v time.Time) *PaymentOrderUpdate {
+ _u.mutation.SetRefundAt(v)
+ return _u
+}
+
+// SetNillableRefundAt sets the "refund_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableRefundAt(v *time.Time) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetRefundAt(*v)
+ }
+ return _u
+}
+
+// ClearRefundAt clears the value of the "refund_at" field.
+func (_u *PaymentOrderUpdate) ClearRefundAt() *PaymentOrderUpdate {
+ _u.mutation.ClearRefundAt()
+ return _u
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (_u *PaymentOrderUpdate) SetForceRefund(v bool) *PaymentOrderUpdate {
+ _u.mutation.SetForceRefund(v)
+ return _u
+}
+
+// SetNillableForceRefund sets the "force_refund" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableForceRefund(v *bool) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetForceRefund(*v)
+ }
+ return _u
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (_u *PaymentOrderUpdate) SetRefundRequestedAt(v time.Time) *PaymentOrderUpdate {
+ _u.mutation.SetRefundRequestedAt(v)
+ return _u
+}
+
+// SetNillableRefundRequestedAt sets the "refund_requested_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableRefundRequestedAt(v *time.Time) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetRefundRequestedAt(*v)
+ }
+ return _u
+}
+
+// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
+func (_u *PaymentOrderUpdate) ClearRefundRequestedAt() *PaymentOrderUpdate {
+ _u.mutation.ClearRefundRequestedAt()
+ return _u
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (_u *PaymentOrderUpdate) SetRefundRequestReason(v string) *PaymentOrderUpdate {
+ _u.mutation.SetRefundRequestReason(v)
+ return _u
+}
+
+// SetNillableRefundRequestReason sets the "refund_request_reason" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableRefundRequestReason(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetRefundRequestReason(*v)
+ }
+ return _u
+}
+
+// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
+func (_u *PaymentOrderUpdate) ClearRefundRequestReason() *PaymentOrderUpdate {
+ _u.mutation.ClearRefundRequestReason()
+ return _u
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (_u *PaymentOrderUpdate) SetRefundRequestedBy(v string) *PaymentOrderUpdate {
+ _u.mutation.SetRefundRequestedBy(v)
+ return _u
+}
+
+// SetNillableRefundRequestedBy sets the "refund_requested_by" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableRefundRequestedBy(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetRefundRequestedBy(*v)
+ }
+ return _u
+}
+
+// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
+func (_u *PaymentOrderUpdate) ClearRefundRequestedBy() *PaymentOrderUpdate {
+ _u.mutation.ClearRefundRequestedBy()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PaymentOrderUpdate) SetExpiresAt(v time.Time) *PaymentOrderUpdate {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableExpiresAt(v *time.Time) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (_u *PaymentOrderUpdate) SetPaidAt(v time.Time) *PaymentOrderUpdate {
+ _u.mutation.SetPaidAt(v)
+ return _u
+}
+
+// SetNillablePaidAt sets the "paid_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillablePaidAt(v *time.Time) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetPaidAt(*v)
+ }
+ return _u
+}
+
+// ClearPaidAt clears the value of the "paid_at" field.
+func (_u *PaymentOrderUpdate) ClearPaidAt() *PaymentOrderUpdate {
+ _u.mutation.ClearPaidAt()
+ return _u
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (_u *PaymentOrderUpdate) SetCompletedAt(v time.Time) *PaymentOrderUpdate {
+ _u.mutation.SetCompletedAt(v)
+ return _u
+}
+
+// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableCompletedAt(v *time.Time) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetCompletedAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletedAt clears the value of the "completed_at" field.
+func (_u *PaymentOrderUpdate) ClearCompletedAt() *PaymentOrderUpdate {
+ _u.mutation.ClearCompletedAt()
+ return _u
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (_u *PaymentOrderUpdate) SetFailedAt(v time.Time) *PaymentOrderUpdate {
+ _u.mutation.SetFailedAt(v)
+ return _u
+}
+
+// SetNillableFailedAt sets the "failed_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableFailedAt(v *time.Time) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetFailedAt(*v)
+ }
+ return _u
+}
+
+// ClearFailedAt clears the value of the "failed_at" field.
+func (_u *PaymentOrderUpdate) ClearFailedAt() *PaymentOrderUpdate {
+ _u.mutation.ClearFailedAt()
+ return _u
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (_u *PaymentOrderUpdate) SetFailedReason(v string) *PaymentOrderUpdate {
+ _u.mutation.SetFailedReason(v)
+ return _u
+}
+
+// SetNillableFailedReason sets the "failed_reason" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableFailedReason(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetFailedReason(*v)
+ }
+ return _u
+}
+
+// ClearFailedReason clears the value of the "failed_reason" field.
+func (_u *PaymentOrderUpdate) ClearFailedReason() *PaymentOrderUpdate {
+ _u.mutation.ClearFailedReason()
+ return _u
+}
+
+// SetClientIP sets the "client_ip" field.
+func (_u *PaymentOrderUpdate) SetClientIP(v string) *PaymentOrderUpdate {
+ _u.mutation.SetClientIP(v)
+ return _u
+}
+
+// SetNillableClientIP sets the "client_ip" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableClientIP(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetClientIP(*v)
+ }
+ return _u
+}
+
+// SetSrcHost sets the "src_host" field.
+func (_u *PaymentOrderUpdate) SetSrcHost(v string) *PaymentOrderUpdate {
+ _u.mutation.SetSrcHost(v)
+ return _u
+}
+
+// SetNillableSrcHost sets the "src_host" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableSrcHost(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetSrcHost(*v)
+ }
+ return _u
+}
+
+// SetSrcURL sets the "src_url" field.
+func (_u *PaymentOrderUpdate) SetSrcURL(v string) *PaymentOrderUpdate {
+ _u.mutation.SetSrcURL(v)
+ return _u
+}
+
+// SetNillableSrcURL sets the "src_url" field if the given value is not nil.
+func (_u *PaymentOrderUpdate) SetNillableSrcURL(v *string) *PaymentOrderUpdate {
+ if v != nil {
+ _u.SetSrcURL(*v)
+ }
+ return _u
+}
+
+// ClearSrcURL clears the value of the "src_url" field.
+func (_u *PaymentOrderUpdate) ClearSrcURL() *PaymentOrderUpdate {
+ _u.mutation.ClearSrcURL()
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PaymentOrderUpdate) SetUpdatedAt(v time.Time) *PaymentOrderUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *PaymentOrderUpdate) SetUser(v *User) *PaymentOrderUpdate {
+ return _u.SetUserID(v.ID)
+}
+
+// Mutation returns the PaymentOrderMutation object of the builder.
+func (_u *PaymentOrderUpdate) Mutation() *PaymentOrderMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *PaymentOrderUpdate) ClearUser() *PaymentOrderUpdate {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PaymentOrderUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PaymentOrderUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *PaymentOrderUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PaymentOrderUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PaymentOrderUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := paymentorder.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PaymentOrderUpdate) check() error {
+ if v, ok := _u.mutation.UserEmail(); ok {
+ if err := paymentorder.UserEmailValidator(v); err != nil {
+ return &ValidationError{Name: "user_email", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.user_email": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.UserName(); ok {
+ if err := paymentorder.UserNameValidator(v); err != nil {
+ return &ValidationError{Name: "user_name", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.user_name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.RechargeCode(); ok {
+ if err := paymentorder.RechargeCodeValidator(v); err != nil {
+ return &ValidationError{Name: "recharge_code", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.recharge_code": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.OutTradeNo(); ok {
+ if err := paymentorder.OutTradeNoValidator(v); err != nil {
+ return &ValidationError{Name: "out_trade_no", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.out_trade_no": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PaymentType(); ok {
+ if err := paymentorder.PaymentTypeValidator(v); err != nil {
+ return &ValidationError{Name: "payment_type", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.payment_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PaymentTradeNo(); ok {
+ if err := paymentorder.PaymentTradeNoValidator(v); err != nil {
+ return &ValidationError{Name: "payment_trade_no", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.payment_trade_no": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.OrderType(); ok {
+ if err := paymentorder.OrderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "order_type", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.order_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderInstanceID(); ok {
+ if err := paymentorder.ProviderInstanceIDValidator(v); err != nil {
+ return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := paymentorder.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.RefundRequestedBy(); ok {
+ if err := paymentorder.RefundRequestedByValidator(v); err != nil {
+ return &ValidationError{Name: "refund_requested_by", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.refund_requested_by": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ClientIP(); ok {
+ if err := paymentorder.ClientIPValidator(v); err != nil {
+ return &ValidationError{Name: "client_ip", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.client_ip": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.SrcHost(); ok {
+ if err := paymentorder.SrcHostValidator(v); err != nil {
+ return &ValidationError{Name: "src_host", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.src_host": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "PaymentOrder.user"`)
+ }
+ return nil
+}
+
+func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(paymentorder.Table, paymentorder.Columns, sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UserEmail(); ok {
+ _spec.SetField(paymentorder.FieldUserEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UserName(); ok {
+ _spec.SetField(paymentorder.FieldUserName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UserNotes(); ok {
+ _spec.SetField(paymentorder.FieldUserNotes, field.TypeString, value)
+ }
+ if _u.mutation.UserNotesCleared() {
+ _spec.ClearField(paymentorder.FieldUserNotes, field.TypeString)
+ }
+ if value, ok := _u.mutation.Amount(); ok {
+ _spec.SetField(paymentorder.FieldAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedAmount(); ok {
+ _spec.AddField(paymentorder.FieldAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.PayAmount(); ok {
+ _spec.SetField(paymentorder.FieldPayAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedPayAmount(); ok {
+ _spec.AddField(paymentorder.FieldPayAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.FeeRate(); ok {
+ _spec.SetField(paymentorder.FieldFeeRate, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedFeeRate(); ok {
+ _spec.AddField(paymentorder.FieldFeeRate, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.RechargeCode(); ok {
+ _spec.SetField(paymentorder.FieldRechargeCode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.OutTradeNo(); ok {
+ _spec.SetField(paymentorder.FieldOutTradeNo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PaymentType(); ok {
+ _spec.SetField(paymentorder.FieldPaymentType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PaymentTradeNo(); ok {
+ _spec.SetField(paymentorder.FieldPaymentTradeNo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PayURL(); ok {
+ _spec.SetField(paymentorder.FieldPayURL, field.TypeString, value)
+ }
+ if _u.mutation.PayURLCleared() {
+ _spec.ClearField(paymentorder.FieldPayURL, field.TypeString)
+ }
+ if value, ok := _u.mutation.QrCode(); ok {
+ _spec.SetField(paymentorder.FieldQrCode, field.TypeString, value)
+ }
+ if _u.mutation.QrCodeCleared() {
+ _spec.ClearField(paymentorder.FieldQrCode, field.TypeString)
+ }
+ if value, ok := _u.mutation.QrCodeImg(); ok {
+ _spec.SetField(paymentorder.FieldQrCodeImg, field.TypeString, value)
+ }
+ if _u.mutation.QrCodeImgCleared() {
+ _spec.ClearField(paymentorder.FieldQrCodeImg, field.TypeString)
+ }
+ if value, ok := _u.mutation.OrderType(); ok {
+ _spec.SetField(paymentorder.FieldOrderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PlanID(); ok {
+ _spec.SetField(paymentorder.FieldPlanID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedPlanID(); ok {
+ _spec.AddField(paymentorder.FieldPlanID, field.TypeInt64, value)
+ }
+ if _u.mutation.PlanIDCleared() {
+ _spec.ClearField(paymentorder.FieldPlanID, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.SubscriptionGroupID(); ok {
+ _spec.SetField(paymentorder.FieldSubscriptionGroupID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSubscriptionGroupID(); ok {
+ _spec.AddField(paymentorder.FieldSubscriptionGroupID, field.TypeInt64, value)
+ }
+ if _u.mutation.SubscriptionGroupIDCleared() {
+ _spec.ClearField(paymentorder.FieldSubscriptionGroupID, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.SubscriptionDays(); ok {
+ _spec.SetField(paymentorder.FieldSubscriptionDays, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedSubscriptionDays(); ok {
+ _spec.AddField(paymentorder.FieldSubscriptionDays, field.TypeInt, value)
+ }
+ if _u.mutation.SubscriptionDaysCleared() {
+ _spec.ClearField(paymentorder.FieldSubscriptionDays, field.TypeInt)
+ }
+ if value, ok := _u.mutation.ProviderInstanceID(); ok {
+ _spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value)
+ }
+ if _u.mutation.ProviderInstanceIDCleared() {
+ _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RefundAmount(); ok {
+ _spec.SetField(paymentorder.FieldRefundAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedRefundAmount(); ok {
+ _spec.AddField(paymentorder.FieldRefundAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.RefundReason(); ok {
+ _spec.SetField(paymentorder.FieldRefundReason, field.TypeString, value)
+ }
+ if _u.mutation.RefundReasonCleared() {
+ _spec.ClearField(paymentorder.FieldRefundReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.RefundAt(); ok {
+ _spec.SetField(paymentorder.FieldRefundAt, field.TypeTime, value)
+ }
+ if _u.mutation.RefundAtCleared() {
+ _spec.ClearField(paymentorder.FieldRefundAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ForceRefund(); ok {
+ _spec.SetField(paymentorder.FieldForceRefund, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.RefundRequestedAt(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestedAt, field.TypeTime, value)
+ }
+ if _u.mutation.RefundRequestedAtCleared() {
+ _spec.ClearField(paymentorder.FieldRefundRequestedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.RefundRequestReason(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestReason, field.TypeString, value)
+ }
+ if _u.mutation.RefundRequestReasonCleared() {
+ _spec.ClearField(paymentorder.FieldRefundRequestReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.RefundRequestedBy(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestedBy, field.TypeString, value)
+ }
+ if _u.mutation.RefundRequestedByCleared() {
+ _spec.ClearField(paymentorder.FieldRefundRequestedBy, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(paymentorder.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.PaidAt(); ok {
+ _spec.SetField(paymentorder.FieldPaidAt, field.TypeTime, value)
+ }
+ if _u.mutation.PaidAtCleared() {
+ _spec.ClearField(paymentorder.FieldPaidAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CompletedAt(); ok {
+ _spec.SetField(paymentorder.FieldCompletedAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletedAtCleared() {
+ _spec.ClearField(paymentorder.FieldCompletedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.FailedAt(); ok {
+ _spec.SetField(paymentorder.FieldFailedAt, field.TypeTime, value)
+ }
+ if _u.mutation.FailedAtCleared() {
+ _spec.ClearField(paymentorder.FieldFailedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.FailedReason(); ok {
+ _spec.SetField(paymentorder.FieldFailedReason, field.TypeString, value)
+ }
+ if _u.mutation.FailedReasonCleared() {
+ _spec.ClearField(paymentorder.FieldFailedReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.ClientIP(); ok {
+ _spec.SetField(paymentorder.FieldClientIP, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SrcHost(); ok {
+ _spec.SetField(paymentorder.FieldSrcHost, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SrcURL(); ok {
+ _spec.SetField(paymentorder.FieldSrcURL, field.TypeString, value)
+ }
+ if _u.mutation.SrcURLCleared() {
+ _spec.ClearField(paymentorder.FieldSrcURL, field.TypeString)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(paymentorder.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: paymentorder.UserTable,
+ Columns: []string{paymentorder.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: paymentorder.UserTable,
+ Columns: []string{paymentorder.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{paymentorder.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// PaymentOrderUpdateOne is the builder for updating a single PaymentOrder entity.
+type PaymentOrderUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *PaymentOrderMutation
+}
+
+// SetUserID sets the "user_id" field.
+func (_u *PaymentOrderUpdateOne) SetUserID(v int64) *PaymentOrderUpdateOne {
+ _u.mutation.SetUserID(v)
+ return _u
+}
+
+// SetNillableUserID sets the "user_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableUserID(v *int64) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetUserID(*v)
+ }
+ return _u
+}
+
+// SetUserEmail sets the "user_email" field.
+func (_u *PaymentOrderUpdateOne) SetUserEmail(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetUserEmail(v)
+ return _u
+}
+
+// SetNillableUserEmail sets the "user_email" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableUserEmail(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetUserEmail(*v)
+ }
+ return _u
+}
+
+// SetUserName sets the "user_name" field.
+func (_u *PaymentOrderUpdateOne) SetUserName(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetUserName(v)
+ return _u
+}
+
+// SetNillableUserName sets the "user_name" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableUserName(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetUserName(*v)
+ }
+ return _u
+}
+
+// SetUserNotes sets the "user_notes" field.
+func (_u *PaymentOrderUpdateOne) SetUserNotes(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetUserNotes(v)
+ return _u
+}
+
+// SetNillableUserNotes sets the "user_notes" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableUserNotes(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetUserNotes(*v)
+ }
+ return _u
+}
+
+// ClearUserNotes clears the value of the "user_notes" field.
+func (_u *PaymentOrderUpdateOne) ClearUserNotes() *PaymentOrderUpdateOne {
+ _u.mutation.ClearUserNotes()
+ return _u
+}
+
+// SetAmount sets the "amount" field.
+func (_u *PaymentOrderUpdateOne) SetAmount(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.ResetAmount()
+ _u.mutation.SetAmount(v)
+ return _u
+}
+
+// SetNillableAmount sets the "amount" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableAmount(v *float64) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetAmount(*v)
+ }
+ return _u
+}
+
+// AddAmount adds value to the "amount" field.
+func (_u *PaymentOrderUpdateOne) AddAmount(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.AddAmount(v)
+ return _u
+}
+
+// SetPayAmount sets the "pay_amount" field.
+func (_u *PaymentOrderUpdateOne) SetPayAmount(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.ResetPayAmount()
+ _u.mutation.SetPayAmount(v)
+ return _u
+}
+
+// SetNillablePayAmount sets the "pay_amount" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillablePayAmount(v *float64) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetPayAmount(*v)
+ }
+ return _u
+}
+
+// AddPayAmount adds value to the "pay_amount" field.
+func (_u *PaymentOrderUpdateOne) AddPayAmount(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.AddPayAmount(v)
+ return _u
+}
+
+// SetFeeRate sets the "fee_rate" field.
+func (_u *PaymentOrderUpdateOne) SetFeeRate(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.ResetFeeRate()
+ _u.mutation.SetFeeRate(v)
+ return _u
+}
+
+// SetNillableFeeRate sets the "fee_rate" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableFeeRate(v *float64) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetFeeRate(*v)
+ }
+ return _u
+}
+
+// AddFeeRate adds value to the "fee_rate" field.
+func (_u *PaymentOrderUpdateOne) AddFeeRate(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.AddFeeRate(v)
+ return _u
+}
+
+// SetRechargeCode sets the "recharge_code" field.
+func (_u *PaymentOrderUpdateOne) SetRechargeCode(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetRechargeCode(v)
+ return _u
+}
+
+// SetNillableRechargeCode sets the "recharge_code" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableRechargeCode(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetRechargeCode(*v)
+ }
+ return _u
+}
+
+// SetOutTradeNo sets the "out_trade_no" field.
+func (_u *PaymentOrderUpdateOne) SetOutTradeNo(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetOutTradeNo(v)
+ return _u
+}
+
+// SetNillableOutTradeNo sets the "out_trade_no" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableOutTradeNo(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetOutTradeNo(*v)
+ }
+ return _u
+}
+
+// SetPaymentType sets the "payment_type" field.
+func (_u *PaymentOrderUpdateOne) SetPaymentType(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetPaymentType(v)
+ return _u
+}
+
+// SetNillablePaymentType sets the "payment_type" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillablePaymentType(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetPaymentType(*v)
+ }
+ return _u
+}
+
+// SetPaymentTradeNo sets the "payment_trade_no" field.
+func (_u *PaymentOrderUpdateOne) SetPaymentTradeNo(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetPaymentTradeNo(v)
+ return _u
+}
+
+// SetNillablePaymentTradeNo sets the "payment_trade_no" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillablePaymentTradeNo(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetPaymentTradeNo(*v)
+ }
+ return _u
+}
+
+// SetPayURL sets the "pay_url" field.
+func (_u *PaymentOrderUpdateOne) SetPayURL(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetPayURL(v)
+ return _u
+}
+
+// SetNillablePayURL sets the "pay_url" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillablePayURL(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetPayURL(*v)
+ }
+ return _u
+}
+
+// ClearPayURL clears the value of the "pay_url" field.
+func (_u *PaymentOrderUpdateOne) ClearPayURL() *PaymentOrderUpdateOne {
+ _u.mutation.ClearPayURL()
+ return _u
+}
+
+// SetQrCode sets the "qr_code" field.
+func (_u *PaymentOrderUpdateOne) SetQrCode(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetQrCode(v)
+ return _u
+}
+
+// SetNillableQrCode sets the "qr_code" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableQrCode(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetQrCode(*v)
+ }
+ return _u
+}
+
+// ClearQrCode clears the value of the "qr_code" field.
+func (_u *PaymentOrderUpdateOne) ClearQrCode() *PaymentOrderUpdateOne {
+ _u.mutation.ClearQrCode()
+ return _u
+}
+
+// SetQrCodeImg sets the "qr_code_img" field.
+func (_u *PaymentOrderUpdateOne) SetQrCodeImg(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetQrCodeImg(v)
+ return _u
+}
+
+// SetNillableQrCodeImg sets the "qr_code_img" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableQrCodeImg(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetQrCodeImg(*v)
+ }
+ return _u
+}
+
+// ClearQrCodeImg clears the value of the "qr_code_img" field.
+func (_u *PaymentOrderUpdateOne) ClearQrCodeImg() *PaymentOrderUpdateOne {
+ _u.mutation.ClearQrCodeImg()
+ return _u
+}
+
+// SetOrderType sets the "order_type" field.
+func (_u *PaymentOrderUpdateOne) SetOrderType(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetOrderType(v)
+ return _u
+}
+
+// SetNillableOrderType sets the "order_type" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableOrderType(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetOrderType(*v)
+ }
+ return _u
+}
+
+// SetPlanID sets the "plan_id" field.
+func (_u *PaymentOrderUpdateOne) SetPlanID(v int64) *PaymentOrderUpdateOne {
+ _u.mutation.ResetPlanID()
+ _u.mutation.SetPlanID(v)
+ return _u
+}
+
+// SetNillablePlanID sets the "plan_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillablePlanID(v *int64) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetPlanID(*v)
+ }
+ return _u
+}
+
+// AddPlanID adds value to the "plan_id" field.
+func (_u *PaymentOrderUpdateOne) AddPlanID(v int64) *PaymentOrderUpdateOne {
+ _u.mutation.AddPlanID(v)
+ return _u
+}
+
+// ClearPlanID clears the value of the "plan_id" field.
+func (_u *PaymentOrderUpdateOne) ClearPlanID() *PaymentOrderUpdateOne {
+ _u.mutation.ClearPlanID()
+ return _u
+}
+
+// SetSubscriptionGroupID sets the "subscription_group_id" field.
+func (_u *PaymentOrderUpdateOne) SetSubscriptionGroupID(v int64) *PaymentOrderUpdateOne {
+ _u.mutation.ResetSubscriptionGroupID()
+ _u.mutation.SetSubscriptionGroupID(v)
+ return _u
+}
+
+// SetNillableSubscriptionGroupID sets the "subscription_group_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableSubscriptionGroupID(v *int64) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetSubscriptionGroupID(*v)
+ }
+ return _u
+}
+
+// AddSubscriptionGroupID adds value to the "subscription_group_id" field.
+func (_u *PaymentOrderUpdateOne) AddSubscriptionGroupID(v int64) *PaymentOrderUpdateOne {
+ _u.mutation.AddSubscriptionGroupID(v)
+ return _u
+}
+
+// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field.
+func (_u *PaymentOrderUpdateOne) ClearSubscriptionGroupID() *PaymentOrderUpdateOne {
+ _u.mutation.ClearSubscriptionGroupID()
+ return _u
+}
+
+// SetSubscriptionDays sets the "subscription_days" field.
+func (_u *PaymentOrderUpdateOne) SetSubscriptionDays(v int) *PaymentOrderUpdateOne {
+ _u.mutation.ResetSubscriptionDays()
+ _u.mutation.SetSubscriptionDays(v)
+ return _u
+}
+
+// SetNillableSubscriptionDays sets the "subscription_days" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableSubscriptionDays(v *int) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetSubscriptionDays(*v)
+ }
+ return _u
+}
+
+// AddSubscriptionDays adds value to the "subscription_days" field.
+func (_u *PaymentOrderUpdateOne) AddSubscriptionDays(v int) *PaymentOrderUpdateOne {
+ _u.mutation.AddSubscriptionDays(v)
+ return _u
+}
+
+// ClearSubscriptionDays clears the value of the "subscription_days" field.
+func (_u *PaymentOrderUpdateOne) ClearSubscriptionDays() *PaymentOrderUpdateOne {
+ _u.mutation.ClearSubscriptionDays()
+ return _u
+}
+
+// SetProviderInstanceID sets the "provider_instance_id" field.
+func (_u *PaymentOrderUpdateOne) SetProviderInstanceID(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderInstanceID(v)
+ return _u
+}
+
+// SetNillableProviderInstanceID sets the "provider_instance_id" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableProviderInstanceID(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetProviderInstanceID(*v)
+ }
+ return _u
+}
+
+// ClearProviderInstanceID clears the value of the "provider_instance_id" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderInstanceID()
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// ClearProviderKey clears the value of the "provider_key" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderKey()
+ return _u
+}
+
+// SetProviderSnapshot sets the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) SetProviderSnapshot(v map[string]interface{}) *PaymentOrderUpdateOne {
+ _u.mutation.SetProviderSnapshot(v)
+ return _u
+}
+
+// ClearProviderSnapshot clears the value of the "provider_snapshot" field.
+func (_u *PaymentOrderUpdateOne) ClearProviderSnapshot() *PaymentOrderUpdateOne {
+ _u.mutation.ClearProviderSnapshot()
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableStatus(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetRefundAmount sets the "refund_amount" field.
+func (_u *PaymentOrderUpdateOne) SetRefundAmount(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.ResetRefundAmount()
+ _u.mutation.SetRefundAmount(v)
+ return _u
+}
+
+// SetNillableRefundAmount sets the "refund_amount" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableRefundAmount(v *float64) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetRefundAmount(*v)
+ }
+ return _u
+}
+
+// AddRefundAmount adds value to the "refund_amount" field.
+func (_u *PaymentOrderUpdateOne) AddRefundAmount(v float64) *PaymentOrderUpdateOne {
+ _u.mutation.AddRefundAmount(v)
+ return _u
+}
+
+// SetRefundReason sets the "refund_reason" field.
+func (_u *PaymentOrderUpdateOne) SetRefundReason(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetRefundReason(v)
+ return _u
+}
+
+// SetNillableRefundReason sets the "refund_reason" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableRefundReason(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetRefundReason(*v)
+ }
+ return _u
+}
+
+// ClearRefundReason clears the value of the "refund_reason" field.
+func (_u *PaymentOrderUpdateOne) ClearRefundReason() *PaymentOrderUpdateOne {
+ _u.mutation.ClearRefundReason()
+ return _u
+}
+
+// SetRefundAt sets the "refund_at" field.
+func (_u *PaymentOrderUpdateOne) SetRefundAt(v time.Time) *PaymentOrderUpdateOne {
+ _u.mutation.SetRefundAt(v)
+ return _u
+}
+
+// SetNillableRefundAt sets the "refund_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableRefundAt(v *time.Time) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetRefundAt(*v)
+ }
+ return _u
+}
+
+// ClearRefundAt clears the value of the "refund_at" field.
+func (_u *PaymentOrderUpdateOne) ClearRefundAt() *PaymentOrderUpdateOne {
+ _u.mutation.ClearRefundAt()
+ return _u
+}
+
+// SetForceRefund sets the "force_refund" field.
+func (_u *PaymentOrderUpdateOne) SetForceRefund(v bool) *PaymentOrderUpdateOne {
+ _u.mutation.SetForceRefund(v)
+ return _u
+}
+
+// SetNillableForceRefund sets the "force_refund" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableForceRefund(v *bool) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetForceRefund(*v)
+ }
+ return _u
+}
+
+// SetRefundRequestedAt sets the "refund_requested_at" field.
+func (_u *PaymentOrderUpdateOne) SetRefundRequestedAt(v time.Time) *PaymentOrderUpdateOne {
+ _u.mutation.SetRefundRequestedAt(v)
+ return _u
+}
+
+// SetNillableRefundRequestedAt sets the "refund_requested_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableRefundRequestedAt(v *time.Time) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetRefundRequestedAt(*v)
+ }
+ return _u
+}
+
+// ClearRefundRequestedAt clears the value of the "refund_requested_at" field.
+func (_u *PaymentOrderUpdateOne) ClearRefundRequestedAt() *PaymentOrderUpdateOne {
+ _u.mutation.ClearRefundRequestedAt()
+ return _u
+}
+
+// SetRefundRequestReason sets the "refund_request_reason" field.
+func (_u *PaymentOrderUpdateOne) SetRefundRequestReason(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetRefundRequestReason(v)
+ return _u
+}
+
+// SetNillableRefundRequestReason sets the "refund_request_reason" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableRefundRequestReason(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetRefundRequestReason(*v)
+ }
+ return _u
+}
+
+// ClearRefundRequestReason clears the value of the "refund_request_reason" field.
+func (_u *PaymentOrderUpdateOne) ClearRefundRequestReason() *PaymentOrderUpdateOne {
+ _u.mutation.ClearRefundRequestReason()
+ return _u
+}
+
+// SetRefundRequestedBy sets the "refund_requested_by" field.
+func (_u *PaymentOrderUpdateOne) SetRefundRequestedBy(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetRefundRequestedBy(v)
+ return _u
+}
+
+// SetNillableRefundRequestedBy sets the "refund_requested_by" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableRefundRequestedBy(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetRefundRequestedBy(*v)
+ }
+ return _u
+}
+
+// ClearRefundRequestedBy clears the value of the "refund_requested_by" field.
+func (_u *PaymentOrderUpdateOne) ClearRefundRequestedBy() *PaymentOrderUpdateOne {
+ _u.mutation.ClearRefundRequestedBy()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PaymentOrderUpdateOne) SetExpiresAt(v time.Time) *PaymentOrderUpdateOne {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableExpiresAt(v *time.Time) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetPaidAt sets the "paid_at" field.
+func (_u *PaymentOrderUpdateOne) SetPaidAt(v time.Time) *PaymentOrderUpdateOne {
+ _u.mutation.SetPaidAt(v)
+ return _u
+}
+
+// SetNillablePaidAt sets the "paid_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillablePaidAt(v *time.Time) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetPaidAt(*v)
+ }
+ return _u
+}
+
+// ClearPaidAt clears the value of the "paid_at" field.
+func (_u *PaymentOrderUpdateOne) ClearPaidAt() *PaymentOrderUpdateOne {
+ _u.mutation.ClearPaidAt()
+ return _u
+}
+
+// SetCompletedAt sets the "completed_at" field.
+func (_u *PaymentOrderUpdateOne) SetCompletedAt(v time.Time) *PaymentOrderUpdateOne {
+ _u.mutation.SetCompletedAt(v)
+ return _u
+}
+
+// SetNillableCompletedAt sets the "completed_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableCompletedAt(v *time.Time) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetCompletedAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletedAt clears the value of the "completed_at" field.
+func (_u *PaymentOrderUpdateOne) ClearCompletedAt() *PaymentOrderUpdateOne {
+ _u.mutation.ClearCompletedAt()
+ return _u
+}
+
+// SetFailedAt sets the "failed_at" field.
+func (_u *PaymentOrderUpdateOne) SetFailedAt(v time.Time) *PaymentOrderUpdateOne {
+ _u.mutation.SetFailedAt(v)
+ return _u
+}
+
+// SetNillableFailedAt sets the "failed_at" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableFailedAt(v *time.Time) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetFailedAt(*v)
+ }
+ return _u
+}
+
+// ClearFailedAt clears the value of the "failed_at" field.
+func (_u *PaymentOrderUpdateOne) ClearFailedAt() *PaymentOrderUpdateOne {
+ _u.mutation.ClearFailedAt()
+ return _u
+}
+
+// SetFailedReason sets the "failed_reason" field.
+func (_u *PaymentOrderUpdateOne) SetFailedReason(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetFailedReason(v)
+ return _u
+}
+
+// SetNillableFailedReason sets the "failed_reason" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableFailedReason(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetFailedReason(*v)
+ }
+ return _u
+}
+
+// ClearFailedReason clears the value of the "failed_reason" field.
+func (_u *PaymentOrderUpdateOne) ClearFailedReason() *PaymentOrderUpdateOne {
+ _u.mutation.ClearFailedReason()
+ return _u
+}
+
+// SetClientIP sets the "client_ip" field.
+func (_u *PaymentOrderUpdateOne) SetClientIP(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetClientIP(v)
+ return _u
+}
+
+// SetNillableClientIP sets the "client_ip" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableClientIP(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetClientIP(*v)
+ }
+ return _u
+}
+
+// SetSrcHost sets the "src_host" field.
+func (_u *PaymentOrderUpdateOne) SetSrcHost(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetSrcHost(v)
+ return _u
+}
+
+// SetNillableSrcHost sets the "src_host" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableSrcHost(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetSrcHost(*v)
+ }
+ return _u
+}
+
+// SetSrcURL sets the "src_url" field.
+func (_u *PaymentOrderUpdateOne) SetSrcURL(v string) *PaymentOrderUpdateOne {
+ _u.mutation.SetSrcURL(v)
+ return _u
+}
+
+// SetNillableSrcURL sets the "src_url" field if the given value is not nil.
+func (_u *PaymentOrderUpdateOne) SetNillableSrcURL(v *string) *PaymentOrderUpdateOne {
+ if v != nil {
+ _u.SetSrcURL(*v)
+ }
+ return _u
+}
+
+// ClearSrcURL clears the value of the "src_url" field.
+func (_u *PaymentOrderUpdateOne) ClearSrcURL() *PaymentOrderUpdateOne {
+ _u.mutation.ClearSrcURL()
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PaymentOrderUpdateOne) SetUpdatedAt(v time.Time) *PaymentOrderUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetUser sets the "user" edge to the User entity.
+func (_u *PaymentOrderUpdateOne) SetUser(v *User) *PaymentOrderUpdateOne {
+ return _u.SetUserID(v.ID)
+}
+
+// Mutation returns the PaymentOrderMutation object of the builder.
+func (_u *PaymentOrderUpdateOne) Mutation() *PaymentOrderMutation {
+ return _u.mutation
+}
+
+// ClearUser clears the "user" edge to the User entity.
+func (_u *PaymentOrderUpdateOne) ClearUser() *PaymentOrderUpdateOne {
+ _u.mutation.ClearUser()
+ return _u
+}
+
+// Where appends a list predicates to the PaymentOrderUpdate builder.
+func (_u *PaymentOrderUpdateOne) Where(ps ...predicate.PaymentOrder) *PaymentOrderUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PaymentOrderUpdateOne) Select(field string, fields ...string) *PaymentOrderUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated PaymentOrder entity.
+func (_u *PaymentOrderUpdateOne) Save(ctx context.Context) (*PaymentOrder, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PaymentOrderUpdateOne) SaveX(ctx context.Context) *PaymentOrder {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PaymentOrderUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PaymentOrderUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PaymentOrderUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := paymentorder.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PaymentOrderUpdateOne) check() error {
+ if v, ok := _u.mutation.UserEmail(); ok {
+ if err := paymentorder.UserEmailValidator(v); err != nil {
+ return &ValidationError{Name: "user_email", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.user_email": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.UserName(); ok {
+ if err := paymentorder.UserNameValidator(v); err != nil {
+ return &ValidationError{Name: "user_name", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.user_name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.RechargeCode(); ok {
+ if err := paymentorder.RechargeCodeValidator(v); err != nil {
+ return &ValidationError{Name: "recharge_code", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.recharge_code": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.OutTradeNo(); ok {
+ if err := paymentorder.OutTradeNoValidator(v); err != nil {
+ return &ValidationError{Name: "out_trade_no", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.out_trade_no": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PaymentType(); ok {
+ if err := paymentorder.PaymentTypeValidator(v); err != nil {
+ return &ValidationError{Name: "payment_type", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.payment_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PaymentTradeNo(); ok {
+ if err := paymentorder.PaymentTradeNoValidator(v); err != nil {
+ return &ValidationError{Name: "payment_trade_no", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.payment_trade_no": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.OrderType(); ok {
+ if err := paymentorder.OrderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "order_type", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.order_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderInstanceID(); ok {
+ if err := paymentorder.ProviderInstanceIDValidator(v); err != nil {
+ return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentorder.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Status(); ok {
+ if err := paymentorder.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.RefundRequestedBy(); ok {
+ if err := paymentorder.RefundRequestedByValidator(v); err != nil {
+ return &ValidationError{Name: "refund_requested_by", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.refund_requested_by": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ClientIP(); ok {
+ if err := paymentorder.ClientIPValidator(v); err != nil {
+ return &ValidationError{Name: "client_ip", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.client_ip": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.SrcHost(); ok {
+ if err := paymentorder.SrcHostValidator(v); err != nil {
+ return &ValidationError{Name: "src_host", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.src_host": %w`, err)}
+ }
+ }
+ if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
+ return errors.New(`ent: clearing a required unique edge "PaymentOrder.user"`)
+ }
+ return nil
+}
+
+func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrder, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(paymentorder.Table, paymentorder.Columns, sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PaymentOrder.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, paymentorder.FieldID)
+ for _, f := range fields {
+ if !paymentorder.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != paymentorder.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UserEmail(); ok {
+ _spec.SetField(paymentorder.FieldUserEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UserName(); ok {
+ _spec.SetField(paymentorder.FieldUserName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UserNotes(); ok {
+ _spec.SetField(paymentorder.FieldUserNotes, field.TypeString, value)
+ }
+ if _u.mutation.UserNotesCleared() {
+ _spec.ClearField(paymentorder.FieldUserNotes, field.TypeString)
+ }
+ if value, ok := _u.mutation.Amount(); ok {
+ _spec.SetField(paymentorder.FieldAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedAmount(); ok {
+ _spec.AddField(paymentorder.FieldAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.PayAmount(); ok {
+ _spec.SetField(paymentorder.FieldPayAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedPayAmount(); ok {
+ _spec.AddField(paymentorder.FieldPayAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.FeeRate(); ok {
+ _spec.SetField(paymentorder.FieldFeeRate, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedFeeRate(); ok {
+ _spec.AddField(paymentorder.FieldFeeRate, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.RechargeCode(); ok {
+ _spec.SetField(paymentorder.FieldRechargeCode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.OutTradeNo(); ok {
+ _spec.SetField(paymentorder.FieldOutTradeNo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PaymentType(); ok {
+ _spec.SetField(paymentorder.FieldPaymentType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PaymentTradeNo(); ok {
+ _spec.SetField(paymentorder.FieldPaymentTradeNo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PayURL(); ok {
+ _spec.SetField(paymentorder.FieldPayURL, field.TypeString, value)
+ }
+ if _u.mutation.PayURLCleared() {
+ _spec.ClearField(paymentorder.FieldPayURL, field.TypeString)
+ }
+ if value, ok := _u.mutation.QrCode(); ok {
+ _spec.SetField(paymentorder.FieldQrCode, field.TypeString, value)
+ }
+ if _u.mutation.QrCodeCleared() {
+ _spec.ClearField(paymentorder.FieldQrCode, field.TypeString)
+ }
+ if value, ok := _u.mutation.QrCodeImg(); ok {
+ _spec.SetField(paymentorder.FieldQrCodeImg, field.TypeString, value)
+ }
+ if _u.mutation.QrCodeImgCleared() {
+ _spec.ClearField(paymentorder.FieldQrCodeImg, field.TypeString)
+ }
+ if value, ok := _u.mutation.OrderType(); ok {
+ _spec.SetField(paymentorder.FieldOrderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.PlanID(); ok {
+ _spec.SetField(paymentorder.FieldPlanID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedPlanID(); ok {
+ _spec.AddField(paymentorder.FieldPlanID, field.TypeInt64, value)
+ }
+ if _u.mutation.PlanIDCleared() {
+ _spec.ClearField(paymentorder.FieldPlanID, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.SubscriptionGroupID(); ok {
+ _spec.SetField(paymentorder.FieldSubscriptionGroupID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedSubscriptionGroupID(); ok {
+ _spec.AddField(paymentorder.FieldSubscriptionGroupID, field.TypeInt64, value)
+ }
+ if _u.mutation.SubscriptionGroupIDCleared() {
+ _spec.ClearField(paymentorder.FieldSubscriptionGroupID, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.SubscriptionDays(); ok {
+ _spec.SetField(paymentorder.FieldSubscriptionDays, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedSubscriptionDays(); ok {
+ _spec.AddField(paymentorder.FieldSubscriptionDays, field.TypeInt, value)
+ }
+ if _u.mutation.SubscriptionDaysCleared() {
+ _spec.ClearField(paymentorder.FieldSubscriptionDays, field.TypeInt)
+ }
+ if value, ok := _u.mutation.ProviderInstanceID(); ok {
+ _spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value)
+ }
+ if _u.mutation.ProviderInstanceIDCleared() {
+ _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value)
+ }
+ if _u.mutation.ProviderKeyCleared() {
+ _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString)
+ }
+ if value, ok := _u.mutation.ProviderSnapshot(); ok {
+ _spec.SetField(paymentorder.FieldProviderSnapshot, field.TypeJSON, value)
+ }
+ if _u.mutation.ProviderSnapshotCleared() {
+ _spec.ClearField(paymentorder.FieldProviderSnapshot, field.TypeJSON)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(paymentorder.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RefundAmount(); ok {
+ _spec.SetField(paymentorder.FieldRefundAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedRefundAmount(); ok {
+ _spec.AddField(paymentorder.FieldRefundAmount, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.RefundReason(); ok {
+ _spec.SetField(paymentorder.FieldRefundReason, field.TypeString, value)
+ }
+ if _u.mutation.RefundReasonCleared() {
+ _spec.ClearField(paymentorder.FieldRefundReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.RefundAt(); ok {
+ _spec.SetField(paymentorder.FieldRefundAt, field.TypeTime, value)
+ }
+ if _u.mutation.RefundAtCleared() {
+ _spec.ClearField(paymentorder.FieldRefundAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ForceRefund(); ok {
+ _spec.SetField(paymentorder.FieldForceRefund, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.RefundRequestedAt(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestedAt, field.TypeTime, value)
+ }
+ if _u.mutation.RefundRequestedAtCleared() {
+ _spec.ClearField(paymentorder.FieldRefundRequestedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.RefundRequestReason(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestReason, field.TypeString, value)
+ }
+ if _u.mutation.RefundRequestReasonCleared() {
+ _spec.ClearField(paymentorder.FieldRefundRequestReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.RefundRequestedBy(); ok {
+ _spec.SetField(paymentorder.FieldRefundRequestedBy, field.TypeString, value)
+ }
+ if _u.mutation.RefundRequestedByCleared() {
+ _spec.ClearField(paymentorder.FieldRefundRequestedBy, field.TypeString)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(paymentorder.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.PaidAt(); ok {
+ _spec.SetField(paymentorder.FieldPaidAt, field.TypeTime, value)
+ }
+ if _u.mutation.PaidAtCleared() {
+ _spec.ClearField(paymentorder.FieldPaidAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.CompletedAt(); ok {
+ _spec.SetField(paymentorder.FieldCompletedAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletedAtCleared() {
+ _spec.ClearField(paymentorder.FieldCompletedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.FailedAt(); ok {
+ _spec.SetField(paymentorder.FieldFailedAt, field.TypeTime, value)
+ }
+ if _u.mutation.FailedAtCleared() {
+ _spec.ClearField(paymentorder.FieldFailedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.FailedReason(); ok {
+ _spec.SetField(paymentorder.FieldFailedReason, field.TypeString, value)
+ }
+ if _u.mutation.FailedReasonCleared() {
+ _spec.ClearField(paymentorder.FieldFailedReason, field.TypeString)
+ }
+ if value, ok := _u.mutation.ClientIP(); ok {
+ _spec.SetField(paymentorder.FieldClientIP, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SrcHost(); ok {
+ _spec.SetField(paymentorder.FieldSrcHost, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SrcURL(); ok {
+ _spec.SetField(paymentorder.FieldSrcURL, field.TypeString, value)
+ }
+ if _u.mutation.SrcURLCleared() {
+ _spec.ClearField(paymentorder.FieldSrcURL, field.TypeString)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(paymentorder.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if _u.mutation.UserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: paymentorder.UserTable,
+ Columns: []string{paymentorder.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.UserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: paymentorder.UserTable,
+ Columns: []string{paymentorder.UserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &PaymentOrder{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{paymentorder.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/paymentproviderinstance.go b/backend/ent/paymentproviderinstance.go
new file mode 100644
index 00000000..4279b86e
--- /dev/null
+++ b/backend/ent/paymentproviderinstance.go
@@ -0,0 +1,229 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+)
+
+// PaymentProviderInstance is the model entity for the PaymentProviderInstance schema.
+type PaymentProviderInstance struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // Config holds the value of the "config" field.
+ Config string `json:"config,omitempty"`
+ // SupportedTypes holds the value of the "supported_types" field.
+ SupportedTypes string `json:"supported_types,omitempty"`
+ // Enabled holds the value of the "enabled" field.
+ Enabled bool `json:"enabled,omitempty"`
+ // PaymentMode holds the value of the "payment_mode" field.
+ PaymentMode string `json:"payment_mode,omitempty"`
+ // SortOrder holds the value of the "sort_order" field.
+ SortOrder int `json:"sort_order,omitempty"`
+ // Limits holds the value of the "limits" field.
+ Limits string `json:"limits,omitempty"`
+ // RefundEnabled holds the value of the "refund_enabled" field.
+ RefundEnabled bool `json:"refund_enabled,omitempty"`
+ // AllowUserRefund holds the value of the "allow_user_refund" field.
+ AllowUserRefund bool `json:"allow_user_refund,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ selectValues sql.SelectValues
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund:
+ values[i] = new(sql.NullBool)
+ case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder:
+ values[i] = new(sql.NullInt64)
+ case paymentproviderinstance.FieldProviderKey, paymentproviderinstance.FieldName, paymentproviderinstance.FieldConfig, paymentproviderinstance.FieldSupportedTypes, paymentproviderinstance.FieldPaymentMode, paymentproviderinstance.FieldLimits:
+ values[i] = new(sql.NullString)
+ case paymentproviderinstance.FieldCreatedAt, paymentproviderinstance.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PaymentProviderInstance fields.
+func (_m *PaymentProviderInstance) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case paymentproviderinstance.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case paymentproviderinstance.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case paymentproviderinstance.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ _m.Name = value.String
+ }
+ case paymentproviderinstance.FieldConfig:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field config", values[i])
+ } else if value.Valid {
+ _m.Config = value.String
+ }
+ case paymentproviderinstance.FieldSupportedTypes:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field supported_types", values[i])
+ } else if value.Valid {
+ _m.SupportedTypes = value.String
+ }
+ case paymentproviderinstance.FieldEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field enabled", values[i])
+ } else if value.Valid {
+ _m.Enabled = value.Bool
+ }
+ case paymentproviderinstance.FieldPaymentMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field payment_mode", values[i])
+ } else if value.Valid {
+ _m.PaymentMode = value.String
+ }
+ case paymentproviderinstance.FieldSortOrder:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sort_order", values[i])
+ } else if value.Valid {
+ _m.SortOrder = int(value.Int64)
+ }
+ case paymentproviderinstance.FieldLimits:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field limits", values[i])
+ } else if value.Valid {
+ _m.Limits = value.String
+ }
+ case paymentproviderinstance.FieldRefundEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field refund_enabled", values[i])
+ } else if value.Valid {
+ _m.RefundEnabled = value.Bool
+ }
+ case paymentproviderinstance.FieldAllowUserRefund:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i])
+ } else if value.Valid {
+ _m.AllowUserRefund = value.Bool
+ }
+ case paymentproviderinstance.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case paymentproviderinstance.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PaymentProviderInstance.
+// This includes values selected through modifiers, order, etc.
+func (_m *PaymentProviderInstance) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// Update returns a builder for updating this PaymentProviderInstance.
+// Note that you need to call PaymentProviderInstance.Unwrap() before calling this method if this PaymentProviderInstance
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PaymentProviderInstance) Update() *PaymentProviderInstanceUpdateOne {
+ return NewPaymentProviderInstanceClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PaymentProviderInstance entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PaymentProviderInstance) Unwrap() *PaymentProviderInstance {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PaymentProviderInstance is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PaymentProviderInstance) String() string {
+ var builder strings.Builder
+ builder.WriteString("PaymentProviderInstance(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(_m.Name)
+ builder.WriteString(", ")
+ builder.WriteString("config=")
+ builder.WriteString(_m.Config)
+ builder.WriteString(", ")
+ builder.WriteString("supported_types=")
+ builder.WriteString(_m.SupportedTypes)
+ builder.WriteString(", ")
+ builder.WriteString("enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Enabled))
+ builder.WriteString(", ")
+ builder.WriteString("payment_mode=")
+ builder.WriteString(_m.PaymentMode)
+ builder.WriteString(", ")
+ builder.WriteString("sort_order=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
+ builder.WriteString(", ")
+ builder.WriteString("limits=")
+ builder.WriteString(_m.Limits)
+ builder.WriteString(", ")
+ builder.WriteString("refund_enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled))
+ builder.WriteString(", ")
+ builder.WriteString("allow_user_refund=")
+ builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund))
+ builder.WriteString(", ")
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PaymentProviderInstances is a parsable slice of PaymentProviderInstance.
+type PaymentProviderInstances []*PaymentProviderInstance
diff --git a/backend/ent/paymentproviderinstance/paymentproviderinstance.go b/backend/ent/paymentproviderinstance/paymentproviderinstance.go
new file mode 100644
index 00000000..eb1b0c52
--- /dev/null
+++ b/backend/ent/paymentproviderinstance/paymentproviderinstance.go
@@ -0,0 +1,170 @@
+// Code generated by ent, DO NOT EDIT.
+
+package paymentproviderinstance
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+)
+
+const (
+ // Label holds the string label denoting the paymentproviderinstance type in the database.
+ Label = "payment_provider_instance"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldConfig holds the string denoting the config field in the database.
+ FieldConfig = "config"
+ // FieldSupportedTypes holds the string denoting the supported_types field in the database.
+ FieldSupportedTypes = "supported_types"
+ // FieldEnabled holds the string denoting the enabled field in the database.
+ FieldEnabled = "enabled"
+ // FieldPaymentMode holds the string denoting the payment_mode field in the database.
+ FieldPaymentMode = "payment_mode"
+ // FieldSortOrder holds the string denoting the sort_order field in the database.
+ FieldSortOrder = "sort_order"
+ // FieldLimits holds the string denoting the limits field in the database.
+ FieldLimits = "limits"
+ // FieldRefundEnabled holds the string denoting the refund_enabled field in the database.
+ FieldRefundEnabled = "refund_enabled"
+ // FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database.
+ FieldAllowUserRefund = "allow_user_refund"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // Table holds the table name of the paymentproviderinstance in the database.
+ Table = "payment_provider_instances"
+)
+
+// Columns holds all SQL columns for paymentproviderinstance fields.
+var Columns = []string{
+ FieldID,
+ FieldProviderKey,
+ FieldName,
+ FieldConfig,
+ FieldSupportedTypes,
+ FieldEnabled,
+ FieldPaymentMode,
+ FieldSortOrder,
+ FieldLimits,
+ FieldRefundEnabled,
+ FieldAllowUserRefund,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // DefaultName holds the default value on creation for the "name" field.
+ DefaultName string
+ // NameValidator is a validator for the "name" field. It is called by the builders before save.
+ NameValidator func(string) error
+ // DefaultSupportedTypes holds the default value on creation for the "supported_types" field.
+ DefaultSupportedTypes string
+ // SupportedTypesValidator is a validator for the "supported_types" field. It is called by the builders before save.
+ SupportedTypesValidator func(string) error
+ // DefaultEnabled holds the default value on creation for the "enabled" field.
+ DefaultEnabled bool
+ // DefaultPaymentMode holds the default value on creation for the "payment_mode" field.
+ DefaultPaymentMode string
+ // PaymentModeValidator is a validator for the "payment_mode" field. It is called by the builders before save.
+ PaymentModeValidator func(string) error
+ // DefaultSortOrder holds the default value on creation for the "sort_order" field.
+ DefaultSortOrder int
+ // DefaultLimits holds the default value on creation for the "limits" field.
+ DefaultLimits string
+ // DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field.
+ DefaultRefundEnabled bool
+ // DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field.
+ DefaultAllowUserRefund bool
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the PaymentProviderInstance queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByConfig orders the results by the config field.
+func ByConfig(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldConfig, opts...).ToFunc()
+}
+
+// BySupportedTypes orders the results by the supported_types field.
+func BySupportedTypes(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSupportedTypes, opts...).ToFunc()
+}
+
+// ByEnabled orders the results by the enabled field.
+func ByEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEnabled, opts...).ToFunc()
+}
+
+// ByPaymentMode orders the results by the payment_mode field.
+func ByPaymentMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPaymentMode, opts...).ToFunc()
+}
+
+// BySortOrder orders the results by the sort_order field.
+func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
+}
+
+// ByLimits orders the results by the limits field.
+func ByLimits(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLimits, opts...).ToFunc()
+}
+
+// ByRefundEnabled orders the results by the refund_enabled field.
+func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc()
+}
+
+// ByAllowUserRefund orders the results by the allow_user_refund field.
+func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
diff --git a/backend/ent/paymentproviderinstance/where.go b/backend/ent/paymentproviderinstance/where.go
new file mode 100644
index 00000000..40e5a1f6
--- /dev/null
+++ b/backend/ent/paymentproviderinstance/where.go
@@ -0,0 +1,670 @@
+// Code generated by ent, DO NOT EDIT.
+
+package paymentproviderinstance
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldID, id))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldName, v))
+}
+
+// Config applies equality check predicate on the "config" field. It's identical to ConfigEQ.
+func Config(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldConfig, v))
+}
+
+// SupportedTypes applies equality check predicate on the "supported_types" field. It's identical to SupportedTypesEQ.
+func SupportedTypes(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldSupportedTypes, v))
+}
+
+// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ.
+func Enabled(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldEnabled, v))
+}
+
+// PaymentMode applies equality check predicate on the "payment_mode" field. It's identical to PaymentModeEQ.
+func PaymentMode(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldPaymentMode, v))
+}
+
+// SortOrder applies equality check predicate on the "sort_order" field. It's identical to SortOrderEQ.
+func SortOrder(v int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldSortOrder, v))
+}
+
+// Limits applies equality check predicate on the "limits" field. It's identical to LimitsEQ.
+func Limits(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldLimits, v))
+}
+
+// RefundEnabled applies equality check predicate on the "refund_enabled" field. It's identical to RefundEnabledEQ.
+func RefundEnabled(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
+}
+
+// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ.
+func AllowUserRefund(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContainsFold(FieldName, v))
+}
+
+// ConfigEQ applies the EQ predicate on the "config" field.
+func ConfigEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldConfig, v))
+}
+
+// ConfigNEQ applies the NEQ predicate on the "config" field.
+func ConfigNEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldConfig, v))
+}
+
+// ConfigIn applies the In predicate on the "config" field.
+func ConfigIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldConfig, vs...))
+}
+
+// ConfigNotIn applies the NotIn predicate on the "config" field.
+func ConfigNotIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldConfig, vs...))
+}
+
+// ConfigGT applies the GT predicate on the "config" field.
+func ConfigGT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldConfig, v))
+}
+
+// ConfigGTE applies the GTE predicate on the "config" field.
+func ConfigGTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldConfig, v))
+}
+
+// ConfigLT applies the LT predicate on the "config" field.
+func ConfigLT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldConfig, v))
+}
+
+// ConfigLTE applies the LTE predicate on the "config" field.
+func ConfigLTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldConfig, v))
+}
+
+// ConfigContains applies the Contains predicate on the "config" field.
+func ConfigContains(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContains(FieldConfig, v))
+}
+
+// ConfigHasPrefix applies the HasPrefix predicate on the "config" field.
+func ConfigHasPrefix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasPrefix(FieldConfig, v))
+}
+
+// ConfigHasSuffix applies the HasSuffix predicate on the "config" field.
+func ConfigHasSuffix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasSuffix(FieldConfig, v))
+}
+
+// ConfigEqualFold applies the EqualFold predicate on the "config" field.
+func ConfigEqualFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEqualFold(FieldConfig, v))
+}
+
+// ConfigContainsFold applies the ContainsFold predicate on the "config" field.
+func ConfigContainsFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContainsFold(FieldConfig, v))
+}
+
+// SupportedTypesEQ applies the EQ predicate on the "supported_types" field.
+func SupportedTypesEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldSupportedTypes, v))
+}
+
+// SupportedTypesNEQ applies the NEQ predicate on the "supported_types" field.
+func SupportedTypesNEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldSupportedTypes, v))
+}
+
+// SupportedTypesIn applies the In predicate on the "supported_types" field.
+func SupportedTypesIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldSupportedTypes, vs...))
+}
+
+// SupportedTypesNotIn applies the NotIn predicate on the "supported_types" field.
+func SupportedTypesNotIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldSupportedTypes, vs...))
+}
+
+// SupportedTypesGT applies the GT predicate on the "supported_types" field.
+func SupportedTypesGT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldSupportedTypes, v))
+}
+
+// SupportedTypesGTE applies the GTE predicate on the "supported_types" field.
+func SupportedTypesGTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldSupportedTypes, v))
+}
+
+// SupportedTypesLT applies the LT predicate on the "supported_types" field.
+func SupportedTypesLT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldSupportedTypes, v))
+}
+
+// SupportedTypesLTE applies the LTE predicate on the "supported_types" field.
+func SupportedTypesLTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldSupportedTypes, v))
+}
+
+// SupportedTypesContains applies the Contains predicate on the "supported_types" field.
+func SupportedTypesContains(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContains(FieldSupportedTypes, v))
+}
+
+// SupportedTypesHasPrefix applies the HasPrefix predicate on the "supported_types" field.
+func SupportedTypesHasPrefix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasPrefix(FieldSupportedTypes, v))
+}
+
+// SupportedTypesHasSuffix applies the HasSuffix predicate on the "supported_types" field.
+func SupportedTypesHasSuffix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasSuffix(FieldSupportedTypes, v))
+}
+
+// SupportedTypesEqualFold applies the EqualFold predicate on the "supported_types" field.
+func SupportedTypesEqualFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEqualFold(FieldSupportedTypes, v))
+}
+
+// SupportedTypesContainsFold applies the ContainsFold predicate on the "supported_types" field.
+func SupportedTypesContainsFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContainsFold(FieldSupportedTypes, v))
+}
+
+// EnabledEQ applies the EQ predicate on the "enabled" field.
+func EnabledEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldEnabled, v))
+}
+
+// EnabledNEQ applies the NEQ predicate on the "enabled" field.
+func EnabledNEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldEnabled, v))
+}
+
+// PaymentModeEQ applies the EQ predicate on the "payment_mode" field.
+func PaymentModeEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldPaymentMode, v))
+}
+
+// PaymentModeNEQ applies the NEQ predicate on the "payment_mode" field.
+func PaymentModeNEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldPaymentMode, v))
+}
+
+// PaymentModeIn applies the In predicate on the "payment_mode" field.
+func PaymentModeIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldPaymentMode, vs...))
+}
+
+// PaymentModeNotIn applies the NotIn predicate on the "payment_mode" field.
+func PaymentModeNotIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldPaymentMode, vs...))
+}
+
+// PaymentModeGT applies the GT predicate on the "payment_mode" field.
+func PaymentModeGT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldPaymentMode, v))
+}
+
+// PaymentModeGTE applies the GTE predicate on the "payment_mode" field.
+func PaymentModeGTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldPaymentMode, v))
+}
+
+// PaymentModeLT applies the LT predicate on the "payment_mode" field.
+func PaymentModeLT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldPaymentMode, v))
+}
+
+// PaymentModeLTE applies the LTE predicate on the "payment_mode" field.
+func PaymentModeLTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldPaymentMode, v))
+}
+
+// PaymentModeContains applies the Contains predicate on the "payment_mode" field.
+func PaymentModeContains(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContains(FieldPaymentMode, v))
+}
+
+// PaymentModeHasPrefix applies the HasPrefix predicate on the "payment_mode" field.
+func PaymentModeHasPrefix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasPrefix(FieldPaymentMode, v))
+}
+
+// PaymentModeHasSuffix applies the HasSuffix predicate on the "payment_mode" field.
+func PaymentModeHasSuffix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasSuffix(FieldPaymentMode, v))
+}
+
+// PaymentModeEqualFold applies the EqualFold predicate on the "payment_mode" field.
+func PaymentModeEqualFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEqualFold(FieldPaymentMode, v))
+}
+
+// PaymentModeContainsFold applies the ContainsFold predicate on the "payment_mode" field.
+func PaymentModeContainsFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContainsFold(FieldPaymentMode, v))
+}
+
+// SortOrderEQ applies the EQ predicate on the "sort_order" field.
+func SortOrderEQ(v int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldSortOrder, v))
+}
+
+// SortOrderNEQ applies the NEQ predicate on the "sort_order" field.
+func SortOrderNEQ(v int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldSortOrder, v))
+}
+
+// SortOrderIn applies the In predicate on the "sort_order" field.
+func SortOrderIn(vs ...int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldSortOrder, vs...))
+}
+
+// SortOrderNotIn applies the NotIn predicate on the "sort_order" field.
+func SortOrderNotIn(vs ...int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldSortOrder, vs...))
+}
+
+// SortOrderGT applies the GT predicate on the "sort_order" field.
+func SortOrderGT(v int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldSortOrder, v))
+}
+
+// SortOrderGTE applies the GTE predicate on the "sort_order" field.
+func SortOrderGTE(v int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldSortOrder, v))
+}
+
+// SortOrderLT applies the LT predicate on the "sort_order" field.
+func SortOrderLT(v int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldSortOrder, v))
+}
+
+// SortOrderLTE applies the LTE predicate on the "sort_order" field.
+func SortOrderLTE(v int) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldSortOrder, v))
+}
+
+// LimitsEQ applies the EQ predicate on the "limits" field.
+func LimitsEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldLimits, v))
+}
+
+// LimitsNEQ applies the NEQ predicate on the "limits" field.
+func LimitsNEQ(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldLimits, v))
+}
+
+// LimitsIn applies the In predicate on the "limits" field.
+func LimitsIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldLimits, vs...))
+}
+
+// LimitsNotIn applies the NotIn predicate on the "limits" field.
+func LimitsNotIn(vs ...string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldLimits, vs...))
+}
+
+// LimitsGT applies the GT predicate on the "limits" field.
+func LimitsGT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldLimits, v))
+}
+
+// LimitsGTE applies the GTE predicate on the "limits" field.
+func LimitsGTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldLimits, v))
+}
+
+// LimitsLT applies the LT predicate on the "limits" field.
+func LimitsLT(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldLimits, v))
+}
+
+// LimitsLTE applies the LTE predicate on the "limits" field.
+func LimitsLTE(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldLimits, v))
+}
+
+// LimitsContains applies the Contains predicate on the "limits" field.
+func LimitsContains(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContains(FieldLimits, v))
+}
+
+// LimitsHasPrefix applies the HasPrefix predicate on the "limits" field.
+func LimitsHasPrefix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasPrefix(FieldLimits, v))
+}
+
+// LimitsHasSuffix applies the HasSuffix predicate on the "limits" field.
+func LimitsHasSuffix(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldHasSuffix(FieldLimits, v))
+}
+
+// LimitsEqualFold applies the EqualFold predicate on the "limits" field.
+func LimitsEqualFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEqualFold(FieldLimits, v))
+}
+
+// LimitsContainsFold applies the ContainsFold predicate on the "limits" field.
+func LimitsContainsFold(v string) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldContainsFold(FieldLimits, v))
+}
+
+// RefundEnabledEQ applies the EQ predicate on the "refund_enabled" field.
+func RefundEnabledEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
+}
+
+// RefundEnabledNEQ applies the NEQ predicate on the "refund_enabled" field.
+func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v))
+}
+
+// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field.
+func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
+}
+
+// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field.
+func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PaymentProviderInstance) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PaymentProviderInstance) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PaymentProviderInstance) predicate.PaymentProviderInstance {
+ return predicate.PaymentProviderInstance(sql.NotPredicates(p))
+}
diff --git a/backend/ent/paymentproviderinstance_create.go b/backend/ent/paymentproviderinstance_create.go
new file mode 100644
index 00000000..d1b14617
--- /dev/null
+++ b/backend/ent/paymentproviderinstance_create.go
@@ -0,0 +1,1176 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+)
+
+// PaymentProviderInstanceCreate is the builder for creating a PaymentProviderInstance entity.
+type PaymentProviderInstanceCreate struct {
+ config
+ mutation *PaymentProviderInstanceMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *PaymentProviderInstanceCreate) SetProviderKey(v string) *PaymentProviderInstanceCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetName sets the "name" field.
+func (_c *PaymentProviderInstanceCreate) SetName(v string) *PaymentProviderInstanceCreate {
+ _c.mutation.SetName(v)
+ return _c
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableName(v *string) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetName(*v)
+ }
+ return _c
+}
+
+// SetConfig sets the "config" field.
+func (_c *PaymentProviderInstanceCreate) SetConfig(v string) *PaymentProviderInstanceCreate {
+ _c.mutation.SetConfig(v)
+ return _c
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (_c *PaymentProviderInstanceCreate) SetSupportedTypes(v string) *PaymentProviderInstanceCreate {
+ _c.mutation.SetSupportedTypes(v)
+ return _c
+}
+
+// SetNillableSupportedTypes sets the "supported_types" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableSupportedTypes(v *string) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetSupportedTypes(*v)
+ }
+ return _c
+}
+
+// SetEnabled sets the "enabled" field.
+func (_c *PaymentProviderInstanceCreate) SetEnabled(v bool) *PaymentProviderInstanceCreate {
+ _c.mutation.SetEnabled(v)
+ return _c
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableEnabled(v *bool) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetEnabled(*v)
+ }
+ return _c
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (_c *PaymentProviderInstanceCreate) SetPaymentMode(v string) *PaymentProviderInstanceCreate {
+ _c.mutation.SetPaymentMode(v)
+ return _c
+}
+
+// SetNillablePaymentMode sets the "payment_mode" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillablePaymentMode(v *string) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetPaymentMode(*v)
+ }
+ return _c
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (_c *PaymentProviderInstanceCreate) SetSortOrder(v int) *PaymentProviderInstanceCreate {
+ _c.mutation.SetSortOrder(v)
+ return _c
+}
+
+// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableSortOrder(v *int) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetSortOrder(*v)
+ }
+ return _c
+}
+
+// SetLimits sets the "limits" field.
+func (_c *PaymentProviderInstanceCreate) SetLimits(v string) *PaymentProviderInstanceCreate {
+ _c.mutation.SetLimits(v)
+ return _c
+}
+
+// SetNillableLimits sets the "limits" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableLimits(v *string) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetLimits(*v)
+ }
+ return _c
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (_c *PaymentProviderInstanceCreate) SetRefundEnabled(v bool) *PaymentProviderInstanceCreate {
+ _c.mutation.SetRefundEnabled(v)
+ return _c
+}
+
+// SetNillableRefundEnabled sets the "refund_enabled" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetRefundEnabled(*v)
+ }
+ return _c
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate {
+ _c.mutation.SetAllowUserRefund(v)
+ return _c
+}
+
+// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetAllowUserRefund(*v)
+ }
+ return _c
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableCreatedAt(v *time.Time) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *PaymentProviderInstanceCreate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *PaymentProviderInstanceCreate) SetNillableUpdatedAt(v *time.Time) *PaymentProviderInstanceCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// Mutation returns the PaymentProviderInstanceMutation object of the builder.
+func (_c *PaymentProviderInstanceCreate) Mutation() *PaymentProviderInstanceMutation {
+ return _c.mutation
+}
+
+// Save creates the PaymentProviderInstance in the database.
+func (_c *PaymentProviderInstanceCreate) Save(ctx context.Context) (*PaymentProviderInstance, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PaymentProviderInstanceCreate) SaveX(ctx context.Context) *PaymentProviderInstance {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PaymentProviderInstanceCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PaymentProviderInstanceCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PaymentProviderInstanceCreate) defaults() {
+ if _, ok := _c.mutation.Name(); !ok {
+ v := paymentproviderinstance.DefaultName
+ _c.mutation.SetName(v)
+ }
+ if _, ok := _c.mutation.SupportedTypes(); !ok {
+ v := paymentproviderinstance.DefaultSupportedTypes
+ _c.mutation.SetSupportedTypes(v)
+ }
+ if _, ok := _c.mutation.Enabled(); !ok {
+ v := paymentproviderinstance.DefaultEnabled
+ _c.mutation.SetEnabled(v)
+ }
+ if _, ok := _c.mutation.PaymentMode(); !ok {
+ v := paymentproviderinstance.DefaultPaymentMode
+ _c.mutation.SetPaymentMode(v)
+ }
+ if _, ok := _c.mutation.SortOrder(); !ok {
+ v := paymentproviderinstance.DefaultSortOrder
+ _c.mutation.SetSortOrder(v)
+ }
+ if _, ok := _c.mutation.Limits(); !ok {
+ v := paymentproviderinstance.DefaultLimits
+ _c.mutation.SetLimits(v)
+ }
+ if _, ok := _c.mutation.RefundEnabled(); !ok {
+ v := paymentproviderinstance.DefaultRefundEnabled
+ _c.mutation.SetRefundEnabled(v)
+ }
+ if _, ok := _c.mutation.AllowUserRefund(); !ok {
+ v := paymentproviderinstance.DefaultAllowUserRefund
+ _c.mutation.SetAllowUserRefund(v)
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := paymentproviderinstance.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := paymentproviderinstance.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PaymentProviderInstanceCreate) check() error {
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PaymentProviderInstance.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := paymentproviderinstance.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "PaymentProviderInstance.name"`)}
+ }
+ if v, ok := _c.mutation.Name(); ok {
+ if err := paymentproviderinstance.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Config(); !ok {
+ return &ValidationError{Name: "config", err: errors.New(`ent: missing required field "PaymentProviderInstance.config"`)}
+ }
+ if _, ok := _c.mutation.SupportedTypes(); !ok {
+ return &ValidationError{Name: "supported_types", err: errors.New(`ent: missing required field "PaymentProviderInstance.supported_types"`)}
+ }
+ if v, ok := _c.mutation.SupportedTypes(); ok {
+ if err := paymentproviderinstance.SupportedTypesValidator(v); err != nil {
+ return &ValidationError{Name: "supported_types", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.supported_types": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Enabled(); !ok {
+ return &ValidationError{Name: "enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.enabled"`)}
+ }
+ if _, ok := _c.mutation.PaymentMode(); !ok {
+ return &ValidationError{Name: "payment_mode", err: errors.New(`ent: missing required field "PaymentProviderInstance.payment_mode"`)}
+ }
+ if v, ok := _c.mutation.PaymentMode(); ok {
+ if err := paymentproviderinstance.PaymentModeValidator(v); err != nil {
+ return &ValidationError{Name: "payment_mode", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.payment_mode": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.SortOrder(); !ok {
+ return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "PaymentProviderInstance.sort_order"`)}
+ }
+ if _, ok := _c.mutation.Limits(); !ok {
+ return &ValidationError{Name: "limits", err: errors.New(`ent: missing required field "PaymentProviderInstance.limits"`)}
+ }
+ if _, ok := _c.mutation.RefundEnabled(); !ok {
+ return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)}
+ }
+ if _, ok := _c.mutation.AllowUserRefund(); !ok {
+ return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)}
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.updated_at"`)}
+ }
+ return nil
+}
+
+func (_c *PaymentProviderInstanceCreate) sqlSave(ctx context.Context) (*PaymentProviderInstance, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PaymentProviderInstance{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(paymentproviderinstance.Table, sqlgraph.NewFieldSpec(paymentproviderinstance.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentproviderinstance.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.Name(); ok {
+ _spec.SetField(paymentproviderinstance.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := _c.mutation.Config(); ok {
+ _spec.SetField(paymentproviderinstance.FieldConfig, field.TypeString, value)
+ _node.Config = value
+ }
+ if value, ok := _c.mutation.SupportedTypes(); ok {
+ _spec.SetField(paymentproviderinstance.FieldSupportedTypes, field.TypeString, value)
+ _node.SupportedTypes = value
+ }
+ if value, ok := _c.mutation.Enabled(); ok {
+ _spec.SetField(paymentproviderinstance.FieldEnabled, field.TypeBool, value)
+ _node.Enabled = value
+ }
+ if value, ok := _c.mutation.PaymentMode(); ok {
+ _spec.SetField(paymentproviderinstance.FieldPaymentMode, field.TypeString, value)
+ _node.PaymentMode = value
+ }
+ if value, ok := _c.mutation.SortOrder(); ok {
+ _spec.SetField(paymentproviderinstance.FieldSortOrder, field.TypeInt, value)
+ _node.SortOrder = value
+ }
+ if value, ok := _c.mutation.Limits(); ok {
+ _spec.SetField(paymentproviderinstance.FieldLimits, field.TypeString, value)
+ _node.Limits = value
+ }
+ if value, ok := _c.mutation.RefundEnabled(); ok {
+ _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
+ _node.RefundEnabled = value
+ }
+ if value, ok := _c.mutation.AllowUserRefund(); ok {
+ _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
+ _node.AllowUserRefund = value
+ }
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PaymentProviderInstance.Create().
+// SetProviderKey(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PaymentProviderInstanceUpsert) {
+// SetProviderKey(v+v).
+// }).
+// Exec(ctx)
+func (_c *PaymentProviderInstanceCreate) OnConflict(opts ...sql.ConflictOption) *PaymentProviderInstanceUpsertOne {
+ _c.conflict = opts
+ return &PaymentProviderInstanceUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PaymentProviderInstance.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PaymentProviderInstanceCreate) OnConflictColumns(columns ...string) *PaymentProviderInstanceUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PaymentProviderInstanceUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PaymentProviderInstanceUpsertOne is the builder for "upsert"-ing
+ // one PaymentProviderInstance node.
+ PaymentProviderInstanceUpsertOne struct {
+ create *PaymentProviderInstanceCreate
+ }
+
+ // PaymentProviderInstanceUpsert is the "OnConflict" setter.
+ PaymentProviderInstanceUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentProviderInstanceUpsert) SetProviderKey(v string) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateProviderKey() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldProviderKey)
+ return u
+}
+
+// SetName sets the "name" field.
+func (u *PaymentProviderInstanceUpsert) SetName(v string) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldName, v)
+ return u
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateName() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldName)
+ return u
+}
+
+// SetConfig sets the "config" field.
+func (u *PaymentProviderInstanceUpsert) SetConfig(v string) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldConfig, v)
+ return u
+}
+
+// UpdateConfig sets the "config" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateConfig() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldConfig)
+ return u
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (u *PaymentProviderInstanceUpsert) SetSupportedTypes(v string) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldSupportedTypes, v)
+ return u
+}
+
+// UpdateSupportedTypes sets the "supported_types" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateSupportedTypes() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldSupportedTypes)
+ return u
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *PaymentProviderInstanceUpsert) SetEnabled(v bool) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldEnabled, v)
+ return u
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateEnabled() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldEnabled)
+ return u
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (u *PaymentProviderInstanceUpsert) SetPaymentMode(v string) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldPaymentMode, v)
+ return u
+}
+
+// UpdatePaymentMode sets the "payment_mode" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdatePaymentMode() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldPaymentMode)
+ return u
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (u *PaymentProviderInstanceUpsert) SetSortOrder(v int) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldSortOrder, v)
+ return u
+}
+
+// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateSortOrder() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldSortOrder)
+ return u
+}
+
+// AddSortOrder adds v to the "sort_order" field.
+func (u *PaymentProviderInstanceUpsert) AddSortOrder(v int) *PaymentProviderInstanceUpsert {
+ u.Add(paymentproviderinstance.FieldSortOrder, v)
+ return u
+}
+
+// SetLimits sets the "limits" field.
+func (u *PaymentProviderInstanceUpsert) SetLimits(v string) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldLimits, v)
+ return u
+}
+
+// UpdateLimits sets the "limits" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateLimits() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldLimits)
+ return u
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (u *PaymentProviderInstanceUpsert) SetRefundEnabled(v bool) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldRefundEnabled, v)
+ return u
+}
+
+// UpdateRefundEnabled sets the "refund_enabled" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldRefundEnabled)
+ return u
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldAllowUserRefund, v)
+ return u
+}
+
+// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund)
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert {
+ u.Set(paymentproviderinstance.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsert) UpdateUpdatedAt() *PaymentProviderInstanceUpsert {
+ u.SetExcluded(paymentproviderinstance.FieldUpdatedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PaymentProviderInstance.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PaymentProviderInstanceUpsertOne) UpdateNewValues() *PaymentProviderInstanceUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(paymentproviderinstance.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PaymentProviderInstance.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PaymentProviderInstanceUpsertOne) Ignore() *PaymentProviderInstanceUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PaymentProviderInstanceUpsertOne) DoNothing() *PaymentProviderInstanceUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PaymentProviderInstanceCreate.OnConflict
+// documentation for more info.
+func (u *PaymentProviderInstanceUpsertOne) Update(set func(*PaymentProviderInstanceUpsert)) *PaymentProviderInstanceUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PaymentProviderInstanceUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentProviderInstanceUpsertOne) SetProviderKey(v string) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateProviderKey() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *PaymentProviderInstanceUpsertOne) SetName(v string) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateName() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetConfig sets the "config" field.
+func (u *PaymentProviderInstanceUpsertOne) SetConfig(v string) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetConfig(v)
+ })
+}
+
+// UpdateConfig sets the "config" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateConfig() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateConfig()
+ })
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (u *PaymentProviderInstanceUpsertOne) SetSupportedTypes(v string) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetSupportedTypes(v)
+ })
+}
+
+// UpdateSupportedTypes sets the "supported_types" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateSupportedTypes() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateSupportedTypes()
+ })
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *PaymentProviderInstanceUpsertOne) SetEnabled(v bool) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetEnabled(v)
+ })
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateEnabled() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateEnabled()
+ })
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (u *PaymentProviderInstanceUpsertOne) SetPaymentMode(v string) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetPaymentMode(v)
+ })
+}
+
+// UpdatePaymentMode sets the "payment_mode" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdatePaymentMode() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdatePaymentMode()
+ })
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (u *PaymentProviderInstanceUpsertOne) SetSortOrder(v int) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetSortOrder(v)
+ })
+}
+
+// AddSortOrder adds v to the "sort_order" field.
+func (u *PaymentProviderInstanceUpsertOne) AddSortOrder(v int) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.AddSortOrder(v)
+ })
+}
+
+// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateSortOrder() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateSortOrder()
+ })
+}
+
+// SetLimits sets the "limits" field.
+func (u *PaymentProviderInstanceUpsertOne) SetLimits(v string) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetLimits(v)
+ })
+}
+
+// UpdateLimits sets the "limits" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateLimits() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateLimits()
+ })
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (u *PaymentProviderInstanceUpsertOne) SetRefundEnabled(v bool) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetRefundEnabled(v)
+ })
+}
+
+// UpdateRefundEnabled sets the "refund_enabled" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateRefundEnabled()
+ })
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetAllowUserRefund(v)
+ })
+}
+
+// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateAllowUserRefund()
+ })
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertOne) UpdateUpdatedAt() *PaymentProviderInstanceUpsertOne {
+ return u.Update(func(s *PaymentProviderInstanceUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PaymentProviderInstanceUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PaymentProviderInstanceCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PaymentProviderInstanceUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *PaymentProviderInstanceUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PaymentProviderInstanceUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PaymentProviderInstanceCreateBulk is the builder for creating many PaymentProviderInstance entities in bulk.
+type PaymentProviderInstanceCreateBulk struct {
+ config
+ err error
+ builders []*PaymentProviderInstanceCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PaymentProviderInstance entities in the database.
+func (_c *PaymentProviderInstanceCreateBulk) Save(ctx context.Context) ([]*PaymentProviderInstance, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*PaymentProviderInstance, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*PaymentProviderInstanceMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PaymentProviderInstanceCreateBulk) SaveX(ctx context.Context) []*PaymentProviderInstance {
+	v, err := _c.Save(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return v
+}
+
+// Exec executes the query. It is like Save, but discards the created nodes.
+func (_c *PaymentProviderInstanceCreateBulk) Exec(ctx context.Context) error {
+	_, err := _c.Save(ctx)
+	return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PaymentProviderInstanceCreateBulk) ExecX(ctx context.Context) {
+	if err := _c.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+//	client.PaymentProviderInstance.CreateBulk(builders...).
+//		OnConflict(
+//			// Update the row with the new values
+//			// that were proposed for insertion.
+//			sql.ResolveWithNewValues(),
+//		).
+//		// Override some of the fields with custom
+//		// update values.
+//		Update(func(u *ent.PaymentProviderInstanceUpsert) {
+//			SetProviderKey(v+v).
+//		}).
+//		Exec(ctx)
+func (_c *PaymentProviderInstanceCreateBulk) OnConflict(opts ...sql.ConflictOption) *PaymentProviderInstanceUpsertBulk {
+	_c.conflict = opts
+	return &PaymentProviderInstanceUpsertBulk{
+		create: _c,
+	}
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+//	client.PaymentProviderInstance.Create().
+//		OnConflict(sql.ConflictColumns(columns...)).
+//		Exec(ctx)
+func (_c *PaymentProviderInstanceCreateBulk) OnConflictColumns(columns ...string) *PaymentProviderInstanceUpsertBulk {
+	_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+	return &PaymentProviderInstanceUpsertBulk{
+		create: _c,
+	}
+}
+
+// PaymentProviderInstanceUpsertBulk is the builder for "upsert"-ing
+// a bulk of PaymentProviderInstance nodes.
+type PaymentProviderInstanceUpsertBulk struct {
+	create *PaymentProviderInstanceCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+//	client.PaymentProviderInstance.Create().
+//		OnConflict(
+//			sql.ResolveWithNewValues(),
+//		).
+//		Exec(ctx)
+func (u *PaymentProviderInstanceUpsertBulk) UpdateNewValues() *PaymentProviderInstanceUpsertBulk {
+	u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+	// created_at is set on create; exclude it from the conflict UPDATE so the
+	// original creation timestamp is preserved on upsert.
+	u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+		for _, b := range u.create.builders {
+			if _, exists := b.mutation.CreatedAt(); exists {
+				s.SetIgnore(paymentproviderinstance.FieldCreatedAt)
+			}
+		}
+	}))
+	return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+//	client.PaymentProviderInstance.Create().
+//		OnConflict(sql.ResolveWithIgnore()).
+//		Exec(ctx)
+func (u *PaymentProviderInstanceUpsertBulk) Ignore() *PaymentProviderInstanceUpsertBulk {
+	u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+	return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PaymentProviderInstanceUpsertBulk) DoNothing() *PaymentProviderInstanceUpsertBulk {
+	u.create.conflict = append(u.create.conflict, sql.DoNothing())
+	return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PaymentProviderInstanceCreateBulk.OnConflict
+// documentation for more info.
+func (u *PaymentProviderInstanceUpsertBulk) Update(set func(*PaymentProviderInstanceUpsert)) *PaymentProviderInstanceUpsertBulk {
+	u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+		set(&PaymentProviderInstanceUpsert{UpdateSet: update})
+	}))
+	return u
+}
+
+// Field-level upsert resolvers. For each field X, SetX overrides the UPDATE
+// value with an explicit constant, while UpdateX resolves the conflict with
+// the value that was proposed on create (the "excluded" row).
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetProviderKey(v string) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetProviderKey(v)
+	})
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateProviderKey() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateProviderKey()
+	})
+}
+
+// SetName sets the "name" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetName(v string) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetName(v)
+	})
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateName() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateName()
+	})
+}
+
+// SetConfig sets the "config" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetConfig(v string) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetConfig(v)
+	})
+}
+
+// UpdateConfig sets the "config" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateConfig() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateConfig()
+	})
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetSupportedTypes(v string) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetSupportedTypes(v)
+	})
+}
+
+// UpdateSupportedTypes sets the "supported_types" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateSupportedTypes() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateSupportedTypes()
+	})
+}
+
+// SetEnabled sets the "enabled" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetEnabled(v bool) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetEnabled(v)
+	})
+}
+
+// UpdateEnabled sets the "enabled" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateEnabled() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateEnabled()
+	})
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetPaymentMode(v string) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetPaymentMode(v)
+	})
+}
+
+// UpdatePaymentMode sets the "payment_mode" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdatePaymentMode() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdatePaymentMode()
+	})
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetSortOrder(v int) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetSortOrder(v)
+	})
+}
+
+// AddSortOrder adds v to the "sort_order" field.
+func (u *PaymentProviderInstanceUpsertBulk) AddSortOrder(v int) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.AddSortOrder(v)
+	})
+}
+
+// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateSortOrder() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateSortOrder()
+	})
+}
+
+// SetLimits sets the "limits" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetLimits(v string) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetLimits(v)
+	})
+}
+
+// UpdateLimits sets the "limits" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateLimits() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateLimits()
+	})
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetRefundEnabled(v bool) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetRefundEnabled(v)
+	})
+}
+
+// UpdateRefundEnabled sets the "refund_enabled" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateRefundEnabled()
+	})
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetAllowUserRefund(v)
+	})
+}
+
+// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateAllowUserRefund()
+	})
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.SetUpdatedAt(v)
+	})
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PaymentProviderInstanceUpsertBulk) UpdateUpdatedAt() *PaymentProviderInstanceUpsertBulk {
+	return u.Update(func(s *PaymentProviderInstanceUpsert) {
+		s.UpdateUpdatedAt()
+	})
+}
+
+// Exec executes the query. It validates that no per-builder OnConflict was
+// set and that at least one conflict option exists before delegating to the
+// underlying bulk-create executor.
+func (u *PaymentProviderInstanceUpsertBulk) Exec(ctx context.Context) error {
+	if u.create.err != nil {
+		return u.create.err
+	}
+	for i, b := range u.create.builders {
+		if len(b.conflict) != 0 {
+			return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PaymentProviderInstanceCreateBulk instead", i)
+		}
+	}
+	if len(u.create.conflict) == 0 {
+		return errors.New("ent: missing options for PaymentProviderInstanceCreateBulk.OnConflict")
+	}
+	return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+// NOTE(review): it calls u.create.Exec directly, so the validation performed
+// in Exec above is skipped — this mirrors the ent generator's output.
+func (u *PaymentProviderInstanceUpsertBulk) ExecX(ctx context.Context) {
+	if err := u.create.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
diff --git a/backend/ent/paymentproviderinstance_delete.go b/backend/ent/paymentproviderinstance_delete.go
new file mode 100644
index 00000000..0cffe731
--- /dev/null
+++ b/backend/ent/paymentproviderinstance_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PaymentProviderInstanceDelete is the builder for deleting a PaymentProviderInstance entity.
+type PaymentProviderInstanceDelete struct {
+	config
+	hooks    []Hook
+	mutation *PaymentProviderInstanceMutation
+}
+
+// Where appends a list of predicates to the PaymentProviderInstanceDelete builder.
+func (_d *PaymentProviderInstanceDelete) Where(ps ...predicate.PaymentProviderInstance) *PaymentProviderInstanceDelete {
+	_d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PaymentProviderInstanceDelete) Exec(ctx context.Context) (int, error) {
+	return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PaymentProviderInstanceDelete) ExecX(ctx context.Context) int {
+	n, err := _d.Exec(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return n
+}
+
+// sqlExec builds the DELETE spec from the accumulated predicates and runs it,
+// wrapping database constraint violations in a typed *ConstraintError.
+func (_d *PaymentProviderInstanceDelete) sqlExec(ctx context.Context) (int, error) {
+	_spec := sqlgraph.NewDeleteSpec(paymentproviderinstance.Table, sqlgraph.NewFieldSpec(paymentproviderinstance.FieldID, field.TypeInt64))
+	if ps := _d.mutation.predicates; len(ps) > 0 {
+		_spec.Predicate = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+	if err != nil && sqlgraph.IsConstraintError(err) {
+		err = &ConstraintError{msg: err.Error(), wrap: err}
+	}
+	_d.mutation.done = true
+	return affected, err
+}
+
+// PaymentProviderInstanceDeleteOne is the builder for deleting a single PaymentProviderInstance entity.
+type PaymentProviderInstanceDeleteOne struct {
+	_d *PaymentProviderInstanceDelete
+}
+
+// Where appends a list of predicates to the PaymentProviderInstanceDelete builder.
+func (_d *PaymentProviderInstanceDeleteOne) Where(ps ...predicate.PaymentProviderInstance) *PaymentProviderInstanceDeleteOne {
+	_d._d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query. It returns *NotFoundError when zero rows
+// were deleted, and nil when at least one row was removed.
+func (_d *PaymentProviderInstanceDeleteOne) Exec(ctx context.Context) error {
+	n, err := _d._d.Exec(ctx)
+	switch {
+	case err != nil:
+		return err
+	case n == 0:
+		return &NotFoundError{paymentproviderinstance.Label}
+	default:
+		return nil
+	}
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PaymentProviderInstanceDeleteOne) ExecX(ctx context.Context) {
+	if err := _d.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
diff --git a/backend/ent/paymentproviderinstance_query.go b/backend/ent/paymentproviderinstance_query.go
new file mode 100644
index 00000000..c0212088
--- /dev/null
+++ b/backend/ent/paymentproviderinstance_query.go
@@ -0,0 +1,564 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PaymentProviderInstanceQuery is the builder for querying PaymentProviderInstance entities.
+type PaymentProviderInstanceQuery struct {
+	config
+	ctx        *QueryContext
+	order      []paymentproviderinstance.OrderOption
+	inters     []Interceptor
+	predicates []predicate.PaymentProviderInstance
+	modifiers  []func(*sql.Selector)
+	// intermediate query (i.e. traversal path).
+	sql  *sql.Selector
+	path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PaymentProviderInstanceQuery builder.
+func (_q *PaymentProviderInstanceQuery) Where(ps ...predicate.PaymentProviderInstance) *PaymentProviderInstanceQuery {
+	_q.predicates = append(_q.predicates, ps...)
+	return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PaymentProviderInstanceQuery) Limit(limit int) *PaymentProviderInstanceQuery {
+	_q.ctx.Limit = &limit
+	return _q
+}
+
+// Offset to start from.
+func (_q *PaymentProviderInstanceQuery) Offset(offset int) *PaymentProviderInstanceQuery {
+	_q.ctx.Offset = &offset
+	return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PaymentProviderInstanceQuery) Unique(unique bool) *PaymentProviderInstanceQuery {
+	_q.ctx.Unique = &unique
+	return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PaymentProviderInstanceQuery) Order(o ...paymentproviderinstance.OrderOption) *PaymentProviderInstanceQuery {
+	_q.order = append(_q.order, o...)
+	return _q
+}
+
+// First returns the first PaymentProviderInstance entity from the query.
+// Returns a *NotFoundError when no PaymentProviderInstance was found.
+func (_q *PaymentProviderInstanceQuery) First(ctx context.Context) (*PaymentProviderInstance, error) {
+	nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+	if err != nil {
+		return nil, err
+	}
+	if len(nodes) == 0 {
+		return nil, &NotFoundError{paymentproviderinstance.Label}
+	}
+	return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+// A not-found result returns nil instead of panicking.
+func (_q *PaymentProviderInstanceQuery) FirstX(ctx context.Context) *PaymentProviderInstance {
+	node, err := _q.First(ctx)
+	if err != nil && !IsNotFound(err) {
+		panic(err)
+	}
+	return node
+}
+
+// FirstID returns the first PaymentProviderInstance ID from the query.
+// Returns a *NotFoundError when no PaymentProviderInstance ID was found.
+func (_q *PaymentProviderInstanceQuery) FirstID(ctx context.Context) (id int64, err error) {
+	var ids []int64
+	if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+		return
+	}
+	if len(ids) == 0 {
+		err = &NotFoundError{paymentproviderinstance.Label}
+		return
+	}
+	return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+// A not-found result returns the zero ID instead of panicking.
+func (_q *PaymentProviderInstanceQuery) FirstIDX(ctx context.Context) int64 {
+	id, err := _q.FirstID(ctx)
+	if err != nil && !IsNotFound(err) {
+		panic(err)
+	}
+	return id
+}
+
+// Only returns a single PaymentProviderInstance entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PaymentProviderInstance entity is found.
+// Returns a *NotFoundError when no PaymentProviderInstance entities are found.
+func (_q *PaymentProviderInstanceQuery) Only(ctx context.Context) (*PaymentProviderInstance, error) {
+	// Limit(2) is enough to distinguish "exactly one" from "more than one".
+	nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+	if err != nil {
+		return nil, err
+	}
+	switch len(nodes) {
+	case 1:
+		return nodes[0], nil
+	case 0:
+		return nil, &NotFoundError{paymentproviderinstance.Label}
+	default:
+		return nil, &NotSingularError{paymentproviderinstance.Label}
+	}
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PaymentProviderInstanceQuery) OnlyX(ctx context.Context) *PaymentProviderInstance {
+	node, err := _q.Only(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return node
+}
+
+// OnlyID is like Only, but returns the only PaymentProviderInstance ID in the query.
+// Returns a *NotSingularError when more than one PaymentProviderInstance ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PaymentProviderInstanceQuery) OnlyID(ctx context.Context) (id int64, err error) {
+	var ids []int64
+	if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+		return
+	}
+	switch len(ids) {
+	case 1:
+		id = ids[0]
+	case 0:
+		err = &NotFoundError{paymentproviderinstance.Label}
+	default:
+		err = &NotSingularError{paymentproviderinstance.Label}
+	}
+	return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PaymentProviderInstanceQuery) OnlyIDX(ctx context.Context) int64 {
+	id, err := _q.OnlyID(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return id
+}
+
+// All executes the query and returns a list of PaymentProviderInstances.
+func (_q *PaymentProviderInstanceQuery) All(ctx context.Context) ([]*PaymentProviderInstance, error) {
+	ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+	if err := _q.prepareQuery(ctx); err != nil {
+		return nil, err
+	}
+	qr := querierAll[[]*PaymentProviderInstance, *PaymentProviderInstanceQuery]()
+	return withInterceptors[[]*PaymentProviderInstance](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PaymentProviderInstanceQuery) AllX(ctx context.Context) []*PaymentProviderInstance {
+	nodes, err := _q.All(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return nodes
+}
+
+// IDs executes the query and returns a list of PaymentProviderInstance IDs.
+func (_q *PaymentProviderInstanceQuery) IDs(ctx context.Context) (ids []int64, err error) {
+	// Traversal queries default to unique IDs unless the caller opted out.
+	if _q.ctx.Unique == nil && _q.path != nil {
+		_q.Unique(true)
+	}
+	ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+	if err = _q.Select(paymentproviderinstance.FieldID).Scan(ctx, &ids); err != nil {
+		return nil, err
+	}
+	return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PaymentProviderInstanceQuery) IDsX(ctx context.Context) []int64 {
+	ids, err := _q.IDs(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PaymentProviderInstanceQuery) Count(ctx context.Context) (int, error) {
+	ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+	if err := _q.prepareQuery(ctx); err != nil {
+		return 0, err
+	}
+	return withInterceptors[int](ctx, _q, querierCount[*PaymentProviderInstanceQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PaymentProviderInstanceQuery) CountX(ctx context.Context) int {
+	count, err := _q.Count(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PaymentProviderInstanceQuery) Exist(ctx context.Context) (bool, error) {
+	ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+	switch _, err := _q.FirstID(ctx); {
+	case IsNotFound(err):
+		return false, nil
+	case err != nil:
+		return false, fmt.Errorf("ent: check existence: %w", err)
+	default:
+		return true, nil
+	}
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PaymentProviderInstanceQuery) ExistX(ctx context.Context) bool {
+	exist, err := _q.Exist(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return exist
+}
+
+// Clone returns a duplicate of the PaymentProviderInstanceQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+// Order, interceptor, and predicate slices are copied so mutations on the
+// clone do not affect the original builder.
+func (_q *PaymentProviderInstanceQuery) Clone() *PaymentProviderInstanceQuery {
+	if _q == nil {
+		return nil
+	}
+	return &PaymentProviderInstanceQuery{
+		config:     _q.config,
+		ctx:        _q.ctx.Clone(),
+		order:      append([]paymentproviderinstance.OrderOption{}, _q.order...),
+		inters:     append([]Interceptor{}, _q.inters...),
+		predicates: append([]predicate.PaymentProviderInstance{}, _q.predicates...),
+		// clone intermediate query.
+		sql:  _q.sql.Clone(),
+		path: _q.path,
+	}
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+//	var v []struct {
+//		ProviderKey string `json:"provider_key,omitempty"`
+//		Count int `json:"count,omitempty"`
+//	}
+//
+//	client.PaymentProviderInstance.Query().
+//		GroupBy(paymentproviderinstance.FieldProviderKey).
+//		Aggregate(ent.Count()).
+//		Scan(ctx, &v)
+func (_q *PaymentProviderInstanceQuery) GroupBy(field string, fields ...string) *PaymentProviderInstanceGroupBy {
+	_q.ctx.Fields = append([]string{field}, fields...)
+	grbuild := &PaymentProviderInstanceGroupBy{build: _q}
+	grbuild.flds = &_q.ctx.Fields
+	grbuild.label = paymentproviderinstance.Label
+	grbuild.scan = grbuild.Scan
+	return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+//	var v []struct {
+//		ProviderKey string `json:"provider_key,omitempty"`
+//	}
+//
+//	client.PaymentProviderInstance.Query().
+//		Select(paymentproviderinstance.FieldProviderKey).
+//		Scan(ctx, &v)
+func (_q *PaymentProviderInstanceQuery) Select(fields ...string) *PaymentProviderInstanceSelect {
+	_q.ctx.Fields = append(_q.ctx.Fields, fields...)
+	sbuild := &PaymentProviderInstanceSelect{PaymentProviderInstanceQuery: _q}
+	sbuild.label = paymentproviderinstance.Label
+	sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+	return sbuild
+}
+
+// Aggregate returns a PaymentProviderInstanceSelect configured with the given aggregations.
+func (_q *PaymentProviderInstanceQuery) Aggregate(fns ...AggregateFunc) *PaymentProviderInstanceSelect {
+	return _q.Select().Aggregate(fns...)
+}
+
+// prepareQuery runs Traverser interceptors, validates the selected columns,
+// and resolves the traversal path (if any) into the intermediate selector.
+func (_q *PaymentProviderInstanceQuery) prepareQuery(ctx context.Context) error {
+	for _, inter := range _q.inters {
+		if inter == nil {
+			return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+		}
+		if trv, ok := inter.(Traverser); ok {
+			if err := trv.Traverse(ctx, _q); err != nil {
+				return err
+			}
+		}
+	}
+	// Reject unknown columns early, before building SQL.
+	for _, f := range _q.ctx.Fields {
+		if !paymentproviderinstance.ValidColumn(f) {
+			return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+		}
+	}
+	if _q.path != nil {
+		prev, err := _q.path(ctx)
+		if err != nil {
+			return err
+		}
+		_q.sql = prev
+	}
+	return nil
+}
+
+// sqlAll executes the query spec and scans each row into a new
+// *PaymentProviderInstance bound to this query's config.
+func (_q *PaymentProviderInstanceQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PaymentProviderInstance, error) {
+	var (
+		nodes = []*PaymentProviderInstance{}
+		_spec = _q.querySpec()
+	)
+	_spec.ScanValues = func(columns []string) ([]any, error) {
+		return (*PaymentProviderInstance).scanValues(nil, columns)
+	}
+	_spec.Assign = func(columns []string, values []any) error {
+		node := &PaymentProviderInstance{config: _q.config}
+		nodes = append(nodes, node)
+		return node.assignValues(columns, values)
+	}
+	if len(_q.modifiers) > 0 {
+		_spec.Modifiers = _q.modifiers
+	}
+	for i := range hooks {
+		hooks[i](ctx, _spec)
+	}
+	if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+		return nil, err
+	}
+	if len(nodes) == 0 {
+		return nodes, nil
+	}
+	return nodes, nil
+}
+
+// sqlCount executes a COUNT over the current query spec.
+func (_q *PaymentProviderInstanceQuery) sqlCount(ctx context.Context) (int, error) {
+	_spec := _q.querySpec()
+	if len(_q.modifiers) > 0 {
+		_spec.Modifiers = _q.modifiers
+	}
+	_spec.Node.Columns = _q.ctx.Fields
+	if len(_q.ctx.Fields) > 0 {
+		_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+	}
+	return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+// querySpec translates the builder state (fields, predicates, order,
+// limit/offset, uniqueness) into a sqlgraph.QuerySpec.
+func (_q *PaymentProviderInstanceQuery) querySpec() *sqlgraph.QuerySpec {
+	_spec := sqlgraph.NewQuerySpec(paymentproviderinstance.Table, paymentproviderinstance.Columns, sqlgraph.NewFieldSpec(paymentproviderinstance.FieldID, field.TypeInt64))
+	_spec.From = _q.sql
+	if unique := _q.ctx.Unique; unique != nil {
+		_spec.Unique = *unique
+	} else if _q.path != nil {
+		_spec.Unique = true
+	}
+	if fields := _q.ctx.Fields; len(fields) > 0 {
+		// The ID column is always selected first; other requested fields follow.
+		_spec.Node.Columns = make([]string, 0, len(fields))
+		_spec.Node.Columns = append(_spec.Node.Columns, paymentproviderinstance.FieldID)
+		for i := range fields {
+			if fields[i] != paymentproviderinstance.FieldID {
+				_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+			}
+		}
+	}
+	if ps := _q.predicates; len(ps) > 0 {
+		_spec.Predicate = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	if limit := _q.ctx.Limit; limit != nil {
+		_spec.Limit = *limit
+	}
+	if offset := _q.ctx.Offset; offset != nil {
+		_spec.Offset = *offset
+	}
+	if ps := _q.order; len(ps) > 0 {
+		_spec.Order = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	return _spec
+}
+
+// sqlQuery builds the raw SQL selector for this query, applying any
+// intermediate traversal selector, modifiers, predicates, and ordering.
+func (_q *PaymentProviderInstanceQuery) sqlQuery(ctx context.Context) *sql.Selector {
+	builder := sql.Dialect(_q.driver.Dialect())
+	t1 := builder.Table(paymentproviderinstance.Table)
+	columns := _q.ctx.Fields
+	if len(columns) == 0 {
+		columns = paymentproviderinstance.Columns
+	}
+	selector := builder.Select(t1.Columns(columns...)...).From(t1)
+	if _q.sql != nil {
+		// Continue from the traversal's intermediate selector instead of a
+		// fresh table scan.
+		selector = _q.sql
+		selector.Select(selector.Columns(columns...)...)
+	}
+	if _q.ctx.Unique != nil && *_q.ctx.Unique {
+		selector.Distinct()
+	}
+	for _, m := range _q.modifiers {
+		m(selector)
+	}
+	for _, p := range _q.predicates {
+		p(selector)
+	}
+	for _, p := range _q.order {
+		p(selector)
+	}
+	if offset := _q.ctx.Offset; offset != nil {
+		// limit is mandatory for offset clause. We start
+		// with default value, and override it below if needed.
+		selector.Offset(*offset).Limit(math.MaxInt32)
+	}
+	if limit := _q.ctx.Limit; limit != nil {
+		selector.Limit(*limit)
+	}
+	return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PaymentProviderInstanceQuery) ForUpdate(opts ...sql.LockOption) *PaymentProviderInstanceQuery {
+	// Postgres does not allow SELECT DISTINCT with FOR UPDATE.
+	if _q.driver.Dialect() == dialect.Postgres {
+		_q.Unique(false)
+	}
+	_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+		s.ForUpdate(opts...)
+	})
+	return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PaymentProviderInstanceQuery) ForShare(opts ...sql.LockOption) *PaymentProviderInstanceQuery {
+	// Postgres does not allow SELECT DISTINCT with row-level locks.
+	if _q.driver.Dialect() == dialect.Postgres {
+		_q.Unique(false)
+	}
+	_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+		s.ForShare(opts...)
+	})
+	return _q
+}
+
+// PaymentProviderInstanceGroupBy is the group-by builder for PaymentProviderInstance entities.
+type PaymentProviderInstanceGroupBy struct {
+	selector
+	build *PaymentProviderInstanceQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PaymentProviderInstanceGroupBy) Aggregate(fns ...AggregateFunc) *PaymentProviderInstanceGroupBy {
+	_g.fns = append(_g.fns, fns...)
+	return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PaymentProviderInstanceGroupBy) Scan(ctx context.Context, v any) error {
+	ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+	if err := _g.build.prepareQuery(ctx); err != nil {
+		return err
+	}
+	return scanWithInterceptors[*PaymentProviderInstanceQuery, *PaymentProviderInstanceGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+// sqlScan builds SELECT <group fields>, <aggregations> ... GROUP BY <group fields>
+// and scans the rows into v.
+func (_g *PaymentProviderInstanceGroupBy) sqlScan(ctx context.Context, root *PaymentProviderInstanceQuery, v any) error {
+	selector := root.sqlQuery(ctx).Select()
+	aggregation := make([]string, 0, len(_g.fns))
+	for _, fn := range _g.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	if len(selector.SelectedColumns()) == 0 {
+		columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+		for _, f := range *_g.flds {
+			columns = append(columns, selector.C(f))
+		}
+		columns = append(columns, aggregation...)
+		selector.Select(columns...)
+	}
+	selector.GroupBy(selector.Columns(*_g.flds...)...)
+	if err := selector.Err(); err != nil {
+		return err
+	}
+	rows := &sql.Rows{}
+	query, args := selector.Query()
+	if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+		return err
+	}
+	defer rows.Close()
+	return sql.ScanSlice(rows, v)
+}
+
+// PaymentProviderInstanceSelect is the builder for selecting fields of PaymentProviderInstance entities.
+type PaymentProviderInstanceSelect struct {
+	*PaymentProviderInstanceQuery
+	selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PaymentProviderInstanceSelect) Aggregate(fns ...AggregateFunc) *PaymentProviderInstanceSelect {
+	_s.fns = append(_s.fns, fns...)
+	return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PaymentProviderInstanceSelect) Scan(ctx context.Context, v any) error {
+	ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+	if err := _s.prepareQuery(ctx); err != nil {
+		return err
+	}
+	return scanWithInterceptors[*PaymentProviderInstanceQuery, *PaymentProviderInstanceSelect](ctx, _s.PaymentProviderInstanceQuery, _s, _s.inters, v)
+}
+
+// sqlScan appends any aggregation expressions to the selected columns (or
+// selects only them when no fields were chosen) and scans the rows into v.
+func (_s *PaymentProviderInstanceSelect) sqlScan(ctx context.Context, root *PaymentProviderInstanceQuery, v any) error {
+	selector := root.sqlQuery(ctx)
+	aggregation := make([]string, 0, len(_s.fns))
+	for _, fn := range _s.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	switch n := len(*_s.selector.flds); {
+	case n == 0 && len(aggregation) > 0:
+		selector.Select(aggregation...)
+	case n != 0 && len(aggregation) > 0:
+		selector.AppendSelect(aggregation...)
+	}
+	rows := &sql.Rows{}
+	query, args := selector.Query()
+	if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+		return err
+	}
+	defer rows.Close()
+	return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/paymentproviderinstance_update.go b/backend/ent/paymentproviderinstance_update.go
new file mode 100644
index 00000000..6bb3a82d
--- /dev/null
+++ b/backend/ent/paymentproviderinstance_update.go
@@ -0,0 +1,628 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PaymentProviderInstanceUpdate is the builder for updating PaymentProviderInstance entities.
+type PaymentProviderInstanceUpdate struct {
+ config
+ hooks []Hook
+ mutation *PaymentProviderInstanceMutation
+}
+
+// Where appends a list predicates to the PaymentProviderInstanceUpdate builder.
+func (_u *PaymentProviderInstanceUpdate) Where(ps ...predicate.PaymentProviderInstance) *PaymentProviderInstanceUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentProviderInstanceUpdate) SetProviderKey(v string) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableProviderKey(v *string) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *PaymentProviderInstanceUpdate) SetName(v string) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableName(v *string) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetConfig sets the "config" field.
+func (_u *PaymentProviderInstanceUpdate) SetConfig(v string) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetConfig(v)
+ return _u
+}
+
+// SetNillableConfig sets the "config" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableConfig(v *string) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetConfig(*v)
+ }
+ return _u
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (_u *PaymentProviderInstanceUpdate) SetSupportedTypes(v string) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetSupportedTypes(v)
+ return _u
+}
+
+// SetNillableSupportedTypes sets the "supported_types" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableSupportedTypes(v *string) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetSupportedTypes(*v)
+ }
+ return _u
+}
+
+// SetEnabled sets the "enabled" field.
+func (_u *PaymentProviderInstanceUpdate) SetEnabled(v bool) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetEnabled(v)
+ return _u
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableEnabled(v *bool) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetEnabled(*v)
+ }
+ return _u
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (_u *PaymentProviderInstanceUpdate) SetPaymentMode(v string) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetPaymentMode(v)
+ return _u
+}
+
+// SetNillablePaymentMode sets the "payment_mode" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillablePaymentMode(v *string) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetPaymentMode(*v)
+ }
+ return _u
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (_u *PaymentProviderInstanceUpdate) SetSortOrder(v int) *PaymentProviderInstanceUpdate {
+ _u.mutation.ResetSortOrder()
+ _u.mutation.SetSortOrder(v)
+ return _u
+}
+
+// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableSortOrder(v *int) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetSortOrder(*v)
+ }
+ return _u
+}
+
+// AddSortOrder adds value to the "sort_order" field.
+func (_u *PaymentProviderInstanceUpdate) AddSortOrder(v int) *PaymentProviderInstanceUpdate {
+ _u.mutation.AddSortOrder(v)
+ return _u
+}
+
+// SetLimits sets the "limits" field.
+func (_u *PaymentProviderInstanceUpdate) SetLimits(v string) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetLimits(v)
+ return _u
+}
+
+// SetNillableLimits sets the "limits" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableLimits(v *string) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetLimits(*v)
+ }
+ return _u
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (_u *PaymentProviderInstanceUpdate) SetRefundEnabled(v bool) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetRefundEnabled(v)
+ return _u
+}
+
+// SetNillableRefundEnabled sets the "refund_enabled" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetRefundEnabled(*v)
+ }
+ return _u
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetAllowUserRefund(v)
+ return _u
+}
+
+// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate {
+ if v != nil {
+ _u.SetAllowUserRefund(*v)
+ }
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// Mutation returns the PaymentProviderInstanceMutation object of the builder.
+func (_u *PaymentProviderInstanceUpdate) Mutation() *PaymentProviderInstanceMutation {
+ return _u.mutation
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PaymentProviderInstanceUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PaymentProviderInstanceUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *PaymentProviderInstanceUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PaymentProviderInstanceUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PaymentProviderInstanceUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := paymentproviderinstance.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PaymentProviderInstanceUpdate) check() error {
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentproviderinstance.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Name(); ok {
+ if err := paymentproviderinstance.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.SupportedTypes(); ok {
+ if err := paymentproviderinstance.SupportedTypesValidator(v); err != nil {
+ return &ValidationError{Name: "supported_types", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.supported_types": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PaymentMode(); ok {
+ if err := paymentproviderinstance.PaymentModeValidator(v); err != nil {
+ return &ValidationError{Name: "payment_mode", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.payment_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(paymentproviderinstance.Table, paymentproviderinstance.Columns, sqlgraph.NewFieldSpec(paymentproviderinstance.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentproviderinstance.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(paymentproviderinstance.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Config(); ok {
+ _spec.SetField(paymentproviderinstance.FieldConfig, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SupportedTypes(); ok {
+ _spec.SetField(paymentproviderinstance.FieldSupportedTypes, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Enabled(); ok {
+ _spec.SetField(paymentproviderinstance.FieldEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.PaymentMode(); ok {
+ _spec.SetField(paymentproviderinstance.FieldPaymentMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SortOrder(); ok {
+ _spec.SetField(paymentproviderinstance.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedSortOrder(); ok {
+ _spec.AddField(paymentproviderinstance.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.Limits(); ok {
+ _spec.SetField(paymentproviderinstance.FieldLimits, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RefundEnabled(); ok {
+ _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AllowUserRefund(); ok {
+ _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{paymentproviderinstance.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// PaymentProviderInstanceUpdateOne is the builder for updating a single PaymentProviderInstance entity.
+type PaymentProviderInstanceUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *PaymentProviderInstanceMutation
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetProviderKey(v string) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableProviderKey(v *string) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetName(v string) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableName(v *string) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetConfig sets the "config" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetConfig(v string) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetConfig(v)
+ return _u
+}
+
+// SetNillableConfig sets the "config" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableConfig(v *string) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetConfig(*v)
+ }
+ return _u
+}
+
+// SetSupportedTypes sets the "supported_types" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetSupportedTypes(v string) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetSupportedTypes(v)
+ return _u
+}
+
+// SetNillableSupportedTypes sets the "supported_types" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableSupportedTypes(v *string) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetSupportedTypes(*v)
+ }
+ return _u
+}
+
+// SetEnabled sets the "enabled" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetEnabled(v bool) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetEnabled(v)
+ return _u
+}
+
+// SetNillableEnabled sets the "enabled" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableEnabled(v *bool) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetEnabled(*v)
+ }
+ return _u
+}
+
+// SetPaymentMode sets the "payment_mode" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetPaymentMode(v string) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetPaymentMode(v)
+ return _u
+}
+
+// SetNillablePaymentMode sets the "payment_mode" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillablePaymentMode(v *string) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetPaymentMode(*v)
+ }
+ return _u
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetSortOrder(v int) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.ResetSortOrder()
+ _u.mutation.SetSortOrder(v)
+ return _u
+}
+
+// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableSortOrder(v *int) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetSortOrder(*v)
+ }
+ return _u
+}
+
+// AddSortOrder adds value to the "sort_order" field.
+func (_u *PaymentProviderInstanceUpdateOne) AddSortOrder(v int) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.AddSortOrder(v)
+ return _u
+}
+
+// SetLimits sets the "limits" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetLimits(v string) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetLimits(v)
+ return _u
+}
+
+// SetNillableLimits sets the "limits" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableLimits(v *string) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetLimits(*v)
+ }
+ return _u
+}
+
+// SetRefundEnabled sets the "refund_enabled" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetRefundEnabled(v bool) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetRefundEnabled(v)
+ return _u
+}
+
+// SetNillableRefundEnabled sets the "refund_enabled" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetRefundEnabled(*v)
+ }
+ return _u
+}
+
+// SetAllowUserRefund sets the "allow_user_refund" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetAllowUserRefund(v)
+ return _u
+}
+
+// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
+func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne {
+ if v != nil {
+ _u.SetAllowUserRefund(*v)
+ }
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// Mutation returns the PaymentProviderInstanceMutation object of the builder.
+func (_u *PaymentProviderInstanceUpdateOne) Mutation() *PaymentProviderInstanceMutation {
+ return _u.mutation
+}
+
+// Where appends a list predicates to the PaymentProviderInstanceUpdate builder.
+func (_u *PaymentProviderInstanceUpdateOne) Where(ps ...predicate.PaymentProviderInstance) *PaymentProviderInstanceUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PaymentProviderInstanceUpdateOne) Select(field string, fields ...string) *PaymentProviderInstanceUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated PaymentProviderInstance entity.
+func (_u *PaymentProviderInstanceUpdateOne) Save(ctx context.Context) (*PaymentProviderInstance, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PaymentProviderInstanceUpdateOne) SaveX(ctx context.Context) *PaymentProviderInstance {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PaymentProviderInstanceUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PaymentProviderInstanceUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PaymentProviderInstanceUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := paymentproviderinstance.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PaymentProviderInstanceUpdateOne) check() error {
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := paymentproviderinstance.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Name(); ok {
+ if err := paymentproviderinstance.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.SupportedTypes(); ok {
+ if err := paymentproviderinstance.SupportedTypesValidator(v); err != nil {
+ return &ValidationError{Name: "supported_types", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.supported_types": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.PaymentMode(); ok {
+ if err := paymentproviderinstance.PaymentModeValidator(v); err != nil {
+ return &ValidationError{Name: "payment_mode", err: fmt.Errorf(`ent: validator failed for field "PaymentProviderInstance.payment_mode": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node *PaymentProviderInstance, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(paymentproviderinstance.Table, paymentproviderinstance.Columns, sqlgraph.NewFieldSpec(paymentproviderinstance.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PaymentProviderInstance.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, paymentproviderinstance.FieldID)
+ for _, f := range fields {
+ if !paymentproviderinstance.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != paymentproviderinstance.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(paymentproviderinstance.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(paymentproviderinstance.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Config(); ok {
+ _spec.SetField(paymentproviderinstance.FieldConfig, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SupportedTypes(); ok {
+ _spec.SetField(paymentproviderinstance.FieldSupportedTypes, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Enabled(); ok {
+ _spec.SetField(paymentproviderinstance.FieldEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.PaymentMode(); ok {
+ _spec.SetField(paymentproviderinstance.FieldPaymentMode, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.SortOrder(); ok {
+ _spec.SetField(paymentproviderinstance.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedSortOrder(); ok {
+ _spec.AddField(paymentproviderinstance.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.Limits(); ok {
+ _spec.SetField(paymentproviderinstance.FieldLimits, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RefundEnabled(); ok {
+ _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.AllowUserRefund(); ok {
+ _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
+ }
+ _node = &PaymentProviderInstance{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{paymentproviderinstance.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go
new file mode 100644
index 00000000..e77c065f
--- /dev/null
+++ b/backend/ent/pendingauthsession.go
@@ -0,0 +1,399 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSession is the model entity for the PendingAuthSession schema.
+type PendingAuthSession struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // SessionToken holds the value of the "session_token" field.
+ SessionToken string `json:"session_token,omitempty"`
+ // Intent holds the value of the "intent" field.
+ Intent string `json:"intent,omitempty"`
+ // ProviderType holds the value of the "provider_type" field.
+ ProviderType string `json:"provider_type,omitempty"`
+ // ProviderKey holds the value of the "provider_key" field.
+ ProviderKey string `json:"provider_key,omitempty"`
+ // ProviderSubject holds the value of the "provider_subject" field.
+ ProviderSubject string `json:"provider_subject,omitempty"`
+ // TargetUserID holds the value of the "target_user_id" field.
+ TargetUserID *int64 `json:"target_user_id,omitempty"`
+ // RedirectTo holds the value of the "redirect_to" field.
+ RedirectTo string `json:"redirect_to,omitempty"`
+ // ResolvedEmail holds the value of the "resolved_email" field.
+ ResolvedEmail string `json:"resolved_email,omitempty"`
+ // RegistrationPasswordHash holds the value of the "registration_password_hash" field.
+ RegistrationPasswordHash string `json:"registration_password_hash,omitempty"`
+ // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field.
+ UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"`
+ // LocalFlowState holds the value of the "local_flow_state" field.
+ LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"`
+ // BrowserSessionKey holds the value of the "browser_session_key" field.
+ BrowserSessionKey string `json:"browser_session_key,omitempty"`
+ // CompletionCodeHash holds the value of the "completion_code_hash" field.
+ CompletionCodeHash string `json:"completion_code_hash,omitempty"`
+ // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field.
+ CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"`
+ // EmailVerifiedAt holds the value of the "email_verified_at" field.
+ EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"`
+ // PasswordVerifiedAt holds the value of the "password_verified_at" field.
+ PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"`
+ // TotpVerifiedAt holds the value of the "totp_verified_at" field.
+ TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+ // ConsumedAt holds the value of the "consumed_at" field.
+ ConsumedAt *time.Time `json:"consumed_at,omitempty"`
+ // Edges holds the relations/edges for other nodes in the graph.
+ // The values are being populated by the PendingAuthSessionQuery when eager-loading is set.
+ Edges PendingAuthSessionEdges `json:"edges"`
+ selectValues sql.SelectValues
+}
+
+// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph.
+type PendingAuthSessionEdges struct {
+ // TargetUser holds the value of the target_user edge.
+ TargetUser *User `json:"target_user,omitempty"`
+ // AdoptionDecision holds the value of the adoption_decision edge.
+ AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"`
+ // loadedTypes holds the information for reporting if a
+ // type was loaded (or requested) in eager-loading or not.
+ loadedTypes [2]bool
+}
+
+// TargetUserOrErr returns the TargetUser value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) {
+ if e.TargetUser != nil {
+ return e.TargetUser, nil
+ } else if e.loadedTypes[0] {
+ return nil, &NotFoundError{label: user.Label}
+ }
+ return nil, &NotLoadedError{edge: "target_user"}
+}
+
+// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge
+// was not loaded in eager-loading, or loaded but was not found.
+func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) {
+ if e.AdoptionDecision != nil {
+ return e.AdoptionDecision, nil
+ } else if e.loadedTypes[1] {
+ return nil, &NotFoundError{label: identityadoptiondecision.Label}
+ }
+ return nil, &NotLoadedError{edge: "adoption_decision"}
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*PendingAuthSession) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState:
+ values[i] = new([]byte)
+ case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID:
+ values[i] = new(sql.NullInt64)
+ case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash:
+ values[i] = new(sql.NullString)
+ case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the PendingAuthSession fields.
+func (_m *PendingAuthSession) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case pendingauthsession.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case pendingauthsession.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case pendingauthsession.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case pendingauthsession.FieldSessionToken:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field session_token", values[i])
+ } else if value.Valid {
+ _m.SessionToken = value.String
+ }
+ case pendingauthsession.FieldIntent:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field intent", values[i])
+ } else if value.Valid {
+ _m.Intent = value.String
+ }
+ case pendingauthsession.FieldProviderType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_type", values[i])
+ } else if value.Valid {
+ _m.ProviderType = value.String
+ }
+ case pendingauthsession.FieldProviderKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_key", values[i])
+ } else if value.Valid {
+ _m.ProviderKey = value.String
+ }
+ case pendingauthsession.FieldProviderSubject:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field provider_subject", values[i])
+ } else if value.Valid {
+ _m.ProviderSubject = value.String
+ }
+ case pendingauthsession.FieldTargetUserID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field target_user_id", values[i])
+ } else if value.Valid {
+ _m.TargetUserID = new(int64)
+ *_m.TargetUserID = value.Int64
+ }
+ case pendingauthsession.FieldRedirectTo:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field redirect_to", values[i])
+ } else if value.Valid {
+ _m.RedirectTo = value.String
+ }
+ case pendingauthsession.FieldResolvedEmail:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field resolved_email", values[i])
+ } else if value.Valid {
+ _m.ResolvedEmail = value.String
+ }
+ case pendingauthsession.FieldRegistrationPasswordHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i])
+ } else if value.Valid {
+ _m.RegistrationPasswordHash = value.String
+ }
+ case pendingauthsession.FieldUpstreamIdentityClaims:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil {
+ return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err)
+ }
+ }
+ case pendingauthsession.FieldLocalFlowState:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field local_flow_state", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil {
+ return fmt.Errorf("unmarshal field local_flow_state: %w", err)
+ }
+ }
+ case pendingauthsession.FieldBrowserSessionKey:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field browser_session_key", values[i])
+ } else if value.Valid {
+ _m.BrowserSessionKey = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeHash:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeHash = value.String
+ }
+ case pendingauthsession.FieldCompletionCodeExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i])
+ } else if value.Valid {
+ _m.CompletionCodeExpiresAt = new(time.Time)
+ *_m.CompletionCodeExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldEmailVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field email_verified_at", values[i])
+ } else if value.Valid {
+ _m.EmailVerifiedAt = new(time.Time)
+ *_m.EmailVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldPasswordVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field password_verified_at", values[i])
+ } else if value.Valid {
+ _m.PasswordVerifiedAt = new(time.Time)
+ *_m.PasswordVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldTotpVerifiedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i])
+ } else if value.Valid {
+ _m.TotpVerifiedAt = new(time.Time)
+ *_m.TotpVerifiedAt = value.Time
+ }
+ case pendingauthsession.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = value.Time
+ }
+ case pendingauthsession.FieldConsumedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field consumed_at", values[i])
+ } else if value.Valid {
+ _m.ConsumedAt = new(time.Time)
+ *_m.ConsumedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession.
+// This includes values selected through modifiers, order, etc.
+func (_m *PendingAuthSession) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryTargetUser() *UserQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m)
+}
+
+// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity.
+func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m)
+}
+
+// Update returns a builder for updating this PendingAuthSession.
+// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne {
+ return NewPendingAuthSessionClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *PendingAuthSession) Unwrap() *PendingAuthSession {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: PendingAuthSession is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *PendingAuthSession) String() string {
+ var builder strings.Builder
+ builder.WriteString("PendingAuthSession(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("session_token=")
+ builder.WriteString(_m.SessionToken)
+ builder.WriteString(", ")
+ builder.WriteString("intent=")
+ builder.WriteString(_m.Intent)
+ builder.WriteString(", ")
+ builder.WriteString("provider_type=")
+ builder.WriteString(_m.ProviderType)
+ builder.WriteString(", ")
+ builder.WriteString("provider_key=")
+ builder.WriteString(_m.ProviderKey)
+ builder.WriteString(", ")
+ builder.WriteString("provider_subject=")
+ builder.WriteString(_m.ProviderSubject)
+ builder.WriteString(", ")
+ if v := _m.TargetUserID; v != nil {
+ builder.WriteString("target_user_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("redirect_to=")
+ builder.WriteString(_m.RedirectTo)
+ builder.WriteString(", ")
+ builder.WriteString("resolved_email=")
+ builder.WriteString(_m.ResolvedEmail)
+ builder.WriteString(", ")
+ builder.WriteString("registration_password_hash=")
+ builder.WriteString(_m.RegistrationPasswordHash)
+ builder.WriteString(", ")
+ builder.WriteString("upstream_identity_claims=")
+ builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims))
+ builder.WriteString(", ")
+ builder.WriteString("local_flow_state=")
+ builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState))
+ builder.WriteString(", ")
+ builder.WriteString("browser_session_key=")
+ builder.WriteString(_m.BrowserSessionKey)
+ builder.WriteString(", ")
+ builder.WriteString("completion_code_hash=")
+ builder.WriteString(_m.CompletionCodeHash)
+ builder.WriteString(", ")
+ if v := _m.CompletionCodeExpiresAt; v != nil {
+ builder.WriteString("completion_code_expires_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.EmailVerifiedAt; v != nil {
+ builder.WriteString("email_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.PasswordVerifiedAt; v != nil {
+ builder.WriteString("password_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.TotpVerifiedAt; v != nil {
+ builder.WriteString("totp_verified_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("expires_at=")
+ builder.WriteString(_m.ExpiresAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ if v := _m.ConsumedAt; v != nil {
+ builder.WriteString("consumed_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// PendingAuthSessions is a parsable slice of PendingAuthSession.
+type PendingAuthSessions []*PendingAuthSession
diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go
new file mode 100644
index 00000000..8a3ac9bf
--- /dev/null
+++ b/backend/ent/pendingauthsession/pendingauthsession.go
@@ -0,0 +1,279 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+)
+
+const (
+ // Label holds the string label denoting the pendingauthsession type in the database.
+ Label = "pending_auth_session"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldSessionToken holds the string denoting the session_token field in the database.
+ FieldSessionToken = "session_token"
+ // FieldIntent holds the string denoting the intent field in the database.
+ FieldIntent = "intent"
+ // FieldProviderType holds the string denoting the provider_type field in the database.
+ FieldProviderType = "provider_type"
+ // FieldProviderKey holds the string denoting the provider_key field in the database.
+ FieldProviderKey = "provider_key"
+ // FieldProviderSubject holds the string denoting the provider_subject field in the database.
+ FieldProviderSubject = "provider_subject"
+ // FieldTargetUserID holds the string denoting the target_user_id field in the database.
+ FieldTargetUserID = "target_user_id"
+ // FieldRedirectTo holds the string denoting the redirect_to field in the database.
+ FieldRedirectTo = "redirect_to"
+ // FieldResolvedEmail holds the string denoting the resolved_email field in the database.
+ FieldResolvedEmail = "resolved_email"
+ // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database.
+ FieldRegistrationPasswordHash = "registration_password_hash"
+ // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database.
+ FieldUpstreamIdentityClaims = "upstream_identity_claims"
+ // FieldLocalFlowState holds the string denoting the local_flow_state field in the database.
+ FieldLocalFlowState = "local_flow_state"
+ // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database.
+ FieldBrowserSessionKey = "browser_session_key"
+ // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database.
+ FieldCompletionCodeHash = "completion_code_hash"
+ // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database.
+ FieldCompletionCodeExpiresAt = "completion_code_expires_at"
+ // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database.
+ FieldEmailVerifiedAt = "email_verified_at"
+ // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database.
+ FieldPasswordVerifiedAt = "password_verified_at"
+ // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database.
+ FieldTotpVerifiedAt = "totp_verified_at"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
+ // FieldConsumedAt holds the string denoting the consumed_at field in the database.
+ FieldConsumedAt = "consumed_at"
+ // EdgeTargetUser holds the string denoting the target_user edge name in mutations.
+ EdgeTargetUser = "target_user"
+ // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations.
+ EdgeAdoptionDecision = "adoption_decision"
+ // Table holds the table name of the pendingauthsession in the database.
+ Table = "pending_auth_sessions"
+ // TargetUserTable is the table that holds the target_user relation/edge.
+ TargetUserTable = "pending_auth_sessions"
+ // TargetUserInverseTable is the table name for the User entity.
+ // It exists in this package in order to avoid circular dependency with the "user" package.
+ TargetUserInverseTable = "users"
+ // TargetUserColumn is the table column denoting the target_user relation/edge.
+ TargetUserColumn = "target_user_id"
+ // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge.
+ AdoptionDecisionTable = "identity_adoption_decisions"
+ // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity.
+ // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package.
+ AdoptionDecisionInverseTable = "identity_adoption_decisions"
+ // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge.
+ AdoptionDecisionColumn = "pending_auth_session_id"
+)
+
+// Columns holds all SQL columns for pendingauthsession fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldSessionToken,
+ FieldIntent,
+ FieldProviderType,
+ FieldProviderKey,
+ FieldProviderSubject,
+ FieldTargetUserID,
+ FieldRedirectTo,
+ FieldResolvedEmail,
+ FieldRegistrationPasswordHash,
+ FieldUpstreamIdentityClaims,
+ FieldLocalFlowState,
+ FieldBrowserSessionKey,
+ FieldCompletionCodeHash,
+ FieldCompletionCodeExpiresAt,
+ FieldEmailVerifiedAt,
+ FieldPasswordVerifiedAt,
+ FieldTotpVerifiedAt,
+ FieldExpiresAt,
+ FieldConsumedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ SessionTokenValidator func(string) error
+ // IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ IntentValidator func(string) error
+ // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ ProviderTypeValidator func(string) error
+ // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ ProviderKeyValidator func(string) error
+ // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ ProviderSubjectValidator func(string) error
+ // DefaultRedirectTo holds the default value on creation for the "redirect_to" field.
+ DefaultRedirectTo string
+ // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field.
+ DefaultResolvedEmail string
+ // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field.
+ DefaultRegistrationPasswordHash string
+ // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field.
+ DefaultUpstreamIdentityClaims func() map[string]interface{}
+ // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field.
+ DefaultLocalFlowState func() map[string]interface{}
+ // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field.
+ DefaultBrowserSessionKey string
+ // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field.
+ DefaultCompletionCodeHash string
+)
+
+// OrderOption defines the ordering options for the PendingAuthSession queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// BySessionToken orders the results by the session_token field.
+func BySessionToken(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSessionToken, opts...).ToFunc()
+}
+
+// ByIntent orders the results by the intent field.
+func ByIntent(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldIntent, opts...).ToFunc()
+}
+
+// ByProviderType orders the results by the provider_type field.
+func ByProviderType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderType, opts...).ToFunc()
+}
+
+// ByProviderKey orders the results by the provider_key field.
+func ByProviderKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderKey, opts...).ToFunc()
+}
+
+// ByProviderSubject orders the results by the provider_subject field.
+func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProviderSubject, opts...).ToFunc()
+}
+
+// ByTargetUserID orders the results by the target_user_id field.
+func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTargetUserID, opts...).ToFunc()
+}
+
+// ByRedirectTo orders the results by the redirect_to field.
+func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRedirectTo, opts...).ToFunc()
+}
+
+// ByResolvedEmail orders the results by the resolved_email field.
+func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc()
+}
+
+// ByRegistrationPasswordHash orders the results by the registration_password_hash field.
+func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc()
+}
+
+// ByBrowserSessionKey orders the results by the browser_session_key field.
+func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc()
+}
+
+// ByCompletionCodeHash orders the results by the completion_code_hash field.
+func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc()
+}
+
+// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field.
+func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc()
+}
+
+// ByEmailVerifiedAt orders the results by the email_verified_at field.
+func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc()
+}
+
+// ByPasswordVerifiedAt orders the results by the password_verified_at field.
+func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc()
+}
+
+// ByTotpVerifiedAt orders the results by the totp_verified_at field.
+func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc()
+}
+
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
+// ByConsumedAt orders the results by the consumed_at field.
+func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldConsumedAt, opts...).ToFunc()
+}
+
+// ByTargetUserField orders the results by target_user field.
+func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...))
+ }
+}
+
+// ByAdoptionDecisionField orders the results by adoption_decision field.
+func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...))
+ }
+}
+func newTargetUserStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(TargetUserInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+}
+func newAdoptionDecisionStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AdoptionDecisionInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+}
diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go
new file mode 100644
index 00000000..cb316f47
--- /dev/null
+++ b/backend/ent/pendingauthsession/where.go
@@ -0,0 +1,1262 @@
+// Code generated by ent, DO NOT EDIT.
+
+package pendingauthsession
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ.
+func SessionToken(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ.
+func Intent(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ.
+func ProviderType(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ.
+func ProviderKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ.
+func ProviderSubject(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ.
+func TargetUserID(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ.
+func RedirectTo(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ.
+func ResolvedEmail(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ.
+func RegistrationPasswordHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ.
+func BrowserSessionKey(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ.
+func CompletionCodeHash(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ.
+func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ.
+func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ.
+func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ.
+func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ.
+func ConsumedAt(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// SessionTokenEQ applies the EQ predicate on the "session_token" field.
+func SessionTokenEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v))
+}
+
+// SessionTokenNEQ applies the NEQ predicate on the "session_token" field.
+func SessionTokenNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v))
+}
+
+// SessionTokenIn applies the In predicate on the "session_token" field.
+func SessionTokenIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenNotIn applies the NotIn predicate on the "session_token" field.
+func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...))
+}
+
+// SessionTokenGT applies the GT predicate on the "session_token" field.
+func SessionTokenGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v))
+}
+
+// SessionTokenGTE applies the GTE predicate on the "session_token" field.
+func SessionTokenGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v))
+}
+
+// SessionTokenLT applies the LT predicate on the "session_token" field.
+func SessionTokenLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v))
+}
+
+// SessionTokenLTE applies the LTE predicate on the "session_token" field.
+func SessionTokenLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v))
+}
+
+// SessionTokenContains applies the Contains predicate on the "session_token" field.
+func SessionTokenContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v))
+}
+
+// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field.
+func SessionTokenHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v))
+}
+
+// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field.
+func SessionTokenHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v))
+}
+
+// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field.
+func SessionTokenEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v))
+}
+
+// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field.
+func SessionTokenContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v))
+}
+
+// IntentEQ applies the EQ predicate on the "intent" field.
+func IntentEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v))
+}
+
+// IntentNEQ applies the NEQ predicate on the "intent" field.
+func IntentNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v))
+}
+
+// IntentIn applies the In predicate on the "intent" field.
+func IntentIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...))
+}
+
+// IntentNotIn applies the NotIn predicate on the "intent" field.
+func IntentNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...))
+}
+
+// IntentGT applies the GT predicate on the "intent" field.
+func IntentGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v))
+}
+
+// IntentGTE applies the GTE predicate on the "intent" field.
+func IntentGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v))
+}
+
+// IntentLT applies the LT predicate on the "intent" field.
+func IntentLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v))
+}
+
+// IntentLTE applies the LTE predicate on the "intent" field.
+func IntentLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v))
+}
+
+// IntentContains applies the Contains predicate on the "intent" field.
+func IntentContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v))
+}
+
+// IntentHasPrefix applies the HasPrefix predicate on the "intent" field.
+func IntentHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v))
+}
+
+// IntentHasSuffix applies the HasSuffix predicate on the "intent" field.
+func IntentHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v))
+}
+
+// IntentEqualFold applies the EqualFold predicate on the "intent" field.
+func IntentEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v))
+}
+
+// IntentContainsFold applies the ContainsFold predicate on the "intent" field.
+func IntentContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v))
+}
+
+// ProviderTypeEQ applies the EQ predicate on the "provider_type" field.
+func ProviderTypeEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v))
+}
+
+// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field.
+func ProviderTypeNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v))
+}
+
+// ProviderTypeIn applies the In predicate on the "provider_type" field.
+func ProviderTypeIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field.
+func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...))
+}
+
+// ProviderTypeGT applies the GT predicate on the "provider_type" field.
+func ProviderTypeGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v))
+}
+
+// ProviderTypeGTE applies the GTE predicate on the "provider_type" field.
+func ProviderTypeGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v))
+}
+
+// ProviderTypeLT applies the LT predicate on the "provider_type" field.
+func ProviderTypeLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v))
+}
+
+// ProviderTypeLTE applies the LTE predicate on the "provider_type" field.
+func ProviderTypeLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v))
+}
+
+// ProviderTypeContains applies the Contains predicate on the "provider_type" field.
+func ProviderTypeContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v))
+}
+
+// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field.
+func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v))
+}
+
+// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field.
+func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v))
+}
+
+// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field.
+func ProviderTypeEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v))
+}
+
+// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field.
+func ProviderTypeContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v))
+}
+
+// ProviderKeyEQ applies the EQ predicate on the "provider_key" field.
+func ProviderKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field.
+func ProviderKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v))
+}
+
+// ProviderKeyIn applies the In predicate on the "provider_key" field.
+func ProviderKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field.
+func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...))
+}
+
+// ProviderKeyGT applies the GT predicate on the "provider_key" field.
+func ProviderKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v))
+}
+
+// ProviderKeyGTE applies the GTE predicate on the "provider_key" field.
+func ProviderKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v))
+}
+
+// ProviderKeyLT applies the LT predicate on the "provider_key" field.
+func ProviderKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v))
+}
+
+// ProviderKeyLTE applies the LTE predicate on the "provider_key" field.
+func ProviderKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v))
+}
+
+// ProviderKeyContains applies the Contains predicate on the "provider_key" field.
+func ProviderKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v))
+}
+
+// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field.
+func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v))
+}
+
+// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field.
+func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v))
+}
+
+// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field.
+func ProviderKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v))
+}
+
+// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field.
+func ProviderKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v))
+}
+
+// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field.
+func ProviderSubjectEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field.
+func ProviderSubjectNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v))
+}
+
+// ProviderSubjectIn applies the In predicate on the "provider_subject" field.
+func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field.
+func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...))
+}
+
+// ProviderSubjectGT applies the GT predicate on the "provider_subject" field.
+func ProviderSubjectGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field.
+func ProviderSubjectGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLT applies the LT predicate on the "provider_subject" field.
+func ProviderSubjectLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v))
+}
+
+// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field.
+func ProviderSubjectLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field.
+func ProviderSubjectContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field.
+func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field.
+func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v))
+}
+
+// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field.
+func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v))
+}
+
+// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field.
+func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v))
+}
+
+// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field.
+func TargetUserIDEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field.
+func TargetUserIDNEQ(v int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v))
+}
+
+// TargetUserIDIn applies the In predicate on the "target_user_id" field.
+func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field.
+func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...))
+}
+
+// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field.
+func TargetUserIDIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID))
+}
+
+// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field.
+func TargetUserIDNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID))
+}
+
+// RedirectToEQ applies the EQ predicate on the "redirect_to" field.
+func RedirectToEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v))
+}
+
+// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field.
+func RedirectToNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v))
+}
+
+// RedirectToIn applies the In predicate on the "redirect_to" field.
+func RedirectToIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field.
+func RedirectToNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...))
+}
+
+// RedirectToGT applies the GT predicate on the "redirect_to" field.
+func RedirectToGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v))
+}
+
+// RedirectToGTE applies the GTE predicate on the "redirect_to" field.
+func RedirectToGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v))
+}
+
+// RedirectToLT applies the LT predicate on the "redirect_to" field.
+func RedirectToLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v))
+}
+
+// RedirectToLTE applies the LTE predicate on the "redirect_to" field.
+func RedirectToLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v))
+}
+
+// RedirectToContains applies the Contains predicate on the "redirect_to" field.
+func RedirectToContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v))
+}
+
+// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field.
+func RedirectToHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v))
+}
+
+// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field.
+func RedirectToHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v))
+}
+
+// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field.
+func RedirectToEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v))
+}
+
+// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field.
+func RedirectToContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v))
+}
+
+// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field.
+func ResolvedEmailEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field.
+func ResolvedEmailNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailIn applies the In predicate on the "resolved_email" field.
+func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field.
+func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...))
+}
+
+// ResolvedEmailGT applies the GT predicate on the "resolved_email" field.
+func ResolvedEmailGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field.
+func ResolvedEmailGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLT applies the LT predicate on the "resolved_email" field.
+func ResolvedEmailLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field.
+func ResolvedEmailLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field.
+func ResolvedEmailContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field.
+func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field.
+func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field.
+func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v))
+}
+
+// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field.
+func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v))
+}
+
+// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...))
+}
+
+// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v))
+}
+
+// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field.
+func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v))
+}
+
+// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field.
+func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field.
+func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field.
+func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...))
+}
+
+// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field.
+func BrowserSessionKeyGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field.
+func BrowserSessionKeyLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field.
+func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field.
+func BrowserSessionKeyContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field.
+func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v))
+}
+
+// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field.
+func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v))
+}
+
+// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field.
+func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field.
+func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field.
+func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...))
+}
+
+// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field.
+func CompletionCodeHashGT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashGTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field.
+func CompletionCodeHashLT(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field.
+func CompletionCodeHashLTE(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field.
+func CompletionCodeHashContains(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field.
+func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field.
+func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v))
+}
+
+// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...))
+}
+
+// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v))
+}
+
+// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt))
+}
+
+// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field.
+func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt))
+}
+
+// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field.
+func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field.
+func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...))
+}
+
+// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field.
+func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field.
+func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field.
+func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v))
+}
+
+// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt))
+}
+
+// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field.
+func EmailVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt))
+}
+
+// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...))
+}
+
+// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field.
+func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v))
+}
+
+// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt))
+}
+
+// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field.
+func PasswordVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt))
+}
+
+// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...))
+}
+
+// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field.
+func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v))
+}
+
+// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt))
+}
+
+// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field.
+func TotpVerifiedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt))
+}
+
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field.
+func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field.
+func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v))
+}
+
+// ConsumedAtIn applies the In predicate on the "consumed_at" field.
+func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field.
+func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...))
+}
+
+// ConsumedAtGT applies the GT predicate on the "consumed_at" field.
+func ConsumedAtGT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v))
+}
+
+// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field.
+func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtLT applies the LT predicate on the "consumed_at" field.
+func ConsumedAtLT(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v))
+}
+
+// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field.
+func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v))
+}
+
+// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field.
+func ConsumedAtIsNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt))
+}
+
+// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field.
+func ConsumedAtNotNil() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt))
+}
+
+// HasTargetUser applies the HasEdge predicate on the "target_user" edge.
+func HasTargetUser() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates).
+func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newTargetUserStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge.
+func HasAdoptionDecision() predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates).
+func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(func(s *sql.Selector) {
+ step := newAdoptionDecisionStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession {
+ return predicate.PendingAuthSession(sql.NotPredicates(p))
+}
diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go
new file mode 100644
index 00000000..60276daa
--- /dev/null
+++ b/backend/ent/pendingauthsession_create.go
@@ -0,0 +1,1815 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity.
+type PendingAuthSessionCreate struct {
+ config
+ mutation *PendingAuthSessionMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetSessionToken(v)
+ return _c
+}
+
+// SetIntent sets the "intent" field.
+func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetIntent(v)
+ return _c
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderType(v)
+ return _c
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderKey(v)
+ return _c
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetProviderSubject(v)
+ return _c
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate {
+ _c.mutation.SetTargetUserID(v)
+ return _c
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTargetUserID(*v)
+ }
+ return _c
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRedirectTo(v)
+ return _c
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRedirectTo(*v)
+ }
+ return _c
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetResolvedEmail(v)
+ return _c
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetResolvedEmail(*v)
+ }
+ return _c
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetRegistrationPasswordHash(v)
+ return _c
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetRegistrationPasswordHash(*v)
+ }
+ return _c
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ return _c
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate {
+ _c.mutation.SetLocalFlowState(v)
+ return _c
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetBrowserSessionKey(v)
+ return _c
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetBrowserSessionKey(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeHash(v)
+ return _c
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeHash(*v)
+ }
+ return _c
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetCompletionCodeExpiresAt(v)
+ return _c
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetCompletionCodeExpiresAt(*v)
+ }
+ return _c
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetEmailVerifiedAt(v)
+ return _c
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetEmailVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetPasswordVerifiedAt(v)
+ return _c
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetPasswordVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetTotpVerifiedAt(v)
+ return _c
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetTotpVerifiedAt(*v)
+ }
+ return _c
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate {
+ _c.mutation.SetConsumedAt(v)
+ return _c
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate {
+ if v != nil {
+ _c.SetConsumedAt(*v)
+ }
+ return _c
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate {
+ return _c.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate {
+ _c.mutation.SetAdoptionDecisionID(id)
+ return _c
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate {
+ if id != nil {
+ _c = _c.SetAdoptionDecisionID(*id)
+ }
+ return _c
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate {
+ return _c.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation {
+ return _c.mutation
+}
+
+// Save creates the PendingAuthSession in the database.
+func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *PendingAuthSessionCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := pendingauthsession.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ v := pendingauthsession.DefaultRedirectTo
+ _c.mutation.SetRedirectTo(v)
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ v := pendingauthsession.DefaultResolvedEmail
+ _c.mutation.SetResolvedEmail(v)
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ v := pendingauthsession.DefaultRegistrationPasswordHash
+ _c.mutation.SetRegistrationPasswordHash(v)
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ v := pendingauthsession.DefaultUpstreamIdentityClaims()
+ _c.mutation.SetUpstreamIdentityClaims(v)
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ v := pendingauthsession.DefaultLocalFlowState()
+ _c.mutation.SetLocalFlowState(v)
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ v := pendingauthsession.DefaultBrowserSessionKey
+ _c.mutation.SetBrowserSessionKey(v)
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ v := pendingauthsession.DefaultCompletionCodeHash
+ _c.mutation.SetCompletionCodeHash(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *PendingAuthSessionCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)}
+ }
+ if _, ok := _c.mutation.SessionToken(); !ok {
+ return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)}
+ }
+ if v, ok := _c.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Intent(); !ok {
+ return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)}
+ }
+ if v, ok := _c.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderType(); !ok {
+ return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)}
+ }
+ if v, ok := _c.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderKey(); !ok {
+ return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)}
+ }
+ if v, ok := _c.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ProviderSubject(); !ok {
+ return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)}
+ }
+ if v, ok := _c.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.RedirectTo(); !ok {
+ return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)}
+ }
+ if _, ok := _c.mutation.ResolvedEmail(); !ok {
+ return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)}
+ }
+ if _, ok := _c.mutation.RegistrationPasswordHash(); !ok {
+ return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)}
+ }
+ if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok {
+ return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)}
+ }
+ if _, ok := _c.mutation.LocalFlowState(); !ok {
+ return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)}
+ }
+ if _, ok := _c.mutation.BrowserSessionKey(); !ok {
+ return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)}
+ }
+ if _, ok := _c.mutation.CompletionCodeHash(); !ok {
+ return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)}
+ }
+ if _, ok := _c.mutation.ExpiresAt(); !ok {
+ return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)}
+ }
+ return nil
+}
+
+func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) {
+ var (
+ _node = &PendingAuthSession{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ _node.SessionToken = value
+ }
+ if value, ok := _c.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ _node.Intent = value
+ }
+ if value, ok := _c.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ _node.ProviderType = value
+ }
+ if value, ok := _c.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ _node.ProviderKey = value
+ }
+ if value, ok := _c.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ _node.ProviderSubject = value
+ }
+ if value, ok := _c.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ _node.RedirectTo = value
+ }
+ if value, ok := _c.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ _node.ResolvedEmail = value
+ }
+ if value, ok := _c.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ _node.RegistrationPasswordHash = value
+ }
+ if value, ok := _c.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ _node.UpstreamIdentityClaims = value
+ }
+ if value, ok := _c.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ _node.LocalFlowState = value
+ }
+ if value, ok := _c.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ _node.BrowserSessionKey = value
+ }
+ if value, ok := _c.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ _node.CompletionCodeHash = value
+ }
+ if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ _node.CompletionCodeExpiresAt = &value
+ }
+ if value, ok := _c.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ _node.EmailVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ _node.PasswordVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ _node.TotpVerifiedAt = &value
+ }
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = value
+ }
+ if value, ok := _c.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ _node.ConsumedAt = &value
+ }
+ if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _node.TargetUserID = &nodes[0]
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+//	client.PendingAuthSession.Create().
+//		SetCreatedAt(v).
+//		OnConflict(
+//			// Update the row with the new values
+//			// that were proposed for insertion.
+//			sql.ResolveWithNewValues(),
+//		).
+//		// Override some of the fields with custom
+//		// update values.
+//		Update(func(u *ent.PendingAuthSessionUpsert) {
+//			SetCreatedAt(v+v).
+//		}).
+//		Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne {
+	_c.conflict = opts
+	return &PendingAuthSessionUpsertOne{
+		create: _c,
+	}
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // PendingAuthSessionUpsertOne is the builder for "upsert"-ing
+ // one PendingAuthSession node.
+ PendingAuthSessionUpsertOne struct {
+ create *PendingAuthSessionCreate
+ }
+
+ // PendingAuthSessionUpsert is the "OnConflict" setter.
+ PendingAuthSessionUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpdatedAt)
+ return u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldSessionToken, v)
+ return u
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldSessionToken)
+ return u
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldIntent, v)
+ return u
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldIntent)
+ return u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderType, v)
+ return u
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderType)
+ return u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderKey, v)
+ return u
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderKey)
+ return u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldProviderSubject, v)
+ return u
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldProviderSubject)
+ return u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTargetUserID, v)
+ return u
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTargetUserID)
+ return u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRedirectTo, v)
+ return u
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRedirectTo)
+ return u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldResolvedEmail, v)
+ return u
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldResolvedEmail)
+ return u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldRegistrationPasswordHash, v)
+ return u
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash)
+ return u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v)
+ return u
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims)
+ return u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldLocalFlowState, v)
+ return u
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldLocalFlowState)
+ return u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldBrowserSessionKey, v)
+ return u
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldBrowserSessionKey)
+ return u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeHash, v)
+ return u
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeHash)
+ return u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v)
+ return u
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt)
+ return u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldEmailVerifiedAt, v)
+ return u
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldEmailVerifiedAt)
+ return u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldPasswordVerifiedAt, v)
+ return u
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldPasswordVerifiedAt)
+ return u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldTotpVerifiedAt, v)
+ return u
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldTotpVerifiedAt)
+ return u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldExpiresAt)
+ return u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert {
+ u.Set(pendingauthsession.FieldConsumedAt, v)
+ return u
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert {
+ u.SetExcluded(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert {
+ u.SetNull(pendingauthsession.FieldConsumedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) {
+	// Delegate to Exec (not u.create.Exec) so the "missing options for
+	// OnConflict" validation applies to the panicking variant as well;
+	// previously a misconfigured upsert would silently run as a plain INSERT.
+	if err := u.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) {
+	node, err := u.create.Save(ctx)
+	if err != nil {
+		return id, err
+	}
+	return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk.
+type PendingAuthSessionCreateBulk struct {
+ config
+ err error
+ builders []*PendingAuthSessionCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the PendingAuthSession entities in the database.
+// Each builder's mutator is linked to the next one's, so running the first
+// mutator walks the whole chain; the final mutator in the chain performs a
+// single batch INSERT (sqlgraph.BatchCreate) for all entities at once.
+func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) {
+	if _c.err != nil {
+		return nil, _c.err
+	}
+	specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+	nodes := make([]*PendingAuthSession, len(_c.builders))
+	mutators := make([]Mutator, len(_c.builders))
+	for i := range _c.builders {
+		func(i int, root context.Context) {
+			builder := _c.builders[i]
+			builder.defaults()
+			var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+				mutation, ok := m.(*PendingAuthSessionMutation)
+				if !ok {
+					return nil, fmt.Errorf("unexpected mutation type %T", m)
+				}
+				if err := builder.check(); err != nil {
+					return nil, err
+				}
+				builder.mutation = mutation
+				var err error
+				nodes[i], specs[i] = builder.createSpec()
+				// Chain to the next builder's mutator; only the last builder
+				// actually hits the database with the accumulated specs.
+				if i < len(mutators)-1 {
+					_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+				} else {
+					spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+					spec.OnConflict = _c.conflict
+					// Invoke the actual operation on the latest mutation in the chain.
+					if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+						if sqlgraph.IsConstraintError(err) {
+							err = &ConstraintError{msg: err.Error(), wrap: err}
+						}
+					}
+				}
+				if err != nil {
+					return nil, err
+				}
+				mutation.id = &nodes[i].ID
+				// Copy the database-assigned ID back onto the node, when present.
+				if specs[i].ID.Value != nil {
+					id := specs[i].ID.Value.(int64)
+					nodes[i].ID = int64(id)
+				}
+				mutation.done = true
+				return nodes[i], nil
+			})
+			// Wrap with the builder's hooks, outermost hook first.
+			for i := len(builder.hooks) - 1; i >= 0; i-- {
+				mut = builder.hooks[i](mut)
+			}
+			mutators[i] = mut
+		}(i, ctx)
+	}
+	if len(mutators) > 0 {
+		// Kick off the chain from the first mutator; it recursively runs the rest.
+		if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+			return nil, err
+		}
+	}
+	return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.PendingAuthSession.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.PendingAuthSessionUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk {
+ _c.conflict = opts
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &PendingAuthSessionUpsertBulk{
+ create: _c,
+ }
+}
+
+// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing
+// a bulk of PendingAuthSession nodes.
+type PendingAuthSessionUpsertBulk struct {
+ create *PendingAuthSessionCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(pendingauthsession.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.PendingAuthSession.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict
+// documentation for more info.
+func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&PendingAuthSessionUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetSessionToken sets the "session_token" field.
+func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetSessionToken(v)
+ })
+}
+
+// UpdateSessionToken sets the "session_token" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateSessionToken()
+ })
+}
+
+// SetIntent sets the "intent" field.
+func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetIntent(v)
+ })
+}
+
+// UpdateIntent sets the "intent" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateIntent()
+ })
+}
+
+// SetProviderType sets the "provider_type" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderType(v)
+ })
+}
+
+// UpdateProviderType sets the "provider_type" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderType()
+ })
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderKey(v)
+ })
+}
+
+// UpdateProviderKey sets the "provider_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderKey()
+ })
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetProviderSubject(v)
+ })
+}
+
+// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateProviderSubject()
+ })
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTargetUserID(v)
+ })
+}
+
+// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTargetUserID()
+ })
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTargetUserID()
+ })
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRedirectTo(v)
+ })
+}
+
+// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRedirectTo()
+ })
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetResolvedEmail(v)
+ })
+}
+
+// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateResolvedEmail()
+ })
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetRegistrationPasswordHash(v)
+ })
+}
+
+// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateRegistrationPasswordHash()
+ })
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetUpstreamIdentityClaims(v)
+ })
+}
+
+// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateUpstreamIdentityClaims()
+ })
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetLocalFlowState(v)
+ })
+}
+
+// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateLocalFlowState()
+ })
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetBrowserSessionKey(v)
+ })
+}
+
+// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateBrowserSessionKey()
+ })
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeHash(v)
+ })
+}
+
+// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeHash()
+ })
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetCompletionCodeExpiresAt(v)
+ })
+}
+
+// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateCompletionCodeExpiresAt()
+ })
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearCompletionCodeExpiresAt()
+ })
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetEmailVerifiedAt(v)
+ })
+}
+
+// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateEmailVerifiedAt()
+ })
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearEmailVerifiedAt()
+ })
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetPasswordVerifiedAt(v)
+ })
+}
+
+// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdatePasswordVerifiedAt()
+ })
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearPasswordVerifiedAt()
+ })
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetTotpVerifiedAt(v)
+ })
+}
+
+// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateTotpVerifiedAt()
+ })
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearTotpVerifiedAt()
+ })
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.SetConsumedAt(v)
+ })
+}
+
+// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create.
+func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.UpdateConsumedAt()
+ })
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk {
+ return u.Update(func(s *PendingAuthSessionUpsert) {
+ s.ClearConsumedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go
new file mode 100644
index 00000000..ee4fe605
--- /dev/null
+++ b/backend/ent/pendingauthsession_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity.
+type PendingAuthSessionDelete struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity.
+type PendingAuthSessionDeleteOne struct {
+ _d *PendingAuthSessionDelete
+}
+
+// Where appends a list predicates to the PendingAuthSessionDelete builder.
+func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go
new file mode 100644
index 00000000..78e29cd2
--- /dev/null
+++ b/backend/ent/pendingauthsession_query.go
@@ -0,0 +1,717 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "database/sql/driver"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities.
+type PendingAuthSessionQuery struct {
+ config
+ ctx *QueryContext
+ order []pendingauthsession.OrderOption
+ inters []Interceptor
+ predicates []predicate.PendingAuthSession
+ withTargetUser *UserQuery
+ withAdoptionDecision *IdentityAdoptionDecisionQuery
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the PendingAuthSessionQuery builder.
+func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// QueryTargetUser chains the current query on the "target_user" edge.
+func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(user.Table, user.FieldID),
+ sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAdoptionDecision chains the current query on the "adoption_decision" edge.
+func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector),
+ sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID),
+ sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// First returns the first PendingAuthSession entity from the query.
+// Returns a *NotFoundError when no PendingAuthSession was found.
+func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{pendingauthsession.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first PendingAuthSession ID from the query.
+// Returns a *NotFoundError when no PendingAuthSession ID was found.
+func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{pendingauthsession.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one PendingAuthSession entity is found.
+// Returns a *NotFoundError when no PendingAuthSession entities are found.
+func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{pendingauthsession.Label}
+ default:
+ return nil, &NotSingularError{pendingauthsession.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only PendingAuthSession ID in the query.
+// Returns a *NotSingularError when more than one PendingAuthSession ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{pendingauthsession.Label}
+ default:
+ err = &NotSingularError{pendingauthsession.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of PendingAuthSessions.
+func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]()
+ return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of PendingAuthSession IDs.
+func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery {
+ if _q == nil {
+ return nil
+ }
+ return &PendingAuthSessionQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]pendingauthsession.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.PendingAuthSession{}, _q.predicates...),
+ withTargetUser: _q.withTargetUser.Clone(),
+ withAdoptionDecision: _q.withAdoptionDecision.Clone(),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// WithTargetUser tells the query-builder to eager-load the nodes that are connected to
+// the "target_user" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery {
+ query := (&UserClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withTargetUser = query
+ return _q
+}
+
+// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to
+// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery {
+ query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAdoptionDecision = query
+ return _q
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// GroupBy(pendingauthsession.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &PendingAuthSessionGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = pendingauthsession.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.PendingAuthSession.Query().
+// Select(pendingauthsession.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q}
+ sbuild.label = pendingauthsession.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations.
+func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) {
+ var (
+ nodes = []*PendingAuthSession{}
+ _spec = _q.querySpec()
+ loadedTypes = [2]bool{
+ _q.withTargetUser != nil,
+ _q.withAdoptionDecision != nil,
+ }
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*PendingAuthSession).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &PendingAuthSession{config: _q.config}
+ nodes = append(nodes, node)
+ node.Edges.loadedTypes = loadedTypes
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ if query := _q.withTargetUser; query != nil {
+ if err := _q.loadTargetUser(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAdoptionDecision; query != nil {
+ if err := _q.loadAdoptionDecision(ctx, query, nodes, nil,
+ func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error {
+ ids := make([]int64, 0, len(nodes))
+ nodeids := make(map[int64][]*PendingAuthSession)
+ for i := range nodes {
+ if nodes[i].TargetUserID == nil {
+ continue
+ }
+ fk := *nodes[i].TargetUserID
+ if _, ok := nodeids[fk]; !ok {
+ ids = append(ids, fk)
+ }
+ nodeids[fk] = append(nodeids[fk], nodes[i])
+ }
+ if len(ids) == 0 {
+ return nil
+ }
+ query.Where(user.IDIn(ids...))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ nodes, ok := nodeids[n.ID]
+ if !ok {
+ return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID)
+ }
+ for i := range nodes {
+ assign(nodes[i], n)
+ }
+ }
+ return nil
+}
+func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*PendingAuthSession)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID)
+ }
+ query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.PendingAuthSessionID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+
+func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for i := range fields {
+ if fields[i] != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ if _q.withTargetUser != nil {
+ _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID)
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(pendingauthsession.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = pendingauthsession.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities.
+type PendingAuthSessionGroupBy struct {
+ selector
+ build *PendingAuthSessionQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities.
+type PendingAuthSessionSelect struct {
+ *PendingAuthSessionQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v)
+}
+
+func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go
new file mode 100644
index 00000000..00066f69
--- /dev/null
+++ b/backend/ent/pendingauthsession_update.go
@@ -0,0 +1,1178 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+)
+
+// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities.
+type PendingAuthSessionUpdate struct {
+ config
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdate) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity.
+type PendingAuthSessionUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *PendingAuthSessionMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetSessionToken sets the "session_token" field.
+func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetSessionToken(v)
+ return _u
+}
+
+// SetNillableSessionToken sets the "session_token" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetSessionToken(*v)
+ }
+ return _u
+}
+
+// SetIntent sets the "intent" field.
+func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetIntent(v)
+ return _u
+}
+
+// SetNillableIntent sets the "intent" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetIntent(*v)
+ }
+ return _u
+}
+
+// SetProviderType sets the "provider_type" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderType(v)
+ return _u
+}
+
+// SetNillableProviderType sets the "provider_type" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderType(*v)
+ }
+ return _u
+}
+
+// SetProviderKey sets the "provider_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderKey(v)
+ return _u
+}
+
+// SetNillableProviderKey sets the "provider_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderKey(*v)
+ }
+ return _u
+}
+
+// SetProviderSubject sets the "provider_subject" field.
+func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetProviderSubject(v)
+ return _u
+}
+
+// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetProviderSubject(*v)
+ }
+ return _u
+}
+
+// SetTargetUserID sets the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTargetUserID(v)
+ return _u
+}
+
+// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTargetUserID(*v)
+ }
+ return _u
+}
+
+// ClearTargetUserID clears the value of the "target_user_id" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUserID()
+ return _u
+}
+
+// SetRedirectTo sets the "redirect_to" field.
+func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRedirectTo(v)
+ return _u
+}
+
+// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRedirectTo(*v)
+ }
+ return _u
+}
+
+// SetResolvedEmail sets the "resolved_email" field.
+func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetResolvedEmail(v)
+ return _u
+}
+
+// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetResolvedEmail(*v)
+ }
+ return _u
+}
+
+// SetRegistrationPasswordHash sets the "registration_password_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetRegistrationPasswordHash(v)
+ return _u
+}
+
+// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetRegistrationPasswordHash(*v)
+ }
+ return _u
+}
+
+// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field.
+func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetUpstreamIdentityClaims(v)
+ return _u
+}
+
+// SetLocalFlowState sets the "local_flow_state" field.
+func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetLocalFlowState(v)
+ return _u
+}
+
+// SetBrowserSessionKey sets the "browser_session_key" field.
+func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetBrowserSessionKey(v)
+ return _u
+}
+
+// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetBrowserSessionKey(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeHash sets the "completion_code_hash" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeHash(v)
+ return _u
+}
+
+// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeHash(*v)
+ }
+ return _u
+}
+
+// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetCompletionCodeExpiresAt(v)
+ return _u
+}
+
+// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetCompletionCodeExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearCompletionCodeExpiresAt()
+ return _u
+}
+
+// SetEmailVerifiedAt sets the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetEmailVerifiedAt(v)
+ return _u
+}
+
+// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetEmailVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearEmailVerifiedAt clears the value of the "email_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearEmailVerifiedAt()
+ return _u
+}
+
+// SetPasswordVerifiedAt sets the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetPasswordVerifiedAt(v)
+ return _u
+}
+
+// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetPasswordVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearPasswordVerifiedAt()
+ return _u
+}
+
+// SetTotpVerifiedAt sets the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetTotpVerifiedAt(v)
+ return _u
+}
+
+// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetTotpVerifiedAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTotpVerifiedAt()
+ return _u
+}
+
+// SetExpiresAt sets the "expires_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// SetConsumedAt sets the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetConsumedAt(v)
+ return _u
+}
+
+// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne {
+ if v != nil {
+ _u.SetConsumedAt(*v)
+ }
+ return _u
+}
+
+// ClearConsumedAt clears the value of the "consumed_at" field.
+func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearConsumedAt()
+ return _u
+}
+
+// SetTargetUser sets the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne {
+ return _u.SetTargetUserID(v.ID)
+}
+
+// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne {
+ _u.mutation.SetAdoptionDecisionID(id)
+ return _u
+}
+
+// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil.
+func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne {
+ if id != nil {
+ _u = _u.SetAdoptionDecisionID(*id)
+ }
+ return _u
+}
+
+// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne {
+ return _u.SetAdoptionDecisionID(v.ID)
+}
+
+// Mutation returns the PendingAuthSessionMutation object of the builder.
+func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation {
+ return _u.mutation
+}
+
+// ClearTargetUser clears the "target_user" edge to the User entity.
+func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearTargetUser()
+ return _u
+}
+
+// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity.
+func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne {
+ _u.mutation.ClearAdoptionDecision()
+ return _u
+}
+
+// Where appends a list predicates to the PendingAuthSessionUpdate builder.
+func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated PendingAuthSession entity.
+func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *PendingAuthSessionUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := pendingauthsession.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *PendingAuthSessionUpdateOne) check() error {
+ if v, ok := _u.mutation.SessionToken(); ok {
+ if err := pendingauthsession.SessionTokenValidator(v); err != nil {
+ return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.Intent(); ok {
+ if err := pendingauthsession.IntentValidator(v); err != nil {
+ return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderType(); ok {
+ if err := pendingauthsession.ProviderTypeValidator(v); err != nil {
+ return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderKey(); ok {
+ if err := pendingauthsession.ProviderKeyValidator(v); err != nil {
+ return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProviderSubject(); ok {
+ if err := pendingauthsession.ProviderSubjectValidator(v); err != nil {
+ return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID)
+ for _, f := range fields {
+ if !pendingauthsession.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != pendingauthsession.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.SessionToken(); ok {
+ _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Intent(); ok {
+ _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderType(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderKey(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProviderSubject(); ok {
+ _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RedirectTo(); ok {
+ _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ResolvedEmail(); ok {
+ _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.RegistrationPasswordHash(); ok {
+ _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.UpstreamIdentityClaims(); ok {
+ _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.LocalFlowState(); ok {
+ _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.BrowserSessionKey(); ok {
+ _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeHash(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.CompletionCodeExpiresAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.EmailVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.EmailVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.PasswordVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.PasswordVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.TotpVerifiedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpVerifiedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.ConsumedAt(); ok {
+ _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value)
+ }
+ if _u.mutation.ConsumedAtCleared() {
+ _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime)
+ }
+ if _u.mutation.TargetUserCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.M2O,
+ Inverse: true,
+ Table: pendingauthsession.TargetUserTable,
+ Columns: []string{pendingauthsession.TargetUserColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AdoptionDecisionCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2O,
+ Inverse: false,
+ Table: pendingauthsession.AdoptionDecisionTable,
+ Columns: []string{pendingauthsession.AdoptionDecisionColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ _node = &PendingAuthSession{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{pendingauthsession.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index a652ab3f..dc86471e 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -21,6 +21,24 @@ type Announcement func(*sql.Selector)
// AnnouncementRead is the predicate function for announcementread builders.
type AnnouncementRead func(*sql.Selector)
+// AuthIdentity is the predicate function for authidentity builders.
+type AuthIdentity func(*sql.Selector)
+
+// AuthIdentityChannel is the predicate function for authidentitychannel builders.
+type AuthIdentityChannel func(*sql.Selector)
+
+// ChannelMonitor is the predicate function for channelmonitor builders.
+type ChannelMonitor func(*sql.Selector)
+
+// ChannelMonitorDailyRollup is the predicate function for channelmonitordailyrollup builders.
+type ChannelMonitorDailyRollup func(*sql.Selector)
+
+// ChannelMonitorHistory is the predicate function for channelmonitorhistory builders.
+type ChannelMonitorHistory func(*sql.Selector)
+
+// ChannelMonitorRequestTemplate is the predicate function for channelmonitorrequesttemplate builders.
+type ChannelMonitorRequestTemplate func(*sql.Selector)
+
// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders.
type ErrorPassthroughRule func(*sql.Selector)
@@ -30,6 +48,21 @@ type Group func(*sql.Selector)
// IdempotencyRecord is the predicate function for idempotencyrecord builders.
type IdempotencyRecord func(*sql.Selector)
+// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders.
+type IdentityAdoptionDecision func(*sql.Selector)
+
+// PaymentAuditLog is the predicate function for paymentauditlog builders.
+type PaymentAuditLog func(*sql.Selector)
+
+// PaymentOrder is the predicate function for paymentorder builders.
+type PaymentOrder func(*sql.Selector)
+
+// PaymentProviderInstance is the predicate function for paymentproviderinstance builders.
+type PaymentProviderInstance func(*sql.Selector)
+
+// PendingAuthSession is the predicate function for pendingauthsession builders.
+type PendingAuthSession func(*sql.Selector)
+
// PromoCode is the predicate function for promocode builders.
type PromoCode func(*sql.Selector)
@@ -48,6 +81,9 @@ type SecuritySecret func(*sql.Selector)
// Setting is the predicate function for setting builders.
type Setting func(*sql.Selector)
+// SubscriptionPlan is the predicate function for subscriptionplan builders.
+type SubscriptionPlan func(*sql.Selector)
+
// TLSFingerprintProfile is the predicate function for tlsfingerprintprofile builders.
type TLSFingerprintProfile func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index fd6be291..6b344a55 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -10,9 +10,20 @@ import (
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitordailyrollup"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/idempotencyrecord"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/proxy"
@@ -20,6 +31,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
"github.com/Wei-Shaw/sub2api/ent/tlsfingerprintprofile"
"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -28,6 +40,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
+ "github.com/Wei-Shaw/sub2api/internal/domain"
)
// The init function reads all schema descriptors with runtime code
@@ -304,6 +317,366 @@ func init() {
announcementreadDescCreatedAt := announcementreadFields[3].Descriptor()
// announcementread.DefaultCreatedAt holds the default value on creation for the created_at field.
announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time)
+ authidentityMixin := schema.AuthIdentity{}.Mixin()
+ authidentityMixinFields0 := authidentityMixin[0].Fields()
+ _ = authidentityMixinFields0
+ authidentityFields := schema.AuthIdentity{}.Fields()
+ _ = authidentityFields
+ // authidentityDescCreatedAt is the schema descriptor for created_at field.
+ authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor()
+ // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time)
+ // authidentityDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor()
+ // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time)
+ // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentityDescProviderType is the schema descriptor for provider_type field.
+ authidentityDescProviderType := authidentityFields[1].Descriptor()
+ // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentity.ProviderTypeValidator = func() func(string) error {
+ validators := authidentityDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentityDescProviderKey is the schema descriptor for provider_key field.
+ authidentityDescProviderKey := authidentityFields[2].Descriptor()
+ // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error)
+ // authidentityDescProviderSubject is the schema descriptor for provider_subject field.
+ authidentityDescProviderSubject := authidentityFields[3].Descriptor()
+ // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error)
+ // authidentityDescMetadata is the schema descriptor for metadata field.
+ authidentityDescMetadata := authidentityFields[6].Descriptor()
+ // authidentity.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{})
+ authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin()
+ authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields()
+ _ = authidentitychannelMixinFields0
+ authidentitychannelFields := schema.AuthIdentityChannel{}.Fields()
+ _ = authidentitychannelFields
+ // authidentitychannelDescCreatedAt is the schema descriptor for created_at field.
+ authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor()
+ // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field.
+ authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time)
+ // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field.
+ authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor()
+ // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time)
+ // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // authidentitychannelDescProviderType is the schema descriptor for provider_type field.
+ authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor()
+ // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ authidentitychannel.ProviderTypeValidator = func() func(string) error {
+ validators := authidentitychannelDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescProviderKey is the schema descriptor for provider_key field.
+ authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor()
+ // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error)
+ // authidentitychannelDescChannel is the schema descriptor for channel field.
+ authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor()
+ // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save.
+ authidentitychannel.ChannelValidator = func() func(string) error {
+ validators := authidentitychannelDescChannel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(channel string) error {
+ for _, fn := range fns {
+ if err := fn(channel); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field.
+ authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor()
+ // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save.
+ authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error)
+ // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field.
+ authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor()
+ // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save.
+ authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error)
+ // authidentitychannelDescMetadata is the schema descriptor for metadata field.
+ authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor()
+ // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field.
+ authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{})
+ channelmonitorMixin := schema.ChannelMonitor{}.Mixin()
+ channelmonitorMixinFields0 := channelmonitorMixin[0].Fields()
+ _ = channelmonitorMixinFields0
+ channelmonitorFields := schema.ChannelMonitor{}.Fields()
+ _ = channelmonitorFields
+ // channelmonitorDescCreatedAt is the schema descriptor for created_at field.
+ channelmonitorDescCreatedAt := channelmonitorMixinFields0[0].Descriptor()
+ // channelmonitor.DefaultCreatedAt holds the default value on creation for the created_at field.
+ channelmonitor.DefaultCreatedAt = channelmonitorDescCreatedAt.Default.(func() time.Time)
+ // channelmonitorDescUpdatedAt is the schema descriptor for updated_at field.
+ channelmonitorDescUpdatedAt := channelmonitorMixinFields0[1].Descriptor()
+ // channelmonitor.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ channelmonitor.DefaultUpdatedAt = channelmonitorDescUpdatedAt.Default.(func() time.Time)
+ // channelmonitor.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ channelmonitor.UpdateDefaultUpdatedAt = channelmonitorDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // channelmonitorDescName is the schema descriptor for name field.
+ channelmonitorDescName := channelmonitorFields[0].Descriptor()
+ // channelmonitor.NameValidator is a validator for the "name" field. It is called by the builders before save.
+ channelmonitor.NameValidator = func() func(string) error {
+ validators := channelmonitorDescName.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(name string) error {
+ for _, fn := range fns {
+ if err := fn(name); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorDescEndpoint is the schema descriptor for endpoint field.
+ channelmonitorDescEndpoint := channelmonitorFields[2].Descriptor()
+ // channelmonitor.EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
+ channelmonitor.EndpointValidator = func() func(string) error {
+ validators := channelmonitorDescEndpoint.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(endpoint string) error {
+ for _, fn := range fns {
+ if err := fn(endpoint); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorDescAPIKeyEncrypted is the schema descriptor for api_key_encrypted field.
+ channelmonitorDescAPIKeyEncrypted := channelmonitorFields[3].Descriptor()
+ // channelmonitor.APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
+ channelmonitor.APIKeyEncryptedValidator = channelmonitorDescAPIKeyEncrypted.Validators[0].(func(string) error)
+ // channelmonitorDescPrimaryModel is the schema descriptor for primary_model field.
+ channelmonitorDescPrimaryModel := channelmonitorFields[4].Descriptor()
+ // channelmonitor.PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save.
+ channelmonitor.PrimaryModelValidator = func() func(string) error {
+ validators := channelmonitorDescPrimaryModel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(primary_model string) error {
+ for _, fn := range fns {
+ if err := fn(primary_model); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorDescExtraModels is the schema descriptor for extra_models field.
+ channelmonitorDescExtraModels := channelmonitorFields[5].Descriptor()
+ // channelmonitor.DefaultExtraModels holds the default value on creation for the extra_models field.
+ channelmonitor.DefaultExtraModels = channelmonitorDescExtraModels.Default.([]string)
+ // channelmonitorDescGroupName is the schema descriptor for group_name field.
+ channelmonitorDescGroupName := channelmonitorFields[6].Descriptor()
+ // channelmonitor.DefaultGroupName holds the default value on creation for the group_name field.
+ channelmonitor.DefaultGroupName = channelmonitorDescGroupName.Default.(string)
+ // channelmonitor.GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save.
+ channelmonitor.GroupNameValidator = channelmonitorDescGroupName.Validators[0].(func(string) error)
+ // channelmonitorDescEnabled is the schema descriptor for enabled field.
+ channelmonitorDescEnabled := channelmonitorFields[7].Descriptor()
+ // channelmonitor.DefaultEnabled holds the default value on creation for the enabled field.
+ channelmonitor.DefaultEnabled = channelmonitorDescEnabled.Default.(bool)
+ // channelmonitorDescIntervalSeconds is the schema descriptor for interval_seconds field.
+ channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
+ // channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
+ channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
+ // channelmonitorDescExtraHeaders is the schema descriptor for extra_headers field.
+ channelmonitorDescExtraHeaders := channelmonitorFields[12].Descriptor()
+ // channelmonitor.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
+ channelmonitor.DefaultExtraHeaders = channelmonitorDescExtraHeaders.Default.(map[string]string)
+ // channelmonitorDescBodyOverrideMode is the schema descriptor for body_override_mode field.
+ channelmonitorDescBodyOverrideMode := channelmonitorFields[13].Descriptor()
+ // channelmonitor.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
+ channelmonitor.DefaultBodyOverrideMode = channelmonitorDescBodyOverrideMode.Default.(string)
+ // channelmonitor.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ channelmonitor.BodyOverrideModeValidator = channelmonitorDescBodyOverrideMode.Validators[0].(func(string) error)
+ channelmonitordailyrollupFields := schema.ChannelMonitorDailyRollup{}.Fields()
+ _ = channelmonitordailyrollupFields
+ // channelmonitordailyrollupDescModel is the schema descriptor for model field.
+ channelmonitordailyrollupDescModel := channelmonitordailyrollupFields[1].Descriptor()
+ // channelmonitordailyrollup.ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ channelmonitordailyrollup.ModelValidator = func() func(string) error {
+ validators := channelmonitordailyrollupDescModel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(model string) error {
+ for _, fn := range fns {
+ if err := fn(model); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitordailyrollupDescTotalChecks is the schema descriptor for total_checks field.
+ channelmonitordailyrollupDescTotalChecks := channelmonitordailyrollupFields[3].Descriptor()
+ // channelmonitordailyrollup.DefaultTotalChecks holds the default value on creation for the total_checks field.
+ channelmonitordailyrollup.DefaultTotalChecks = channelmonitordailyrollupDescTotalChecks.Default.(int)
+ // channelmonitordailyrollupDescOkCount is the schema descriptor for ok_count field.
+ channelmonitordailyrollupDescOkCount := channelmonitordailyrollupFields[4].Descriptor()
+ // channelmonitordailyrollup.DefaultOkCount holds the default value on creation for the ok_count field.
+ channelmonitordailyrollup.DefaultOkCount = channelmonitordailyrollupDescOkCount.Default.(int)
+ // channelmonitordailyrollupDescOperationalCount is the schema descriptor for operational_count field.
+ channelmonitordailyrollupDescOperationalCount := channelmonitordailyrollupFields[5].Descriptor()
+ // channelmonitordailyrollup.DefaultOperationalCount holds the default value on creation for the operational_count field.
+ channelmonitordailyrollup.DefaultOperationalCount = channelmonitordailyrollupDescOperationalCount.Default.(int)
+ // channelmonitordailyrollupDescDegradedCount is the schema descriptor for degraded_count field.
+ channelmonitordailyrollupDescDegradedCount := channelmonitordailyrollupFields[6].Descriptor()
+ // channelmonitordailyrollup.DefaultDegradedCount holds the default value on creation for the degraded_count field.
+ channelmonitordailyrollup.DefaultDegradedCount = channelmonitordailyrollupDescDegradedCount.Default.(int)
+ // channelmonitordailyrollupDescFailedCount is the schema descriptor for failed_count field.
+ channelmonitordailyrollupDescFailedCount := channelmonitordailyrollupFields[7].Descriptor()
+ // channelmonitordailyrollup.DefaultFailedCount holds the default value on creation for the failed_count field.
+ channelmonitordailyrollup.DefaultFailedCount = channelmonitordailyrollupDescFailedCount.Default.(int)
+ // channelmonitordailyrollupDescErrorCount is the schema descriptor for error_count field.
+ channelmonitordailyrollupDescErrorCount := channelmonitordailyrollupFields[8].Descriptor()
+ // channelmonitordailyrollup.DefaultErrorCount holds the default value on creation for the error_count field.
+ channelmonitordailyrollup.DefaultErrorCount = channelmonitordailyrollupDescErrorCount.Default.(int)
+ // channelmonitordailyrollupDescSumLatencyMs is the schema descriptor for sum_latency_ms field.
+ channelmonitordailyrollupDescSumLatencyMs := channelmonitordailyrollupFields[9].Descriptor()
+ // channelmonitordailyrollup.DefaultSumLatencyMs holds the default value on creation for the sum_latency_ms field.
+ channelmonitordailyrollup.DefaultSumLatencyMs = channelmonitordailyrollupDescSumLatencyMs.Default.(int64)
+ // channelmonitordailyrollupDescCountLatency is the schema descriptor for count_latency field.
+ channelmonitordailyrollupDescCountLatency := channelmonitordailyrollupFields[10].Descriptor()
+ // channelmonitordailyrollup.DefaultCountLatency holds the default value on creation for the count_latency field.
+ channelmonitordailyrollup.DefaultCountLatency = channelmonitordailyrollupDescCountLatency.Default.(int)
+ // channelmonitordailyrollupDescSumPingLatencyMs is the schema descriptor for sum_ping_latency_ms field.
+ channelmonitordailyrollupDescSumPingLatencyMs := channelmonitordailyrollupFields[11].Descriptor()
+ // channelmonitordailyrollup.DefaultSumPingLatencyMs holds the default value on creation for the sum_ping_latency_ms field.
+ channelmonitordailyrollup.DefaultSumPingLatencyMs = channelmonitordailyrollupDescSumPingLatencyMs.Default.(int64)
+ // channelmonitordailyrollupDescCountPingLatency is the schema descriptor for count_ping_latency field.
+ channelmonitordailyrollupDescCountPingLatency := channelmonitordailyrollupFields[12].Descriptor()
+ // channelmonitordailyrollup.DefaultCountPingLatency holds the default value on creation for the count_ping_latency field.
+ channelmonitordailyrollup.DefaultCountPingLatency = channelmonitordailyrollupDescCountPingLatency.Default.(int)
+ // channelmonitordailyrollupDescComputedAt is the schema descriptor for computed_at field.
+ channelmonitordailyrollupDescComputedAt := channelmonitordailyrollupFields[13].Descriptor()
+ // channelmonitordailyrollup.DefaultComputedAt holds the default value on creation for the computed_at field.
+ channelmonitordailyrollup.DefaultComputedAt = channelmonitordailyrollupDescComputedAt.Default.(func() time.Time)
+ // channelmonitordailyrollup.UpdateDefaultComputedAt holds the default value on update for the computed_at field.
+ channelmonitordailyrollup.UpdateDefaultComputedAt = channelmonitordailyrollupDescComputedAt.UpdateDefault.(func() time.Time)
+ channelmonitorhistoryFields := schema.ChannelMonitorHistory{}.Fields()
+ _ = channelmonitorhistoryFields
+ // channelmonitorhistoryDescModel is the schema descriptor for model field.
+ channelmonitorhistoryDescModel := channelmonitorhistoryFields[1].Descriptor()
+ // channelmonitorhistory.ModelValidator is a validator for the "model" field. It is called by the builders before save.
+ channelmonitorhistory.ModelValidator = func() func(string) error {
+ validators := channelmonitorhistoryDescModel.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(model string) error {
+ for _, fn := range fns {
+ if err := fn(model); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorhistoryDescMessage is the schema descriptor for message field.
+ channelmonitorhistoryDescMessage := channelmonitorhistoryFields[5].Descriptor()
+ // channelmonitorhistory.DefaultMessage holds the default value on creation for the message field.
+ channelmonitorhistory.DefaultMessage = channelmonitorhistoryDescMessage.Default.(string)
+ // channelmonitorhistory.MessageValidator is a validator for the "message" field. It is called by the builders before save.
+ channelmonitorhistory.MessageValidator = channelmonitorhistoryDescMessage.Validators[0].(func(string) error)
+ // channelmonitorhistoryDescCheckedAt is the schema descriptor for checked_at field.
+ channelmonitorhistoryDescCheckedAt := channelmonitorhistoryFields[6].Descriptor()
+ // channelmonitorhistory.DefaultCheckedAt holds the default value on creation for the checked_at field.
+ channelmonitorhistory.DefaultCheckedAt = channelmonitorhistoryDescCheckedAt.Default.(func() time.Time)
+ channelmonitorrequesttemplateMixin := schema.ChannelMonitorRequestTemplate{}.Mixin()
+ channelmonitorrequesttemplateMixinFields0 := channelmonitorrequesttemplateMixin[0].Fields()
+ _ = channelmonitorrequesttemplateMixinFields0
+ channelmonitorrequesttemplateFields := schema.ChannelMonitorRequestTemplate{}.Fields()
+ _ = channelmonitorrequesttemplateFields
+ // channelmonitorrequesttemplateDescCreatedAt is the schema descriptor for created_at field.
+ channelmonitorrequesttemplateDescCreatedAt := channelmonitorrequesttemplateMixinFields0[0].Descriptor()
+ // channelmonitorrequesttemplate.DefaultCreatedAt holds the default value on creation for the created_at field.
+ channelmonitorrequesttemplate.DefaultCreatedAt = channelmonitorrequesttemplateDescCreatedAt.Default.(func() time.Time)
+ // channelmonitorrequesttemplateDescUpdatedAt is the schema descriptor for updated_at field.
+ channelmonitorrequesttemplateDescUpdatedAt := channelmonitorrequesttemplateMixinFields0[1].Descriptor()
+ // channelmonitorrequesttemplate.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ channelmonitorrequesttemplate.DefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.Default.(func() time.Time)
+ // channelmonitorrequesttemplate.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ channelmonitorrequesttemplate.UpdateDefaultUpdatedAt = channelmonitorrequesttemplateDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // channelmonitorrequesttemplateDescName is the schema descriptor for name field.
+ channelmonitorrequesttemplateDescName := channelmonitorrequesttemplateFields[0].Descriptor()
+ // channelmonitorrequesttemplate.NameValidator is a validator for the "name" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.NameValidator = func() func(string) error {
+ validators := channelmonitorrequesttemplateDescName.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(name string) error {
+ for _, fn := range fns {
+ if err := fn(name); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // channelmonitorrequesttemplateDescDescription is the schema descriptor for description field.
+ channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[2].Descriptor()
+ // channelmonitorrequesttemplate.DefaultDescription holds the default value on creation for the description field.
+ channelmonitorrequesttemplate.DefaultDescription = channelmonitorrequesttemplateDescDescription.Default.(string)
+ // channelmonitorrequesttemplate.DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.DescriptionValidator = channelmonitorrequesttemplateDescDescription.Validators[0].(func(string) error)
+ // channelmonitorrequesttemplateDescExtraHeaders is the schema descriptor for extra_headers field.
+ channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[3].Descriptor()
+ // channelmonitorrequesttemplate.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
+ channelmonitorrequesttemplate.DefaultExtraHeaders = channelmonitorrequesttemplateDescExtraHeaders.Default.(map[string]string)
+ // channelmonitorrequesttemplateDescBodyOverrideMode is the schema descriptor for body_override_mode field.
+ channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[4].Descriptor()
+ // channelmonitorrequesttemplate.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
+ channelmonitorrequesttemplate.DefaultBodyOverrideMode = channelmonitorrequesttemplateDescBodyOverrideMode.Default.(string)
+ // channelmonitorrequesttemplate.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.BodyOverrideModeValidator = channelmonitorrequesttemplateDescBodyOverrideMode.Validators[0].(func(string) error)
errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin()
errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields()
_ = errorpassthroughruleMixinFields0
@@ -430,48 +803,52 @@ func init() {
groupDescDefaultValidityDays := groupFields[10].Descriptor()
// group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field.
group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int)
- // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
- groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor()
- // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
- group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64)
// groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field.
- groupDescClaudeCodeOnly := groupFields[19].Descriptor()
+ groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
// groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
- groupDescModelRoutingEnabled := groupFields[23].Descriptor()
+ groupDescModelRoutingEnabled := groupFields[18].Descriptor()
// group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
// groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field.
- groupDescMcpXMLInject := groupFields[24].Descriptor()
+ groupDescMcpXMLInject := groupFields[19].Descriptor()
// group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field.
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool)
// groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field.
- groupDescSupportedModelScopes := groupFields[25].Descriptor()
+ groupDescSupportedModelScopes := groupFields[20].Descriptor()
// group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field.
group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string)
// groupDescSortOrder is the schema descriptor for sort_order field.
- groupDescSortOrder := groupFields[26].Descriptor()
+ groupDescSortOrder := groupFields[21].Descriptor()
// group.DefaultSortOrder holds the default value on creation for the sort_order field.
group.DefaultSortOrder = groupDescSortOrder.Default.(int)
// groupDescAllowMessagesDispatch is the schema descriptor for allow_messages_dispatch field.
- groupDescAllowMessagesDispatch := groupFields[27].Descriptor()
+ groupDescAllowMessagesDispatch := groupFields[22].Descriptor()
// group.DefaultAllowMessagesDispatch holds the default value on creation for the allow_messages_dispatch field.
group.DefaultAllowMessagesDispatch = groupDescAllowMessagesDispatch.Default.(bool)
// groupDescRequireOauthOnly is the schema descriptor for require_oauth_only field.
- groupDescRequireOauthOnly := groupFields[28].Descriptor()
+ groupDescRequireOauthOnly := groupFields[23].Descriptor()
// group.DefaultRequireOauthOnly holds the default value on creation for the require_oauth_only field.
group.DefaultRequireOauthOnly = groupDescRequireOauthOnly.Default.(bool)
// groupDescRequirePrivacySet is the schema descriptor for require_privacy_set field.
- groupDescRequirePrivacySet := groupFields[29].Descriptor()
+ groupDescRequirePrivacySet := groupFields[24].Descriptor()
// group.DefaultRequirePrivacySet holds the default value on creation for the require_privacy_set field.
group.DefaultRequirePrivacySet = groupDescRequirePrivacySet.Default.(bool)
// groupDescDefaultMappedModel is the schema descriptor for default_mapped_model field.
- groupDescDefaultMappedModel := groupFields[30].Descriptor()
+ groupDescDefaultMappedModel := groupFields[25].Descriptor()
// group.DefaultDefaultMappedModel holds the default value on creation for the default_mapped_model field.
group.DefaultDefaultMappedModel = groupDescDefaultMappedModel.Default.(string)
// group.DefaultMappedModelValidator is a validator for the "default_mapped_model" field. It is called by the builders before save.
group.DefaultMappedModelValidator = groupDescDefaultMappedModel.Validators[0].(func(string) error)
+ // groupDescMessagesDispatchModelConfig is the schema descriptor for messages_dispatch_model_config field.
+ groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor()
+ // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field.
+ group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig)
+ // groupDescRpmLimit is the schema descriptor for rpm_limit field.
+ groupDescRpmLimit := groupFields[27].Descriptor()
+ // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
+ group.DefaultRpmLimit = groupDescRpmLimit.Default.(int)
idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin()
idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields()
_ = idempotencyrecordMixinFields0
@@ -507,6 +884,314 @@ func init() {
idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor()
// idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save.
idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error)
+ identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin()
+ identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields()
+ _ = identityadoptiondecisionMixinFields0
+ identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields()
+ _ = identityadoptiondecisionFields
+ // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field.
+ identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor()
+ // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field.
+ identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time)
+ // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field.
+ identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor()
+ // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time)
+ // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field.
+ identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor()
+ // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field.
+ identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool)
+ // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field.
+ identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor()
+ // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field.
+ identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool)
+ // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field.
+ identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor()
+ // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field.
+ identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time)
+ paymentauditlogFields := schema.PaymentAuditLog{}.Fields()
+ _ = paymentauditlogFields
+ // paymentauditlogDescOrderID is the schema descriptor for order_id field.
+ paymentauditlogDescOrderID := paymentauditlogFields[0].Descriptor()
+ // paymentauditlog.OrderIDValidator is a validator for the "order_id" field. It is called by the builders before save.
+ paymentauditlog.OrderIDValidator = paymentauditlogDescOrderID.Validators[0].(func(string) error)
+ // paymentauditlogDescAction is the schema descriptor for action field.
+ paymentauditlogDescAction := paymentauditlogFields[1].Descriptor()
+ // paymentauditlog.ActionValidator is a validator for the "action" field. It is called by the builders before save.
+ paymentauditlog.ActionValidator = paymentauditlogDescAction.Validators[0].(func(string) error)
+ // paymentauditlogDescDetail is the schema descriptor for detail field.
+ paymentauditlogDescDetail := paymentauditlogFields[2].Descriptor()
+ // paymentauditlog.DefaultDetail holds the default value on creation for the detail field.
+ paymentauditlog.DefaultDetail = paymentauditlogDescDetail.Default.(string)
+ // paymentauditlogDescOperator is the schema descriptor for operator field.
+ paymentauditlogDescOperator := paymentauditlogFields[3].Descriptor()
+ // paymentauditlog.DefaultOperator holds the default value on creation for the operator field.
+ paymentauditlog.DefaultOperator = paymentauditlogDescOperator.Default.(string)
+ // paymentauditlog.OperatorValidator is a validator for the "operator" field. It is called by the builders before save.
+ paymentauditlog.OperatorValidator = paymentauditlogDescOperator.Validators[0].(func(string) error)
+ // paymentauditlogDescCreatedAt is the schema descriptor for created_at field.
+ paymentauditlogDescCreatedAt := paymentauditlogFields[4].Descriptor()
+ // paymentauditlog.DefaultCreatedAt holds the default value on creation for the created_at field.
+ paymentauditlog.DefaultCreatedAt = paymentauditlogDescCreatedAt.Default.(func() time.Time)
+ paymentorderFields := schema.PaymentOrder{}.Fields()
+ _ = paymentorderFields
+ // paymentorderDescUserEmail is the schema descriptor for user_email field.
+ paymentorderDescUserEmail := paymentorderFields[1].Descriptor()
+ // paymentorder.UserEmailValidator is a validator for the "user_email" field. It is called by the builders before save.
+ paymentorder.UserEmailValidator = paymentorderDescUserEmail.Validators[0].(func(string) error)
+ // paymentorderDescUserName is the schema descriptor for user_name field.
+ paymentorderDescUserName := paymentorderFields[2].Descriptor()
+ // paymentorder.UserNameValidator is a validator for the "user_name" field. It is called by the builders before save.
+ paymentorder.UserNameValidator = paymentorderDescUserName.Validators[0].(func(string) error)
+ // paymentorderDescFeeRate is the schema descriptor for fee_rate field.
+ paymentorderDescFeeRate := paymentorderFields[6].Descriptor()
+ // paymentorder.DefaultFeeRate holds the default value on creation for the fee_rate field.
+ paymentorder.DefaultFeeRate = paymentorderDescFeeRate.Default.(float64)
+ // paymentorderDescRechargeCode is the schema descriptor for recharge_code field.
+ paymentorderDescRechargeCode := paymentorderFields[7].Descriptor()
+ // paymentorder.RechargeCodeValidator is a validator for the "recharge_code" field. It is called by the builders before save.
+ paymentorder.RechargeCodeValidator = paymentorderDescRechargeCode.Validators[0].(func(string) error)
+ // paymentorderDescOutTradeNo is the schema descriptor for out_trade_no field.
+ paymentorderDescOutTradeNo := paymentorderFields[8].Descriptor()
+ // paymentorder.DefaultOutTradeNo holds the default value on creation for the out_trade_no field.
+ paymentorder.DefaultOutTradeNo = paymentorderDescOutTradeNo.Default.(string)
+ // paymentorder.OutTradeNoValidator is a validator for the "out_trade_no" field. It is called by the builders before save.
+ paymentorder.OutTradeNoValidator = paymentorderDescOutTradeNo.Validators[0].(func(string) error)
+ // paymentorderDescPaymentType is the schema descriptor for payment_type field.
+ paymentorderDescPaymentType := paymentorderFields[9].Descriptor()
+ // paymentorder.PaymentTypeValidator is a validator for the "payment_type" field. It is called by the builders before save.
+ paymentorder.PaymentTypeValidator = paymentorderDescPaymentType.Validators[0].(func(string) error)
+ // paymentorderDescPaymentTradeNo is the schema descriptor for payment_trade_no field.
+ paymentorderDescPaymentTradeNo := paymentorderFields[10].Descriptor()
+ // paymentorder.PaymentTradeNoValidator is a validator for the "payment_trade_no" field. It is called by the builders before save.
+ paymentorder.PaymentTradeNoValidator = paymentorderDescPaymentTradeNo.Validators[0].(func(string) error)
+ // paymentorderDescOrderType is the schema descriptor for order_type field.
+ paymentorderDescOrderType := paymentorderFields[14].Descriptor()
+ // paymentorder.DefaultOrderType holds the default value on creation for the order_type field.
+ paymentorder.DefaultOrderType = paymentorderDescOrderType.Default.(string)
+ // paymentorder.OrderTypeValidator is a validator for the "order_type" field. It is called by the builders before save.
+ paymentorder.OrderTypeValidator = paymentorderDescOrderType.Validators[0].(func(string) error)
+ // paymentorderDescProviderInstanceID is the schema descriptor for provider_instance_id field.
+ paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor()
+ // paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save.
+ paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error)
+ // paymentorderDescProviderKey is the schema descriptor for provider_key field.
+ paymentorderDescProviderKey := paymentorderFields[19].Descriptor()
+ // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error)
+ // paymentorderDescStatus is the schema descriptor for status field.
+ paymentorderDescStatus := paymentorderFields[21].Descriptor()
+ // paymentorder.DefaultStatus holds the default value on creation for the status field.
+ paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string)
+ // paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error)
+ // paymentorderDescRefundAmount is the schema descriptor for refund_amount field.
+ paymentorderDescRefundAmount := paymentorderFields[22].Descriptor()
+ // paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field.
+ paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64)
+ // paymentorderDescForceRefund is the schema descriptor for force_refund field.
+ paymentorderDescForceRefund := paymentorderFields[25].Descriptor()
+ // paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field.
+ paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool)
+ // paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field.
+ paymentorderDescRefundRequestedBy := paymentorderFields[28].Descriptor()
+ // paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save.
+ paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error)
+ // paymentorderDescClientIP is the schema descriptor for client_ip field.
+ paymentorderDescClientIP := paymentorderFields[34].Descriptor()
+ // paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save.
+ paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error)
+ // paymentorderDescSrcHost is the schema descriptor for src_host field.
+ paymentorderDescSrcHost := paymentorderFields[35].Descriptor()
+ // paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save.
+ paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error)
+ // paymentorderDescCreatedAt is the schema descriptor for created_at field.
+ paymentorderDescCreatedAt := paymentorderFields[37].Descriptor()
+ // paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field.
+ paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time)
+ // paymentorderDescUpdatedAt is the schema descriptor for updated_at field.
+ paymentorderDescUpdatedAt := paymentorderFields[38].Descriptor()
+ // paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time)
+ // paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ paymentorder.UpdateDefaultUpdatedAt = paymentorderDescUpdatedAt.UpdateDefault.(func() time.Time)
+ paymentproviderinstanceFields := schema.PaymentProviderInstance{}.Fields()
+ _ = paymentproviderinstanceFields
+ // paymentproviderinstanceDescProviderKey is the schema descriptor for provider_key field.
+ paymentproviderinstanceDescProviderKey := paymentproviderinstanceFields[0].Descriptor()
+ // paymentproviderinstance.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ paymentproviderinstance.ProviderKeyValidator = func() func(string) error {
+ validators := paymentproviderinstanceDescProviderKey.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(provider_key string) error {
+ for _, fn := range fns {
+ if err := fn(provider_key); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // paymentproviderinstanceDescName is the schema descriptor for name field.
+ paymentproviderinstanceDescName := paymentproviderinstanceFields[1].Descriptor()
+ // paymentproviderinstance.DefaultName holds the default value on creation for the name field.
+ paymentproviderinstance.DefaultName = paymentproviderinstanceDescName.Default.(string)
+ // paymentproviderinstance.NameValidator is a validator for the "name" field. It is called by the builders before save.
+ paymentproviderinstance.NameValidator = paymentproviderinstanceDescName.Validators[0].(func(string) error)
+ // paymentproviderinstanceDescSupportedTypes is the schema descriptor for supported_types field.
+ paymentproviderinstanceDescSupportedTypes := paymentproviderinstanceFields[3].Descriptor()
+ // paymentproviderinstance.DefaultSupportedTypes holds the default value on creation for the supported_types field.
+ paymentproviderinstance.DefaultSupportedTypes = paymentproviderinstanceDescSupportedTypes.Default.(string)
+ // paymentproviderinstance.SupportedTypesValidator is a validator for the "supported_types" field. It is called by the builders before save.
+ paymentproviderinstance.SupportedTypesValidator = paymentproviderinstanceDescSupportedTypes.Validators[0].(func(string) error)
+ // paymentproviderinstanceDescEnabled is the schema descriptor for enabled field.
+ paymentproviderinstanceDescEnabled := paymentproviderinstanceFields[4].Descriptor()
+ // paymentproviderinstance.DefaultEnabled holds the default value on creation for the enabled field.
+ paymentproviderinstance.DefaultEnabled = paymentproviderinstanceDescEnabled.Default.(bool)
+ // paymentproviderinstanceDescPaymentMode is the schema descriptor for payment_mode field.
+ paymentproviderinstanceDescPaymentMode := paymentproviderinstanceFields[5].Descriptor()
+ // paymentproviderinstance.DefaultPaymentMode holds the default value on creation for the payment_mode field.
+ paymentproviderinstance.DefaultPaymentMode = paymentproviderinstanceDescPaymentMode.Default.(string)
+ // paymentproviderinstance.PaymentModeValidator is a validator for the "payment_mode" field. It is called by the builders before save.
+ paymentproviderinstance.PaymentModeValidator = paymentproviderinstanceDescPaymentMode.Validators[0].(func(string) error)
+ // paymentproviderinstanceDescSortOrder is the schema descriptor for sort_order field.
+ paymentproviderinstanceDescSortOrder := paymentproviderinstanceFields[6].Descriptor()
+ // paymentproviderinstance.DefaultSortOrder holds the default value on creation for the sort_order field.
+ paymentproviderinstance.DefaultSortOrder = paymentproviderinstanceDescSortOrder.Default.(int)
+ // paymentproviderinstanceDescLimits is the schema descriptor for limits field.
+ paymentproviderinstanceDescLimits := paymentproviderinstanceFields[7].Descriptor()
+ // paymentproviderinstance.DefaultLimits holds the default value on creation for the limits field.
+ paymentproviderinstance.DefaultLimits = paymentproviderinstanceDescLimits.Default.(string)
+ // paymentproviderinstanceDescRefundEnabled is the schema descriptor for refund_enabled field.
+ paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor()
+ // paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field.
+ paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool)
+ // paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field.
+ paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor()
+ // paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field.
+ paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool)
+ // paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field.
+ paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor()
+ // paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field.
+ paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time)
+ // paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field.
+ paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor()
+ // paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
+ // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time)
+ pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin()
+ pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields()
+ _ = pendingauthsessionMixinFields0
+ pendingauthsessionFields := schema.PendingAuthSession{}.Fields()
+ _ = pendingauthsessionFields
+ // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field.
+ pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor()
+ // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field.
+ pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time)
+ // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field.
+ pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor()
+ // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time)
+ // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // pendingauthsessionDescSessionToken is the schema descriptor for session_token field.
+ pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor()
+ // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save.
+ pendingauthsession.SessionTokenValidator = func() func(string) error {
+ validators := pendingauthsessionDescSessionToken.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(session_token string) error {
+ for _, fn := range fns {
+ if err := fn(session_token); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescIntent is the schema descriptor for intent field.
+ pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor()
+ // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save.
+ pendingauthsession.IntentValidator = func() func(string) error {
+ validators := pendingauthsessionDescIntent.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(intent string) error {
+ for _, fn := range fns {
+ if err := fn(intent); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderType is the schema descriptor for provider_type field.
+ pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor()
+ // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save.
+ pendingauthsession.ProviderTypeValidator = func() func(string) error {
+ validators := pendingauthsessionDescProviderType.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ validators[2].(func(string) error),
+ }
+ return func(provider_type string) error {
+ for _, fn := range fns {
+ if err := fn(provider_type); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field.
+ pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor()
+ // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save.
+ pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error)
+ // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field.
+ pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor()
+ // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save.
+ pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error)
+ // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field.
+ pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor()
+ // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field.
+ pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string)
+ // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field.
+ pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor()
+ // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field.
+ pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string)
+ // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field.
+ pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor()
+ // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field.
+ pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string)
+ // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field.
+ pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor()
+ // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field.
+ pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{})
+ // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field.
+ pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor()
+ // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field.
+ pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{})
+ // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field.
+ pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor()
+ // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field.
+ pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string)
+ // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field.
+ pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor()
+ // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field.
+ pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -755,6 +1440,68 @@ func init() {
setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time)
// setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time)
+ subscriptionplanFields := schema.SubscriptionPlan{}.Fields()
+ _ = subscriptionplanFields
+ // subscriptionplanDescName is the schema descriptor for name field.
+ subscriptionplanDescName := subscriptionplanFields[1].Descriptor()
+ // subscriptionplan.NameValidator is a validator for the "name" field. It is called by the builders before save.
+ subscriptionplan.NameValidator = func() func(string) error {
+ validators := subscriptionplanDescName.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(name string) error {
+ for _, fn := range fns {
+ if err := fn(name); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // subscriptionplanDescDescription is the schema descriptor for description field.
+ subscriptionplanDescDescription := subscriptionplanFields[2].Descriptor()
+ // subscriptionplan.DefaultDescription holds the default value on creation for the description field.
+ subscriptionplan.DefaultDescription = subscriptionplanDescDescription.Default.(string)
+ // subscriptionplanDescValidityDays is the schema descriptor for validity_days field.
+ subscriptionplanDescValidityDays := subscriptionplanFields[5].Descriptor()
+ // subscriptionplan.DefaultValidityDays holds the default value on creation for the validity_days field.
+ subscriptionplan.DefaultValidityDays = subscriptionplanDescValidityDays.Default.(int)
+ // subscriptionplanDescValidityUnit is the schema descriptor for validity_unit field.
+ subscriptionplanDescValidityUnit := subscriptionplanFields[6].Descriptor()
+ // subscriptionplan.DefaultValidityUnit holds the default value on creation for the validity_unit field.
+ subscriptionplan.DefaultValidityUnit = subscriptionplanDescValidityUnit.Default.(string)
+ // subscriptionplan.ValidityUnitValidator is a validator for the "validity_unit" field. It is called by the builders before save.
+ subscriptionplan.ValidityUnitValidator = subscriptionplanDescValidityUnit.Validators[0].(func(string) error)
+ // subscriptionplanDescFeatures is the schema descriptor for features field.
+ subscriptionplanDescFeatures := subscriptionplanFields[7].Descriptor()
+ // subscriptionplan.DefaultFeatures holds the default value on creation for the features field.
+ subscriptionplan.DefaultFeatures = subscriptionplanDescFeatures.Default.(string)
+ // subscriptionplanDescProductName is the schema descriptor for product_name field.
+ subscriptionplanDescProductName := subscriptionplanFields[8].Descriptor()
+ // subscriptionplan.DefaultProductName holds the default value on creation for the product_name field.
+ subscriptionplan.DefaultProductName = subscriptionplanDescProductName.Default.(string)
+ // subscriptionplan.ProductNameValidator is a validator for the "product_name" field. It is called by the builders before save.
+ subscriptionplan.ProductNameValidator = subscriptionplanDescProductName.Validators[0].(func(string) error)
+ // subscriptionplanDescForSale is the schema descriptor for for_sale field.
+ subscriptionplanDescForSale := subscriptionplanFields[9].Descriptor()
+ // subscriptionplan.DefaultForSale holds the default value on creation for the for_sale field.
+ subscriptionplan.DefaultForSale = subscriptionplanDescForSale.Default.(bool)
+ // subscriptionplanDescSortOrder is the schema descriptor for sort_order field.
+ subscriptionplanDescSortOrder := subscriptionplanFields[10].Descriptor()
+ // subscriptionplan.DefaultSortOrder holds the default value on creation for the sort_order field.
+ subscriptionplan.DefaultSortOrder = subscriptionplanDescSortOrder.Default.(int)
+ // subscriptionplanDescCreatedAt is the schema descriptor for created_at field.
+ subscriptionplanDescCreatedAt := subscriptionplanFields[11].Descriptor()
+ // subscriptionplan.DefaultCreatedAt holds the default value on creation for the created_at field.
+ subscriptionplan.DefaultCreatedAt = subscriptionplanDescCreatedAt.Default.(func() time.Time)
+ // subscriptionplanDescUpdatedAt is the schema descriptor for updated_at field.
+ subscriptionplanDescUpdatedAt := subscriptionplanFields[12].Descriptor()
+ // subscriptionplan.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ subscriptionplan.DefaultUpdatedAt = subscriptionplanDescUpdatedAt.Default.(func() time.Time)
+ // subscriptionplan.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ subscriptionplan.UpdateDefaultUpdatedAt = subscriptionplanDescUpdatedAt.UpdateDefault.(func() time.Time)
tlsfingerprintprofileMixin := schema.TLSFingerprintProfile{}.Mixin()
tlsfingerprintprofileMixinFields0 := tlsfingerprintprofileMixin[0].Fields()
_ = tlsfingerprintprofileMixinFields0
@@ -875,92 +1622,100 @@ func init() {
usagelogDescUpstreamModel := usagelogFields[6].Descriptor()
// usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error)
+ // usagelogDescModelMappingChain is the schema descriptor for model_mapping_chain field.
+ usagelogDescModelMappingChain := usagelogFields[8].Descriptor()
+ // usagelog.ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
+ usagelog.ModelMappingChainValidator = usagelogDescModelMappingChain.Validators[0].(func(string) error)
+ // usagelogDescBillingTier is the schema descriptor for billing_tier field.
+ usagelogDescBillingTier := usagelogFields[9].Descriptor()
+ // usagelog.BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
+ usagelog.BillingTierValidator = usagelogDescBillingTier.Validators[0].(func(string) error)
+ // usagelogDescBillingMode is the schema descriptor for billing_mode field.
+ usagelogDescBillingMode := usagelogFields[10].Descriptor()
+ // usagelog.BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
+ usagelog.BillingModeValidator = usagelogDescBillingMode.Validators[0].(func(string) error)
// usagelogDescInputTokens is the schema descriptor for input_tokens field.
- usagelogDescInputTokens := usagelogFields[9].Descriptor()
+ usagelogDescInputTokens := usagelogFields[13].Descriptor()
// usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field.
usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int)
// usagelogDescOutputTokens is the schema descriptor for output_tokens field.
- usagelogDescOutputTokens := usagelogFields[10].Descriptor()
+ usagelogDescOutputTokens := usagelogFields[14].Descriptor()
// usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field.
usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int)
// usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field.
- usagelogDescCacheCreationTokens := usagelogFields[11].Descriptor()
+ usagelogDescCacheCreationTokens := usagelogFields[15].Descriptor()
// usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field.
usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int)
// usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field.
- usagelogDescCacheReadTokens := usagelogFields[12].Descriptor()
+ usagelogDescCacheReadTokens := usagelogFields[16].Descriptor()
// usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field.
usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int)
// usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field.
- usagelogDescCacheCreation5mTokens := usagelogFields[13].Descriptor()
+ usagelogDescCacheCreation5mTokens := usagelogFields[17].Descriptor()
// usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field.
usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int)
// usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field.
- usagelogDescCacheCreation1hTokens := usagelogFields[14].Descriptor()
+ usagelogDescCacheCreation1hTokens := usagelogFields[18].Descriptor()
// usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field.
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int)
// usagelogDescInputCost is the schema descriptor for input_cost field.
- usagelogDescInputCost := usagelogFields[15].Descriptor()
+ usagelogDescInputCost := usagelogFields[19].Descriptor()
// usagelog.DefaultInputCost holds the default value on creation for the input_cost field.
usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64)
// usagelogDescOutputCost is the schema descriptor for output_cost field.
- usagelogDescOutputCost := usagelogFields[16].Descriptor()
+ usagelogDescOutputCost := usagelogFields[20].Descriptor()
// usagelog.DefaultOutputCost holds the default value on creation for the output_cost field.
usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64)
// usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field.
- usagelogDescCacheCreationCost := usagelogFields[17].Descriptor()
+ usagelogDescCacheCreationCost := usagelogFields[21].Descriptor()
// usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field.
usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64)
// usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field.
- usagelogDescCacheReadCost := usagelogFields[18].Descriptor()
+ usagelogDescCacheReadCost := usagelogFields[22].Descriptor()
// usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field.
usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64)
// usagelogDescTotalCost is the schema descriptor for total_cost field.
- usagelogDescTotalCost := usagelogFields[19].Descriptor()
+ usagelogDescTotalCost := usagelogFields[23].Descriptor()
// usagelog.DefaultTotalCost holds the default value on creation for the total_cost field.
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64)
// usagelogDescActualCost is the schema descriptor for actual_cost field.
- usagelogDescActualCost := usagelogFields[20].Descriptor()
+ usagelogDescActualCost := usagelogFields[24].Descriptor()
// usagelog.DefaultActualCost holds the default value on creation for the actual_cost field.
usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64)
// usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field.
- usagelogDescRateMultiplier := usagelogFields[21].Descriptor()
+ usagelogDescRateMultiplier := usagelogFields[25].Descriptor()
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
- usagelogDescBillingType := usagelogFields[23].Descriptor()
+ usagelogDescBillingType := usagelogFields[27].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
- usagelogDescStream := usagelogFields[24].Descriptor()
+ usagelogDescStream := usagelogFields[28].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
- usagelogDescUserAgent := usagelogFields[27].Descriptor()
+ usagelogDescUserAgent := usagelogFields[31].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
- usagelogDescIPAddress := usagelogFields[28].Descriptor()
+ usagelogDescIPAddress := usagelogFields[32].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
- usagelogDescImageCount := usagelogFields[29].Descriptor()
+ usagelogDescImageCount := usagelogFields[33].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
- usagelogDescImageSize := usagelogFields[30].Descriptor()
+ usagelogDescImageSize := usagelogFields[34].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
- // usagelogDescMediaType is the schema descriptor for media_type field.
- usagelogDescMediaType := usagelogFields[31].Descriptor()
- // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
- usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
- usagelogDescCacheTTLOverridden := usagelogFields[32].Descriptor()
+ usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
- usagelogDescCreatedAt := usagelogFields[33].Descriptor()
+ usagelogDescCreatedAt := usagelogFields[36].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
@@ -1052,14 +1807,32 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
- // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field.
- userDescSoraStorageQuotaBytes := userFields[11].Descriptor()
- // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field.
- user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64)
- // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field.
- userDescSoraStorageUsedBytes := userFields[12].Descriptor()
- // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field.
- user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64)
+ // userDescSignupSource is the schema descriptor for signup_source field.
+ userDescSignupSource := userFields[11].Descriptor()
+ // user.DefaultSignupSource holds the default value on creation for the signup_source field.
+ user.DefaultSignupSource = userDescSignupSource.Default.(string)
+ // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error)
+ // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
+ userDescBalanceNotifyEnabled := userFields[14].Descriptor()
+ // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
+ user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
+ // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
+ userDescBalanceNotifyThresholdType := userFields[15].Descriptor()
+ // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
+ user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
+ // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
+ userDescBalanceNotifyExtraEmails := userFields[17].Descriptor()
+ // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
+ user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
+ // userDescTotalRecharged is the schema descriptor for total_recharged field.
+ userDescTotalRecharged := userFields[18].Descriptor()
+ // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
+ user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
+ // userDescRpmLimit is the schema descriptor for rpm_limit field.
+ userDescRpmLimit := userFields[19].Descriptor()
+ // user.DefaultRpmLimit holds the default value on creation for the rpm_limit field.
+ user.DefaultRpmLimit = userDescRpmLimit.Default.(int)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go
new file mode 100644
index 00000000..0b1b56ab
--- /dev/null
+++ b/backend/ent/schema/auth_identity.go
@@ -0,0 +1,94 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var authProviderTypes = map[string]struct{}{
+ "email": {},
+ "linuxdo": {},
+ "oidc": {},
+ "wechat": {},
+}
+
+func validateAuthProviderType(value string) error {
+ if _, ok := authProviderTypes[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid auth provider type %q", value)
+}
+
+// AuthIdentity stores the canonical login identity for an account.
+type AuthIdentity struct {
+ ent.Schema
+}
+
+func (AuthIdentity) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identities"},
+ }
+}
+
+func (AuthIdentity) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentity) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("user_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.String("issuer").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentity) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("user", User.Type).
+ Ref("auth_identities").
+ Field("user_id").
+ Required().
+ Unique(),
+ edge.To("channels", AuthIdentityChannel.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("adoption_decisions", IdentityAdoptionDecision.Type),
+ }
+}
+
+func (AuthIdentity) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "provider_subject").Unique(),
+ index.Fields("user_id"),
+ index.Fields("user_id", "provider_type"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go
new file mode 100644
index 00000000..69f2ad02
--- /dev/null
+++ b/backend/ent/schema/auth_identity_channel.go
@@ -0,0 +1,72 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity.
+type AuthIdentityChannel struct {
+ ent.Schema
+}
+
+func (AuthIdentityChannel) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "auth_identity_channels"},
+ }
+}
+
+func (AuthIdentityChannel) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (AuthIdentityChannel) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("identity_id"),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel").
+ MaxLen(20).
+ NotEmpty(),
+ field.String("channel_app_id").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("channel_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("metadata", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ }
+}
+
+func (AuthIdentityChannel) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("identity", AuthIdentity.Type).
+ Ref("channels").
+ Field("identity_id").
+ Required().
+ Unique(),
+ }
+}
+
+func (AuthIdentityChannel) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go
new file mode 100644
index 00000000..fbb93236
--- /dev/null
+++ b/backend/ent/schema/auth_identity_schema_test.go
@@ -0,0 +1,168 @@
+package schema
+
+import (
+ "testing"
+
+ "entgo.io/ent"
+ "entgo.io/ent/entc/load"
+ "entgo.io/ent/schema/field"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityFoundationSchemas(t *testing.T) {
+ spec, err := (&load.Config{Path: "."}).Load()
+ require.NoError(t, err)
+
+ schemas := map[string]*load.Schema{}
+ for _, schema := range spec.Schemas {
+ schemas[schema.Name] = schema
+ }
+
+ authIdentity := requireSchema(t, schemas, "AuthIdentity")
+ requireSchemaFields(t, authIdentity,
+ "user_id",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "verified_at",
+ "issuer",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject")
+
+ authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel")
+ requireSchemaFields(t, authIdentityChannel,
+ "identity_id",
+ "provider_type",
+ "provider_key",
+ "channel",
+ "channel_app_id",
+ "channel_subject",
+ "metadata",
+ )
+ requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject")
+
+ pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession")
+ requireSchemaFields(t, pendingAuthSession,
+ "intent",
+ "provider_type",
+ "provider_key",
+ "provider_subject",
+ "target_user_id",
+ "redirect_to",
+ "resolved_email",
+ "registration_password_hash",
+ "upstream_identity_claims",
+ "local_flow_state",
+ "browser_session_key",
+ "completion_code_hash",
+ "completion_code_expires_at",
+ "email_verified_at",
+ "password_verified_at",
+ "totp_verified_at",
+ "expires_at",
+ "consumed_at",
+ )
+
+ adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision")
+ requireSchemaFields(t, adoptionDecision,
+ "pending_auth_session_id",
+ "identity_id",
+ "adopt_display_name",
+ "adopt_avatar",
+ "decided_at",
+ )
+ requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id")
+
+ userSchema := requireSchema(t, schemas, "User")
+ requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
+ signupSource := requireSchemaField(t, userSchema, "signup_source")
+ require.Equal(t, field.TypeString, signupSource.Info.Type)
+ require.True(t, signupSource.Default)
+ require.Equal(t, "email", signupSource.DefaultValue)
+ require.Equal(t, 1, signupSource.Validators)
+
+ validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
+ for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
+ require.NoError(t, validator(value))
+ }
+ require.Error(t, validator("github"))
+}
+
+func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
+ t.Helper()
+
+ schema, ok := schemas[name]
+ require.True(t, ok, "schema %s should exist", name)
+ return schema
+}
+
+func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
+ t.Helper()
+
+ fields := map[string]struct{}{}
+ for _, field := range schema.Fields {
+ fields[field.Name] = struct{}{}
+ }
+
+ for _, name := range names {
+ _, ok := fields[name]
+ require.True(t, ok, "schema %s should include field %s", schema.Name, name)
+ }
+}
+
+func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field {
+ t.Helper()
+
+ for _, schemaField := range schema.Fields {
+ if schemaField.Name == name {
+ return schemaField
+ }
+ }
+
+ require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name)
+ return nil
+}
+
+func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error {
+ t.Helper()
+
+ for _, entField := range fields {
+ descriptor := entField.Descriptor()
+ if descriptor.Name != name {
+ continue
+ }
+ require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name)
+ validator, ok := descriptor.Validators[0].(func(string) error)
+ require.True(t, ok, "field %s validator should be func(string) error", name)
+ return validator
+ }
+
+ require.Failf(t, "missing field validator", "schema should include field %s", name)
+ return nil
+}
+
+func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
+ t.Helper()
+
+ for _, index := range schema.Indexes {
+ if !index.Unique {
+ continue
+ }
+ if len(index.Fields) != len(fields) {
+ continue
+ }
+ match := true
+ for i := range fields {
+ if index.Fields[i] != fields[i] {
+ match = false
+ break
+ }
+ }
+ if match {
+ return
+ }
+ }
+
+ require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields)
+}
diff --git a/backend/ent/schema/channel_monitor.go b/backend/ent/schema/channel_monitor.go
new file mode 100644
index 00000000..355ade4b
--- /dev/null
+++ b/backend/ent/schema/channel_monitor.go
@@ -0,0 +1,110 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitor holds the schema definition for the ChannelMonitor entity.
+// 渠道监控配置:定期对指定 provider/endpoint/api_key 下的模型做心跳测试。
+type ChannelMonitor struct {
+ ent.Schema
+}
+
+func (ChannelMonitor) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitors"},
+ }
+}
+
+func (ChannelMonitor) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (ChannelMonitor) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("name").
+ NotEmpty().
+ MaxLen(100),
+ field.Enum("provider").
+ Values("openai", "anthropic", "gemini"),
+ field.String("endpoint").
+ NotEmpty().
+ MaxLen(500).
+ Comment("Provider base origin, e.g. https://api.openai.com"),
+ field.String("api_key_encrypted").
+ NotEmpty().
+ Sensitive().
+ Comment("AES-256-GCM encrypted API key"),
+ field.String("primary_model").
+ NotEmpty().
+ MaxLen(200),
+ field.JSON("extra_models", []string{}).
+ Default([]string{}).
+ Comment("Additional model names to test alongside primary_model"),
+ field.String("group_name").
+ Optional().
+ Default("").
+ MaxLen(100),
+ field.Bool("enabled").
+ Default(true),
+ field.Int("interval_seconds").
+ Range(15, 3600),
+ field.Time("last_checked_at").
+ Optional().
+ Nillable(),
+ field.Int64("created_by"),
+
+ // ---- 自定义请求快照字段(来自模板 / 手动编辑) ----
+
+ // template_id: 关联的请求模板 ID(仅用于 UI 分组 + 一键应用)。
+ // 实际运行时 checker 只读下面 3 个快照字段,**不再回查模板表**。
+ // 模板被删除时此字段会被 SET NULL(见 Edges 的 OnDelete 注解)。
+ field.Int64("template_id").
+ Optional().
+ Nillable(),
+ // extra_headers: 自定义 HTTP 头快照(来自模板 or 用户手填)。
+ // 运行时 merge 进 adapter 默认 headers。
+ field.JSON("extra_headers", map[string]string{}).
+ Default(map[string]string{}),
+ // body_override_mode: 同 ChannelMonitorRequestTemplate.body_override_mode
+ field.String("body_override_mode").
+ Default("off").
+ MaxLen(10),
+ // body_override: 同 ChannelMonitorRequestTemplate.body_override
+ field.JSON("body_override", map[string]any{}).
+ Optional(),
+ }
+}
+
+func (ChannelMonitor) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.To("history", ChannelMonitorHistory.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("daily_rollups", ChannelMonitorDailyRollup.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ // 关联请求模板:模板被删除时 template_id 自动置空,
+ // 监控本身保留(继续用快照字段跑)。
+ edge.To("request_template", ChannelMonitorRequestTemplate.Type).
+ Field("template_id").
+ Unique().
+ Annotations(entsql.OnDelete(entsql.SetNull)),
+ }
+}
+
+func (ChannelMonitor) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("enabled", "last_checked_at"),
+ index.Fields("provider"),
+ index.Fields("group_name"),
+ index.Fields("template_id"),
+ }
+}
diff --git a/backend/ent/schema/channel_monitor_daily_rollup.go b/backend/ent/schema/channel_monitor_daily_rollup.go
new file mode 100644
index 00000000..23f032e3
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_daily_rollup.go
@@ -0,0 +1,66 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitorDailyRollup 按 (monitor_id, model, bucket_date) 维度聚合的渠道监控日统计。
+// 每天的明细被收敛为一行(保留 status 分布 + 延迟和),用于 7d/15d/30d 窗口的可用率
+// 加权计算(avg_latency = sum_latency_ms / count_latency;availability = ok_count / total_checks)。
+// 超过保留期由每日维护任务分批物理删(不用软删除,理由同 channel_monitor_history)。
+type ChannelMonitorDailyRollup struct {
+ ent.Schema
+}
+
+func (ChannelMonitorDailyRollup) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_daily_rollups"},
+ }
+}
+
+func (ChannelMonitorDailyRollup) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("monitor_id"),
+ field.String("model").
+ NotEmpty().
+ MaxLen(200),
+ field.Time("bucket_date").
+ SchemaType(map[string]string{dialect.Postgres: "date"}),
+ field.Int("total_checks").Default(0),
+ field.Int("ok_count").Default(0),
+ field.Int("operational_count").Default(0),
+ field.Int("degraded_count").Default(0),
+ field.Int("failed_count").Default(0),
+ field.Int("error_count").Default(0),
+ field.Int64("sum_latency_ms").Default(0),
+ field.Int("count_latency").Default(0),
+ field.Int64("sum_ping_latency_ms").Default(0),
+ field.Int("count_ping_latency").Default(0),
+ field.Time("computed_at").Default(time.Now).UpdateDefault(time.Now),
+ }
+}
+
+func (ChannelMonitorDailyRollup) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitor", ChannelMonitor.Type).
+ Ref("daily_rollups").
+ Field("monitor_id").
+ Unique().
+ Required(),
+ }
+}
+
+func (ChannelMonitorDailyRollup) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("monitor_id", "model", "bucket_date").Unique(),
+ index.Fields("bucket_date"),
+ }
+}
diff --git a/backend/ent/schema/channel_monitor_history.go b/backend/ent/schema/channel_monitor_history.go
new file mode 100644
index 00000000..4366e79a
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_history.go
@@ -0,0 +1,66 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitorHistory holds the schema definition for the ChannelMonitorHistory entity.
+// 渠道监控历史:每次检测每个模型一行记录。明细只保留 1 天,超过 1 天由每日维护任务
+// 先聚合到 channel_monitor_daily_rollups,再分批物理删(不用软删除:日志类表无恢复
+// 需求,软删会让行和索引只增不减,徒增磁盘和查询开销)。
+type ChannelMonitorHistory struct {
+ ent.Schema
+}
+
+func (ChannelMonitorHistory) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_histories"},
+ }
+}
+
+func (ChannelMonitorHistory) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("monitor_id"),
+ field.String("model").
+ NotEmpty().
+ MaxLen(200),
+ field.Enum("status").
+ Values("operational", "degraded", "failed", "error"),
+ field.Int("latency_ms").
+ Optional().
+ Nillable(),
+ field.Int("ping_latency_ms").
+ Optional().
+ Nillable(),
+ field.String("message").
+ Optional().
+ Default("").
+ MaxLen(500),
+ field.Time("checked_at").
+ Default(time.Now),
+ }
+}
+
+func (ChannelMonitorHistory) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitor", ChannelMonitor.Type).
+ Ref("history").
+ Field("monitor_id").
+ Unique().
+ Required(),
+ }
+}
+
+func (ChannelMonitorHistory) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("monitor_id", "model", "checked_at"),
+ index.Fields("checked_at"),
+ }
+}
diff --git a/backend/ent/schema/channel_monitor_request_template.go b/backend/ent/schema/channel_monitor_request_template.go
new file mode 100644
index 00000000..59df2f29
--- /dev/null
+++ b/backend/ent/schema/channel_monitor_request_template.go
@@ -0,0 +1,80 @@
+package schema
+
+import (
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// ChannelMonitorRequestTemplate 请求模板:一组可复用的 headers + 可选 body 覆盖配置。
+//
+// 语义为快照:模板被"应用"到监控时,extra_headers / body_override_mode / body_override
+// 会被**拷贝**到 channel_monitors 同名字段;后续模板变动不会自动影响已应用的监控——
+// 必须用户主动在模板编辑 Dialog 里点「应用到关联监控」才会覆盖快照。
+// 这样模板改错不会瞬间打挂所有已经跑起来的监控。
+type ChannelMonitorRequestTemplate struct {
+ ent.Schema
+}
+
+func (ChannelMonitorRequestTemplate) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "channel_monitor_request_templates"},
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("name").
+ NotEmpty().
+ MaxLen(100),
+ field.Enum("provider").
+ Values("openai", "anthropic", "gemini"),
+ field.String("description").
+ Optional().
+ Default("").
+ MaxLen(500),
+ // extra_headers: 用户自定义 HTTP 头(如 User-Agent 伪装)。
+ // 运行时 merge 进 adapter 默认 headers,用户值优先;
+ // hop-by-hop 黑名单(Host/Content-Length/...)由 checker 过滤。
+ field.JSON("extra_headers", map[string]string{}).
+ Default(map[string]string{}),
+ // body_override_mode: 'off' | 'merge' | 'replace'
+ // off - 用 adapter 默认 body(忽略 body_override)
+ // merge - adapter 默认 body 与 body_override 浅合并(body_override 优先,
+ // model/messages/contents 等关键字段在 checker 里走黑名单跳过)
+ // replace - 直接用 body_override 作为完整 body;此时跳过 challenge 校验,
+ // 改为 HTTP 2xx + 响应文本非空即视为可用
+ field.String("body_override_mode").
+ Default("off").
+ MaxLen(10),
+ // body_override: JSON 对象,根据 body_override_mode 使用。
+ // 用 map[string]any 以便前端传任意结构(含嵌套)。
+ field.JSON("body_override", map[string]any{}).
+ Optional(),
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("monitors", ChannelMonitor.Type).
+ Ref("request_template"),
+ }
+}
+
+func (ChannelMonitorRequestTemplate) Indexes() []ent.Index {
+ return []ent.Index{
+ // 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。
+ index.Fields("provider", "name").Unique(),
+ }
+}
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index fd83bf26..11f38d66 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -87,28 +87,6 @@ func (Group) Fields() []ent.Field {
Nillable().
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- // Sora 按次计费配置(阶段 1)
- field.Float("sora_image_price_360").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- field.Float("sora_image_price_540").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- field.Float("sora_video_price_per_request").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
- field.Float("sora_video_price_per_request_hd").
- Optional().
- Nillable().
- SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}),
-
- // Sora 存储配额
- field.Int64("sora_storage_quota_bytes").
- Default(0),
-
// Claude Code 客户端限制 (added by migration 029)
field.Bool("claude_code_only").
Default(false).
@@ -163,6 +141,15 @@ func (Group) Fields() []ent.Field {
MaxLen(100).
Default("").
Comment("默认映射模型 ID,当账号级映射找不到时使用此值"),
+ field.JSON("messages_dispatch_model_config", domain.OpenAIMessagesDispatchModelConfig{}).
+ Default(domain.OpenAIMessagesDispatchModelConfig{}).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
+ Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"),
+
+ // 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。
+ field.Int("rpm_limit").
+ Default(0).
+ Comment("分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流"),
}
}
diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go
new file mode 100644
index 00000000..9fdd26fb
--- /dev/null
+++ b/backend/ent/schema/identity_adoption_decision.go
@@ -0,0 +1,70 @@
+package schema
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow.
+type IdentityAdoptionDecision struct {
+ ent.Schema
+}
+
+func (IdentityAdoptionDecision) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "identity_adoption_decisions"},
+ }
+}
+
+func (IdentityAdoptionDecision) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (IdentityAdoptionDecision) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("pending_auth_session_id"),
+ field.Int64("identity_id").
+ Optional().
+ Nillable(),
+ field.Bool("adopt_display_name").
+ Default(false),
+ field.Bool("adopt_avatar").
+ Default(false),
+ field.Time("decided_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (IdentityAdoptionDecision) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("pending_auth_session", PendingAuthSession.Type).
+ Ref("adoption_decision").
+ Field("pending_auth_session_id").
+ Required().
+ Unique(),
+ edge.From("identity", AuthIdentity.Type).
+ Ref("adoption_decisions").
+ Field("identity_id").
+ Unique(),
+ }
+}
+
+func (IdentityAdoptionDecision) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("pending_auth_session_id").Unique(),
+ index.Fields("identity_id"),
+ }
+}
diff --git a/backend/ent/schema/payment_audit_log.go b/backend/ent/schema/payment_audit_log.go
new file mode 100644
index 00000000..7f8a8c04
--- /dev/null
+++ b/backend/ent/schema/payment_audit_log.go
@@ -0,0 +1,54 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// PaymentAuditLog holds the schema definition for the PaymentAuditLog entity.
+//
+// 删除策略:硬删除
+// PaymentAuditLog 使用硬删除而非软删除,原因如下:
+// - 审计日志本身即为不可变记录,通常只追加不修改
+// - 如需清理历史日志,直接按时间范围批量删除即可
+// - 保持表结构简洁,提升插入和查询性能
+type PaymentAuditLog struct {
+ ent.Schema
+}
+
+func (PaymentAuditLog) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "payment_audit_logs"},
+ }
+}
+
+func (PaymentAuditLog) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("order_id").
+ MaxLen(64),
+ field.String("action").
+ MaxLen(50),
+ field.String("detail").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Default(""),
+ field.String("operator").
+ MaxLen(100).
+ Default("system"),
+ field.Time("created_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PaymentAuditLog) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("order_id"),
+ }
+}
diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go
new file mode 100644
index 00000000..d25d1e5e
--- /dev/null
+++ b/backend/ent/schema/payment_order.go
@@ -0,0 +1,199 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// PaymentOrder holds the schema definition for the PaymentOrder entity.
+//
+// 删除策略:硬删除
+// PaymentOrder 使用硬删除而非软删除,原因如下:
+// - 订单通过 status 字段追踪完整生命周期,无需依赖软删除
+// - 订单审计通过 PaymentAuditLog 表记录,删除前可归档
+// - 减少查询复杂度,避免软删除过滤开销
+type PaymentOrder struct {
+ ent.Schema
+}
+
+func (PaymentOrder) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "payment_orders"},
+ }
+}
+
+func (PaymentOrder) Fields() []ent.Field {
+ return []ent.Field{
+ // 用户信息(冗余存储,避免关联查询)
+ field.Int64("user_id"),
+ field.String("user_email").
+ MaxLen(255),
+ field.String("user_name").
+ MaxLen(100),
+ field.String("user_notes").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+
+ // 金额信息
+ field.Float("amount").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,2)"}),
+ field.Float("pay_amount").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,2)"}),
+ field.Float("fee_rate").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
+ Default(0),
+ field.String("recharge_code").
+ MaxLen(64),
+
+ // 支付信息
+ field.String("out_trade_no").
+ MaxLen(64).
+ Default(""),
+ field.String("payment_type").
+ MaxLen(30),
+ field.String("payment_trade_no").
+ MaxLen(128),
+ field.String("pay_url").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("qr_code").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("qr_code_img").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+
+ // 订单类型 & 订阅关联
+ field.String("order_type").
+ MaxLen(20).
+ Default("balance"),
+ field.Int64("plan_id").
+ Optional().
+ Nillable(),
+ field.Int64("subscription_group_id").
+ Optional().
+ Nillable(),
+ field.Int("subscription_days").
+ Optional().
+ Nillable(),
+ field.String("provider_instance_id").
+ Optional().
+ Nillable().
+ MaxLen(64),
+ field.String("provider_key").
+ Optional().
+ Nillable().
+ MaxLen(30),
+ field.JSON("provider_snapshot", map[string]any{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+
+ // 状态
+ field.String("status").
+ MaxLen(30).
+ Default("PENDING"),
+
+ // 退款信息
+ field.Float("refund_amount").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,2)"}).
+ Default(0),
+ field.String("refund_reason").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("refund_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Bool("force_refund").
+ Default(false),
+ field.Time("refund_requested_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.String("refund_request_reason").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("refund_requested_by").
+ Optional().
+ Nillable().
+ MaxLen(20),
+
+ // 时间节点
+ field.Time("expires_at").
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("paid_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("completed_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("failed_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.String("failed_reason").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+
+ // 来源信息
+ field.String("client_ip").
+ MaxLen(50),
+ field.String("src_host").
+ MaxLen(255),
+ field.String("src_url").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+
+ // 时间戳
+ field.Time("created_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("updated_at").
+ Default(time.Now).
+ UpdateDefault(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PaymentOrder) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("user", User.Type).
+ Ref("payment_orders").
+ Field("user_id").
+ Unique().
+ Required(),
+ }
+}
+
+func (PaymentOrder) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("out_trade_no").
+ Unique().
+ Annotations(entsql.IndexWhere("out_trade_no <> ''")),
+ index.Fields("user_id"),
+ index.Fields("status"),
+ index.Fields("expires_at"),
+ index.Fields("created_at"),
+ index.Fields("paid_at"),
+ index.Fields("payment_type", "paid_at"),
+ index.Fields("order_type"),
+ }
+}
diff --git a/backend/ent/schema/payment_provider_instance.go b/backend/ent/schema/payment_provider_instance.go
new file mode 100644
index 00000000..e4c0b72c
--- /dev/null
+++ b/backend/ent/schema/payment_provider_instance.go
@@ -0,0 +1,74 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// PaymentProviderInstance holds the schema definition for the PaymentProviderInstance entity.
+//
+// 删除策略:硬删除
+// PaymentProviderInstance 使用硬删除而非软删除,原因如下:
+// - 服务商实例为管理员配置的支付通道,删除即表示废弃
+// - 通过 enabled 字段控制是否启用,删除仅用于彻底移除
+// - config 字段存储加密后的密钥信息,删除时应彻底清除
+type PaymentProviderInstance struct {
+ ent.Schema
+}
+
+func (PaymentProviderInstance) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "payment_provider_instances"},
+ }
+}
+
+func (PaymentProviderInstance) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("provider_key").
+ MaxLen(30).
+ NotEmpty(),
+ field.String("name").
+ MaxLen(100).
+ Default(""),
+ field.String("config").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("supported_types").
+ MaxLen(200).
+ Default(""),
+ field.Bool("enabled").
+ Default(true),
+ field.String("payment_mode").
+ MaxLen(20).
+ Default(""),
+ field.Int("sort_order").
+ Default(0),
+ field.String("limits").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Default(""),
+ field.Bool("refund_enabled").
+ Default(false),
+ field.Bool("allow_user_refund").
+ Default(false),
+ field.Time("created_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("updated_at").
+ Default(time.Now).
+ UpdateDefault(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PaymentProviderInstance) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("provider_key"),
+ index.Fields("enabled"),
+ }
+}
diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go
new file mode 100644
index 00000000..7e95f085
--- /dev/null
+++ b/backend/ent/schema/pending_auth_session.go
@@ -0,0 +1,135 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/edge"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+var pendingAuthIntents = map[string]struct{}{
+ "login": {},
+ "bind_current_user": {},
+ "adopt_existing_user_by_email": {},
+}
+
+func validatePendingAuthIntent(value string) error {
+ if _, ok := pendingAuthIntents[value]; ok {
+ return nil
+ }
+ return fmt.Errorf("invalid pending auth intent %q", value)
+}
+
+// PendingAuthSession stores a short-lived post-auth decision session.
+type PendingAuthSession struct {
+ ent.Schema
+}
+
+func (PendingAuthSession) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "pending_auth_sessions"},
+ }
+}
+
+func (PendingAuthSession) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (PendingAuthSession) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("session_token").
+ MaxLen(255).
+ NotEmpty(),
+ field.String("intent").
+ MaxLen(40).
+ NotEmpty().
+ Validate(validatePendingAuthIntent),
+ field.String("provider_type").
+ MaxLen(20).
+ NotEmpty().
+ Validate(validateAuthProviderType),
+ field.String("provider_key").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("provider_subject").
+ NotEmpty().
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Int64("target_user_id").
+ Optional().
+ Nillable(),
+ field.String("redirect_to").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("resolved_email").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("registration_password_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.JSON("upstream_identity_claims", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.JSON("local_flow_state", map[string]any{}).
+ Default(func() map[string]any { return map[string]any{} }).
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
+ field.String("browser_session_key").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.String("completion_code_hash").
+ Default("").
+ SchemaType(map[string]string{dialect.Postgres: "text"}),
+ field.Time("completion_code_expires_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("email_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("password_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("totp_verified_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("expires_at").
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("consumed_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (PendingAuthSession) Edges() []ent.Edge {
+ return []ent.Edge{
+ edge.From("target_user", User.Type).
+ Ref("pending_auth_sessions").
+ Field("target_user_id").
+ Unique(),
+ edge.To("adoption_decision", IdentityAdoptionDecision.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)).
+ Unique(),
+ }
+}
+
+func (PendingAuthSession) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("session_token").Unique(),
+ index.Fields("target_user_id"),
+ index.Fields("expires_at"),
+ index.Fields("provider_type", "provider_key", "provider_subject"),
+ index.Fields("completion_code_hash"),
+ }
+}
diff --git a/backend/ent/schema/subscription_plan.go b/backend/ent/schema/subscription_plan.go
new file mode 100644
index 00000000..3e30490b
--- /dev/null
+++ b/backend/ent/schema/subscription_plan.go
@@ -0,0 +1,77 @@
+package schema
+
+import (
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// SubscriptionPlan holds the schema definition for the SubscriptionPlan entity.
+//
+// 删除策略:硬删除
+// SubscriptionPlan 使用硬删除而非软删除,原因如下:
+// - 套餐为管理员维护的商品配置,删除即表示下架移除
+// - 通过 for_sale 字段控制是否在售,删除仅用于彻底移除
+// - 已购买的订阅记录保存在 UserSubscription 中,不受套餐删除影响
+type SubscriptionPlan struct {
+ ent.Schema
+}
+
+func (SubscriptionPlan) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "subscription_plans"},
+ }
+}
+
+func (SubscriptionPlan) Fields() []ent.Field {
+ return []ent.Field{
+ field.Int64("group_id"),
+ field.String("name").
+ MaxLen(100).
+ NotEmpty(),
+ field.String("description").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Default(""),
+ field.Float("price").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,2)"}),
+ field.Float("original_price").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,2)"}).
+ Optional().
+ Nillable(),
+ field.Int("validity_days").
+ Default(30),
+ field.String("validity_unit").
+ MaxLen(10).
+ Default("day"),
+ field.String("features").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Default(""),
+ field.String("product_name").
+ MaxLen(100).
+ Default(""),
+ field.Bool("for_sale").
+ Default(true),
+ field.Int("sort_order").
+ Default(0),
+ field.Time("created_at").
+ Immutable().
+ Default(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("updated_at").
+ Default(time.Now).
+ UpdateDefault(time.Now).
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ }
+}
+
+func (SubscriptionPlan) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("group_id"),
+ index.Fields("for_sale"),
+ }
+}
diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go
index 32c39e25..bd3ebfcc 100644
--- a/backend/ent/schema/usage_log.go
+++ b/backend/ent/schema/usage_log.go
@@ -53,6 +53,10 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(100).
Optional().
Nillable(),
+ field.Int64("channel_id").Optional().Nillable().Comment("渠道 ID"),
+ field.String("model_mapping_chain").MaxLen(500).Optional().Nillable().Comment("模型映射链"),
+ field.String("billing_tier").MaxLen(50).Optional().Nillable().Comment("计费层级标签"),
+ field.String("billing_mode").MaxLen(20).Optional().Nillable().Comment("计费模式:token/per_request/image"),
field.Int64("group_id").
Optional().
Nillable(),
@@ -130,12 +134,6 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10).
Optional().
Nillable(),
- // 媒体类型字段(sora 使用)
- field.String("media_type").
- MaxLen(16).
- Optional().
- Nillable(),
-
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index 0a3b5d9e..83da5c32 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -1,6 +1,8 @@
package schema
import (
+ "fmt"
+
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -72,11 +74,43 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at").
Optional().
Nillable(),
+ field.String("signup_source").
+ Validate(func(value string) error {
+ switch value {
+ case "email", "linuxdo", "wechat", "oidc":
+ return nil
+ default:
+ return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
+ }
+ }).
+ Default("email"),
+ field.Time("last_login_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("last_active_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
- // Sora 存储配额
- field.Int64("sora_storage_quota_bytes").
+ // 余额不足通知
+ field.Bool("balance_notify_enabled").
+ Default(true),
+ field.String("balance_notify_threshold_type").
+ Default("fixed"), // "fixed" | "percentage"
+ field.Float("balance_notify_threshold").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
+ Optional().
+ Nillable(),
+ field.String("balance_notify_extra_emails").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Default("[]"),
+ field.Float("total_recharged").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
- field.Int64("sora_storage_used_bytes").
+
+ // 用户级每分钟请求数上限(0 = 不限制)。仅当所在分组未设置 rpm_limit 时作为兜底生效。
+ field.Int("rpm_limit").
Default(0),
}
}
@@ -93,6 +127,10 @@ func (User) Edges() []ent.Edge {
edge.To("usage_logs", UsageLog.Type),
edge.To("attribute_values", UserAttributeValue.Type),
edge.To("promo_code_usages", PromoCodeUsage.Type),
+ edge.To("payment_orders", PaymentOrder.Type),
+ edge.To("auth_identities", AuthIdentity.Type).
+ Annotations(entsql.OnDelete(entsql.Cascade)),
+ edge.To("pending_auth_sessions", PendingAuthSession.Type),
}
}
diff --git a/backend/ent/subscriptionplan.go b/backend/ent/subscriptionplan.go
new file mode 100644
index 00000000..fa4d7ae3
--- /dev/null
+++ b/backend/ent/subscriptionplan.go
@@ -0,0 +1,245 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
+)
+
+// SubscriptionPlan is the model entity for the SubscriptionPlan schema.
+type SubscriptionPlan struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // GroupID holds the value of the "group_id" field.
+ GroupID int64 `json:"group_id,omitempty"`
+ // Name holds the value of the "name" field.
+ Name string `json:"name,omitempty"`
+ // Description holds the value of the "description" field.
+ Description string `json:"description,omitempty"`
+ // Price holds the value of the "price" field.
+ Price float64 `json:"price,omitempty"`
+ // OriginalPrice holds the value of the "original_price" field.
+ OriginalPrice *float64 `json:"original_price,omitempty"`
+ // ValidityDays holds the value of the "validity_days" field.
+ ValidityDays int `json:"validity_days,omitempty"`
+ // ValidityUnit holds the value of the "validity_unit" field.
+ ValidityUnit string `json:"validity_unit,omitempty"`
+ // Features holds the value of the "features" field.
+ Features string `json:"features,omitempty"`
+ // ProductName holds the value of the "product_name" field.
+ ProductName string `json:"product_name,omitempty"`
+ // ForSale holds the value of the "for_sale" field.
+ ForSale bool `json:"for_sale,omitempty"`
+ // SortOrder holds the value of the "sort_order" field.
+ SortOrder int `json:"sort_order,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ selectValues sql.SelectValues
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*SubscriptionPlan) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case subscriptionplan.FieldForSale:
+ values[i] = new(sql.NullBool)
+ case subscriptionplan.FieldPrice, subscriptionplan.FieldOriginalPrice:
+ values[i] = new(sql.NullFloat64)
+ case subscriptionplan.FieldID, subscriptionplan.FieldGroupID, subscriptionplan.FieldValidityDays, subscriptionplan.FieldSortOrder:
+ values[i] = new(sql.NullInt64)
+ case subscriptionplan.FieldName, subscriptionplan.FieldDescription, subscriptionplan.FieldValidityUnit, subscriptionplan.FieldFeatures, subscriptionplan.FieldProductName:
+ values[i] = new(sql.NullString)
+ case subscriptionplan.FieldCreatedAt, subscriptionplan.FieldUpdatedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the SubscriptionPlan fields.
+func (_m *SubscriptionPlan) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case subscriptionplan.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case subscriptionplan.FieldGroupID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field group_id", values[i])
+ } else if value.Valid {
+ _m.GroupID = value.Int64
+ }
+ case subscriptionplan.FieldName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field name", values[i])
+ } else if value.Valid {
+ _m.Name = value.String
+ }
+ case subscriptionplan.FieldDescription:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field description", values[i])
+ } else if value.Valid {
+ _m.Description = value.String
+ }
+ case subscriptionplan.FieldPrice:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field price", values[i])
+ } else if value.Valid {
+ _m.Price = value.Float64
+ }
+ case subscriptionplan.FieldOriginalPrice:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field original_price", values[i])
+ } else if value.Valid {
+ _m.OriginalPrice = new(float64)
+ *_m.OriginalPrice = value.Float64
+ }
+ case subscriptionplan.FieldValidityDays:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field validity_days", values[i])
+ } else if value.Valid {
+ _m.ValidityDays = int(value.Int64)
+ }
+ case subscriptionplan.FieldValidityUnit:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field validity_unit", values[i])
+ } else if value.Valid {
+ _m.ValidityUnit = value.String
+ }
+ case subscriptionplan.FieldFeatures:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field features", values[i])
+ } else if value.Valid {
+ _m.Features = value.String
+ }
+ case subscriptionplan.FieldProductName:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field product_name", values[i])
+ } else if value.Valid {
+ _m.ProductName = value.String
+ }
+ case subscriptionplan.FieldForSale:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field for_sale", values[i])
+ } else if value.Valid {
+ _m.ForSale = value.Bool
+ }
+ case subscriptionplan.FieldSortOrder:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field sort_order", values[i])
+ } else if value.Valid {
+ _m.SortOrder = int(value.Int64)
+ }
+ case subscriptionplan.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case subscriptionplan.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the SubscriptionPlan.
+// This includes values selected through modifiers, order, etc.
+func (_m *SubscriptionPlan) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// Update returns a builder for updating this SubscriptionPlan.
+// Note that you need to call SubscriptionPlan.Unwrap() before calling this method if this SubscriptionPlan
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *SubscriptionPlan) Update() *SubscriptionPlanUpdateOne {
+ return NewSubscriptionPlanClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the SubscriptionPlan entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *SubscriptionPlan) Unwrap() *SubscriptionPlan {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: SubscriptionPlan is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *SubscriptionPlan) String() string {
+ var builder strings.Builder
+ builder.WriteString("SubscriptionPlan(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("group_id=")
+ builder.WriteString(fmt.Sprintf("%v", _m.GroupID))
+ builder.WriteString(", ")
+ builder.WriteString("name=")
+ builder.WriteString(_m.Name)
+ builder.WriteString(", ")
+ builder.WriteString("description=")
+ builder.WriteString(_m.Description)
+ builder.WriteString(", ")
+ builder.WriteString("price=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Price))
+ builder.WriteString(", ")
+ if v := _m.OriginalPrice; v != nil {
+ builder.WriteString("original_price=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("validity_days=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ValidityDays))
+ builder.WriteString(", ")
+ builder.WriteString("validity_unit=")
+ builder.WriteString(_m.ValidityUnit)
+ builder.WriteString(", ")
+ builder.WriteString("features=")
+ builder.WriteString(_m.Features)
+ builder.WriteString(", ")
+ builder.WriteString("product_name=")
+ builder.WriteString(_m.ProductName)
+ builder.WriteString(", ")
+ builder.WriteString("for_sale=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ForSale))
+ builder.WriteString(", ")
+ builder.WriteString("sort_order=")
+ builder.WriteString(fmt.Sprintf("%v", _m.SortOrder))
+ builder.WriteString(", ")
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// SubscriptionPlans is a parsable slice of SubscriptionPlan.
+type SubscriptionPlans []*SubscriptionPlan
diff --git a/backend/ent/subscriptionplan/subscriptionplan.go b/backend/ent/subscriptionplan/subscriptionplan.go
new file mode 100644
index 00000000..fa125aa7
--- /dev/null
+++ b/backend/ent/subscriptionplan/subscriptionplan.go
@@ -0,0 +1,174 @@
+// Code generated by ent, DO NOT EDIT.
+
+package subscriptionplan
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+)
+
+const (
+ // Label holds the string label denoting the subscriptionplan type in the database.
+ Label = "subscription_plan"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldGroupID holds the string denoting the group_id field in the database.
+ FieldGroupID = "group_id"
+ // FieldName holds the string denoting the name field in the database.
+ FieldName = "name"
+ // FieldDescription holds the string denoting the description field in the database.
+ FieldDescription = "description"
+ // FieldPrice holds the string denoting the price field in the database.
+ FieldPrice = "price"
+ // FieldOriginalPrice holds the string denoting the original_price field in the database.
+ FieldOriginalPrice = "original_price"
+ // FieldValidityDays holds the string denoting the validity_days field in the database.
+ FieldValidityDays = "validity_days"
+ // FieldValidityUnit holds the string denoting the validity_unit field in the database.
+ FieldValidityUnit = "validity_unit"
+ // FieldFeatures holds the string denoting the features field in the database.
+ FieldFeatures = "features"
+ // FieldProductName holds the string denoting the product_name field in the database.
+ FieldProductName = "product_name"
+ // FieldForSale holds the string denoting the for_sale field in the database.
+ FieldForSale = "for_sale"
+ // FieldSortOrder holds the string denoting the sort_order field in the database.
+ FieldSortOrder = "sort_order"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // Table holds the table name of the subscriptionplan in the database.
+ Table = "subscription_plans"
+)
+
+// Columns holds all SQL columns for subscriptionplan fields.
+var Columns = []string{
+ FieldID,
+ FieldGroupID,
+ FieldName,
+ FieldDescription,
+ FieldPrice,
+ FieldOriginalPrice,
+ FieldValidityDays,
+ FieldValidityUnit,
+ FieldFeatures,
+ FieldProductName,
+ FieldForSale,
+ FieldSortOrder,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // NameValidator is a validator for the "name" field. It is called by the builders before save.
+ NameValidator func(string) error
+ // DefaultDescription holds the default value on creation for the "description" field.
+ DefaultDescription string
+ // DefaultValidityDays holds the default value on creation for the "validity_days" field.
+ DefaultValidityDays int
+ // DefaultValidityUnit holds the default value on creation for the "validity_unit" field.
+ DefaultValidityUnit string
+ // ValidityUnitValidator is a validator for the "validity_unit" field. It is called by the builders before save.
+ ValidityUnitValidator func(string) error
+ // DefaultFeatures holds the default value on creation for the "features" field.
+ DefaultFeatures string
+ // DefaultProductName holds the default value on creation for the "product_name" field.
+ DefaultProductName string
+ // ProductNameValidator is a validator for the "product_name" field. It is called by the builders before save.
+ ProductNameValidator func(string) error
+ // DefaultForSale holds the default value on creation for the "for_sale" field.
+ DefaultForSale bool
+ // DefaultSortOrder holds the default value on creation for the "sort_order" field.
+ DefaultSortOrder int
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+)
+
+// OrderOption defines the ordering options for the SubscriptionPlan queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByGroupID orders the results by the group_id field.
+func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldGroupID, opts...).ToFunc()
+}
+
+// ByName orders the results by the name field.
+func ByName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldName, opts...).ToFunc()
+}
+
+// ByDescription orders the results by the description field.
+func ByDescription(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDescription, opts...).ToFunc()
+}
+
+// ByPrice orders the results by the price field.
+func ByPrice(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldPrice, opts...).ToFunc()
+}
+
+// ByOriginalPrice orders the results by the original_price field.
+func ByOriginalPrice(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldOriginalPrice, opts...).ToFunc()
+}
+
+// ByValidityDays orders the results by the validity_days field.
+func ByValidityDays(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldValidityDays, opts...).ToFunc()
+}
+
+// ByValidityUnit orders the results by the validity_unit field.
+func ByValidityUnit(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldValidityUnit, opts...).ToFunc()
+}
+
+// ByFeatures orders the results by the features field.
+func ByFeatures(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFeatures, opts...).ToFunc()
+}
+
+// ByProductName orders the results by the product_name field.
+func ByProductName(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldProductName, opts...).ToFunc()
+}
+
+// ByForSale orders the results by the for_sale field.
+func ByForSale(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldForSale, opts...).ToFunc()
+}
+
+// BySortOrder orders the results by the sort_order field.
+func BySortOrder(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSortOrder, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
diff --git a/backend/ent/subscriptionplan/where.go b/backend/ent/subscriptionplan/where.go
new file mode 100644
index 00000000..319cfdb5
--- /dev/null
+++ b/backend/ent/subscriptionplan/where.go
@@ -0,0 +1,760 @@
+// Code generated by ent, DO NOT EDIT.
+
+package subscriptionplan
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldID, id))
+}
+
+// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
+func GroupID(v int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldGroupID, v))
+}
+
+// Name applies equality check predicate on the "name" field. It's identical to NameEQ.
+func Name(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldName, v))
+}
+
+// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
+func Description(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldDescription, v))
+}
+
+// Price applies equality check predicate on the "price" field. It's identical to PriceEQ.
+func Price(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldPrice, v))
+}
+
+// OriginalPrice applies equality check predicate on the "original_price" field. It's identical to OriginalPriceEQ.
+func OriginalPrice(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldOriginalPrice, v))
+}
+
+// ValidityDays applies equality check predicate on the "validity_days" field. It's identical to ValidityDaysEQ.
+func ValidityDays(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldValidityDays, v))
+}
+
+// ValidityUnit applies equality check predicate on the "validity_unit" field. It's identical to ValidityUnitEQ.
+func ValidityUnit(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldValidityUnit, v))
+}
+
+// Features applies equality check predicate on the "features" field. It's identical to FeaturesEQ.
+func Features(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldFeatures, v))
+}
+
+// ProductName applies equality check predicate on the "product_name" field. It's identical to ProductNameEQ.
+func ProductName(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldProductName, v))
+}
+
+// ForSale applies equality check predicate on the "for_sale" field. It's identical to ForSaleEQ.
+func ForSale(v bool) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldForSale, v))
+}
+
+// SortOrder applies equality check predicate on the "sort_order" field. It's identical to SortOrderEQ.
+func SortOrder(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldSortOrder, v))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// GroupIDEQ applies the EQ predicate on the "group_id" field.
+func GroupIDEQ(v int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldGroupID, v))
+}
+
+// GroupIDNEQ applies the NEQ predicate on the "group_id" field.
+func GroupIDNEQ(v int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldGroupID, v))
+}
+
+// GroupIDIn applies the In predicate on the "group_id" field.
+func GroupIDIn(vs ...int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldGroupID, vs...))
+}
+
+// GroupIDNotIn applies the NotIn predicate on the "group_id" field.
+func GroupIDNotIn(vs ...int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldGroupID, vs...))
+}
+
+// GroupIDGT applies the GT predicate on the "group_id" field.
+func GroupIDGT(v int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldGroupID, v))
+}
+
+// GroupIDGTE applies the GTE predicate on the "group_id" field.
+func GroupIDGTE(v int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldGroupID, v))
+}
+
+// GroupIDLT applies the LT predicate on the "group_id" field.
+func GroupIDLT(v int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldGroupID, v))
+}
+
+// GroupIDLTE applies the LTE predicate on the "group_id" field.
+func GroupIDLTE(v int64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldGroupID, v))
+}
+
+// NameEQ applies the EQ predicate on the "name" field.
+func NameEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldName, v))
+}
+
+// NameNEQ applies the NEQ predicate on the "name" field.
+func NameNEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldName, v))
+}
+
+// NameIn applies the In predicate on the "name" field.
+func NameIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldName, vs...))
+}
+
+// NameNotIn applies the NotIn predicate on the "name" field.
+func NameNotIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldName, vs...))
+}
+
+// NameGT applies the GT predicate on the "name" field.
+func NameGT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldName, v))
+}
+
+// NameGTE applies the GTE predicate on the "name" field.
+func NameGTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldName, v))
+}
+
+// NameLT applies the LT predicate on the "name" field.
+func NameLT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldName, v))
+}
+
+// NameLTE applies the LTE predicate on the "name" field.
+func NameLTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldName, v))
+}
+
+// NameContains applies the Contains predicate on the "name" field.
+func NameContains(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContains(FieldName, v))
+}
+
+// NameHasPrefix applies the HasPrefix predicate on the "name" field.
+func NameHasPrefix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasPrefix(FieldName, v))
+}
+
+// NameHasSuffix applies the HasSuffix predicate on the "name" field.
+func NameHasSuffix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasSuffix(FieldName, v))
+}
+
+// NameEqualFold applies the EqualFold predicate on the "name" field.
+func NameEqualFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEqualFold(FieldName, v))
+}
+
+// NameContainsFold applies the ContainsFold predicate on the "name" field.
+func NameContainsFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContainsFold(FieldName, v))
+}
+
+// DescriptionEQ applies the EQ predicate on the "description" field.
+func DescriptionEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldDescription, v))
+}
+
+// DescriptionNEQ applies the NEQ predicate on the "description" field.
+func DescriptionNEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldDescription, v))
+}
+
+// DescriptionIn applies the In predicate on the "description" field.
+func DescriptionIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldDescription, vs...))
+}
+
+// DescriptionNotIn applies the NotIn predicate on the "description" field.
+func DescriptionNotIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldDescription, vs...))
+}
+
+// DescriptionGT applies the GT predicate on the "description" field.
+func DescriptionGT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldDescription, v))
+}
+
+// DescriptionGTE applies the GTE predicate on the "description" field.
+func DescriptionGTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldDescription, v))
+}
+
+// DescriptionLT applies the LT predicate on the "description" field.
+func DescriptionLT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldDescription, v))
+}
+
+// DescriptionLTE applies the LTE predicate on the "description" field.
+func DescriptionLTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldDescription, v))
+}
+
+// DescriptionContains applies the Contains predicate on the "description" field.
+func DescriptionContains(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContains(FieldDescription, v))
+}
+
+// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field.
+func DescriptionHasPrefix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasPrefix(FieldDescription, v))
+}
+
+// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field.
+func DescriptionHasSuffix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasSuffix(FieldDescription, v))
+}
+
+// DescriptionEqualFold applies the EqualFold predicate on the "description" field.
+func DescriptionEqualFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEqualFold(FieldDescription, v))
+}
+
+// DescriptionContainsFold applies the ContainsFold predicate on the "description" field.
+func DescriptionContainsFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContainsFold(FieldDescription, v))
+}
+
+// PriceEQ applies the EQ predicate on the "price" field.
+func PriceEQ(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldPrice, v))
+}
+
+// PriceNEQ applies the NEQ predicate on the "price" field.
+func PriceNEQ(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldPrice, v))
+}
+
+// PriceIn applies the In predicate on the "price" field.
+func PriceIn(vs ...float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldPrice, vs...))
+}
+
+// PriceNotIn applies the NotIn predicate on the "price" field.
+func PriceNotIn(vs ...float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldPrice, vs...))
+}
+
+// PriceGT applies the GT predicate on the "price" field.
+func PriceGT(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldPrice, v))
+}
+
+// PriceGTE applies the GTE predicate on the "price" field.
+func PriceGTE(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldPrice, v))
+}
+
+// PriceLT applies the LT predicate on the "price" field.
+func PriceLT(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldPrice, v))
+}
+
+// PriceLTE applies the LTE predicate on the "price" field.
+func PriceLTE(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldPrice, v))
+}
+
+// OriginalPriceEQ applies the EQ predicate on the "original_price" field.
+func OriginalPriceEQ(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldOriginalPrice, v))
+}
+
+// OriginalPriceNEQ applies the NEQ predicate on the "original_price" field.
+func OriginalPriceNEQ(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldOriginalPrice, v))
+}
+
+// OriginalPriceIn applies the In predicate on the "original_price" field.
+func OriginalPriceIn(vs ...float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldOriginalPrice, vs...))
+}
+
+// OriginalPriceNotIn applies the NotIn predicate on the "original_price" field.
+func OriginalPriceNotIn(vs ...float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldOriginalPrice, vs...))
+}
+
+// OriginalPriceGT applies the GT predicate on the "original_price" field.
+func OriginalPriceGT(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldOriginalPrice, v))
+}
+
+// OriginalPriceGTE applies the GTE predicate on the "original_price" field.
+func OriginalPriceGTE(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldOriginalPrice, v))
+}
+
+// OriginalPriceLT applies the LT predicate on the "original_price" field.
+func OriginalPriceLT(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldOriginalPrice, v))
+}
+
+// OriginalPriceLTE applies the LTE predicate on the "original_price" field.
+func OriginalPriceLTE(v float64) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldOriginalPrice, v))
+}
+
+// OriginalPriceIsNil applies the IsNil predicate on the "original_price" field.
+func OriginalPriceIsNil() predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIsNull(FieldOriginalPrice))
+}
+
+// OriginalPriceNotNil applies the NotNil predicate on the "original_price" field.
+func OriginalPriceNotNil() predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotNull(FieldOriginalPrice))
+}
+
+// ValidityDaysEQ applies the EQ predicate on the "validity_days" field.
+func ValidityDaysEQ(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldValidityDays, v))
+}
+
+// ValidityDaysNEQ applies the NEQ predicate on the "validity_days" field.
+func ValidityDaysNEQ(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldValidityDays, v))
+}
+
+// ValidityDaysIn applies the In predicate on the "validity_days" field.
+func ValidityDaysIn(vs ...int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldValidityDays, vs...))
+}
+
+// ValidityDaysNotIn applies the NotIn predicate on the "validity_days" field.
+func ValidityDaysNotIn(vs ...int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldValidityDays, vs...))
+}
+
+// ValidityDaysGT applies the GT predicate on the "validity_days" field.
+func ValidityDaysGT(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldValidityDays, v))
+}
+
+// ValidityDaysGTE applies the GTE predicate on the "validity_days" field.
+func ValidityDaysGTE(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldValidityDays, v))
+}
+
+// ValidityDaysLT applies the LT predicate on the "validity_days" field.
+func ValidityDaysLT(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldValidityDays, v))
+}
+
+// ValidityDaysLTE applies the LTE predicate on the "validity_days" field.
+func ValidityDaysLTE(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldValidityDays, v))
+}
+
+// ValidityUnitEQ applies the EQ predicate on the "validity_unit" field.
+func ValidityUnitEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldValidityUnit, v))
+}
+
+// ValidityUnitNEQ applies the NEQ predicate on the "validity_unit" field.
+func ValidityUnitNEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldValidityUnit, v))
+}
+
+// ValidityUnitIn applies the In predicate on the "validity_unit" field.
+func ValidityUnitIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldValidityUnit, vs...))
+}
+
+// ValidityUnitNotIn applies the NotIn predicate on the "validity_unit" field.
+func ValidityUnitNotIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldValidityUnit, vs...))
+}
+
+// ValidityUnitGT applies the GT predicate on the "validity_unit" field.
+func ValidityUnitGT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldValidityUnit, v))
+}
+
+// ValidityUnitGTE applies the GTE predicate on the "validity_unit" field.
+func ValidityUnitGTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldValidityUnit, v))
+}
+
+// ValidityUnitLT applies the LT predicate on the "validity_unit" field.
+func ValidityUnitLT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldValidityUnit, v))
+}
+
+// ValidityUnitLTE applies the LTE predicate on the "validity_unit" field.
+func ValidityUnitLTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldValidityUnit, v))
+}
+
+// ValidityUnitContains applies the Contains predicate on the "validity_unit" field.
+func ValidityUnitContains(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContains(FieldValidityUnit, v))
+}
+
+// ValidityUnitHasPrefix applies the HasPrefix predicate on the "validity_unit" field.
+func ValidityUnitHasPrefix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasPrefix(FieldValidityUnit, v))
+}
+
+// ValidityUnitHasSuffix applies the HasSuffix predicate on the "validity_unit" field.
+func ValidityUnitHasSuffix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasSuffix(FieldValidityUnit, v))
+}
+
+// ValidityUnitEqualFold applies the EqualFold predicate on the "validity_unit" field.
+func ValidityUnitEqualFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEqualFold(FieldValidityUnit, v))
+}
+
+// ValidityUnitContainsFold applies the ContainsFold predicate on the "validity_unit" field.
+func ValidityUnitContainsFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContainsFold(FieldValidityUnit, v))
+}
+
+// FeaturesEQ applies the EQ predicate on the "features" field.
+func FeaturesEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldFeatures, v))
+}
+
+// FeaturesNEQ applies the NEQ predicate on the "features" field.
+func FeaturesNEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldFeatures, v))
+}
+
+// FeaturesIn applies the In predicate on the "features" field.
+func FeaturesIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldFeatures, vs...))
+}
+
+// FeaturesNotIn applies the NotIn predicate on the "features" field.
+func FeaturesNotIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldFeatures, vs...))
+}
+
+// FeaturesGT applies the GT predicate on the "features" field.
+func FeaturesGT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldFeatures, v))
+}
+
+// FeaturesGTE applies the GTE predicate on the "features" field.
+func FeaturesGTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldFeatures, v))
+}
+
+// FeaturesLT applies the LT predicate on the "features" field.
+func FeaturesLT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldFeatures, v))
+}
+
+// FeaturesLTE applies the LTE predicate on the "features" field.
+func FeaturesLTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldFeatures, v))
+}
+
+// FeaturesContains applies the Contains predicate on the "features" field.
+func FeaturesContains(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContains(FieldFeatures, v))
+}
+
+// FeaturesHasPrefix applies the HasPrefix predicate on the "features" field.
+func FeaturesHasPrefix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasPrefix(FieldFeatures, v))
+}
+
+// FeaturesHasSuffix applies the HasSuffix predicate on the "features" field.
+func FeaturesHasSuffix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasSuffix(FieldFeatures, v))
+}
+
+// FeaturesEqualFold applies the EqualFold predicate on the "features" field.
+func FeaturesEqualFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEqualFold(FieldFeatures, v))
+}
+
+// FeaturesContainsFold applies the ContainsFold predicate on the "features" field.
+func FeaturesContainsFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContainsFold(FieldFeatures, v))
+}
+
+// ProductNameEQ applies the EQ predicate on the "product_name" field.
+func ProductNameEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldProductName, v))
+}
+
+// ProductNameNEQ applies the NEQ predicate on the "product_name" field.
+func ProductNameNEQ(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldProductName, v))
+}
+
+// ProductNameIn applies the In predicate on the "product_name" field.
+func ProductNameIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldProductName, vs...))
+}
+
+// ProductNameNotIn applies the NotIn predicate on the "product_name" field.
+func ProductNameNotIn(vs ...string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldProductName, vs...))
+}
+
+// ProductNameGT applies the GT predicate on the "product_name" field.
+func ProductNameGT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldProductName, v))
+}
+
+// ProductNameGTE applies the GTE predicate on the "product_name" field.
+func ProductNameGTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldProductName, v))
+}
+
+// ProductNameLT applies the LT predicate on the "product_name" field.
+func ProductNameLT(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldProductName, v))
+}
+
+// ProductNameLTE applies the LTE predicate on the "product_name" field.
+func ProductNameLTE(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldProductName, v))
+}
+
+// ProductNameContains applies the Contains predicate on the "product_name" field.
+func ProductNameContains(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContains(FieldProductName, v))
+}
+
+// ProductNameHasPrefix applies the HasPrefix predicate on the "product_name" field.
+func ProductNameHasPrefix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasPrefix(FieldProductName, v))
+}
+
+// ProductNameHasSuffix applies the HasSuffix predicate on the "product_name" field.
+func ProductNameHasSuffix(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldHasSuffix(FieldProductName, v))
+}
+
+// ProductNameEqualFold applies the EqualFold predicate on the "product_name" field.
+func ProductNameEqualFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEqualFold(FieldProductName, v))
+}
+
+// ProductNameContainsFold applies the ContainsFold predicate on the "product_name" field.
+func ProductNameContainsFold(v string) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldContainsFold(FieldProductName, v))
+}
+
+// ForSaleEQ applies the EQ predicate on the "for_sale" field.
+func ForSaleEQ(v bool) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldForSale, v))
+}
+
+// ForSaleNEQ applies the NEQ predicate on the "for_sale" field.
+func ForSaleNEQ(v bool) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldForSale, v))
+}
+
+// SortOrderEQ applies the EQ predicate on the "sort_order" field.
+func SortOrderEQ(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldSortOrder, v))
+}
+
+// SortOrderNEQ applies the NEQ predicate on the "sort_order" field.
+func SortOrderNEQ(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldSortOrder, v))
+}
+
+// SortOrderIn applies the In predicate on the "sort_order" field.
+func SortOrderIn(vs ...int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldSortOrder, vs...))
+}
+
+// SortOrderNotIn applies the NotIn predicate on the "sort_order" field.
+func SortOrderNotIn(vs ...int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldSortOrder, vs...))
+}
+
+// SortOrderGT applies the GT predicate on the "sort_order" field.
+func SortOrderGT(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldSortOrder, v))
+}
+
+// SortOrderGTE applies the GTE predicate on the "sort_order" field.
+func SortOrderGTE(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldSortOrder, v))
+}
+
+// SortOrderLT applies the LT predicate on the "sort_order" field.
+func SortOrderLT(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldSortOrder, v))
+}
+
+// SortOrderLTE applies the LTE predicate on the "sort_order" field.
+func SortOrderLTE(v int) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldSortOrder, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.SubscriptionPlan) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.SubscriptionPlan) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.SubscriptionPlan) predicate.SubscriptionPlan {
+ return predicate.SubscriptionPlan(sql.NotPredicates(p))
+}
diff --git a/backend/ent/subscriptionplan_create.go b/backend/ent/subscriptionplan_create.go
new file mode 100644
index 00000000..9109db3a
--- /dev/null
+++ b/backend/ent/subscriptionplan_create.go
@@ -0,0 +1,1317 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
+)
+
+// SubscriptionPlanCreate is the builder for creating a SubscriptionPlan entity.
+type SubscriptionPlanCreate struct {
+ config
+ mutation *SubscriptionPlanMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetGroupID sets the "group_id" field.
+func (_c *SubscriptionPlanCreate) SetGroupID(v int64) *SubscriptionPlanCreate {
+ _c.mutation.SetGroupID(v)
+ return _c
+}
+
+// SetName sets the "name" field.
+func (_c *SubscriptionPlanCreate) SetName(v string) *SubscriptionPlanCreate {
+ _c.mutation.SetName(v)
+ return _c
+}
+
+// SetDescription sets the "description" field.
+func (_c *SubscriptionPlanCreate) SetDescription(v string) *SubscriptionPlanCreate {
+ _c.mutation.SetDescription(v)
+ return _c
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableDescription(v *string) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetDescription(*v)
+ }
+ return _c
+}
+
+// SetPrice sets the "price" field.
+func (_c *SubscriptionPlanCreate) SetPrice(v float64) *SubscriptionPlanCreate {
+ _c.mutation.SetPrice(v)
+ return _c
+}
+
+// SetOriginalPrice sets the "original_price" field.
+func (_c *SubscriptionPlanCreate) SetOriginalPrice(v float64) *SubscriptionPlanCreate {
+ _c.mutation.SetOriginalPrice(v)
+ return _c
+}
+
+// SetNillableOriginalPrice sets the "original_price" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableOriginalPrice(v *float64) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetOriginalPrice(*v)
+ }
+ return _c
+}
+
+// SetValidityDays sets the "validity_days" field.
+func (_c *SubscriptionPlanCreate) SetValidityDays(v int) *SubscriptionPlanCreate {
+ _c.mutation.SetValidityDays(v)
+ return _c
+}
+
+// SetNillableValidityDays sets the "validity_days" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableValidityDays(v *int) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetValidityDays(*v)
+ }
+ return _c
+}
+
+// SetValidityUnit sets the "validity_unit" field.
+func (_c *SubscriptionPlanCreate) SetValidityUnit(v string) *SubscriptionPlanCreate {
+ _c.mutation.SetValidityUnit(v)
+ return _c
+}
+
+// SetNillableValidityUnit sets the "validity_unit" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableValidityUnit(v *string) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetValidityUnit(*v)
+ }
+ return _c
+}
+
+// SetFeatures sets the "features" field.
+func (_c *SubscriptionPlanCreate) SetFeatures(v string) *SubscriptionPlanCreate {
+ _c.mutation.SetFeatures(v)
+ return _c
+}
+
+// SetNillableFeatures sets the "features" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableFeatures(v *string) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetFeatures(*v)
+ }
+ return _c
+}
+
+// SetProductName sets the "product_name" field.
+func (_c *SubscriptionPlanCreate) SetProductName(v string) *SubscriptionPlanCreate {
+ _c.mutation.SetProductName(v)
+ return _c
+}
+
+// SetNillableProductName sets the "product_name" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableProductName(v *string) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetProductName(*v)
+ }
+ return _c
+}
+
+// SetForSale sets the "for_sale" field.
+func (_c *SubscriptionPlanCreate) SetForSale(v bool) *SubscriptionPlanCreate {
+ _c.mutation.SetForSale(v)
+ return _c
+}
+
+// SetNillableForSale sets the "for_sale" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableForSale(v *bool) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetForSale(*v)
+ }
+ return _c
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (_c *SubscriptionPlanCreate) SetSortOrder(v int) *SubscriptionPlanCreate {
+ _c.mutation.SetSortOrder(v)
+ return _c
+}
+
+// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableSortOrder(v *int) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetSortOrder(*v)
+ }
+ return _c
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *SubscriptionPlanCreate) SetCreatedAt(v time.Time) *SubscriptionPlanCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableCreatedAt(v *time.Time) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *SubscriptionPlanCreate) SetUpdatedAt(v time.Time) *SubscriptionPlanCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *SubscriptionPlanCreate) SetNillableUpdatedAt(v *time.Time) *SubscriptionPlanCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// Mutation returns the SubscriptionPlanMutation object of the builder.
+func (_c *SubscriptionPlanCreate) Mutation() *SubscriptionPlanMutation {
+ return _c.mutation
+}
+
+// Save creates the SubscriptionPlan in the database.
+func (_c *SubscriptionPlanCreate) Save(ctx context.Context) (*SubscriptionPlan, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *SubscriptionPlanCreate) SaveX(ctx context.Context) *SubscriptionPlan {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *SubscriptionPlanCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *SubscriptionPlanCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *SubscriptionPlanCreate) defaults() {
+ if _, ok := _c.mutation.Description(); !ok {
+ v := subscriptionplan.DefaultDescription
+ _c.mutation.SetDescription(v)
+ }
+ if _, ok := _c.mutation.ValidityDays(); !ok {
+ v := subscriptionplan.DefaultValidityDays
+ _c.mutation.SetValidityDays(v)
+ }
+ if _, ok := _c.mutation.ValidityUnit(); !ok {
+ v := subscriptionplan.DefaultValidityUnit
+ _c.mutation.SetValidityUnit(v)
+ }
+ if _, ok := _c.mutation.Features(); !ok {
+ v := subscriptionplan.DefaultFeatures
+ _c.mutation.SetFeatures(v)
+ }
+ if _, ok := _c.mutation.ProductName(); !ok {
+ v := subscriptionplan.DefaultProductName
+ _c.mutation.SetProductName(v)
+ }
+ if _, ok := _c.mutation.ForSale(); !ok {
+ v := subscriptionplan.DefaultForSale
+ _c.mutation.SetForSale(v)
+ }
+ if _, ok := _c.mutation.SortOrder(); !ok {
+ v := subscriptionplan.DefaultSortOrder
+ _c.mutation.SetSortOrder(v)
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := subscriptionplan.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := subscriptionplan.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *SubscriptionPlanCreate) check() error {
+ if _, ok := _c.mutation.GroupID(); !ok {
+ return &ValidationError{Name: "group_id", err: errors.New(`ent: missing required field "SubscriptionPlan.group_id"`)}
+ }
+ if _, ok := _c.mutation.Name(); !ok {
+ return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "SubscriptionPlan.name"`)}
+ }
+ if v, ok := _c.mutation.Name(); ok {
+ if err := subscriptionplan.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Description(); !ok {
+ return &ValidationError{Name: "description", err: errors.New(`ent: missing required field "SubscriptionPlan.description"`)}
+ }
+ if _, ok := _c.mutation.Price(); !ok {
+ return &ValidationError{Name: "price", err: errors.New(`ent: missing required field "SubscriptionPlan.price"`)}
+ }
+ if _, ok := _c.mutation.ValidityDays(); !ok {
+ return &ValidationError{Name: "validity_days", err: errors.New(`ent: missing required field "SubscriptionPlan.validity_days"`)}
+ }
+ if _, ok := _c.mutation.ValidityUnit(); !ok {
+ return &ValidationError{Name: "validity_unit", err: errors.New(`ent: missing required field "SubscriptionPlan.validity_unit"`)}
+ }
+ if v, ok := _c.mutation.ValidityUnit(); ok {
+ if err := subscriptionplan.ValidityUnitValidator(v); err != nil {
+ return &ValidationError{Name: "validity_unit", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.validity_unit": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Features(); !ok {
+ return &ValidationError{Name: "features", err: errors.New(`ent: missing required field "SubscriptionPlan.features"`)}
+ }
+ if _, ok := _c.mutation.ProductName(); !ok {
+ return &ValidationError{Name: "product_name", err: errors.New(`ent: missing required field "SubscriptionPlan.product_name"`)}
+ }
+ if v, ok := _c.mutation.ProductName(); ok {
+ if err := subscriptionplan.ProductNameValidator(v); err != nil {
+ return &ValidationError{Name: "product_name", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.product_name": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.ForSale(); !ok {
+ return &ValidationError{Name: "for_sale", err: errors.New(`ent: missing required field "SubscriptionPlan.for_sale"`)}
+ }
+ if _, ok := _c.mutation.SortOrder(); !ok {
+ return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "SubscriptionPlan.sort_order"`)}
+ }
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SubscriptionPlan.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SubscriptionPlan.updated_at"`)}
+ }
+ return nil
+}
+
+func (_c *SubscriptionPlanCreate) sqlSave(ctx context.Context) (*SubscriptionPlan, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *SubscriptionPlanCreate) createSpec() (*SubscriptionPlan, *sqlgraph.CreateSpec) {
+ var (
+ _node = &SubscriptionPlan{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(subscriptionplan.Table, sqlgraph.NewFieldSpec(subscriptionplan.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.GroupID(); ok {
+ _spec.SetField(subscriptionplan.FieldGroupID, field.TypeInt64, value)
+ _node.GroupID = value
+ }
+ if value, ok := _c.mutation.Name(); ok {
+ _spec.SetField(subscriptionplan.FieldName, field.TypeString, value)
+ _node.Name = value
+ }
+ if value, ok := _c.mutation.Description(); ok {
+ _spec.SetField(subscriptionplan.FieldDescription, field.TypeString, value)
+ _node.Description = value
+ }
+ if value, ok := _c.mutation.Price(); ok {
+ _spec.SetField(subscriptionplan.FieldPrice, field.TypeFloat64, value)
+ _node.Price = value
+ }
+ if value, ok := _c.mutation.OriginalPrice(); ok {
+ _spec.SetField(subscriptionplan.FieldOriginalPrice, field.TypeFloat64, value)
+ _node.OriginalPrice = &value
+ }
+ if value, ok := _c.mutation.ValidityDays(); ok {
+ _spec.SetField(subscriptionplan.FieldValidityDays, field.TypeInt, value)
+ _node.ValidityDays = value
+ }
+ if value, ok := _c.mutation.ValidityUnit(); ok {
+ _spec.SetField(subscriptionplan.FieldValidityUnit, field.TypeString, value)
+ _node.ValidityUnit = value
+ }
+ if value, ok := _c.mutation.Features(); ok {
+ _spec.SetField(subscriptionplan.FieldFeatures, field.TypeString, value)
+ _node.Features = value
+ }
+ if value, ok := _c.mutation.ProductName(); ok {
+ _spec.SetField(subscriptionplan.FieldProductName, field.TypeString, value)
+ _node.ProductName = value
+ }
+ if value, ok := _c.mutation.ForSale(); ok {
+ _spec.SetField(subscriptionplan.FieldForSale, field.TypeBool, value)
+ _node.ForSale = value
+ }
+ if value, ok := _c.mutation.SortOrder(); ok {
+ _spec.SetField(subscriptionplan.FieldSortOrder, field.TypeInt, value)
+ _node.SortOrder = value
+ }
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(subscriptionplan.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(subscriptionplan.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.SubscriptionPlan.Create().
+// SetGroupID(v).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.SubscriptionPlanUpsert) {
+// SetGroupID(v+v).
+// }).
+// Exec(ctx)
+func (_c *SubscriptionPlanCreate) OnConflict(opts ...sql.ConflictOption) *SubscriptionPlanUpsertOne {
+ _c.conflict = opts
+ return &SubscriptionPlanUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.SubscriptionPlan.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *SubscriptionPlanCreate) OnConflictColumns(columns ...string) *SubscriptionPlanUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &SubscriptionPlanUpsertOne{
+ create: _c,
+ }
+}
+
+type (
+ // SubscriptionPlanUpsertOne is the builder for "upsert"-ing
+ // one SubscriptionPlan node.
+ SubscriptionPlanUpsertOne struct {
+ create *SubscriptionPlanCreate
+ }
+
+ // SubscriptionPlanUpsert is the "OnConflict" setter.
+ SubscriptionPlanUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetGroupID sets the "group_id" field.
+func (u *SubscriptionPlanUpsert) SetGroupID(v int64) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldGroupID, v)
+ return u
+}
+
+// UpdateGroupID sets the "group_id" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateGroupID() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldGroupID)
+ return u
+}
+
+// AddGroupID adds v to the "group_id" field.
+func (u *SubscriptionPlanUpsert) AddGroupID(v int64) *SubscriptionPlanUpsert {
+ u.Add(subscriptionplan.FieldGroupID, v)
+ return u
+}
+
+// SetName sets the "name" field.
+func (u *SubscriptionPlanUpsert) SetName(v string) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldName, v)
+ return u
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateName() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldName)
+ return u
+}
+
+// SetDescription sets the "description" field.
+func (u *SubscriptionPlanUpsert) SetDescription(v string) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldDescription, v)
+ return u
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateDescription() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldDescription)
+ return u
+}
+
+// SetPrice sets the "price" field.
+func (u *SubscriptionPlanUpsert) SetPrice(v float64) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldPrice, v)
+ return u
+}
+
+// UpdatePrice sets the "price" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdatePrice() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldPrice)
+ return u
+}
+
+// AddPrice adds v to the "price" field.
+func (u *SubscriptionPlanUpsert) AddPrice(v float64) *SubscriptionPlanUpsert {
+ u.Add(subscriptionplan.FieldPrice, v)
+ return u
+}
+
+// SetOriginalPrice sets the "original_price" field.
+func (u *SubscriptionPlanUpsert) SetOriginalPrice(v float64) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldOriginalPrice, v)
+ return u
+}
+
+// UpdateOriginalPrice sets the "original_price" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateOriginalPrice() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldOriginalPrice)
+ return u
+}
+
+// AddOriginalPrice adds v to the "original_price" field.
+func (u *SubscriptionPlanUpsert) AddOriginalPrice(v float64) *SubscriptionPlanUpsert {
+ u.Add(subscriptionplan.FieldOriginalPrice, v)
+ return u
+}
+
+// ClearOriginalPrice clears the value of the "original_price" field.
+func (u *SubscriptionPlanUpsert) ClearOriginalPrice() *SubscriptionPlanUpsert {
+ u.SetNull(subscriptionplan.FieldOriginalPrice)
+ return u
+}
+
+// SetValidityDays sets the "validity_days" field.
+func (u *SubscriptionPlanUpsert) SetValidityDays(v int) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldValidityDays, v)
+ return u
+}
+
+// UpdateValidityDays sets the "validity_days" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateValidityDays() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldValidityDays)
+ return u
+}
+
+// AddValidityDays adds v to the "validity_days" field.
+func (u *SubscriptionPlanUpsert) AddValidityDays(v int) *SubscriptionPlanUpsert {
+ u.Add(subscriptionplan.FieldValidityDays, v)
+ return u
+}
+
+// SetValidityUnit sets the "validity_unit" field.
+func (u *SubscriptionPlanUpsert) SetValidityUnit(v string) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldValidityUnit, v)
+ return u
+}
+
+// UpdateValidityUnit sets the "validity_unit" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateValidityUnit() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldValidityUnit)
+ return u
+}
+
+// SetFeatures sets the "features" field.
+func (u *SubscriptionPlanUpsert) SetFeatures(v string) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldFeatures, v)
+ return u
+}
+
+// UpdateFeatures sets the "features" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateFeatures() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldFeatures)
+ return u
+}
+
+// SetProductName sets the "product_name" field.
+func (u *SubscriptionPlanUpsert) SetProductName(v string) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldProductName, v)
+ return u
+}
+
+// UpdateProductName sets the "product_name" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateProductName() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldProductName)
+ return u
+}
+
+// SetForSale sets the "for_sale" field.
+func (u *SubscriptionPlanUpsert) SetForSale(v bool) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldForSale, v)
+ return u
+}
+
+// UpdateForSale sets the "for_sale" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateForSale() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldForSale)
+ return u
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (u *SubscriptionPlanUpsert) SetSortOrder(v int) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldSortOrder, v)
+ return u
+}
+
+// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateSortOrder() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldSortOrder)
+ return u
+}
+
+// AddSortOrder adds v to the "sort_order" field.
+func (u *SubscriptionPlanUpsert) AddSortOrder(v int) *SubscriptionPlanUpsert {
+ u.Add(subscriptionplan.FieldSortOrder, v)
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *SubscriptionPlanUpsert) SetUpdatedAt(v time.Time) *SubscriptionPlanUpsert {
+ u.Set(subscriptionplan.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsert) UpdateUpdatedAt() *SubscriptionPlanUpsert {
+ u.SetExcluded(subscriptionplan.FieldUpdatedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.SubscriptionPlan.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *SubscriptionPlanUpsertOne) UpdateNewValues() *SubscriptionPlanUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(subscriptionplan.FieldCreatedAt)
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.SubscriptionPlan.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *SubscriptionPlanUpsertOne) Ignore() *SubscriptionPlanUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *SubscriptionPlanUpsertOne) DoNothing() *SubscriptionPlanUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the SubscriptionPlanCreate.OnConflict
+// documentation for more info.
+func (u *SubscriptionPlanUpsertOne) Update(set func(*SubscriptionPlanUpsert)) *SubscriptionPlanUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&SubscriptionPlanUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetGroupID sets the "group_id" field.
+func (u *SubscriptionPlanUpsertOne) SetGroupID(v int64) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetGroupID(v)
+ })
+}
+
+// AddGroupID adds v to the "group_id" field.
+func (u *SubscriptionPlanUpsertOne) AddGroupID(v int64) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddGroupID(v)
+ })
+}
+
+// UpdateGroupID sets the "group_id" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateGroupID() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateGroupID()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *SubscriptionPlanUpsertOne) SetName(v string) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateName() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *SubscriptionPlanUpsertOne) SetDescription(v string) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateDescription() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// SetPrice sets the "price" field.
+func (u *SubscriptionPlanUpsertOne) SetPrice(v float64) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetPrice(v)
+ })
+}
+
+// AddPrice adds v to the "price" field.
+func (u *SubscriptionPlanUpsertOne) AddPrice(v float64) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddPrice(v)
+ })
+}
+
+// UpdatePrice sets the "price" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdatePrice() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdatePrice()
+ })
+}
+
+// SetOriginalPrice sets the "original_price" field.
+func (u *SubscriptionPlanUpsertOne) SetOriginalPrice(v float64) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetOriginalPrice(v)
+ })
+}
+
+// AddOriginalPrice adds v to the "original_price" field.
+func (u *SubscriptionPlanUpsertOne) AddOriginalPrice(v float64) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddOriginalPrice(v)
+ })
+}
+
+// UpdateOriginalPrice sets the "original_price" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateOriginalPrice() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateOriginalPrice()
+ })
+}
+
+// ClearOriginalPrice clears the value of the "original_price" field.
+func (u *SubscriptionPlanUpsertOne) ClearOriginalPrice() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.ClearOriginalPrice()
+ })
+}
+
+// SetValidityDays sets the "validity_days" field.
+func (u *SubscriptionPlanUpsertOne) SetValidityDays(v int) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetValidityDays(v)
+ })
+}
+
+// AddValidityDays adds v to the "validity_days" field.
+func (u *SubscriptionPlanUpsertOne) AddValidityDays(v int) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddValidityDays(v)
+ })
+}
+
+// UpdateValidityDays sets the "validity_days" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateValidityDays() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateValidityDays()
+ })
+}
+
+// SetValidityUnit sets the "validity_unit" field.
+func (u *SubscriptionPlanUpsertOne) SetValidityUnit(v string) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetValidityUnit(v)
+ })
+}
+
+// UpdateValidityUnit sets the "validity_unit" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateValidityUnit() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateValidityUnit()
+ })
+}
+
+// SetFeatures sets the "features" field.
+func (u *SubscriptionPlanUpsertOne) SetFeatures(v string) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetFeatures(v)
+ })
+}
+
+// UpdateFeatures sets the "features" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateFeatures() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateFeatures()
+ })
+}
+
+// SetProductName sets the "product_name" field.
+func (u *SubscriptionPlanUpsertOne) SetProductName(v string) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetProductName(v)
+ })
+}
+
+// UpdateProductName sets the "product_name" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateProductName() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateProductName()
+ })
+}
+
+// SetForSale sets the "for_sale" field.
+func (u *SubscriptionPlanUpsertOne) SetForSale(v bool) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetForSale(v)
+ })
+}
+
+// UpdateForSale sets the "for_sale" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateForSale() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateForSale()
+ })
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (u *SubscriptionPlanUpsertOne) SetSortOrder(v int) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetSortOrder(v)
+ })
+}
+
+// AddSortOrder adds v to the "sort_order" field.
+func (u *SubscriptionPlanUpsertOne) AddSortOrder(v int) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddSortOrder(v)
+ })
+}
+
+// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateSortOrder() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateSortOrder()
+ })
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *SubscriptionPlanUpsertOne) SetUpdatedAt(v time.Time) *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertOne) UpdateUpdatedAt() *SubscriptionPlanUpsertOne {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *SubscriptionPlanUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for SubscriptionPlanCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *SubscriptionPlanUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// Exec executes the UPSERT query and returns the inserted/updated ID.
+func (u *SubscriptionPlanUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *SubscriptionPlanUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// SubscriptionPlanCreateBulk is the builder for creating many SubscriptionPlan entities in bulk.
+type SubscriptionPlanCreateBulk struct {
+ config
+ err error
+ builders []*SubscriptionPlanCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the SubscriptionPlan entities in the database.
+func (_c *SubscriptionPlanCreateBulk) Save(ctx context.Context) ([]*SubscriptionPlan, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*SubscriptionPlan, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*SubscriptionPlanMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *SubscriptionPlanCreateBulk) SaveX(ctx context.Context) []*SubscriptionPlan {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *SubscriptionPlanCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *SubscriptionPlanCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.SubscriptionPlan.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // the was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.SubscriptionPlanUpsert) {
+// SetGroupID(v+v).
+// }).
+// Exec(ctx)
+func (_c *SubscriptionPlanCreateBulk) OnConflict(opts ...sql.ConflictOption) *SubscriptionPlanUpsertBulk {
+ _c.conflict = opts
+ return &SubscriptionPlanUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.SubscriptionPlan.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *SubscriptionPlanCreateBulk) OnConflictColumns(columns ...string) *SubscriptionPlanUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &SubscriptionPlanUpsertBulk{
+ create: _c,
+ }
+}
+
+// SubscriptionPlanUpsertBulk is the builder for "upsert"-ing
+// a bulk of SubscriptionPlan nodes.
+type SubscriptionPlanUpsertBulk struct {
+ create *SubscriptionPlanCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.SubscriptionPlan.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *SubscriptionPlanUpsertBulk) UpdateNewValues() *SubscriptionPlanUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(subscriptionplan.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.SubscriptionPlan.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *SubscriptionPlanUpsertBulk) Ignore() *SubscriptionPlanUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *SubscriptionPlanUpsertBulk) DoNothing() *SubscriptionPlanUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding fields `UPDATE` values. See the SubscriptionPlanCreateBulk.OnConflict
+// documentation for more info.
+func (u *SubscriptionPlanUpsertBulk) Update(set func(*SubscriptionPlanUpsert)) *SubscriptionPlanUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&SubscriptionPlanUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetGroupID sets the "group_id" field.
+func (u *SubscriptionPlanUpsertBulk) SetGroupID(v int64) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetGroupID(v)
+ })
+}
+
+// AddGroupID adds v to the "group_id" field.
+func (u *SubscriptionPlanUpsertBulk) AddGroupID(v int64) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddGroupID(v)
+ })
+}
+
+// UpdateGroupID sets the "group_id" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateGroupID() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateGroupID()
+ })
+}
+
+// SetName sets the "name" field.
+func (u *SubscriptionPlanUpsertBulk) SetName(v string) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetName(v)
+ })
+}
+
+// UpdateName sets the "name" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateName() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateName()
+ })
+}
+
+// SetDescription sets the "description" field.
+func (u *SubscriptionPlanUpsertBulk) SetDescription(v string) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetDescription(v)
+ })
+}
+
+// UpdateDescription sets the "description" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateDescription() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateDescription()
+ })
+}
+
+// SetPrice sets the "price" field.
+func (u *SubscriptionPlanUpsertBulk) SetPrice(v float64) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetPrice(v)
+ })
+}
+
+// AddPrice adds v to the "price" field.
+func (u *SubscriptionPlanUpsertBulk) AddPrice(v float64) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddPrice(v)
+ })
+}
+
+// UpdatePrice sets the "price" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdatePrice() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdatePrice()
+ })
+}
+
+// SetOriginalPrice sets the "original_price" field.
+func (u *SubscriptionPlanUpsertBulk) SetOriginalPrice(v float64) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetOriginalPrice(v)
+ })
+}
+
+// AddOriginalPrice adds v to the "original_price" field.
+func (u *SubscriptionPlanUpsertBulk) AddOriginalPrice(v float64) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddOriginalPrice(v)
+ })
+}
+
+// UpdateOriginalPrice sets the "original_price" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateOriginalPrice() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateOriginalPrice()
+ })
+}
+
+// ClearOriginalPrice clears the value of the "original_price" field.
+func (u *SubscriptionPlanUpsertBulk) ClearOriginalPrice() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.ClearOriginalPrice()
+ })
+}
+
+// SetValidityDays sets the "validity_days" field.
+func (u *SubscriptionPlanUpsertBulk) SetValidityDays(v int) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetValidityDays(v)
+ })
+}
+
+// AddValidityDays adds v to the "validity_days" field.
+func (u *SubscriptionPlanUpsertBulk) AddValidityDays(v int) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddValidityDays(v)
+ })
+}
+
+// UpdateValidityDays sets the "validity_days" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateValidityDays() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateValidityDays()
+ })
+}
+
+// SetValidityUnit sets the "validity_unit" field.
+func (u *SubscriptionPlanUpsertBulk) SetValidityUnit(v string) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetValidityUnit(v)
+ })
+}
+
+// UpdateValidityUnit sets the "validity_unit" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateValidityUnit() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateValidityUnit()
+ })
+}
+
+// SetFeatures sets the "features" field.
+func (u *SubscriptionPlanUpsertBulk) SetFeatures(v string) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetFeatures(v)
+ })
+}
+
+// UpdateFeatures sets the "features" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateFeatures() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateFeatures()
+ })
+}
+
+// SetProductName sets the "product_name" field.
+func (u *SubscriptionPlanUpsertBulk) SetProductName(v string) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetProductName(v)
+ })
+}
+
+// UpdateProductName sets the "product_name" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateProductName() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateProductName()
+ })
+}
+
+// SetForSale sets the "for_sale" field.
+func (u *SubscriptionPlanUpsertBulk) SetForSale(v bool) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetForSale(v)
+ })
+}
+
+// UpdateForSale sets the "for_sale" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateForSale() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateForSale()
+ })
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (u *SubscriptionPlanUpsertBulk) SetSortOrder(v int) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetSortOrder(v)
+ })
+}
+
+// AddSortOrder adds v to the "sort_order" field.
+func (u *SubscriptionPlanUpsertBulk) AddSortOrder(v int) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.AddSortOrder(v)
+ })
+}
+
+// UpdateSortOrder sets the "sort_order" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateSortOrder() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateSortOrder()
+ })
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *SubscriptionPlanUpsertBulk) SetUpdatedAt(v time.Time) *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *SubscriptionPlanUpsertBulk) UpdateUpdatedAt() *SubscriptionPlanUpsertBulk {
+ return u.Update(func(s *SubscriptionPlanUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *SubscriptionPlanUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SubscriptionPlanCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for SubscriptionPlanCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *SubscriptionPlanUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/subscriptionplan_delete.go b/backend/ent/subscriptionplan_delete.go
new file mode 100644
index 00000000..90c71239
--- /dev/null
+++ b/backend/ent/subscriptionplan_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
+)
+
+// SubscriptionPlanDelete is the builder for deleting a SubscriptionPlan entity.
+type SubscriptionPlanDelete struct {
+ config
+ hooks []Hook
+ mutation *SubscriptionPlanMutation
+}
+
+// Where appends a list predicates to the SubscriptionPlanDelete builder.
+func (_d *SubscriptionPlanDelete) Where(ps ...predicate.SubscriptionPlan) *SubscriptionPlanDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *SubscriptionPlanDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *SubscriptionPlanDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *SubscriptionPlanDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(subscriptionplan.Table, sqlgraph.NewFieldSpec(subscriptionplan.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// SubscriptionPlanDeleteOne is the builder for deleting a single SubscriptionPlan entity.
+type SubscriptionPlanDeleteOne struct {
+ _d *SubscriptionPlanDelete
+}
+
+// Where appends a list predicates to the SubscriptionPlanDelete builder.
+func (_d *SubscriptionPlanDeleteOne) Where(ps ...predicate.SubscriptionPlan) *SubscriptionPlanDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *SubscriptionPlanDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{subscriptionplan.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *SubscriptionPlanDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/subscriptionplan_query.go b/backend/ent/subscriptionplan_query.go
new file mode 100644
index 00000000..6c301dcd
--- /dev/null
+++ b/backend/ent/subscriptionplan_query.go
@@ -0,0 +1,564 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
+)
+
+// SubscriptionPlanQuery is the builder for querying SubscriptionPlan entities.
+type SubscriptionPlanQuery struct {
+ config
+ ctx *QueryContext
+ order []subscriptionplan.OrderOption
+ inters []Interceptor
+ predicates []predicate.SubscriptionPlan
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the SubscriptionPlanQuery builder.
+func (_q *SubscriptionPlanQuery) Where(ps ...predicate.SubscriptionPlan) *SubscriptionPlanQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *SubscriptionPlanQuery) Limit(limit int) *SubscriptionPlanQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *SubscriptionPlanQuery) Offset(offset int) *SubscriptionPlanQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *SubscriptionPlanQuery) Unique(unique bool) *SubscriptionPlanQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *SubscriptionPlanQuery) Order(o ...subscriptionplan.OrderOption) *SubscriptionPlanQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// First returns the first SubscriptionPlan entity from the query.
+// Returns a *NotFoundError when no SubscriptionPlan was found.
+func (_q *SubscriptionPlanQuery) First(ctx context.Context) (*SubscriptionPlan, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{subscriptionplan.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) FirstX(ctx context.Context) *SubscriptionPlan {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first SubscriptionPlan ID from the query.
+// Returns a *NotFoundError when no SubscriptionPlan ID was found.
+func (_q *SubscriptionPlanQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{subscriptionplan.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single SubscriptionPlan entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one SubscriptionPlan entity is found.
+// Returns a *NotFoundError when no SubscriptionPlan entities are found.
+func (_q *SubscriptionPlanQuery) Only(ctx context.Context) (*SubscriptionPlan, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{subscriptionplan.Label}
+ default:
+ return nil, &NotSingularError{subscriptionplan.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) OnlyX(ctx context.Context) *SubscriptionPlan {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only SubscriptionPlan ID in the query.
+// Returns a *NotSingularError when more than one SubscriptionPlan ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *SubscriptionPlanQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{subscriptionplan.Label}
+ default:
+ err = &NotSingularError{subscriptionplan.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of SubscriptionPlans.
+func (_q *SubscriptionPlanQuery) All(ctx context.Context) ([]*SubscriptionPlan, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*SubscriptionPlan, *SubscriptionPlanQuery]()
+ return withInterceptors[[]*SubscriptionPlan](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) AllX(ctx context.Context) []*SubscriptionPlan {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of SubscriptionPlan IDs.
+func (_q *SubscriptionPlanQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(subscriptionplan.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *SubscriptionPlanQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*SubscriptionPlanQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *SubscriptionPlanQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *SubscriptionPlanQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
+
+// Clone returns a duplicate of the SubscriptionPlanQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *SubscriptionPlanQuery) Clone() *SubscriptionPlanQuery {
+ if _q == nil {
+ return nil
+ }
+ return &SubscriptionPlanQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]subscriptionplan.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.SubscriptionPlan{}, _q.predicates...),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// GroupID int64 `json:"group_id,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.SubscriptionPlan.Query().
+// GroupBy(subscriptionplan.FieldGroupID).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *SubscriptionPlanQuery) GroupBy(field string, fields ...string) *SubscriptionPlanGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &SubscriptionPlanGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = subscriptionplan.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// GroupID int64 `json:"group_id,omitempty"`
+// }
+//
+// client.SubscriptionPlan.Query().
+// Select(subscriptionplan.FieldGroupID).
+// Scan(ctx, &v)
+func (_q *SubscriptionPlanQuery) Select(fields ...string) *SubscriptionPlanSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &SubscriptionPlanSelect{SubscriptionPlanQuery: _q}
+ sbuild.label = subscriptionplan.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a SubscriptionPlanSelect configured with the given aggregations.
+func (_q *SubscriptionPlanQuery) Aggregate(fns ...AggregateFunc) *SubscriptionPlanSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *SubscriptionPlanQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !subscriptionplan.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *SubscriptionPlanQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SubscriptionPlan, error) {
+ var (
+ nodes = []*SubscriptionPlan{}
+ _spec = _q.querySpec()
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*SubscriptionPlan).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &SubscriptionPlan{config: _q.config}
+ nodes = append(nodes, node)
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ return nodes, nil
+}
+
+func (_q *SubscriptionPlanQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *SubscriptionPlanQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(subscriptionplan.Table, subscriptionplan.Columns, sqlgraph.NewFieldSpec(subscriptionplan.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, subscriptionplan.FieldID)
+ for i := range fields {
+ if fields[i] != subscriptionplan.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *SubscriptionPlanQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(subscriptionplan.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = subscriptionplan.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ // limit is mandatory for offset clause. We start
+ // with default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevent them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *SubscriptionPlanQuery) ForUpdate(opts ...sql.LockOption) *SubscriptionPlanQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *SubscriptionPlanQuery) ForShare(opts ...sql.LockOption) *SubscriptionPlanQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
+
+// SubscriptionPlanGroupBy is the group-by builder for SubscriptionPlan entities.
+type SubscriptionPlanGroupBy struct {
+ selector
+ build *SubscriptionPlanQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *SubscriptionPlanGroupBy) Aggregate(fns ...AggregateFunc) *SubscriptionPlanGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *SubscriptionPlanGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*SubscriptionPlanQuery, *SubscriptionPlanGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *SubscriptionPlanGroupBy) sqlScan(ctx context.Context, root *SubscriptionPlanQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// SubscriptionPlanSelect is the builder for selecting fields of SubscriptionPlan entities.
+type SubscriptionPlanSelect struct {
+ *SubscriptionPlanQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *SubscriptionPlanSelect) Aggregate(fns ...AggregateFunc) *SubscriptionPlanSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *SubscriptionPlanSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*SubscriptionPlanQuery, *SubscriptionPlanSelect](ctx, _s.SubscriptionPlanQuery, _s, _s.inters, v)
+}
+
+func (_s *SubscriptionPlanSelect) sqlScan(ctx context.Context, root *SubscriptionPlanQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/subscriptionplan_update.go b/backend/ent/subscriptionplan_update.go
new file mode 100644
index 00000000..c9225d0f
--- /dev/null
+++ b/backend/ent/subscriptionplan_update.go
@@ -0,0 +1,750 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/subscriptionplan"
+)
+
+// SubscriptionPlanUpdate is the builder for updating SubscriptionPlan entities.
+type SubscriptionPlanUpdate struct {
+ config
+ hooks []Hook
+ mutation *SubscriptionPlanMutation
+}
+
+// Where appends a list predicates to the SubscriptionPlanUpdate builder.
+func (_u *SubscriptionPlanUpdate) Where(ps ...predicate.SubscriptionPlan) *SubscriptionPlanUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetGroupID sets the "group_id" field.
+func (_u *SubscriptionPlanUpdate) SetGroupID(v int64) *SubscriptionPlanUpdate {
+ _u.mutation.ResetGroupID()
+ _u.mutation.SetGroupID(v)
+ return _u
+}
+
+// SetNillableGroupID sets the "group_id" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableGroupID(v *int64) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetGroupID(*v)
+ }
+ return _u
+}
+
+// AddGroupID adds value to the "group_id" field.
+func (_u *SubscriptionPlanUpdate) AddGroupID(v int64) *SubscriptionPlanUpdate {
+ _u.mutation.AddGroupID(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *SubscriptionPlanUpdate) SetName(v string) *SubscriptionPlanUpdate {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableName(v *string) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *SubscriptionPlanUpdate) SetDescription(v string) *SubscriptionPlanUpdate {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableDescription(v *string) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// SetPrice sets the "price" field.
+func (_u *SubscriptionPlanUpdate) SetPrice(v float64) *SubscriptionPlanUpdate {
+ _u.mutation.ResetPrice()
+ _u.mutation.SetPrice(v)
+ return _u
+}
+
+// SetNillablePrice sets the "price" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillablePrice(v *float64) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetPrice(*v)
+ }
+ return _u
+}
+
+// AddPrice adds value to the "price" field.
+func (_u *SubscriptionPlanUpdate) AddPrice(v float64) *SubscriptionPlanUpdate {
+ _u.mutation.AddPrice(v)
+ return _u
+}
+
+// SetOriginalPrice sets the "original_price" field.
+func (_u *SubscriptionPlanUpdate) SetOriginalPrice(v float64) *SubscriptionPlanUpdate {
+ _u.mutation.ResetOriginalPrice()
+ _u.mutation.SetOriginalPrice(v)
+ return _u
+}
+
+// SetNillableOriginalPrice sets the "original_price" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableOriginalPrice(v *float64) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetOriginalPrice(*v)
+ }
+ return _u
+}
+
+// AddOriginalPrice adds value to the "original_price" field.
+func (_u *SubscriptionPlanUpdate) AddOriginalPrice(v float64) *SubscriptionPlanUpdate {
+ _u.mutation.AddOriginalPrice(v)
+ return _u
+}
+
+// ClearOriginalPrice clears the value of the "original_price" field.
+func (_u *SubscriptionPlanUpdate) ClearOriginalPrice() *SubscriptionPlanUpdate {
+ _u.mutation.ClearOriginalPrice()
+ return _u
+}
+
+// SetValidityDays sets the "validity_days" field.
+func (_u *SubscriptionPlanUpdate) SetValidityDays(v int) *SubscriptionPlanUpdate {
+ _u.mutation.ResetValidityDays()
+ _u.mutation.SetValidityDays(v)
+ return _u
+}
+
+// SetNillableValidityDays sets the "validity_days" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableValidityDays(v *int) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetValidityDays(*v)
+ }
+ return _u
+}
+
+// AddValidityDays adds value to the "validity_days" field.
+func (_u *SubscriptionPlanUpdate) AddValidityDays(v int) *SubscriptionPlanUpdate {
+ _u.mutation.AddValidityDays(v)
+ return _u
+}
+
+// SetValidityUnit sets the "validity_unit" field.
+func (_u *SubscriptionPlanUpdate) SetValidityUnit(v string) *SubscriptionPlanUpdate {
+ _u.mutation.SetValidityUnit(v)
+ return _u
+}
+
+// SetNillableValidityUnit sets the "validity_unit" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableValidityUnit(v *string) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetValidityUnit(*v)
+ }
+ return _u
+}
+
+// SetFeatures sets the "features" field.
+func (_u *SubscriptionPlanUpdate) SetFeatures(v string) *SubscriptionPlanUpdate {
+ _u.mutation.SetFeatures(v)
+ return _u
+}
+
+// SetNillableFeatures sets the "features" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableFeatures(v *string) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetFeatures(*v)
+ }
+ return _u
+}
+
+// SetProductName sets the "product_name" field.
+func (_u *SubscriptionPlanUpdate) SetProductName(v string) *SubscriptionPlanUpdate {
+ _u.mutation.SetProductName(v)
+ return _u
+}
+
+// SetNillableProductName sets the "product_name" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableProductName(v *string) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetProductName(*v)
+ }
+ return _u
+}
+
+// SetForSale sets the "for_sale" field.
+func (_u *SubscriptionPlanUpdate) SetForSale(v bool) *SubscriptionPlanUpdate {
+ _u.mutation.SetForSale(v)
+ return _u
+}
+
+// SetNillableForSale sets the "for_sale" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableForSale(v *bool) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetForSale(*v)
+ }
+ return _u
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (_u *SubscriptionPlanUpdate) SetSortOrder(v int) *SubscriptionPlanUpdate {
+ _u.mutation.ResetSortOrder()
+ _u.mutation.SetSortOrder(v)
+ return _u
+}
+
+// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdate) SetNillableSortOrder(v *int) *SubscriptionPlanUpdate {
+ if v != nil {
+ _u.SetSortOrder(*v)
+ }
+ return _u
+}
+
+// AddSortOrder adds value to the "sort_order" field.
+func (_u *SubscriptionPlanUpdate) AddSortOrder(v int) *SubscriptionPlanUpdate {
+ _u.mutation.AddSortOrder(v)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *SubscriptionPlanUpdate) SetUpdatedAt(v time.Time) *SubscriptionPlanUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// Mutation returns the SubscriptionPlanMutation object of the builder.
+func (_u *SubscriptionPlanUpdate) Mutation() *SubscriptionPlanMutation {
+ return _u.mutation
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *SubscriptionPlanUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *SubscriptionPlanUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *SubscriptionPlanUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *SubscriptionPlanUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *SubscriptionPlanUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := subscriptionplan.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *SubscriptionPlanUpdate) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := subscriptionplan.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ValidityUnit(); ok {
+ if err := subscriptionplan.ValidityUnitValidator(v); err != nil {
+ return &ValidationError{Name: "validity_unit", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.validity_unit": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProductName(); ok {
+ if err := subscriptionplan.ProductNameValidator(v); err != nil {
+ return &ValidationError{Name: "product_name", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.product_name": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *SubscriptionPlanUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(subscriptionplan.Table, subscriptionplan.Columns, sqlgraph.NewFieldSpec(subscriptionplan.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.GroupID(); ok {
+ _spec.SetField(subscriptionplan.FieldGroupID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedGroupID(); ok {
+ _spec.AddField(subscriptionplan.FieldGroupID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(subscriptionplan.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(subscriptionplan.FieldDescription, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Price(); ok {
+ _spec.SetField(subscriptionplan.FieldPrice, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedPrice(); ok {
+ _spec.AddField(subscriptionplan.FieldPrice, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.OriginalPrice(); ok {
+ _spec.SetField(subscriptionplan.FieldOriginalPrice, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedOriginalPrice(); ok {
+ _spec.AddField(subscriptionplan.FieldOriginalPrice, field.TypeFloat64, value)
+ }
+ if _u.mutation.OriginalPriceCleared() {
+ _spec.ClearField(subscriptionplan.FieldOriginalPrice, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.ValidityDays(); ok {
+ _spec.SetField(subscriptionplan.FieldValidityDays, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedValidityDays(); ok {
+ _spec.AddField(subscriptionplan.FieldValidityDays, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ValidityUnit(); ok {
+ _spec.SetField(subscriptionplan.FieldValidityUnit, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Features(); ok {
+ _spec.SetField(subscriptionplan.FieldFeatures, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProductName(); ok {
+ _spec.SetField(subscriptionplan.FieldProductName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ForSale(); ok {
+ _spec.SetField(subscriptionplan.FieldForSale, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.SortOrder(); ok {
+ _spec.SetField(subscriptionplan.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedSortOrder(); ok {
+ _spec.AddField(subscriptionplan.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(subscriptionplan.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{subscriptionplan.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// SubscriptionPlanUpdateOne is the builder for updating a single SubscriptionPlan entity.
+type SubscriptionPlanUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *SubscriptionPlanMutation
+}
+
+// SetGroupID sets the "group_id" field.
+func (_u *SubscriptionPlanUpdateOne) SetGroupID(v int64) *SubscriptionPlanUpdateOne {
+ _u.mutation.ResetGroupID()
+ _u.mutation.SetGroupID(v)
+ return _u
+}
+
+// SetNillableGroupID sets the "group_id" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableGroupID(v *int64) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetGroupID(*v)
+ }
+ return _u
+}
+
+// AddGroupID adds value to the "group_id" field.
+func (_u *SubscriptionPlanUpdateOne) AddGroupID(v int64) *SubscriptionPlanUpdateOne {
+ _u.mutation.AddGroupID(v)
+ return _u
+}
+
+// SetName sets the "name" field.
+func (_u *SubscriptionPlanUpdateOne) SetName(v string) *SubscriptionPlanUpdateOne {
+ _u.mutation.SetName(v)
+ return _u
+}
+
+// SetNillableName sets the "name" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableName(v *string) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetName(*v)
+ }
+ return _u
+}
+
+// SetDescription sets the "description" field.
+func (_u *SubscriptionPlanUpdateOne) SetDescription(v string) *SubscriptionPlanUpdateOne {
+ _u.mutation.SetDescription(v)
+ return _u
+}
+
+// SetNillableDescription sets the "description" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableDescription(v *string) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetDescription(*v)
+ }
+ return _u
+}
+
+// SetPrice sets the "price" field.
+func (_u *SubscriptionPlanUpdateOne) SetPrice(v float64) *SubscriptionPlanUpdateOne {
+ _u.mutation.ResetPrice()
+ _u.mutation.SetPrice(v)
+ return _u
+}
+
+// SetNillablePrice sets the "price" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillablePrice(v *float64) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetPrice(*v)
+ }
+ return _u
+}
+
+// AddPrice adds value to the "price" field.
+func (_u *SubscriptionPlanUpdateOne) AddPrice(v float64) *SubscriptionPlanUpdateOne {
+ _u.mutation.AddPrice(v)
+ return _u
+}
+
+// SetOriginalPrice sets the "original_price" field.
+func (_u *SubscriptionPlanUpdateOne) SetOriginalPrice(v float64) *SubscriptionPlanUpdateOne {
+ _u.mutation.ResetOriginalPrice()
+ _u.mutation.SetOriginalPrice(v)
+ return _u
+}
+
+// SetNillableOriginalPrice sets the "original_price" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableOriginalPrice(v *float64) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetOriginalPrice(*v)
+ }
+ return _u
+}
+
+// AddOriginalPrice adds value to the "original_price" field.
+func (_u *SubscriptionPlanUpdateOne) AddOriginalPrice(v float64) *SubscriptionPlanUpdateOne {
+ _u.mutation.AddOriginalPrice(v)
+ return _u
+}
+
+// ClearOriginalPrice clears the value of the "original_price" field.
+func (_u *SubscriptionPlanUpdateOne) ClearOriginalPrice() *SubscriptionPlanUpdateOne {
+ _u.mutation.ClearOriginalPrice()
+ return _u
+}
+
+// SetValidityDays sets the "validity_days" field.
+func (_u *SubscriptionPlanUpdateOne) SetValidityDays(v int) *SubscriptionPlanUpdateOne {
+ _u.mutation.ResetValidityDays()
+ _u.mutation.SetValidityDays(v)
+ return _u
+}
+
+// SetNillableValidityDays sets the "validity_days" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableValidityDays(v *int) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetValidityDays(*v)
+ }
+ return _u
+}
+
+// AddValidityDays adds value to the "validity_days" field.
+func (_u *SubscriptionPlanUpdateOne) AddValidityDays(v int) *SubscriptionPlanUpdateOne {
+ _u.mutation.AddValidityDays(v)
+ return _u
+}
+
+// SetValidityUnit sets the "validity_unit" field.
+func (_u *SubscriptionPlanUpdateOne) SetValidityUnit(v string) *SubscriptionPlanUpdateOne {
+ _u.mutation.SetValidityUnit(v)
+ return _u
+}
+
+// SetNillableValidityUnit sets the "validity_unit" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableValidityUnit(v *string) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetValidityUnit(*v)
+ }
+ return _u
+}
+
+// SetFeatures sets the "features" field.
+func (_u *SubscriptionPlanUpdateOne) SetFeatures(v string) *SubscriptionPlanUpdateOne {
+ _u.mutation.SetFeatures(v)
+ return _u
+}
+
+// SetNillableFeatures sets the "features" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableFeatures(v *string) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetFeatures(*v)
+ }
+ return _u
+}
+
+// SetProductName sets the "product_name" field.
+func (_u *SubscriptionPlanUpdateOne) SetProductName(v string) *SubscriptionPlanUpdateOne {
+ _u.mutation.SetProductName(v)
+ return _u
+}
+
+// SetNillableProductName sets the "product_name" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableProductName(v *string) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetProductName(*v)
+ }
+ return _u
+}
+
+// SetForSale sets the "for_sale" field.
+func (_u *SubscriptionPlanUpdateOne) SetForSale(v bool) *SubscriptionPlanUpdateOne {
+ _u.mutation.SetForSale(v)
+ return _u
+}
+
+// SetNillableForSale sets the "for_sale" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableForSale(v *bool) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetForSale(*v)
+ }
+ return _u
+}
+
+// SetSortOrder sets the "sort_order" field.
+func (_u *SubscriptionPlanUpdateOne) SetSortOrder(v int) *SubscriptionPlanUpdateOne {
+ _u.mutation.ResetSortOrder()
+ _u.mutation.SetSortOrder(v)
+ return _u
+}
+
+// SetNillableSortOrder sets the "sort_order" field if the given value is not nil.
+func (_u *SubscriptionPlanUpdateOne) SetNillableSortOrder(v *int) *SubscriptionPlanUpdateOne {
+ if v != nil {
+ _u.SetSortOrder(*v)
+ }
+ return _u
+}
+
+// AddSortOrder adds value to the "sort_order" field.
+func (_u *SubscriptionPlanUpdateOne) AddSortOrder(v int) *SubscriptionPlanUpdateOne {
+ _u.mutation.AddSortOrder(v)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *SubscriptionPlanUpdateOne) SetUpdatedAt(v time.Time) *SubscriptionPlanUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// Mutation returns the SubscriptionPlanMutation object of the builder.
+func (_u *SubscriptionPlanUpdateOne) Mutation() *SubscriptionPlanMutation {
+ return _u.mutation
+}
+
+// Where appends a list predicates to the SubscriptionPlanUpdate builder.
+func (_u *SubscriptionPlanUpdateOne) Where(ps ...predicate.SubscriptionPlan) *SubscriptionPlanUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *SubscriptionPlanUpdateOne) Select(field string, fields ...string) *SubscriptionPlanUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated SubscriptionPlan entity.
+func (_u *SubscriptionPlanUpdateOne) Save(ctx context.Context) (*SubscriptionPlan, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *SubscriptionPlanUpdateOne) SaveX(ctx context.Context) *SubscriptionPlan {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *SubscriptionPlanUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *SubscriptionPlanUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *SubscriptionPlanUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := subscriptionplan.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *SubscriptionPlanUpdateOne) check() error {
+ if v, ok := _u.mutation.Name(); ok {
+ if err := subscriptionplan.NameValidator(v); err != nil {
+ return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.name": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ValidityUnit(); ok {
+ if err := subscriptionplan.ValidityUnitValidator(v); err != nil {
+ return &ValidationError{Name: "validity_unit", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.validity_unit": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ProductName(); ok {
+ if err := subscriptionplan.ProductNameValidator(v); err != nil {
+ return &ValidationError{Name: "product_name", err: fmt.Errorf(`ent: validator failed for field "SubscriptionPlan.product_name": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *SubscriptionPlanUpdateOne) sqlSave(ctx context.Context) (_node *SubscriptionPlan, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(subscriptionplan.Table, subscriptionplan.Columns, sqlgraph.NewFieldSpec(subscriptionplan.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SubscriptionPlan.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, subscriptionplan.FieldID)
+ for _, f := range fields {
+ if !subscriptionplan.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != subscriptionplan.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.GroupID(); ok {
+ _spec.SetField(subscriptionplan.FieldGroupID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedGroupID(); ok {
+ _spec.AddField(subscriptionplan.FieldGroupID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.Name(); ok {
+ _spec.SetField(subscriptionplan.FieldName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Description(); ok {
+ _spec.SetField(subscriptionplan.FieldDescription, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Price(); ok {
+ _spec.SetField(subscriptionplan.FieldPrice, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedPrice(); ok {
+ _spec.AddField(subscriptionplan.FieldPrice, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.OriginalPrice(); ok {
+ _spec.SetField(subscriptionplan.FieldOriginalPrice, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedOriginalPrice(); ok {
+ _spec.AddField(subscriptionplan.FieldOriginalPrice, field.TypeFloat64, value)
+ }
+ if _u.mutation.OriginalPriceCleared() {
+ _spec.ClearField(subscriptionplan.FieldOriginalPrice, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.ValidityDays(); ok {
+ _spec.SetField(subscriptionplan.FieldValidityDays, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedValidityDays(); ok {
+ _spec.AddField(subscriptionplan.FieldValidityDays, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.ValidityUnit(); ok {
+ _spec.SetField(subscriptionplan.FieldValidityUnit, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Features(); ok {
+ _spec.SetField(subscriptionplan.FieldFeatures, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ProductName(); ok {
+ _spec.SetField(subscriptionplan.FieldProductName, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.ForSale(); ok {
+ _spec.SetField(subscriptionplan.FieldForSale, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.SortOrder(); ok {
+ _spec.SetField(subscriptionplan.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedSortOrder(); ok {
+ _spec.AddField(subscriptionplan.FieldSortOrder, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(subscriptionplan.FieldUpdatedAt, field.TypeTime, value)
+ }
+ _node = &SubscriptionPlan{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{subscriptionplan.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index b5aea447..611028e9 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -24,12 +24,34 @@ type Tx struct {
Announcement *AnnouncementClient
// AnnouncementRead is the client for interacting with the AnnouncementRead builders.
AnnouncementRead *AnnouncementReadClient
+ // AuthIdentity is the client for interacting with the AuthIdentity builders.
+ AuthIdentity *AuthIdentityClient
+ // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders.
+ AuthIdentityChannel *AuthIdentityChannelClient
+ // ChannelMonitor is the client for interacting with the ChannelMonitor builders.
+ ChannelMonitor *ChannelMonitorClient
+ // ChannelMonitorDailyRollup is the client for interacting with the ChannelMonitorDailyRollup builders.
+ ChannelMonitorDailyRollup *ChannelMonitorDailyRollupClient
+ // ChannelMonitorHistory is the client for interacting with the ChannelMonitorHistory builders.
+ ChannelMonitorHistory *ChannelMonitorHistoryClient
+ // ChannelMonitorRequestTemplate is the client for interacting with the ChannelMonitorRequestTemplate builders.
+ ChannelMonitorRequestTemplate *ChannelMonitorRequestTemplateClient
// ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders.
ErrorPassthroughRule *ErrorPassthroughRuleClient
// Group is the client for interacting with the Group builders.
Group *GroupClient
// IdempotencyRecord is the client for interacting with the IdempotencyRecord builders.
IdempotencyRecord *IdempotencyRecordClient
+ // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders.
+ IdentityAdoptionDecision *IdentityAdoptionDecisionClient
+ // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders.
+ PaymentAuditLog *PaymentAuditLogClient
+ // PaymentOrder is the client for interacting with the PaymentOrder builders.
+ PaymentOrder *PaymentOrderClient
+ // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders.
+ PaymentProviderInstance *PaymentProviderInstanceClient
+ // PendingAuthSession is the client for interacting with the PendingAuthSession builders.
+ PendingAuthSession *PendingAuthSessionClient
// PromoCode is the client for interacting with the PromoCode builders.
PromoCode *PromoCodeClient
// PromoCodeUsage is the client for interacting with the PromoCodeUsage builders.
@@ -42,6 +64,8 @@ type Tx struct {
SecuritySecret *SecuritySecretClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
+ // SubscriptionPlan is the client for interacting with the SubscriptionPlan builders.
+ SubscriptionPlan *SubscriptionPlanClient
// TLSFingerprintProfile is the client for interacting with the TLSFingerprintProfile builders.
TLSFingerprintProfile *TLSFingerprintProfileClient
// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
@@ -194,15 +218,27 @@ func (tx *Tx) init() {
tx.AccountGroup = NewAccountGroupClient(tx.config)
tx.Announcement = NewAnnouncementClient(tx.config)
tx.AnnouncementRead = NewAnnouncementReadClient(tx.config)
+ tx.AuthIdentity = NewAuthIdentityClient(tx.config)
+ tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config)
+ tx.ChannelMonitor = NewChannelMonitorClient(tx.config)
+ tx.ChannelMonitorDailyRollup = NewChannelMonitorDailyRollupClient(tx.config)
+ tx.ChannelMonitorHistory = NewChannelMonitorHistoryClient(tx.config)
+ tx.ChannelMonitorRequestTemplate = NewChannelMonitorRequestTemplateClient(tx.config)
tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config)
tx.Group = NewGroupClient(tx.config)
tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config)
+ tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config)
+ tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config)
+ tx.PaymentOrder = NewPaymentOrderClient(tx.config)
+ tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config)
+ tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config)
tx.PromoCode = NewPromoCodeClient(tx.config)
tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config)
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.SecuritySecret = NewSecuritySecretClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
+ tx.SubscriptionPlan = NewSubscriptionPlanClient(tx.config)
tx.TLSFingerprintProfile = NewTLSFingerprintProfileClient(tx.config)
tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config)
diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go
index fb4ee1c5..a8e0cc6c 100644
--- a/backend/ent/usagelog.go
+++ b/backend/ent/usagelog.go
@@ -36,6 +36,14 @@ type UsageLog struct {
RequestedModel *string `json:"requested_model,omitempty"`
// UpstreamModel holds the value of the "upstream_model" field.
UpstreamModel *string `json:"upstream_model,omitempty"`
+ // 渠道 ID
+ ChannelID *int64 `json:"channel_id,omitempty"`
+ // 模型映射链
+ ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
+ // 计费层级标签
+ BillingTier *string `json:"billing_tier,omitempty"`
+ // 计费模式:token/per_request/image
+ BillingMode *string `json:"billing_mode,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// SubscriptionID holds the value of the "subscription_id" field.
@@ -84,8 +92,6 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
- // MediaType holds the value of the "media_type" field.
- MediaType *string `json:"media_type,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
@@ -177,9 +183,9 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
- case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
+ case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
- case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType:
+ case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -248,6 +254,34 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.UpstreamModel = new(string)
*_m.UpstreamModel = value.String
}
+ case usagelog.FieldChannelID:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field channel_id", values[i])
+ } else if value.Valid {
+ _m.ChannelID = new(int64)
+ *_m.ChannelID = value.Int64
+ }
+ case usagelog.FieldModelMappingChain:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field model_mapping_chain", values[i])
+ } else if value.Valid {
+ _m.ModelMappingChain = new(string)
+ *_m.ModelMappingChain = value.String
+ }
+ case usagelog.FieldBillingTier:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field billing_tier", values[i])
+ } else if value.Valid {
+ _m.BillingTier = new(string)
+ *_m.BillingTier = value.String
+ }
+ case usagelog.FieldBillingMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field billing_mode", values[i])
+ } else if value.Valid {
+ _m.BillingMode = new(string)
+ *_m.BillingMode = value.String
+ }
case usagelog.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@@ -400,13 +434,6 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
- case usagelog.FieldMediaType:
- if value, ok := values[i].(*sql.NullString); !ok {
- return fmt.Errorf("unexpected type %T for field media_type", values[i])
- } else if value.Valid {
- _m.MediaType = new(string)
- *_m.MediaType = value.String
- }
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
@@ -505,6 +532,26 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ if v := _m.ChannelID; v != nil {
+ builder.WriteString("channel_id=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.ModelMappingChain; v != nil {
+ builder.WriteString("model_mapping_chain=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.BillingTier; v != nil {
+ builder.WriteString("billing_tier=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.BillingMode; v != nil {
+ builder.WriteString("billing_mode=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
@@ -593,11 +640,6 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
- if v := _m.MediaType; v != nil {
- builder.WriteString("media_type=")
- builder.WriteString(*v)
- }
- builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")
diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go
index b534f193..a7438e60 100644
--- a/backend/ent/usagelog/usagelog.go
+++ b/backend/ent/usagelog/usagelog.go
@@ -28,6 +28,14 @@ const (
FieldRequestedModel = "requested_model"
// FieldUpstreamModel holds the string denoting the upstream_model field in the database.
FieldUpstreamModel = "upstream_model"
+ // FieldChannelID holds the string denoting the channel_id field in the database.
+ FieldChannelID = "channel_id"
+ // FieldModelMappingChain holds the string denoting the model_mapping_chain field in the database.
+ FieldModelMappingChain = "model_mapping_chain"
+ // FieldBillingTier holds the string denoting the billing_tier field in the database.
+ FieldBillingTier = "billing_tier"
+ // FieldBillingMode holds the string denoting the billing_mode field in the database.
+ FieldBillingMode = "billing_mode"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldSubscriptionID holds the string denoting the subscription_id field in the database.
@@ -76,8 +84,6 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
- // FieldMediaType holds the string denoting the media_type field in the database.
- FieldMediaType = "media_type"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
@@ -141,6 +147,10 @@ var Columns = []string{
FieldModel,
FieldRequestedModel,
FieldUpstreamModel,
+ FieldChannelID,
+ FieldModelMappingChain,
+ FieldBillingTier,
+ FieldBillingMode,
FieldGroupID,
FieldSubscriptionID,
FieldInputTokens,
@@ -165,7 +175,6 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
- FieldMediaType,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@@ -189,6 +198,12 @@ var (
RequestedModelValidator func(string) error
// UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save.
UpstreamModelValidator func(string) error
+ // ModelMappingChainValidator is a validator for the "model_mapping_chain" field. It is called by the builders before save.
+ ModelMappingChainValidator func(string) error
+ // BillingTierValidator is a validator for the "billing_tier" field. It is called by the builders before save.
+ BillingTierValidator func(string) error
+ // BillingModeValidator is a validator for the "billing_mode" field. It is called by the builders before save.
+ BillingModeValidator func(string) error
// DefaultInputTokens holds the default value on creation for the "input_tokens" field.
DefaultInputTokens int
// DefaultOutputTokens holds the default value on creation for the "output_tokens" field.
@@ -227,8 +242,6 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
- // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save.
- MediaTypeValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
@@ -278,6 +291,26 @@ func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc()
}
+// ByChannelID orders the results by the channel_id field.
+func ByChannelID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldChannelID, opts...).ToFunc()
+}
+
+// ByModelMappingChain orders the results by the model_mapping_chain field.
+func ByModelMappingChain(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldModelMappingChain, opts...).ToFunc()
+}
+
+// ByBillingTier orders the results by the billing_tier field.
+func ByBillingTier(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBillingTier, opts...).ToFunc()
+}
+
+// ByBillingMode orders the results by the billing_mode field.
+func ByBillingMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBillingMode, opts...).ToFunc()
+}
+
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
@@ -398,11 +431,6 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
-// ByMediaType orders the results by the media_type field.
-func ByMediaType(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldMediaType, opts...).ToFunc()
-}
-
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go
index f95bceb7..b8439a03 100644
--- a/backend/ent/usagelog/where.go
+++ b/backend/ent/usagelog/where.go
@@ -90,6 +90,26 @@ func UpstreamModel(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v))
}
+// ChannelID applies equality check predicate on the "channel_id" field. It's identical to ChannelIDEQ.
+func ChannelID(v int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
+}
+
+// ModelMappingChain applies equality check predicate on the "model_mapping_chain" field. It's identical to ModelMappingChainEQ.
+func ModelMappingChain(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
+}
+
+// BillingTier applies equality check predicate on the "billing_tier" field. It's identical to BillingTierEQ.
+func BillingTier(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
+}
+
+// BillingMode applies equality check predicate on the "billing_mode" field. It's identical to BillingModeEQ.
+func BillingMode(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
+}
+
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@@ -210,11 +230,6 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
-// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ.
-func MediaType(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
-}
-
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
@@ -565,6 +580,281 @@ func UpstreamModelContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v))
}
+// ChannelIDEQ applies the EQ predicate on the "channel_id" field.
+func ChannelIDEQ(v int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldChannelID, v))
+}
+
+// ChannelIDNEQ applies the NEQ predicate on the "channel_id" field.
+func ChannelIDNEQ(v int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldChannelID, v))
+}
+
+// ChannelIDIn applies the In predicate on the "channel_id" field.
+func ChannelIDIn(vs ...int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldChannelID, vs...))
+}
+
+// ChannelIDNotIn applies the NotIn predicate on the "channel_id" field.
+func ChannelIDNotIn(vs ...int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldChannelID, vs...))
+}
+
+// ChannelIDGT applies the GT predicate on the "channel_id" field.
+func ChannelIDGT(v int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldChannelID, v))
+}
+
+// ChannelIDGTE applies the GTE predicate on the "channel_id" field.
+func ChannelIDGTE(v int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldChannelID, v))
+}
+
+// ChannelIDLT applies the LT predicate on the "channel_id" field.
+func ChannelIDLT(v int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldChannelID, v))
+}
+
+// ChannelIDLTE applies the LTE predicate on the "channel_id" field.
+func ChannelIDLTE(v int64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldChannelID, v))
+}
+
+// ChannelIDIsNil applies the IsNil predicate on the "channel_id" field.
+func ChannelIDIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldChannelID))
+}
+
+// ChannelIDNotNil applies the NotNil predicate on the "channel_id" field.
+func ChannelIDNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldChannelID))
+}
+
+// ModelMappingChainEQ applies the EQ predicate on the "model_mapping_chain" field.
+func ModelMappingChainEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainNEQ applies the NEQ predicate on the "model_mapping_chain" field.
+func ModelMappingChainNEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainIn applies the In predicate on the "model_mapping_chain" field.
+func ModelMappingChainIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldModelMappingChain, vs...))
+}
+
+// ModelMappingChainNotIn applies the NotIn predicate on the "model_mapping_chain" field.
+func ModelMappingChainNotIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldModelMappingChain, vs...))
+}
+
+// ModelMappingChainGT applies the GT predicate on the "model_mapping_chain" field.
+func ModelMappingChainGT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainGTE applies the GTE predicate on the "model_mapping_chain" field.
+func ModelMappingChainGTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainLT applies the LT predicate on the "model_mapping_chain" field.
+func ModelMappingChainLT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainLTE applies the LTE predicate on the "model_mapping_chain" field.
+func ModelMappingChainLTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainContains applies the Contains predicate on the "model_mapping_chain" field.
+func ModelMappingChainContains(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContains(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainHasPrefix applies the HasPrefix predicate on the "model_mapping_chain" field.
+func ModelMappingChainHasPrefix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasPrefix(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainHasSuffix applies the HasSuffix predicate on the "model_mapping_chain" field.
+func ModelMappingChainHasSuffix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasSuffix(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainIsNil applies the IsNil predicate on the "model_mapping_chain" field.
+func ModelMappingChainIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldModelMappingChain))
+}
+
+// ModelMappingChainNotNil applies the NotNil predicate on the "model_mapping_chain" field.
+func ModelMappingChainNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldModelMappingChain))
+}
+
+// ModelMappingChainEqualFold applies the EqualFold predicate on the "model_mapping_chain" field.
+func ModelMappingChainEqualFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEqualFold(FieldModelMappingChain, v))
+}
+
+// ModelMappingChainContainsFold applies the ContainsFold predicate on the "model_mapping_chain" field.
+func ModelMappingChainContainsFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContainsFold(FieldModelMappingChain, v))
+}
+
+// BillingTierEQ applies the EQ predicate on the "billing_tier" field.
+func BillingTierEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldBillingTier, v))
+}
+
+// BillingTierNEQ applies the NEQ predicate on the "billing_tier" field.
+func BillingTierNEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldBillingTier, v))
+}
+
+// BillingTierIn applies the In predicate on the "billing_tier" field.
+func BillingTierIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldBillingTier, vs...))
+}
+
+// BillingTierNotIn applies the NotIn predicate on the "billing_tier" field.
+func BillingTierNotIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldBillingTier, vs...))
+}
+
+// BillingTierGT applies the GT predicate on the "billing_tier" field.
+func BillingTierGT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldBillingTier, v))
+}
+
+// BillingTierGTE applies the GTE predicate on the "billing_tier" field.
+func BillingTierGTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldBillingTier, v))
+}
+
+// BillingTierLT applies the LT predicate on the "billing_tier" field.
+func BillingTierLT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldBillingTier, v))
+}
+
+// BillingTierLTE applies the LTE predicate on the "billing_tier" field.
+func BillingTierLTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldBillingTier, v))
+}
+
+// BillingTierContains applies the Contains predicate on the "billing_tier" field.
+func BillingTierContains(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContains(FieldBillingTier, v))
+}
+
+// BillingTierHasPrefix applies the HasPrefix predicate on the "billing_tier" field.
+func BillingTierHasPrefix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingTier, v))
+}
+
+// BillingTierHasSuffix applies the HasSuffix predicate on the "billing_tier" field.
+func BillingTierHasSuffix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingTier, v))
+}
+
+// BillingTierIsNil applies the IsNil predicate on the "billing_tier" field.
+func BillingTierIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldBillingTier))
+}
+
+// BillingTierNotNil applies the NotNil predicate on the "billing_tier" field.
+func BillingTierNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldBillingTier))
+}
+
+// BillingTierEqualFold applies the EqualFold predicate on the "billing_tier" field.
+func BillingTierEqualFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEqualFold(FieldBillingTier, v))
+}
+
+// BillingTierContainsFold applies the ContainsFold predicate on the "billing_tier" field.
+func BillingTierContainsFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContainsFold(FieldBillingTier, v))
+}
+
+// BillingModeEQ applies the EQ predicate on the "billing_mode" field.
+func BillingModeEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldBillingMode, v))
+}
+
+// BillingModeNEQ applies the NEQ predicate on the "billing_mode" field.
+func BillingModeNEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldBillingMode, v))
+}
+
+// BillingModeIn applies the In predicate on the "billing_mode" field.
+func BillingModeIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldBillingMode, vs...))
+}
+
+// BillingModeNotIn applies the NotIn predicate on the "billing_mode" field.
+func BillingModeNotIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldBillingMode, vs...))
+}
+
+// BillingModeGT applies the GT predicate on the "billing_mode" field.
+func BillingModeGT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldBillingMode, v))
+}
+
+// BillingModeGTE applies the GTE predicate on the "billing_mode" field.
+func BillingModeGTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldBillingMode, v))
+}
+
+// BillingModeLT applies the LT predicate on the "billing_mode" field.
+func BillingModeLT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldBillingMode, v))
+}
+
+// BillingModeLTE applies the LTE predicate on the "billing_mode" field.
+func BillingModeLTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldBillingMode, v))
+}
+
+// BillingModeContains applies the Contains predicate on the "billing_mode" field.
+func BillingModeContains(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContains(FieldBillingMode, v))
+}
+
+// BillingModeHasPrefix applies the HasPrefix predicate on the "billing_mode" field.
+func BillingModeHasPrefix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasPrefix(FieldBillingMode, v))
+}
+
+// BillingModeHasSuffix applies the HasSuffix predicate on the "billing_mode" field.
+func BillingModeHasSuffix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasSuffix(FieldBillingMode, v))
+}
+
+// BillingModeIsNil applies the IsNil predicate on the "billing_mode" field.
+func BillingModeIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldBillingMode))
+}
+
+// BillingModeNotNil applies the NotNil predicate on the "billing_mode" field.
+func BillingModeNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldBillingMode))
+}
+
+// BillingModeEqualFold applies the EqualFold predicate on the "billing_mode" field.
+func BillingModeEqualFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEqualFold(FieldBillingMode, v))
+}
+
+// BillingModeContainsFold applies the ContainsFold predicate on the "billing_mode" field.
+func BillingModeContainsFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContainsFold(FieldBillingMode, v))
+}
+
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v))
@@ -1610,81 +1900,6 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
-// MediaTypeEQ applies the EQ predicate on the "media_type" field.
-func MediaTypeEQ(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v))
-}
-
-// MediaTypeNEQ applies the NEQ predicate on the "media_type" field.
-func MediaTypeNEQ(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v))
-}
-
-// MediaTypeIn applies the In predicate on the "media_type" field.
-func MediaTypeIn(vs ...string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...))
-}
-
-// MediaTypeNotIn applies the NotIn predicate on the "media_type" field.
-func MediaTypeNotIn(vs ...string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...))
-}
-
-// MediaTypeGT applies the GT predicate on the "media_type" field.
-func MediaTypeGT(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldGT(FieldMediaType, v))
-}
-
-// MediaTypeGTE applies the GTE predicate on the "media_type" field.
-func MediaTypeGTE(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v))
-}
-
-// MediaTypeLT applies the LT predicate on the "media_type" field.
-func MediaTypeLT(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldLT(FieldMediaType, v))
-}
-
-// MediaTypeLTE applies the LTE predicate on the "media_type" field.
-func MediaTypeLTE(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v))
-}
-
-// MediaTypeContains applies the Contains predicate on the "media_type" field.
-func MediaTypeContains(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldContains(FieldMediaType, v))
-}
-
-// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field.
-func MediaTypeHasPrefix(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v))
-}
-
-// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field.
-func MediaTypeHasSuffix(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v))
-}
-
-// MediaTypeIsNil applies the IsNil predicate on the "media_type" field.
-func MediaTypeIsNil() predicate.UsageLog {
- return predicate.UsageLog(sql.FieldIsNull(FieldMediaType))
-}
-
-// MediaTypeNotNil applies the NotNil predicate on the "media_type" field.
-func MediaTypeNotNil() predicate.UsageLog {
- return predicate.UsageLog(sql.FieldNotNull(FieldMediaType))
-}
-
-// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field.
-func MediaTypeEqualFold(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v))
-}
-
-// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field.
-func MediaTypeContainsFold(v string) predicate.UsageLog {
- return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v))
-}
-
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go
index 6ae0bf59..fded364e 100644
--- a/backend/ent/usagelog_create.go
+++ b/backend/ent/usagelog_create.go
@@ -85,6 +85,62 @@ func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate {
return _c
}
+// SetChannelID sets the "channel_id" field.
+func (_c *UsageLogCreate) SetChannelID(v int64) *UsageLogCreate {
+ _c.mutation.SetChannelID(v)
+ return _c
+}
+
+// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableChannelID(v *int64) *UsageLogCreate {
+ if v != nil {
+ _c.SetChannelID(*v)
+ }
+ return _c
+}
+
+// SetModelMappingChain sets the "model_mapping_chain" field.
+func (_c *UsageLogCreate) SetModelMappingChain(v string) *UsageLogCreate {
+ _c.mutation.SetModelMappingChain(v)
+ return _c
+}
+
+// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableModelMappingChain(v *string) *UsageLogCreate {
+ if v != nil {
+ _c.SetModelMappingChain(*v)
+ }
+ return _c
+}
+
+// SetBillingTier sets the "billing_tier" field.
+func (_c *UsageLogCreate) SetBillingTier(v string) *UsageLogCreate {
+ _c.mutation.SetBillingTier(v)
+ return _c
+}
+
+// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableBillingTier(v *string) *UsageLogCreate {
+ if v != nil {
+ _c.SetBillingTier(*v)
+ }
+ return _c
+}
+
+// SetBillingMode sets the "billing_mode" field.
+func (_c *UsageLogCreate) SetBillingMode(v string) *UsageLogCreate {
+ _c.mutation.SetBillingMode(v)
+ return _c
+}
+
+// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableBillingMode(v *string) *UsageLogCreate {
+ if v != nil {
+ _c.SetBillingMode(*v)
+ }
+ return _c
+}
+
// SetGroupID sets the "group_id" field.
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate {
_c.mutation.SetGroupID(v)
@@ -421,20 +477,6 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
-// SetMediaType sets the "media_type" field.
-func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate {
- _c.mutation.SetMediaType(v)
- return _c
-}
-
-// SetNillableMediaType sets the "media_type" field if the given value is not nil.
-func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate {
- if v != nil {
- _c.SetMediaType(*v)
- }
- return _c
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
@@ -634,6 +676,21 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
+ if v, ok := _c.mutation.ModelMappingChain(); ok {
+ if err := usagelog.ModelMappingChainValidator(v); err != nil {
+ return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.BillingTier(); ok {
+ if err := usagelog.BillingTierValidator(v); err != nil {
+ return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.BillingMode(); ok {
+ if err := usagelog.BillingModeValidator(v); err != nil {
+ return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.InputTokens(); !ok {
return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)}
}
@@ -697,11 +754,6 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
- if v, ok := _c.mutation.MediaType(); ok {
- if err := usagelog.MediaTypeValidator(v); err != nil {
- return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
- }
- }
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
@@ -760,6 +812,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value)
_node.UpstreamModel = &value
}
+ if value, ok := _c.mutation.ChannelID(); ok {
+ _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
+ _node.ChannelID = &value
+ }
+ if value, ok := _c.mutation.ModelMappingChain(); ok {
+ _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
+ _node.ModelMappingChain = &value
+ }
+ if value, ok := _c.mutation.BillingTier(); ok {
+ _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
+ _node.BillingTier = &value
+ }
+ if value, ok := _c.mutation.BillingMode(); ok {
+ _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
+ _node.BillingMode = &value
+ }
if value, ok := _c.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
_node.InputTokens = value
@@ -848,10 +916,6 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
- if value, ok := _c.mutation.MediaType(); ok {
- _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
- _node.MediaType = &value
- }
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
@@ -1093,6 +1157,84 @@ func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert {
return u
}
+// SetChannelID sets the "channel_id" field.
+func (u *UsageLogUpsert) SetChannelID(v int64) *UsageLogUpsert {
+ u.Set(usagelog.FieldChannelID, v)
+ return u
+}
+
+// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateChannelID() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldChannelID)
+ return u
+}
+
+// AddChannelID adds v to the "channel_id" field.
+func (u *UsageLogUpsert) AddChannelID(v int64) *UsageLogUpsert {
+ u.Add(usagelog.FieldChannelID, v)
+ return u
+}
+
+// ClearChannelID clears the value of the "channel_id" field.
+func (u *UsageLogUpsert) ClearChannelID() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldChannelID)
+ return u
+}
+
+// SetModelMappingChain sets the "model_mapping_chain" field.
+func (u *UsageLogUpsert) SetModelMappingChain(v string) *UsageLogUpsert {
+ u.Set(usagelog.FieldModelMappingChain, v)
+ return u
+}
+
+// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateModelMappingChain() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldModelMappingChain)
+ return u
+}
+
+// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
+func (u *UsageLogUpsert) ClearModelMappingChain() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldModelMappingChain)
+ return u
+}
+
+// SetBillingTier sets the "billing_tier" field.
+func (u *UsageLogUpsert) SetBillingTier(v string) *UsageLogUpsert {
+ u.Set(usagelog.FieldBillingTier, v)
+ return u
+}
+
+// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateBillingTier() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldBillingTier)
+ return u
+}
+
+// ClearBillingTier clears the value of the "billing_tier" field.
+func (u *UsageLogUpsert) ClearBillingTier() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldBillingTier)
+ return u
+}
+
+// SetBillingMode sets the "billing_mode" field.
+func (u *UsageLogUpsert) SetBillingMode(v string) *UsageLogUpsert {
+ u.Set(usagelog.FieldBillingMode, v)
+ return u
+}
+
+// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateBillingMode() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldBillingMode)
+ return u
+}
+
+// ClearBillingMode clears the value of the "billing_mode" field.
+func (u *UsageLogUpsert) ClearBillingMode() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldBillingMode)
+ return u
+}
+
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert {
u.Set(usagelog.FieldGroupID, v)
@@ -1537,24 +1679,6 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
-// SetMediaType sets the "media_type" field.
-func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert {
- u.Set(usagelog.FieldMediaType, v)
- return u
-}
-
-// UpdateMediaType sets the "media_type" field to the value that was provided on create.
-func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert {
- u.SetExcluded(usagelog.FieldMediaType)
- return u
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert {
- u.SetNull(usagelog.FieldMediaType)
- return u
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
@@ -1724,6 +1848,97 @@ func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne {
})
}
+// SetChannelID sets the "channel_id" field.
+func (u *UsageLogUpsertOne) SetChannelID(v int64) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetChannelID(v)
+ })
+}
+
+// AddChannelID adds v to the "channel_id" field.
+func (u *UsageLogUpsertOne) AddChannelID(v int64) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.AddChannelID(v)
+ })
+}
+
+// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateChannelID() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateChannelID()
+ })
+}
+
+// ClearChannelID clears the value of the "channel_id" field.
+func (u *UsageLogUpsertOne) ClearChannelID() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearChannelID()
+ })
+}
+
+// SetModelMappingChain sets the "model_mapping_chain" field.
+func (u *UsageLogUpsertOne) SetModelMappingChain(v string) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetModelMappingChain(v)
+ })
+}
+
+// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateModelMappingChain() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateModelMappingChain()
+ })
+}
+
+// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
+func (u *UsageLogUpsertOne) ClearModelMappingChain() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearModelMappingChain()
+ })
+}
+
+// SetBillingTier sets the "billing_tier" field.
+func (u *UsageLogUpsertOne) SetBillingTier(v string) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetBillingTier(v)
+ })
+}
+
+// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateBillingTier() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateBillingTier()
+ })
+}
+
+// ClearBillingTier clears the value of the "billing_tier" field.
+func (u *UsageLogUpsertOne) ClearBillingTier() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearBillingTier()
+ })
+}
+
+// SetBillingMode sets the "billing_mode" field.
+func (u *UsageLogUpsertOne) SetBillingMode(v string) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetBillingMode(v)
+ })
+}
+
+// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateBillingMode() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateBillingMode()
+ })
+}
+
+// ClearBillingMode clears the value of the "billing_mode" field.
+func (u *UsageLogUpsertOne) ClearBillingMode() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearBillingMode()
+ })
+}
+
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2242,27 +2457,6 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
-// SetMediaType sets the "media_type" field.
-func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne {
- return u.Update(func(s *UsageLogUpsert) {
- s.SetMediaType(v)
- })
-}
-
-// UpdateMediaType sets the "media_type" field to the value that was provided on create.
-func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne {
- return u.Update(func(s *UsageLogUpsert) {
- s.UpdateMediaType()
- })
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne {
- return u.Update(func(s *UsageLogUpsert) {
- s.ClearMediaType()
- })
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2600,6 +2794,97 @@ func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk {
})
}
+// SetChannelID sets the "channel_id" field.
+func (u *UsageLogUpsertBulk) SetChannelID(v int64) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetChannelID(v)
+ })
+}
+
+// AddChannelID adds v to the "channel_id" field.
+func (u *UsageLogUpsertBulk) AddChannelID(v int64) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.AddChannelID(v)
+ })
+}
+
+// UpdateChannelID sets the "channel_id" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateChannelID() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateChannelID()
+ })
+}
+
+// ClearChannelID clears the value of the "channel_id" field.
+func (u *UsageLogUpsertBulk) ClearChannelID() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearChannelID()
+ })
+}
+
+// SetModelMappingChain sets the "model_mapping_chain" field.
+func (u *UsageLogUpsertBulk) SetModelMappingChain(v string) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetModelMappingChain(v)
+ })
+}
+
+// UpdateModelMappingChain sets the "model_mapping_chain" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateModelMappingChain() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateModelMappingChain()
+ })
+}
+
+// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
+func (u *UsageLogUpsertBulk) ClearModelMappingChain() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearModelMappingChain()
+ })
+}
+
+// SetBillingTier sets the "billing_tier" field.
+func (u *UsageLogUpsertBulk) SetBillingTier(v string) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetBillingTier(v)
+ })
+}
+
+// UpdateBillingTier sets the "billing_tier" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateBillingTier() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateBillingTier()
+ })
+}
+
+// ClearBillingTier clears the value of the "billing_tier" field.
+func (u *UsageLogUpsertBulk) ClearBillingTier() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearBillingTier()
+ })
+}
+
+// SetBillingMode sets the "billing_mode" field.
+func (u *UsageLogUpsertBulk) SetBillingMode(v string) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetBillingMode(v)
+ })
+}
+
+// UpdateBillingMode sets the "billing_mode" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateBillingMode() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateBillingMode()
+ })
+}
+
+// ClearBillingMode clears the value of the "billing_mode" field.
+func (u *UsageLogUpsertBulk) ClearBillingMode() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearBillingMode()
+ })
+}
+
// SetGroupID sets the "group_id" field.
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
@@ -3118,27 +3403,6 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
-// SetMediaType sets the "media_type" field.
-func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk {
- return u.Update(func(s *UsageLogUpsert) {
- s.SetMediaType(v)
- })
-}
-
-// UpdateMediaType sets the "media_type" field to the value that was provided on create.
-func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk {
- return u.Update(func(s *UsageLogUpsert) {
- s.UpdateMediaType()
- })
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk {
- return u.Update(func(s *UsageLogUpsert) {
- s.ClearMediaType()
- })
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go
index 516407b9..bb5ac86c 100644
--- a/backend/ent/usagelog_update.go
+++ b/backend/ent/usagelog_update.go
@@ -142,6 +142,93 @@ func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate {
return _u
}
+// SetChannelID sets the "channel_id" field.
+func (_u *UsageLogUpdate) SetChannelID(v int64) *UsageLogUpdate {
+ _u.mutation.ResetChannelID()
+ _u.mutation.SetChannelID(v)
+ return _u
+}
+
+// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableChannelID(v *int64) *UsageLogUpdate {
+ if v != nil {
+ _u.SetChannelID(*v)
+ }
+ return _u
+}
+
+// AddChannelID adds value to the "channel_id" field.
+func (_u *UsageLogUpdate) AddChannelID(v int64) *UsageLogUpdate {
+ _u.mutation.AddChannelID(v)
+ return _u
+}
+
+// ClearChannelID clears the value of the "channel_id" field.
+func (_u *UsageLogUpdate) ClearChannelID() *UsageLogUpdate {
+ _u.mutation.ClearChannelID()
+ return _u
+}
+
+// SetModelMappingChain sets the "model_mapping_chain" field.
+func (_u *UsageLogUpdate) SetModelMappingChain(v string) *UsageLogUpdate {
+ _u.mutation.SetModelMappingChain(v)
+ return _u
+}
+
+// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableModelMappingChain(v *string) *UsageLogUpdate {
+ if v != nil {
+ _u.SetModelMappingChain(*v)
+ }
+ return _u
+}
+
+// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
+func (_u *UsageLogUpdate) ClearModelMappingChain() *UsageLogUpdate {
+ _u.mutation.ClearModelMappingChain()
+ return _u
+}
+
+// SetBillingTier sets the "billing_tier" field.
+func (_u *UsageLogUpdate) SetBillingTier(v string) *UsageLogUpdate {
+ _u.mutation.SetBillingTier(v)
+ return _u
+}
+
+// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableBillingTier(v *string) *UsageLogUpdate {
+ if v != nil {
+ _u.SetBillingTier(*v)
+ }
+ return _u
+}
+
+// ClearBillingTier clears the value of the "billing_tier" field.
+func (_u *UsageLogUpdate) ClearBillingTier() *UsageLogUpdate {
+ _u.mutation.ClearBillingTier()
+ return _u
+}
+
+// SetBillingMode sets the "billing_mode" field.
+func (_u *UsageLogUpdate) SetBillingMode(v string) *UsageLogUpdate {
+ _u.mutation.SetBillingMode(v)
+ return _u
+}
+
+// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableBillingMode(v *string) *UsageLogUpdate {
+ if v != nil {
+ _u.SetBillingMode(*v)
+ }
+ return _u
+}
+
+// ClearBillingMode clears the value of the "billing_mode" field.
+func (_u *UsageLogUpdate) ClearBillingMode() *UsageLogUpdate {
+ _u.mutation.ClearBillingMode()
+ return _u
+}
+
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate {
_u.mutation.SetGroupID(v)
@@ -652,26 +739,6 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
-// SetMediaType sets the "media_type" field.
-func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate {
- _u.mutation.SetMediaType(v)
- return _u
-}
-
-// SetNillableMediaType sets the "media_type" field if the given value is not nil.
-func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate {
- if v != nil {
- _u.SetMediaType(*v)
- }
- return _u
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate {
- _u.mutation.ClearMediaType()
- return _u
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
@@ -795,6 +862,21 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
+ if v, ok := _u.mutation.ModelMappingChain(); ok {
+ if err := usagelog.ModelMappingChainValidator(v); err != nil {
+ return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BillingTier(); ok {
+ if err := usagelog.BillingTierValidator(v); err != nil {
+ return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BillingMode(); ok {
+ if err := usagelog.BillingModeValidator(v); err != nil {
+ return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -810,11 +892,6 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
- if v, ok := _u.mutation.MediaType(); ok {
- if err := usagelog.MediaTypeValidator(v); err != nil {
- return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
- }
- }
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -857,6 +934,33 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
+ if value, ok := _u.mutation.ChannelID(); ok {
+ _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedChannelID(); ok {
+ _spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
+ }
+ if _u.mutation.ChannelIDCleared() {
+ _spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.ModelMappingChain(); ok {
+ _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
+ }
+ if _u.mutation.ModelMappingChainCleared() {
+ _spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
+ }
+ if value, ok := _u.mutation.BillingTier(); ok {
+ _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
+ }
+ if _u.mutation.BillingTierCleared() {
+ _spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
+ }
+ if value, ok := _u.mutation.BillingMode(); ok {
+ _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
+ }
+ if _u.mutation.BillingModeCleared() {
+ _spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
+ }
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
@@ -995,12 +1099,6 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
- if value, ok := _u.mutation.MediaType(); ok {
- _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
- }
- if _u.mutation.MediaTypeCleared() {
- _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
- }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
@@ -1279,6 +1377,93 @@ func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne {
return _u
}
+// SetChannelID sets the "channel_id" field.
+func (_u *UsageLogUpdateOne) SetChannelID(v int64) *UsageLogUpdateOne {
+ _u.mutation.ResetChannelID()
+ _u.mutation.SetChannelID(v)
+ return _u
+}
+
+// SetNillableChannelID sets the "channel_id" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableChannelID(v *int64) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetChannelID(*v)
+ }
+ return _u
+}
+
+// AddChannelID adds value to the "channel_id" field.
+func (_u *UsageLogUpdateOne) AddChannelID(v int64) *UsageLogUpdateOne {
+ _u.mutation.AddChannelID(v)
+ return _u
+}
+
+// ClearChannelID clears the value of the "channel_id" field.
+func (_u *UsageLogUpdateOne) ClearChannelID() *UsageLogUpdateOne {
+ _u.mutation.ClearChannelID()
+ return _u
+}
+
+// SetModelMappingChain sets the "model_mapping_chain" field.
+func (_u *UsageLogUpdateOne) SetModelMappingChain(v string) *UsageLogUpdateOne {
+ _u.mutation.SetModelMappingChain(v)
+ return _u
+}
+
+// SetNillableModelMappingChain sets the "model_mapping_chain" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableModelMappingChain(v *string) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetModelMappingChain(*v)
+ }
+ return _u
+}
+
+// ClearModelMappingChain clears the value of the "model_mapping_chain" field.
+func (_u *UsageLogUpdateOne) ClearModelMappingChain() *UsageLogUpdateOne {
+ _u.mutation.ClearModelMappingChain()
+ return _u
+}
+
+// SetBillingTier sets the "billing_tier" field.
+func (_u *UsageLogUpdateOne) SetBillingTier(v string) *UsageLogUpdateOne {
+ _u.mutation.SetBillingTier(v)
+ return _u
+}
+
+// SetNillableBillingTier sets the "billing_tier" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableBillingTier(v *string) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetBillingTier(*v)
+ }
+ return _u
+}
+
+// ClearBillingTier clears the value of the "billing_tier" field.
+func (_u *UsageLogUpdateOne) ClearBillingTier() *UsageLogUpdateOne {
+ _u.mutation.ClearBillingTier()
+ return _u
+}
+
+// SetBillingMode sets the "billing_mode" field.
+func (_u *UsageLogUpdateOne) SetBillingMode(v string) *UsageLogUpdateOne {
+ _u.mutation.SetBillingMode(v)
+ return _u
+}
+
+// SetNillableBillingMode sets the "billing_mode" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableBillingMode(v *string) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetBillingMode(*v)
+ }
+ return _u
+}
+
+// ClearBillingMode clears the value of the "billing_mode" field.
+func (_u *UsageLogUpdateOne) ClearBillingMode() *UsageLogUpdateOne {
+ _u.mutation.ClearBillingMode()
+ return _u
+}
+
// SetGroupID sets the "group_id" field.
func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne {
_u.mutation.SetGroupID(v)
@@ -1789,26 +1974,6 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
-// SetMediaType sets the "media_type" field.
-func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne {
- _u.mutation.SetMediaType(v)
- return _u
-}
-
-// SetNillableMediaType sets the "media_type" field if the given value is not nil.
-func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne {
- if v != nil {
- _u.SetMediaType(*v)
- }
- return _u
-}
-
-// ClearMediaType clears the value of the "media_type" field.
-func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne {
- _u.mutation.ClearMediaType()
- return _u
-}
-
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
@@ -1945,6 +2110,21 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)}
}
}
+ if v, ok := _u.mutation.ModelMappingChain(); ok {
+ if err := usagelog.ModelMappingChainValidator(v); err != nil {
+ return &ValidationError{Name: "model_mapping_chain", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model_mapping_chain": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BillingTier(); ok {
+ if err := usagelog.BillingTierValidator(v); err != nil {
+ return &ValidationError{Name: "billing_tier", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_tier": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.BillingMode(); ok {
+ if err := usagelog.BillingModeValidator(v); err != nil {
+ return &ValidationError{Name: "billing_mode", err: fmt.Errorf(`ent: validator failed for field "UsageLog.billing_mode": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.UserAgent(); ok {
if err := usagelog.UserAgentValidator(v); err != nil {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
@@ -1960,11 +2140,6 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
- if v, ok := _u.mutation.MediaType(); ok {
- if err := usagelog.MediaTypeValidator(v); err != nil {
- return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)}
- }
- }
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -2024,6 +2199,33 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.UpstreamModelCleared() {
_spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString)
}
+ if value, ok := _u.mutation.ChannelID(); ok {
+ _spec.SetField(usagelog.FieldChannelID, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedChannelID(); ok {
+ _spec.AddField(usagelog.FieldChannelID, field.TypeInt64, value)
+ }
+ if _u.mutation.ChannelIDCleared() {
+ _spec.ClearField(usagelog.FieldChannelID, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.ModelMappingChain(); ok {
+ _spec.SetField(usagelog.FieldModelMappingChain, field.TypeString, value)
+ }
+ if _u.mutation.ModelMappingChainCleared() {
+ _spec.ClearField(usagelog.FieldModelMappingChain, field.TypeString)
+ }
+ if value, ok := _u.mutation.BillingTier(); ok {
+ _spec.SetField(usagelog.FieldBillingTier, field.TypeString, value)
+ }
+ if _u.mutation.BillingTierCleared() {
+ _spec.ClearField(usagelog.FieldBillingTier, field.TypeString)
+ }
+ if value, ok := _u.mutation.BillingMode(); ok {
+ _spec.SetField(usagelog.FieldBillingMode, field.TypeString, value)
+ }
+ if _u.mutation.BillingModeCleared() {
+ _spec.ClearField(usagelog.FieldBillingMode, field.TypeString)
+ }
if value, ok := _u.mutation.InputTokens(); ok {
_spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value)
}
@@ -2162,12 +2364,6 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
- if value, ok := _u.mutation.MediaType(); ok {
- _spec.SetField(usagelog.FieldMediaType, field.TypeString, value)
- }
- if _u.mutation.MediaTypeCleared() {
- _spec.ClearField(usagelog.FieldMediaType, field.TypeString)
- }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
diff --git a/backend/ent/user.go b/backend/ent/user.go
index b3f933f6..06670444 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -45,10 +45,24 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
- // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field.
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"`
- // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field.
- SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"`
+ // SignupSource holds the value of the "signup_source" field.
+ SignupSource string `json:"signup_source,omitempty"`
+ // LastLoginAt holds the value of the "last_login_at" field.
+ LastLoginAt *time.Time `json:"last_login_at,omitempty"`
+ // LastActiveAt holds the value of the "last_active_at" field.
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
+ // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
+ // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"`
+ // BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field.
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
+ // BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field.
+ BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
+ // TotalRecharged holds the value of the "total_recharged" field.
+ TotalRecharged float64 `json:"total_recharged,omitempty"`
+ // RpmLimit holds the value of the "rpm_limit" field.
+ RpmLimit int `json:"rpm_limit,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -75,11 +89,17 @@ type UserEdges struct {
AttributeValues []*UserAttributeValue `json:"attribute_values,omitempty"`
// PromoCodeUsages holds the value of the promo_code_usages edge.
PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"`
+ // PaymentOrders holds the value of the payment_orders edge.
+ PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"`
+ // AuthIdentities holds the value of the auth_identities edge.
+ AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"`
+ // PendingAuthSessions holds the value of the pending_auth_sessions edge.
+ PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"`
// UserAllowedGroups holds the value of the user_allowed_groups edge.
UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"`
// loadedTypes holds the information for reporting if a
// type was loaded (or requested) in eager-loading or not.
- loadedTypes [10]bool
+ loadedTypes [13]bool
}
// APIKeysOrErr returns the APIKeys value or an error if the edge
@@ -163,10 +183,37 @@ func (e UserEdges) PromoCodeUsagesOrErr() ([]*PromoCodeUsage, error) {
return nil, &NotLoadedError{edge: "promo_code_usages"}
}
+// PaymentOrdersOrErr returns the PaymentOrders value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) {
+ if e.loadedTypes[9] {
+ return e.PaymentOrders, nil
+ }
+ return nil, &NotLoadedError{edge: "payment_orders"}
+}
+
+// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) {
+ if e.loadedTypes[10] {
+ return e.AuthIdentities, nil
+ }
+ return nil, &NotLoadedError{edge: "auth_identities"}
+}
+
+// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge
+// was not loaded in eager-loading.
+func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) {
+ if e.loadedTypes[11] {
+ return e.PendingAuthSessions, nil
+ }
+ return nil, &NotLoadedError{edge: "pending_auth_sessions"}
+}
+
// UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge
// was not loaded in eager-loading.
func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) {
- if e.loadedTypes[9] {
+ if e.loadedTypes[12] {
return e.UserAllowedGroups, nil
}
return nil, &NotLoadedError{edge: "user_allowed_groups"}
@@ -177,15 +224,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case user.FieldTotpEnabled:
+ case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled:
values[i] = new(sql.NullBool)
- case user.FieldBalance:
+ case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
values[i] = new(sql.NullFloat64)
- case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes:
+ case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString)
- case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
+ case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -295,17 +342,62 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time
}
- case user.FieldSoraStorageQuotaBytes:
- if value, ok := values[i].(*sql.NullInt64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i])
+ case user.FieldSignupSource:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field signup_source", values[i])
} else if value.Valid {
- _m.SoraStorageQuotaBytes = value.Int64
+ _m.SignupSource = value.String
}
- case user.FieldSoraStorageUsedBytes:
- if value, ok := values[i].(*sql.NullInt64); !ok {
- return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i])
+ case user.FieldLastLoginAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_login_at", values[i])
} else if value.Valid {
- _m.SoraStorageUsedBytes = value.Int64
+ _m.LastLoginAt = new(time.Time)
+ *_m.LastLoginAt = value.Time
+ }
+ case user.FieldLastActiveAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field last_active_at", values[i])
+ } else if value.Valid {
+ _m.LastActiveAt = new(time.Time)
+ *_m.LastActiveAt = value.Time
+ }
+ case user.FieldBalanceNotifyEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyEnabled = value.Bool
+ }
+ case user.FieldBalanceNotifyThresholdType:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyThresholdType = value.String
+ }
+ case user.FieldBalanceNotifyThreshold:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyThreshold = new(float64)
+ *_m.BalanceNotifyThreshold = value.Float64
+ }
+ case user.FieldBalanceNotifyExtraEmails:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field balance_notify_extra_emails", values[i])
+ } else if value.Valid {
+ _m.BalanceNotifyExtraEmails = value.String
+ }
+ case user.FieldTotalRecharged:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field total_recharged", values[i])
+ } else if value.Valid {
+ _m.TotalRecharged = value.Float64
+ }
+ case user.FieldRpmLimit:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field rpm_limit", values[i])
+ } else if value.Valid {
+ _m.RpmLimit = int(value.Int64)
}
default:
_m.selectValues.Set(columns[i], values[i])
@@ -365,6 +457,21 @@ func (_m *User) QueryPromoCodeUsages() *PromoCodeUsageQuery {
return NewUserClient(_m.config).QueryPromoCodeUsages(_m)
}
+// QueryPaymentOrders queries the "payment_orders" edge of the User entity.
+func (_m *User) QueryPaymentOrders() *PaymentOrderQuery {
+ return NewUserClient(_m.config).QueryPaymentOrders(_m)
+}
+
+// QueryAuthIdentities queries the "auth_identities" edge of the User entity.
+func (_m *User) QueryAuthIdentities() *AuthIdentityQuery {
+ return NewUserClient(_m.config).QueryAuthIdentities(_m)
+}
+
+// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity.
+func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ return NewUserClient(_m.config).QueryPendingAuthSessions(_m)
+}
+
// QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity.
func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery {
return NewUserClient(_m.config).QueryUserAllowedGroups(_m)
@@ -441,11 +548,38 @@ func (_m *User) String() string {
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
- builder.WriteString("sora_storage_quota_bytes=")
- builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes))
+ builder.WriteString("signup_source=")
+ builder.WriteString(_m.SignupSource)
builder.WriteString(", ")
- builder.WriteString("sora_storage_used_bytes=")
- builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes))
+ if v := _m.LastLoginAt; v != nil {
+ builder.WriteString("last_login_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.LastActiveAt; v != nil {
+ builder.WriteString("last_active_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_threshold_type=")
+ builder.WriteString(_m.BalanceNotifyThresholdType)
+ builder.WriteString(", ")
+ if v := _m.BalanceNotifyThreshold; v != nil {
+ builder.WriteString("balance_notify_threshold=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ builder.WriteString("balance_notify_extra_emails=")
+ builder.WriteString(_m.BalanceNotifyExtraEmails)
+ builder.WriteString(", ")
+ builder.WriteString("total_recharged=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
+ builder.WriteString(", ")
+ builder.WriteString("rpm_limit=")
+ builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index 155b9160..e11a8a32 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -43,10 +43,24 @@ const (
FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at"
- // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database.
- FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes"
- // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database.
- FieldSoraStorageUsedBytes = "sora_storage_used_bytes"
+ // FieldSignupSource holds the string denoting the signup_source field in the database.
+ FieldSignupSource = "signup_source"
+ // FieldLastLoginAt holds the string denoting the last_login_at field in the database.
+ FieldLastLoginAt = "last_login_at"
+ // FieldLastActiveAt holds the string denoting the last_active_at field in the database.
+ FieldLastActiveAt = "last_active_at"
+ // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
+ FieldBalanceNotifyEnabled = "balance_notify_enabled"
+ // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
+ FieldBalanceNotifyThresholdType = "balance_notify_threshold_type"
+ // FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database.
+ FieldBalanceNotifyThreshold = "balance_notify_threshold"
+ // FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database.
+ FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
+ // FieldTotalRecharged holds the string denoting the total_recharged field in the database.
+ FieldTotalRecharged = "total_recharged"
+ // FieldRpmLimit holds the string denoting the rpm_limit field in the database.
+ FieldRpmLimit = "rpm_limit"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -65,6 +79,12 @@ const (
EdgeAttributeValues = "attribute_values"
// EdgePromoCodeUsages holds the string denoting the promo_code_usages edge name in mutations.
EdgePromoCodeUsages = "promo_code_usages"
+ // EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations.
+ EdgePaymentOrders = "payment_orders"
+ // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations.
+ EdgeAuthIdentities = "auth_identities"
+ // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations.
+ EdgePendingAuthSessions = "pending_auth_sessions"
// EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations.
EdgeUserAllowedGroups = "user_allowed_groups"
// Table holds the table name of the user in the database.
@@ -130,6 +150,27 @@ const (
PromoCodeUsagesInverseTable = "promo_code_usages"
// PromoCodeUsagesColumn is the table column denoting the promo_code_usages relation/edge.
PromoCodeUsagesColumn = "user_id"
+ // PaymentOrdersTable is the table that holds the payment_orders relation/edge.
+ PaymentOrdersTable = "payment_orders"
+ // PaymentOrdersInverseTable is the table name for the PaymentOrder entity.
+ // It exists in this package in order to avoid circular dependency with the "paymentorder" package.
+ PaymentOrdersInverseTable = "payment_orders"
+ // PaymentOrdersColumn is the table column denoting the payment_orders relation/edge.
+ PaymentOrdersColumn = "user_id"
+ // AuthIdentitiesTable is the table that holds the auth_identities relation/edge.
+ AuthIdentitiesTable = "auth_identities"
+ // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity.
+ // It exists in this package in order to avoid circular dependency with the "authidentity" package.
+ AuthIdentitiesInverseTable = "auth_identities"
+ // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge.
+ AuthIdentitiesColumn = "user_id"
+ // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge.
+ PendingAuthSessionsTable = "pending_auth_sessions"
+ // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity.
+ // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package.
+ PendingAuthSessionsInverseTable = "pending_auth_sessions"
+ // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge.
+ PendingAuthSessionsColumn = "target_user_id"
// UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge.
UserAllowedGroupsTable = "user_allowed_groups"
// UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity.
@@ -156,8 +197,15 @@ var Columns = []string{
FieldTotpSecretEncrypted,
FieldTotpEnabled,
FieldTotpEnabledAt,
- FieldSoraStorageQuotaBytes,
- FieldSoraStorageUsedBytes,
+ FieldSignupSource,
+ FieldLastLoginAt,
+ FieldLastActiveAt,
+ FieldBalanceNotifyEnabled,
+ FieldBalanceNotifyThresholdType,
+ FieldBalanceNotifyThreshold,
+ FieldBalanceNotifyExtraEmails,
+ FieldTotalRecharged,
+ FieldRpmLimit,
}
var (
@@ -214,10 +262,20 @@ var (
DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool
- // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field.
- DefaultSoraStorageQuotaBytes int64
- // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field.
- DefaultSoraStorageUsedBytes int64
+ // DefaultSignupSource holds the default value on creation for the "signup_source" field.
+ DefaultSignupSource string
+ // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save.
+ SignupSourceValidator func(string) error
+ // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
+ DefaultBalanceNotifyEnabled bool
+ // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
+ DefaultBalanceNotifyThresholdType string
+ // DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field.
+ DefaultBalanceNotifyExtraEmails string
+ // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
+ DefaultTotalRecharged float64
+ // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field.
+ DefaultRpmLimit int
)
// OrderOption defines the ordering options for the User queries.
@@ -298,14 +356,49 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
}
-// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field.
-func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc()
+// BySignupSource orders the results by the signup_source field.
+func BySignupSource(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldSignupSource, opts...).ToFunc()
}
-// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field.
-func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption {
- return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc()
+// ByLastLoginAt orders the results by the last_login_at field.
+func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc()
+}
+
+// ByLastActiveAt orders the results by the last_active_at field.
+func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc()
+}
+
+// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
+func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
+}
+
+// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field.
+func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc()
+}
+
+// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field.
+func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc()
+}
+
+// ByBalanceNotifyExtraEmails orders the results by the balance_notify_extra_emails field.
+func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc()
+}
+
+// ByTotalRecharged orders the results by the total_recharged field.
+func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
+}
+
+// ByRpmLimit orders the results by the rpm_limit field.
+func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRpmLimit, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count.
@@ -434,6 +527,48 @@ func ByPromoCodeUsages(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
}
}
+// ByPaymentOrdersCount orders the results by payment_orders count.
+func ByPaymentOrdersCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newPaymentOrdersStep(), opts...)
+ }
+}
+
+// ByPaymentOrders orders the results by payment_orders terms.
+func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPaymentOrdersStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByAuthIdentitiesCount orders the results by auth_identities count.
+func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...)
+ }
+}
+
+// ByAuthIdentities orders the results by auth_identities terms.
+func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
+// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count.
+func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...)
+ }
+}
+
+// ByPendingAuthSessions orders the results by pending_auth_sessions terms.
+func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption {
+ return func(s *sql.Selector) {
+ sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...)
+ }
+}
+
// ByUserAllowedGroupsCount orders the results by user_allowed_groups count.
func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
@@ -510,6 +645,27 @@ func newPromoCodeUsagesStep() *sqlgraph.Step {
sqlgraph.Edge(sqlgraph.O2M, false, PromoCodeUsagesTable, PromoCodeUsagesColumn),
)
}
+func newPaymentOrdersStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PaymentOrdersInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
+ )
+}
+func newAuthIdentitiesStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(AuthIdentitiesInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+}
+func newPendingAuthSessionsStep() *sqlgraph.Step {
+ return sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.To(PendingAuthSessionsInverseTable, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+}
func newUserAllowedGroupsStep() *sqlgraph.Step {
return sqlgraph.NewStep(
sqlgraph.From(Table, FieldID),
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index e26afcf3..05d3b35b 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -125,14 +125,49 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
}
-// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ.
-func SoraStorageQuotaBytes(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ.
+func SignupSource(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
-// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ.
-func SoraStorageUsedBytes(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
+// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ.
+func LastLoginAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
+}
+
+// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ.
+func LastActiveAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
+// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
+func BalanceNotifyEnabled(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ.
+func BalanceNotifyThresholdType(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ.
+func BalanceNotifyThreshold(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyExtraEmails applies equality check predicate on the "balance_notify_extra_emails" field. It's identical to BalanceNotifyExtraEmailsEQ.
+func BalanceNotifyExtraEmails(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ.
+func TotalRecharged(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
+}
+
+// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ.
+func RpmLimit(v int) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
@@ -870,84 +905,439 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
}
-// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesEQ(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v))
+// SignupSourceEQ applies the EQ predicate on the "signup_source" field.
+func SignupSourceEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldSignupSource, v))
}
-// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNEQ(v int64) predicate.User {
- return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v))
+// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field.
+func SignupSourceNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldSignupSource, v))
}
-// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...))
+// SignupSourceIn applies the In predicate on the "signup_source" field.
+func SignupSourceIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldSignupSource, vs...))
}
-// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...))
+// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field.
+func SignupSourceNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...))
}
-// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGT(v int64) predicate.User {
- return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v))
+// SignupSourceGT applies the GT predicate on the "signup_source" field.
+func SignupSourceGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldSignupSource, v))
}
-// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesGTE(v int64) predicate.User {
- return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v))
+// SignupSourceGTE applies the GTE predicate on the "signup_source" field.
+func SignupSourceGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldSignupSource, v))
}
-// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLT(v int64) predicate.User {
- return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v))
+// SignupSourceLT applies the LT predicate on the "signup_source" field.
+func SignupSourceLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldSignupSource, v))
}
-// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field.
-func SoraStorageQuotaBytesLTE(v int64) predicate.User {
- return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v))
+// SignupSourceLTE applies the LTE predicate on the "signup_source" field.
+func SignupSourceLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldSignupSource, v))
}
-// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesEQ(v int64) predicate.User {
- return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v))
+// SignupSourceContains applies the Contains predicate on the "signup_source" field.
+func SignupSourceContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldSignupSource, v))
}
-// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesNEQ(v int64) predicate.User {
- return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v))
+// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field.
+func SignupSourceHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v))
}
-// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...))
+// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field.
+func SignupSourceHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v))
}
-// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User {
- return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...))
+// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field.
+func SignupSourceEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldSignupSource, v))
}
-// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesGT(v int64) predicate.User {
- return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v))
+// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field.
+func SignupSourceContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldSignupSource, v))
}
-// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesGTE(v int64) predicate.User {
- return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v))
+// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field.
+func LastLoginAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastLoginAt, v))
}
-// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesLT(v int64) predicate.User {
- return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v))
+// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field.
+func LastLoginAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v))
}
-// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field.
-func SoraStorageUsedBytesLTE(v int64) predicate.User {
- return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v))
+// LastLoginAtIn applies the In predicate on the "last_login_at" field.
+func LastLoginAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field.
+func LastLoginAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...))
+}
+
+// LastLoginAtGT applies the GT predicate on the "last_login_at" field.
+func LastLoginAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field.
+func LastLoginAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLT applies the LT predicate on the "last_login_at" field.
+func LastLoginAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastLoginAt, v))
+}
+
+// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field.
+func LastLoginAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastLoginAt, v))
+}
+
+// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field.
+func LastLoginAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastLoginAt))
+}
+
+// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field.
+func LastLoginAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastLoginAt))
+}
+
+// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field.
+func LastActiveAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field.
+func LastActiveAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIn applies the In predicate on the "last_active_at" field.
+func LastActiveAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field.
+func LastActiveAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...))
+}
+
+// LastActiveAtGT applies the GT predicate on the "last_active_at" field.
+func LastActiveAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field.
+func LastActiveAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLT applies the LT predicate on the "last_active_at" field.
+func LastActiveAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldLastActiveAt, v))
+}
+
+// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field.
+func LastActiveAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldLastActiveAt, v))
+}
+
+// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field.
+func LastActiveAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldLastActiveAt))
+}
+
+// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field.
+func LastActiveAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldLastActiveAt))
+}
+
+// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
+func BalanceNotifyEnabledEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyEnabledNEQ applies the NEQ predicate on the "balance_notify_enabled" field.
+func BalanceNotifyEnabledNEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v))
+}
+
+// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...))
+}
+
+// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...))
+}
+
+// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field.
+func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v))
+}
+
+// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdNEQ applies the NEQ predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdIn applies the In predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyThreshold, vs...))
+}
+
+// BalanceNotifyThresholdNotIn applies the NotIn predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNotIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThreshold, vs...))
+}
+
+// BalanceNotifyThresholdGT applies the GT predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdGT(v float64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdGTE applies the GTE predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdGTE(v float64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdLT applies the LT predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdLT(v float64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdLTE applies the LTE predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdLTE(v float64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyThreshold, v))
+}
+
+// BalanceNotifyThresholdIsNil applies the IsNil predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldBalanceNotifyThreshold))
+}
+
+// BalanceNotifyThresholdNotNil applies the NotNil predicate on the "balance_notify_threshold" field.
+func BalanceNotifyThresholdNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldBalanceNotifyThreshold))
+}
+
+// BalanceNotifyExtraEmailsEQ applies the EQ predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsNEQ applies the NEQ predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsIn applies the In predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldBalanceNotifyExtraEmails, vs...))
+}
+
+// BalanceNotifyExtraEmailsNotIn applies the NotIn predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldBalanceNotifyExtraEmails, vs...))
+}
+
+// BalanceNotifyExtraEmailsGT applies the GT predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsGTE applies the GTE predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsLT applies the LT predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsLTE applies the LTE predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsContains applies the Contains predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsHasPrefix applies the HasPrefix predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsHasSuffix applies the HasSuffix predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsEqualFold applies the EqualFold predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyExtraEmails, v))
+}
+
+// BalanceNotifyExtraEmailsContainsFold applies the ContainsFold predicate on the "balance_notify_extra_emails" field.
+func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v))
+}
+
+// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field.
+func TotalRechargedEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
+}
+
+// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field.
+func TotalRechargedNEQ(v float64) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v))
+}
+
+// TotalRechargedIn applies the In predicate on the "total_recharged" field.
+func TotalRechargedIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...))
+}
+
+// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field.
+func TotalRechargedNotIn(vs ...float64) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...))
+}
+
+// TotalRechargedGT applies the GT predicate on the "total_recharged" field.
+func TotalRechargedGT(v float64) predicate.User {
+ return predicate.User(sql.FieldGT(FieldTotalRecharged, v))
+}
+
+// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field.
+func TotalRechargedGTE(v float64) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldTotalRecharged, v))
+}
+
+// TotalRechargedLT applies the LT predicate on the "total_recharged" field.
+func TotalRechargedLT(v float64) predicate.User {
+ return predicate.User(sql.FieldLT(FieldTotalRecharged, v))
+}
+
+// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field.
+func TotalRechargedLTE(v float64) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
+}
+
+// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field.
+func RpmLimitEQ(v int) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field.
+func RpmLimitNEQ(v int) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldRpmLimit, v))
+}
+
+// RpmLimitIn applies the In predicate on the "rpm_limit" field.
+func RpmLimitIn(vs ...int) predicate.User {
+ return predicate.User(sql.FieldIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field.
+func RpmLimitNotIn(vs ...int) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...))
+}
+
+// RpmLimitGT applies the GT predicate on the "rpm_limit" field.
+func RpmLimitGT(v int) predicate.User {
+ return predicate.User(sql.FieldGT(FieldRpmLimit, v))
+}
+
+// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field.
+func RpmLimitGTE(v int) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldRpmLimit, v))
+}
+
+// RpmLimitLT applies the LT predicate on the "rpm_limit" field.
+func RpmLimitLT(v int) predicate.User {
+ return predicate.User(sql.FieldLT(FieldRpmLimit, v))
+}
+
+// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field.
+func RpmLimitLTE(v int) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldRpmLimit, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
@@ -1157,6 +1547,75 @@ func HasPromoCodeUsagesWith(preds ...predicate.PromoCodeUsage) predicate.User {
})
}
+// HasPaymentOrders applies the HasEdge predicate on the "payment_orders" edge.
+func HasPaymentOrders() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPaymentOrdersWith applies the HasEdge predicate on the "payment_orders" edge with a given conditions (other predicates).
+func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newPaymentOrdersStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge.
+func HasAuthIdentities() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates).
+func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newAuthIdentitiesStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
+// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge.
+func HasPendingAuthSessions() predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := sqlgraph.NewStep(
+ sqlgraph.From(Table, FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn),
+ )
+ sqlgraph.HasNeighbors(s, step)
+ })
+}
+
+// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates).
+func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User {
+ return predicate.User(func(s *sql.Selector) {
+ step := newPendingAuthSessionsStep()
+ sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) {
+ for _, p := range preds {
+ p(s)
+ }
+ })
+ })
+}
+
// HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge.
func HasUserAllowedGroups() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index df0c6bcc..b4161128 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -13,7 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
@@ -210,30 +213,128 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate {
- _c.mutation.SetSoraStorageQuotaBytes(v)
+// SetSignupSource sets the "signup_source" field.
+func (_c *UserCreate) SetSignupSource(v string) *UserCreate {
+ _c.mutation.SetSignupSource(v)
return _c
}
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate {
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate {
if v != nil {
- _c.SetSoraStorageQuotaBytes(*v)
+ _c.SetSignupSource(*v)
}
return _c
}
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate {
- _c.mutation.SetSoraStorageUsedBytes(v)
+// SetLastLoginAt sets the "last_login_at" field.
+func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastLoginAt(v)
return _c
}
-// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
-func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate {
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate {
if v != nil {
- _c.SetSoraStorageUsedBytes(*v)
+ _c.SetLastLoginAt(*v)
+ }
+ return _c
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate {
+ _c.mutation.SetLastActiveAt(v)
+ return _c
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetLastActiveAt(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
+ _c.mutation.SetBalanceNotifyEnabled(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyEnabled(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate {
+ _c.mutation.SetBalanceNotifyThresholdType(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyThresholdType(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate {
+ _c.mutation.SetBalanceNotifyThreshold(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyThreshold(v *float64) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyThreshold(*v)
+ }
+ return _c
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_c *UserCreate) SetBalanceNotifyExtraEmails(v string) *UserCreate {
+ _c.mutation.SetBalanceNotifyExtraEmails(v)
+ return _c
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate {
+ if v != nil {
+ _c.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _c
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate {
+ _c.mutation.SetTotalRecharged(v)
+ return _c
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
+ if v != nil {
+ _c.SetTotalRecharged(*v)
+ }
+ return _c
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (_c *UserCreate) SetRpmLimit(v int) *UserCreate {
+ _c.mutation.SetRpmLimit(v)
+ return _c
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate {
+ if v != nil {
+ _c.SetRpmLimit(*v)
}
return _c
}
@@ -373,6 +474,51 @@ func (_c *UserCreate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserCreate {
return _c.AddPromoCodeUsageIDs(ids...)
}
+// AddPaymentOrderIDs adds the "payment_orders" edge to the PaymentOrder entity by IDs.
+func (_c *UserCreate) AddPaymentOrderIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddPaymentOrderIDs(ids...)
+ return _c
+}
+
+// AddPaymentOrders adds the "payment_orders" edges to the PaymentOrder entity.
+func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddPaymentOrderIDs(ids...)
+}
+
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddAuthIdentityIDs(ids...)
+ return _c
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate {
+ _c.mutation.AddPendingAuthSessionIDs(ids...)
+ return _c
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _c.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_c *UserCreate) Mutation() *UserMutation {
return _c.mutation
@@ -452,13 +598,29 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v)
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- v := user.DefaultSoraStorageQuotaBytes
- _c.mutation.SetSoraStorageQuotaBytes(v)
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ v := user.DefaultSignupSource
+ _c.mutation.SetSignupSource(v)
}
- if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
- v := user.DefaultSoraStorageUsedBytes
- _c.mutation.SetSoraStorageUsedBytes(v)
+ if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
+ v := user.DefaultBalanceNotifyEnabled
+ _c.mutation.SetBalanceNotifyEnabled(v)
+ }
+ if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
+ v := user.DefaultBalanceNotifyThresholdType
+ _c.mutation.SetBalanceNotifyThresholdType(v)
+ }
+ if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
+ v := user.DefaultBalanceNotifyExtraEmails
+ _c.mutation.SetBalanceNotifyExtraEmails(v)
+ }
+ if _, ok := _c.mutation.TotalRecharged(); !ok {
+ v := user.DefaultTotalRecharged
+ _c.mutation.SetTotalRecharged(v)
+ }
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ v := user.DefaultRpmLimit
+ _c.mutation.SetRpmLimit(v)
}
return nil
}
@@ -523,11 +685,28 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
}
- if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok {
- return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)}
+ if _, ok := _c.mutation.SignupSource(); !ok {
+ return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)}
}
- if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok {
- return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)}
+ if v, ok := _c.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
+ return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
+ }
+ if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
+ return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)}
+ }
+ if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
+ return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)}
+ }
+ if _, ok := _c.mutation.TotalRecharged(); !ok {
+ return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
+ }
+ if _, ok := _c.mutation.RpmLimit(); !ok {
+ return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)}
}
return nil
}
@@ -612,13 +791,41 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value
}
- if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
- _node.SoraStorageQuotaBytes = value
+ if value, ok := _c.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
+ _node.SignupSource = value
}
- if value, ok := _c.mutation.SoraStorageUsedBytes(); ok {
- _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
- _node.SoraStorageUsedBytes = value
+ if value, ok := _c.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
+ _node.LastLoginAt = &value
+ }
+ if value, ok := _c.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ _node.LastActiveAt = &value
+ }
+ if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ _node.BalanceNotifyEnabled = value
+ }
+ if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ _node.BalanceNotifyThresholdType = value
+ }
+ if value, ok := _c.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ _node.BalanceNotifyThreshold = &value
+ }
+ if value, ok := _c.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ _node.BalanceNotifyExtraEmails = value
+ }
+ if value, ok := _c.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ _node.TotalRecharged = value
+ }
+ if value, ok := _c.mutation.RpmLimit(); ok {
+ _spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
+ _node.RpmLimit = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
@@ -768,6 +975,54 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
}
_spec.Edges = append(_spec.Edges, edge)
}
+ if nodes := _c.mutation.PaymentOrdersIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PaymentOrdersTable,
+ Columns: []string{user.PaymentOrdersColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
+ if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges = append(_spec.Edges, edge)
+ }
return _node, _spec
}
@@ -1006,39 +1261,147 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert {
- u.Set(user.FieldSoraStorageQuotaBytes, v)
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsert) SetSignupSource(v string) *UserUpsert {
+ u.Set(user.FieldSignupSource, v)
return u
}
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert {
- u.SetExcluded(user.FieldSoraStorageQuotaBytes)
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsert) UpdateSignupSource() *UserUpsert {
+ u.SetExcluded(user.FieldSignupSource)
return u
}
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert {
- u.Add(user.FieldSoraStorageQuotaBytes, v)
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastLoginAt, v)
return u
}
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert {
- u.Set(user.FieldSoraStorageUsedBytes, v)
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastLoginAt)
return u
}
-// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
-func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert {
- u.SetExcluded(user.FieldSoraStorageUsedBytes)
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsert) ClearLastLoginAt() *UserUpsert {
+ u.SetNull(user.FieldLastLoginAt)
return u
}
-// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
-func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert {
- u.Add(user.FieldSoraStorageUsedBytes, v)
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldLastActiveAt, v)
+ return u
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert {
+ u.SetExcluded(user.FieldLastActiveAt)
+ return u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsert) ClearLastActiveAt() *UserUpsert {
+ u.SetNull(user.FieldLastActiveAt)
+ return u
+}
+
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyEnabled, v)
+ return u
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyEnabled)
+ return u
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyThresholdType, v)
+ return u
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyThresholdType)
+ return u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyThreshold, v)
+ return u
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyThreshold() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyThreshold)
+ return u
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsert) AddBalanceNotifyThreshold(v float64) *UserUpsert {
+ u.Add(user.FieldBalanceNotifyThreshold, v)
+ return u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsert) ClearBalanceNotifyThreshold() *UserUpsert {
+ u.SetNull(user.FieldBalanceNotifyThreshold)
+ return u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsert) SetBalanceNotifyExtraEmails(v string) *UserUpsert {
+ u.Set(user.FieldBalanceNotifyExtraEmails, v)
+ return u
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert {
+ u.SetExcluded(user.FieldBalanceNotifyExtraEmails)
+ return u
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert {
+ u.Set(user.FieldTotalRecharged, v)
+ return u
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert {
+ u.SetExcluded(user.FieldTotalRecharged)
+ return u
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
+ u.Add(user.FieldTotalRecharged, v)
+ return u
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert {
+ u.Set(user.FieldRpmLimit, v)
+ return u
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *UserUpsert) UpdateRpmLimit() *UserUpsert {
+ u.SetExcluded(user.FieldRpmLimit)
+ return u
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert {
+ u.Add(user.FieldRpmLimit, v)
return u
}
@@ -1304,45 +1667,171 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
})
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne {
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageQuotaBytes(v)
+ s.SetSignupSource(v)
})
}
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne {
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageQuotaBytes(v)
+ s.UpdateSignupSource()
})
}
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne {
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageQuotaBytes()
+ s.SetLastLoginAt(v)
})
}
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne {
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageUsedBytes(v)
+ s.UpdateLastLoginAt()
})
}
-// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
-func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne {
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageUsedBytes(v)
+ s.ClearLastLoginAt()
})
}
-// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
-func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne {
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageUsedBytes()
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyEnabled(v)
+ })
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyEnabled()
+ })
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThresholdType(v)
+ })
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThresholdType()
+ })
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThreshold(v)
+ })
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsertOne) AddBalanceNotifyThreshold(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddBalanceNotifyThreshold(v)
+ })
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyThreshold() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThreshold()
+ })
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsertOne) ClearBalanceNotifyThreshold() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearBalanceNotifyThreshold()
+ })
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsertOne) SetBalanceNotifyExtraEmails(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyExtraEmails(v)
+ })
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyExtraEmails()
+ })
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotalRecharged(v)
+ })
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddTotalRecharged(v)
+ })
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotalRecharged()
+ })
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateRpmLimit()
})
}
@@ -1774,45 +2263,171 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
})
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
+// SetSignupSource sets the "signup_source" field.
+func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageQuotaBytes(v)
+ s.SetSignupSource(v)
})
}
-// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field.
-func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk {
+// UpdateSignupSource sets the "signup_source" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageQuotaBytes(v)
+ s.UpdateSignupSource()
})
}
-// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create.
-func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk {
+// SetLastLoginAt sets the "last_login_at" field.
+func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageQuotaBytes()
+ s.SetLastLoginAt(v)
})
}
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk {
+// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
- s.SetSoraStorageUsedBytes(v)
+ s.UpdateLastLoginAt()
})
}
-// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field.
-func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk {
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
- s.AddSoraStorageUsedBytes(v)
+ s.ClearLastLoginAt()
})
}
-// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create.
-func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk {
+// SetLastActiveAt sets the "last_active_at" field.
+func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
- s.UpdateSoraStorageUsedBytes()
+ s.SetLastActiveAt(v)
+ })
+}
+
+// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateLastActiveAt()
+ })
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearLastActiveAt()
+ })
+}
+
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyEnabled(v)
+ })
+}
+
+// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyEnabled()
+ })
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThresholdType(v)
+ })
+}
+
+// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThresholdType()
+ })
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyThreshold(v)
+ })
+}
+
+// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) AddBalanceNotifyThreshold(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddBalanceNotifyThreshold(v)
+ })
+}
+
+// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyThreshold() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyThreshold()
+ })
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (u *UserUpsertBulk) ClearBalanceNotifyThreshold() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearBalanceNotifyThreshold()
+ })
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (u *UserUpsertBulk) SetBalanceNotifyExtraEmails(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetBalanceNotifyExtraEmails(v)
+ })
+}
+
+// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateBalanceNotifyExtraEmails()
+ })
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotalRecharged(v)
+ })
+}
+
+// AddTotalRecharged adds v to the "total_recharged" field.
+func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddTotalRecharged(v)
+ })
+}
+
+// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotalRecharged()
+ })
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetRpmLimit(v)
+ })
+}
+
+// AddRpmLimit adds v to the "rpm_limit" field.
+func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.AddRpmLimit(v)
+ })
+}
+
+// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateRpmLimit()
})
}
diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go
index 4b56e16f..f1ee5cfe 100644
--- a/backend/ent/user_query.go
+++ b/backend/ent/user_query.go
@@ -15,7 +15,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -42,6 +45,9 @@ type UserQuery struct {
withUsageLogs *UsageLogQuery
withAttributeValues *UserAttributeValueQuery
withPromoCodeUsages *PromoCodeUsageQuery
+ withPaymentOrders *PaymentOrderQuery
+ withAuthIdentities *AuthIdentityQuery
+ withPendingAuthSessions *PendingAuthSessionQuery
withUserAllowedGroups *UserAllowedGroupQuery
modifiers []func(*sql.Selector)
// intermediate query (i.e. traversal path).
@@ -278,6 +284,72 @@ func (_q *UserQuery) QueryPromoCodeUsages() *PromoCodeUsageQuery {
return query
}
+// QueryPaymentOrders chains the current query on the "payment_orders" edge.
+func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery {
+ query := (&PaymentOrderClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(paymentorder.Table, paymentorder.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PaymentOrdersTable, user.PaymentOrdersColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryAuthIdentities chains the current query on the "auth_identities" edge.
+func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(authidentity.Table, authidentity.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
+// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge.
+func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ query.path = func(ctx context.Context) (fromU *sql.Selector, err error) {
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ selector := _q.sqlQuery(ctx)
+ if err := selector.Err(); err != nil {
+ return nil, err
+ }
+ step := sqlgraph.NewStep(
+ sqlgraph.From(user.Table, user.FieldID, selector),
+ sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID),
+ sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn),
+ )
+ fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step)
+ return fromU, nil
+ }
+ return query
+}
+
// QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge.
func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery {
query := (&UserAllowedGroupClient{config: _q.config}).Query()
@@ -501,6 +573,9 @@ func (_q *UserQuery) Clone() *UserQuery {
withUsageLogs: _q.withUsageLogs.Clone(),
withAttributeValues: _q.withAttributeValues.Clone(),
withPromoCodeUsages: _q.withPromoCodeUsages.Clone(),
+ withPaymentOrders: _q.withPaymentOrders.Clone(),
+ withAuthIdentities: _q.withAuthIdentities.Clone(),
+ withPendingAuthSessions: _q.withPendingAuthSessions.Clone(),
withUserAllowedGroups: _q.withUserAllowedGroups.Clone(),
// clone intermediate query.
sql: _q.sql.Clone(),
@@ -607,6 +682,39 @@ func (_q *UserQuery) WithPromoCodeUsages(opts ...func(*PromoCodeUsageQuery)) *Us
return _q
}
+// WithPaymentOrders tells the query-builder to eager-load the nodes that are connected to
+// the "payment_orders" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQuery {
+ query := (&PaymentOrderClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPaymentOrders = query
+ return _q
+}
+
+// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to
+// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery {
+ query := (&AuthIdentityClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withAuthIdentities = query
+ return _q
+}
+
+// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to
+// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge.
+func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery {
+ query := (&PendingAuthSessionClient{config: _q.config}).Query()
+ for _, opt := range opts {
+ opt(query)
+ }
+ _q.withPendingAuthSessions = query
+ return _q
+}
+
// WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to
// the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge.
func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery {
@@ -696,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
var (
nodes = []*User{}
_spec = _q.querySpec()
- loadedTypes = [10]bool{
+ loadedTypes = [13]bool{
_q.withAPIKeys != nil,
_q.withRedeemCodes != nil,
_q.withSubscriptions != nil,
@@ -706,6 +814,9 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
_q.withUsageLogs != nil,
_q.withAttributeValues != nil,
_q.withPromoCodeUsages != nil,
+ _q.withPaymentOrders != nil,
+ _q.withAuthIdentities != nil,
+ _q.withPendingAuthSessions != nil,
_q.withUserAllowedGroups != nil,
}
)
@@ -795,6 +906,29 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e
return nil, err
}
}
+ if query := _q.withPaymentOrders; query != nil {
+ if err := _q.loadPaymentOrders(ctx, query, nodes,
+ func(n *User) { n.Edges.PaymentOrders = []*PaymentOrder{} },
+ func(n *User, e *PaymentOrder) { n.Edges.PaymentOrders = append(n.Edges.PaymentOrders, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withAuthIdentities; query != nil {
+ if err := _q.loadAuthIdentities(ctx, query, nodes,
+ func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} },
+ func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil {
+ return nil, err
+ }
+ }
+ if query := _q.withPendingAuthSessions; query != nil {
+ if err := _q.loadPendingAuthSessions(ctx, query, nodes,
+ func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} },
+ func(n *User, e *PendingAuthSession) {
+ n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e)
+ }); err != nil {
+ return nil, err
+ }
+ }
if query := _q.withUserAllowedGroups; query != nil {
if err := _q.loadUserAllowedGroups(ctx, query, nodes,
func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} },
@@ -1112,6 +1246,99 @@ func (_q *UserQuery) loadPromoCodeUsages(ctx context.Context, query *PromoCodeUs
}
return nil
}
+func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQuery, nodes []*User, init func(*User), assign func(*User, *PaymentOrder)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(paymentorder.FieldUserID)
+ }
+ query.Where(predicate.PaymentOrder(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.PaymentOrdersColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.UserID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(authidentity.FieldUserID)
+ }
+ query.Where(predicate.AuthIdentity(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.UserID
+ node, ok := nodeids[fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
+func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error {
+ fks := make([]driver.Value, 0, len(nodes))
+ nodeids := make(map[int64]*User)
+ for i := range nodes {
+ fks = append(fks, nodes[i].ID)
+ nodeids[nodes[i].ID] = nodes[i]
+ if init != nil {
+ init(nodes[i])
+ }
+ }
+ if len(query.ctx.Fields) > 0 {
+ query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID)
+ }
+ query.Where(predicate.PendingAuthSession(func(s *sql.Selector) {
+ s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...))
+ }))
+ neighbors, err := query.All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, n := range neighbors {
+ fk := n.TargetUserID
+ if fk == nil {
+ return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID)
+ }
+ node, ok := nodeids[*fk]
+ if !ok {
+ return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID)
+ }
+ assign(node, n)
+ }
+ return nil
+}
func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error {
fks := make([]driver.Value, 0, len(nodes))
nodeids := make(map[int64]*User)
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index f71f0cad..f1d759ce 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -13,7 +13,10 @@ import (
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/announcementread"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/predicate"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
@@ -242,45 +245,168 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate {
+ _u.mutation.SetSignupSource(v)
return _u
}
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate {
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate {
if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
+ _u.SetSignupSource(*v)
}
return _u
}
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate {
- _u.mutation.AddSoraStorageQuotaBytes(v)
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastLoginAt(v)
return _u
}
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate {
- _u.mutation.ResetSoraStorageUsedBytes()
- _u.mutation.SetSoraStorageUsedBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
-func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate {
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate {
if v != nil {
- _u.SetSoraStorageUsedBytes(*v)
+ _u.SetLastLoginAt(*v)
}
return _u
}
-// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
-func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate {
- _u.mutation.AddSoraStorageUsedBytes(v)
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
+ _u.mutation.SetBalanceNotifyEnabled(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyEnabled(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate {
+ _u.mutation.SetBalanceNotifyThresholdType(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyThresholdType(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate {
+ _u.mutation.ResetBalanceNotifyThreshold()
+ _u.mutation.SetBalanceNotifyThreshold(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyThreshold(*v)
+ }
+ return _u
+}
+
+// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
+func (_u *UserUpdate) AddBalanceNotifyThreshold(v float64) *UserUpdate {
+ _u.mutation.AddBalanceNotifyThreshold(v)
+ return _u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (_u *UserUpdate) ClearBalanceNotifyThreshold() *UserUpdate {
+ _u.mutation.ClearBalanceNotifyThreshold()
+ return _u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_u *UserUpdate) SetBalanceNotifyExtraEmails(v string) *UserUpdate {
+ _u.mutation.SetBalanceNotifyExtraEmails(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _u
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate {
+ _u.mutation.ResetTotalRecharged()
+ _u.mutation.SetTotalRecharged(v)
+ return _u
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate {
+ if v != nil {
+ _u.SetTotalRecharged(*v)
+ }
+ return _u
+}
+
+// AddTotalRecharged adds value to the "total_recharged" field.
+func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
+ _u.mutation.AddTotalRecharged(v)
+ return _u
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate {
+ _u.mutation.AddRpmLimit(v)
return _u
}
@@ -419,6 +545,51 @@ func (_u *UserUpdate) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
return _u.AddPromoCodeUsageIDs(ids...)
}
+// AddPaymentOrderIDs adds the "payment_orders" edge to the PaymentOrder entity by IDs.
+func (_u *UserUpdate) AddPaymentOrderIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddPaymentOrderIDs(ids...)
+ return _u
+}
+
+// AddPaymentOrders adds the "payment_orders" edges to the PaymentOrder entity.
+func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPaymentOrderIDs(ids...)
+}
+
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdate) Mutation() *UserMutation {
return _u.mutation
@@ -613,6 +784,69 @@ func (_u *UserUpdate) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate {
return _u.RemovePromoCodeUsageIDs(ids...)
}
+// ClearPaymentOrders clears all "payment_orders" edges to the PaymentOrder entity.
+func (_u *UserUpdate) ClearPaymentOrders() *UserUpdate {
+ _u.mutation.ClearPaymentOrders()
+ return _u
+}
+
+// RemovePaymentOrderIDs removes the "payment_orders" edge to PaymentOrder entities by IDs.
+func (_u *UserUpdate) RemovePaymentOrderIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemovePaymentOrderIDs(ids...)
+ return _u
+}
+
+// RemovePaymentOrders removes "payment_orders" edges to PaymentOrder entities.
+func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePaymentOrderIDs(ids...)
+}
+
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *UserUpdate) Save(ctx context.Context) (int, error) {
if err := _u.defaults(); err != nil {
@@ -682,6 +916,11 @@ func (_u *UserUpdate) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -751,17 +990,50 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
- if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
- _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
- if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
- _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
+ _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if _u.mutation.BalanceNotifyThresholdCleared() {
+ _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedTotalRecharged(); ok {
+ _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
@@ -1180,6 +1452,141 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.PaymentOrdersCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PaymentOrdersTable,
+ Columns: []string{user.PaymentOrdersColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPaymentOrdersIDs(); len(nodes) > 0 && !_u.mutation.PaymentOrdersCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PaymentOrdersTable,
+ Columns: []string{user.PaymentOrdersColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PaymentOrdersIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PaymentOrdersTable,
+ Columns: []string{user.PaymentOrdersColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{user.Label}
@@ -1406,45 +1813,168 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u
}
-// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field.
-func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne {
- _u.mutation.ResetSoraStorageQuotaBytes()
- _u.mutation.SetSoraStorageQuotaBytes(v)
+// SetSignupSource sets the "signup_source" field.
+func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne {
+ _u.mutation.SetSignupSource(v)
return _u
}
-// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil.
-func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne {
+// SetNillableSignupSource sets the "signup_source" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne {
if v != nil {
- _u.SetSoraStorageQuotaBytes(*v)
+ _u.SetSignupSource(*v)
}
return _u
}
-// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field.
-func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne {
- _u.mutation.AddSoraStorageQuotaBytes(v)
+// SetLastLoginAt sets the "last_login_at" field.
+func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastLoginAt(v)
return _u
}
-// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field.
-func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne {
- _u.mutation.ResetSoraStorageUsedBytes()
- _u.mutation.SetSoraStorageUsedBytes(v)
- return _u
-}
-
-// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil.
-func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne {
+// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne {
if v != nil {
- _u.SetSoraStorageUsedBytes(*v)
+ _u.SetLastLoginAt(*v)
}
return _u
}
-// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field.
-func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne {
- _u.mutation.AddSoraStorageUsedBytes(v)
+// ClearLastLoginAt clears the value of the "last_login_at" field.
+func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne {
+ _u.mutation.ClearLastLoginAt()
+ return _u
+}
+
+// SetLastActiveAt sets the "last_active_at" field.
+func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetLastActiveAt(v)
+ return _u
+}
+
+// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetLastActiveAt(*v)
+ }
+ return _u
+}
+
+// ClearLastActiveAt clears the value of the "last_active_at" field.
+func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne {
+ _u.mutation.ClearLastActiveAt()
+ return _u
+}
+
+// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
+func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyEnabled(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyEnabled(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
+func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyThresholdType(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyThresholdType(*v)
+ }
+ return _u
+}
+
+// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne {
+ _u.mutation.ResetBalanceNotifyThreshold()
+ _u.mutation.SetBalanceNotifyThreshold(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyThreshold(*v)
+ }
+ return _u
+}
+
+// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) AddBalanceNotifyThreshold(v float64) *UserUpdateOne {
+ _u.mutation.AddBalanceNotifyThreshold(v)
+ return _u
+}
+
+// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
+func (_u *UserUpdateOne) ClearBalanceNotifyThreshold() *UserUpdateOne {
+ _u.mutation.ClearBalanceNotifyThreshold()
+ return _u
+}
+
+// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
+func (_u *UserUpdateOne) SetBalanceNotifyExtraEmails(v string) *UserUpdateOne {
+ _u.mutation.SetBalanceNotifyExtraEmails(v)
+ return _u
+}
+
+// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetBalanceNotifyExtraEmails(*v)
+ }
+ return _u
+}
+
+// SetTotalRecharged sets the "total_recharged" field.
+func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne {
+ _u.mutation.ResetTotalRecharged()
+ _u.mutation.SetTotalRecharged(v)
+ return _u
+}
+
+// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotalRecharged(*v)
+ }
+ return _u
+}
+
+// AddTotalRecharged adds value to the "total_recharged" field.
+func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
+ _u.mutation.AddTotalRecharged(v)
+ return _u
+}
+
+// SetRpmLimit sets the "rpm_limit" field.
+func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne {
+ _u.mutation.ResetRpmLimit()
+ _u.mutation.SetRpmLimit(v)
+ return _u
+}
+
+// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne {
+ if v != nil {
+ _u.SetRpmLimit(*v)
+ }
+ return _u
+}
+
+// AddRpmLimit adds value to the "rpm_limit" field.
+func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne {
+ _u.mutation.AddRpmLimit(v)
return _u
}
@@ -1583,6 +2113,51 @@ func (_u *UserUpdateOne) AddPromoCodeUsages(v ...*PromoCodeUsage) *UserUpdateOne
return _u.AddPromoCodeUsageIDs(ids...)
}
+// AddPaymentOrderIDs adds the "payment_orders" edge to the PaymentOrder entity by IDs.
+func (_u *UserUpdateOne) AddPaymentOrderIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddPaymentOrderIDs(ids...)
+ return _u
+}
+
+// AddPaymentOrders adds the "payment_orders" edges to the PaymentOrder entity.
+func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPaymentOrderIDs(ids...)
+}
+
+// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs.
+func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddAuthIdentityIDs(ids...)
+ return _u
+}
+
+// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddAuthIdentityIDs(ids...)
+}
+
+// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs.
+func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.AddPendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.AddPendingAuthSessionIDs(ids...)
+}
+
// Mutation returns the UserMutation object of the builder.
func (_u *UserUpdateOne) Mutation() *UserMutation {
return _u.mutation
@@ -1777,6 +2352,69 @@ func (_u *UserUpdateOne) RemovePromoCodeUsages(v ...*PromoCodeUsage) *UserUpdate
return _u.RemovePromoCodeUsageIDs(ids...)
}
+// ClearPaymentOrders clears all "payment_orders" edges to the PaymentOrder entity.
+func (_u *UserUpdateOne) ClearPaymentOrders() *UserUpdateOne {
+ _u.mutation.ClearPaymentOrders()
+ return _u
+}
+
+// RemovePaymentOrderIDs removes the "payment_orders" edge to PaymentOrder entities by IDs.
+func (_u *UserUpdateOne) RemovePaymentOrderIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemovePaymentOrderIDs(ids...)
+ return _u
+}
+
+// RemovePaymentOrders removes "payment_orders" edges to PaymentOrder entities.
+func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePaymentOrderIDs(ids...)
+}
+
+// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity.
+func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne {
+ _u.mutation.ClearAuthIdentities()
+ return _u
+}
+
+// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs.
+func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemoveAuthIdentityIDs(ids...)
+ return _u
+}
+
+// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities.
+func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemoveAuthIdentityIDs(ids...)
+}
+
+// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity.
+func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne {
+ _u.mutation.ClearPendingAuthSessions()
+ return _u
+}
+
+// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs.
+func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne {
+ _u.mutation.RemovePendingAuthSessionIDs(ids...)
+ return _u
+}
+
+// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities.
+func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne {
+ ids := make([]int64, len(v))
+ for i := range v {
+ ids[i] = v[i].ID
+ }
+ return _u.RemovePendingAuthSessionIDs(ids...)
+}
+
// Where appends a list predicates to the UserUpdate builder.
func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne {
_u.mutation.Where(ps...)
@@ -1859,6 +2497,11 @@ func (_u *UserUpdateOne) check() error {
return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)}
}
}
+ if v, ok := _u.mutation.SignupSource(); ok {
+ if err := user.SignupSourceValidator(v); err != nil {
+ return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)}
+ }
+ }
return nil
}
@@ -1945,17 +2588,50 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
}
- if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok {
- _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ if value, ok := _u.mutation.SignupSource(); ok {
+ _spec.SetField(user.FieldSignupSource, field.TypeString, value)
}
- if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok {
- _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value)
+ if value, ok := _u.mutation.LastLoginAt(); ok {
+ _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value)
}
- if value, ok := _u.mutation.SoraStorageUsedBytes(); ok {
- _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ if _u.mutation.LastLoginAtCleared() {
+ _spec.ClearField(user.FieldLastLoginAt, field.TypeTime)
}
- if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok {
- _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value)
+ if value, ok := _u.mutation.LastActiveAt(); ok {
+ _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value)
+ }
+ if _u.mutation.LastActiveAtCleared() {
+ _spec.ClearField(user.FieldLastActiveAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
+ _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
+ _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
+ _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
+ }
+ if _u.mutation.BalanceNotifyThresholdCleared() {
+ _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
+ }
+ if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
+ _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.TotalRecharged(); ok {
+ _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedTotalRecharged(); ok {
+ _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.RpmLimit(); ok {
+ _spec.SetField(user.FieldRpmLimit, field.TypeInt, value)
+ }
+ if value, ok := _u.mutation.AddedRpmLimit(); ok {
+ _spec.AddField(user.FieldRpmLimit, field.TypeInt, value)
}
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
@@ -2374,6 +3050,141 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
}
_spec.Edges.Add = append(_spec.Edges.Add, edge)
}
+ if _u.mutation.PaymentOrdersCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PaymentOrdersTable,
+ Columns: []string{user.PaymentOrdersColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPaymentOrdersIDs(); len(nodes) > 0 && !_u.mutation.PaymentOrdersCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PaymentOrdersTable,
+ Columns: []string{user.PaymentOrdersColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PaymentOrdersIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PaymentOrdersTable,
+ Columns: []string{user.PaymentOrdersColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(paymentorder.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.AuthIdentitiesTable,
+ Columns: []string{user.AuthIdentitiesColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
+ if _u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Clear = append(_spec.Edges.Clear, edge)
+ }
+ if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 {
+ edge := &sqlgraph.EdgeSpec{
+ Rel: sqlgraph.O2M,
+ Inverse: false,
+ Table: user.PendingAuthSessionsTable,
+ Columns: []string{user.PendingAuthSessionsColumn},
+ Bidi: false,
+ Target: &sqlgraph.EdgeTarget{
+ IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64),
+ },
+ }
+ for _, k := range nodes {
+ edge.Target.Nodes = append(edge.Target.Nodes, k)
+ }
+ _spec.Edges.Add = append(_spec.Edges.Add, edge)
+ }
_node = &User{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
diff --git a/backend/go.mod b/backend/go.mod
index 135cbd3e..982bf91b 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -1,12 +1,12 @@
module github.com/Wei-Shaw/sub2api
-go 1.26.1
+go 1.26.2
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
- github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
+ github.com/andybalholm/brotli v1.2.0
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
@@ -27,18 +27,23 @@ require (
github.com/refraction-networking/utls v1.8.2
github.com/robfig/cron/v3 v3.0.1
github.com/shirou/gopsutil/v4 v4.25.6
+ github.com/shopspring/decimal v1.4.0
+ github.com/smartwalle/alipay/v3 v3.2.29
github.com/spf13/viper v1.18.2
github.com/stretchr/testify v1.11.1
+ github.com/stripe/stripe-go/v85 v85.0.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
github.com/testcontainers/testcontainers-go/modules/redis v0.40.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
+ github.com/wechatpay-apiv3/wechatpay-go v0.2.21
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
- golang.org/x/crypto v0.48.0
- golang.org/x/net v0.49.0
- golang.org/x/sync v0.19.0
- golang.org/x/term v0.40.0
+ golang.org/x/crypto v0.49.0
+ golang.org/x/image v0.39.0
+ golang.org/x/net v0.52.0
+ golang.org/x/sync v0.20.0
+ golang.org/x/term v0.41.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.44.3
@@ -50,7 +55,6 @@ require (
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
- github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
@@ -67,14 +71,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
- github.com/bdandy/go-errors v1.2.2 // indirect
- github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
- github.com/bogdanfinn/fhttp v0.6.8 // indirect
- github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect
- github.com/bogdanfinn/tls-client v1.14.0 // indirect
- github.com/bogdanfinn/utls v1.7.7-barnius // indirect
- github.com/bogdanfinn/websocket v1.5.5-barnius // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@@ -107,6 +104,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
+ github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -145,13 +143,15 @@ require (
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
+ github.com/smartwalle/ncrypto v1.0.4 // indirect
+ github.com/smartwalle/ngx v1.1.0 // indirect
+ github.com/smartwalle/nsign v1.0.9 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
- github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
@@ -173,9 +173,10 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
- golang.org/x/mod v0.32.0 // indirect
- golang.org/x/sys v0.41.0 // indirect
- golang.org/x/text v0.34.0 // indirect
+ golang.org/x/mod v0.34.0 // indirect
+ golang.org/x/sys v0.42.0 // indirect
+ golang.org/x/text v0.36.0 // indirect
+ golang.org/x/tools v0.43.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index f5b7968f..0f366ee1 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -10,12 +10,12 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
-github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU=
-github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
+github.com/agiledragon/gomonkey v2.0.2+incompatible h1:eXKi9/piiC3cjJD1658mEE2o3NjkJ5vDLgYjCQu0Xlw=
+github.com/agiledragon/gomonkey v2.0.2+incompatible/go.mod h1:2NGfXu1a80LLr2cmWXGBDaHEjb1idR6+FVlX5T3D9hw=
github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw=
github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
@@ -60,24 +60,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
-github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
-github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
-github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
-github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
-github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4=
-github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M=
-github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s=
-github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg=
-github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A=
-github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM=
-github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU=
-github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg=
-github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI=
-github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -94,10 +80,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
-github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
-github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
-github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
-github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@@ -180,6 +162,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
+github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
+github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@@ -236,8 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
-github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
+github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
+github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -302,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -314,8 +300,18 @@ github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
+github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
+github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
+github.com/smartwalle/alipay/v3 v3.2.29 h1:roGFqlml8hDa//0TpFmlyxZhndTYs7rbYLu/HlNFNJo=
+github.com/smartwalle/alipay/v3 v3.2.29/go.mod h1:XarBLuAkwK3ah7mYjVtghRu+ysxzlex9sRkgqNMzMRU=
+github.com/smartwalle/ncrypto v1.0.4 h1:P2rqQxDepJwgeO5ShoC+wGcK2wNJDmcdBOWAksuIgx8=
+github.com/smartwalle/ncrypto v1.0.4/go.mod h1:Dwlp6sfeNaPMnOxMNayMTacvC5JGEVln3CVdiVDgbBk=
+github.com/smartwalle/ngx v1.1.0 h1:q8nANgWSPRGeI/u+ixBoA4mf68DrUq6vZ+n9L5UKv9I=
+github.com/smartwalle/ngx v1.1.0/go.mod h1:mx/nz2Pk5j+RBs7t6u6k22MPiBG/8CtOMpCnALIG8Y0=
+github.com/smartwalle/nsign v1.0.9 h1:8poAgG7zBd8HkZy9RQDwasC6XZvJpDGQWSjzL2FZL6E=
+github.com/smartwalle/nsign v1.0.9/go.mod h1:eY6I4CJlyNdVMP+t6z1H6Jpd4m5/V+8xi44ufSTxXgc=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
@@ -345,10 +341,10 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+github.com/stripe/stripe-go/v85 v85.0.0 h1:HMlFJXW6I/9WvkeSAtj8V7dI5pzeDu4gS1TaqR1ccI4=
+github.com/stripe/stripe-go/v85 v85.0.0/go.mod h1:5P+HGFenpWgak27T5Is6JMsmDfUC1yJnjhhmquz7kXw=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
-github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
-github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
@@ -372,6 +368,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
+github.com/wechatpay-apiv3/wechatpay-go v0.2.21 h1:uIyMpzvcaHA33W/QPtHstccw+X52HO1gFdvVL9O6Lfs=
+github.com/wechatpay-apiv3/wechatpay-go v0.2.21/go.mod h1:A254AUBVB6R+EqQFo3yTgeh7HtyqRRtN2w9hQSOrd4Q=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
@@ -415,21 +413,20 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
-golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
-golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
+golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
+golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
-golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
-golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
-golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
-golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
-golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
-golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
+golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
+golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
+golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
+golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
+golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -437,19 +434,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
-golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
-golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
-golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
-golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
+golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
+golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
+golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
+golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
+golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
+golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
-golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
+golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
+golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 3ee5d6cd..87263db0 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -28,7 +28,7 @@ const (
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
-const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// UMQ(用户消息队列)模式常量
const (
@@ -52,6 +52,11 @@ const (
ConnectionPoolIsolationAccountProxy = "account_proxy"
)
+// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。
+// 128 MB 可容纳 1-2 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。
+// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。
+const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024
+
type Config struct {
Server ServerConfig `mapstructure:"server"`
Log LogConfig `mapstructure:"log"`
@@ -65,6 +70,8 @@ type Config struct {
JWT JWTConfig `mapstructure:"jwt"`
Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
+ WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
+ OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Pricing PricingConfig `mapstructure:"pricing"`
@@ -77,7 +84,6 @@ type Config struct {
UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
- Sora SoraConfig `mapstructure:"sora"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
Gemini GeminiConfig `mapstructure:"gemini"`
@@ -185,6 +191,274 @@ type LinuxDoConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
+type WeChatConnectConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ AppID string `mapstructure:"app_id"`
+ AppSecret string `mapstructure:"app_secret"`
+ OpenAppID string `mapstructure:"open_app_id"`
+ OpenAppSecret string `mapstructure:"open_app_secret"`
+ MPAppID string `mapstructure:"mp_app_id"`
+ MPAppSecret string `mapstructure:"mp_app_secret"`
+ MobileAppID string `mapstructure:"mobile_app_id"`
+ MobileAppSecret string `mapstructure:"mobile_app_secret"`
+ OpenEnabled bool `mapstructure:"open_enabled"`
+ MPEnabled bool `mapstructure:"mp_enabled"`
+ MobileEnabled bool `mapstructure:"mobile_enabled"`
+ Mode string `mapstructure:"mode"`
+ Scopes string `mapstructure:"scopes"`
+ RedirectURL string `mapstructure:"redirect_url"`
+ FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
+}
+
+type OIDCConnectConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
+ ClientID string `mapstructure:"client_id"`
+ ClientSecret string `mapstructure:"client_secret"`
+ IssuerURL string `mapstructure:"issuer_url"`
+ DiscoveryURL string `mapstructure:"discovery_url"`
+ AuthorizeURL string `mapstructure:"authorize_url"`
+ TokenURL string `mapstructure:"token_url"`
+ UserInfoURL string `mapstructure:"userinfo_url"`
+ JWKSURL string `mapstructure:"jwks_url"`
+ Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
+ RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
+ FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
+ TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
+ UsePKCE bool `mapstructure:"use_pkce"`
+ ValidateIDToken bool `mapstructure:"validate_id_token"`
+ UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
+ ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
+ AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
+ ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
+ RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
+
+ // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
+ // 为空时,服务端会尝试一组常见字段名。
+ UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
+ UserInfoIDPath string `mapstructure:"userinfo_id_path"`
+ UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
+}
+
+const (
+ defaultWeChatConnectMode = "open"
+ defaultWeChatConnectScopes = "snsapi_login"
+ defaultWeChatConnectFrontendRedirect = "/auth/wechat/callback"
+)
+
+func firstNonEmptyString(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
+func normalizeWeChatConnectMode(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "mp":
+ return "mp"
+ case "mobile":
+ return "mobile"
+ default:
+ return defaultWeChatConnectMode
+ }
+}
+
+func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string {
+ mode = normalizeWeChatConnectMode(mode)
+ switch mode {
+ case "open":
+ if openEnabled {
+ return "open"
+ }
+ case "mp":
+ if mpEnabled {
+ return "mp"
+ }
+ case "mobile":
+ if mobileEnabled {
+ return "mobile"
+ }
+ }
+ switch {
+ case openEnabled:
+ return "open"
+ case mpEnabled:
+ return "mp"
+ case mobileEnabled:
+ return "mobile"
+ default:
+ return mode
+ }
+}
+
+func defaultWeChatConnectScopesForMode(mode string) string {
+ switch normalizeWeChatConnectMode(mode) {
+ case "mp":
+ return "snsapi_userinfo"
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func normalizeWeChatConnectScopes(raw, mode string) string {
+ switch normalizeWeChatConnectMode(mode) {
+ case "mp":
+ switch strings.TrimSpace(raw) {
+ case "snsapi_base":
+ return "snsapi_base"
+ case "snsapi_userinfo":
+ return "snsapi_userinfo"
+ default:
+ return defaultWeChatConnectScopesForMode(mode)
+ }
+ case "mobile":
+ return ""
+ default:
+ return defaultWeChatConnectScopes
+ }
+}
+
+func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
+ if viper.InConfig(configKey) {
+ return false
+ }
+ _, hasNewEnv := os.LookupEnv(envKey)
+ return !hasNewEnv
+}
+
+func hasExplicitConfigOrEnv(configKey, envKey string) bool {
+ if viper.InConfig(configKey) {
+ return true
+ }
+ _, ok := os.LookupEnv(envKey)
+ return ok
+}
+
+func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
+ if cfg == nil {
+ return
+ }
+
+ legacyOpenAppID := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_id", "WECHAT_CONNECT_OPEN_APP_ID") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
+ legacyOpenAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID"))
+ if legacyOpenAppID != "" {
+ cfg.OpenAppID = legacyOpenAppID
+ }
+ }
+
+ legacyOpenAppSecret := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_app_secret", "WECHAT_CONNECT_OPEN_APP_SECRET") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
+ legacyOpenAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET"))
+ if legacyOpenAppSecret != "" {
+ cfg.OpenAppSecret = legacyOpenAppSecret
+ }
+ }
+
+ legacyMPAppID := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_id", "WECHAT_CONNECT_MP_APP_ID") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_id", "WECHAT_CONNECT_APP_ID") {
+ legacyMPAppID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID"))
+ if legacyMPAppID != "" {
+ cfg.MPAppID = legacyMPAppID
+ }
+ }
+
+ legacyMPAppSecret := ""
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_app_secret", "WECHAT_CONNECT_MP_APP_SECRET") &&
+ shouldApplyLegacyWeChatEnv("wechat_connect.app_secret", "WECHAT_CONNECT_APP_SECRET") {
+ legacyMPAppSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET"))
+ if legacyMPAppSecret != "" {
+ cfg.MPAppSecret = legacyMPAppSecret
+ }
+ }
+
+ if shouldApplyLegacyWeChatEnv("wechat_connect.frontend_redirect_url", "WECHAT_CONNECT_FRONTEND_REDIRECT_URL") {
+ if legacyFrontend := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")); legacyFrontend != "" {
+ cfg.FrontendRedirectURL = legacyFrontend
+ }
+ }
+
+ hasLegacyOpen := legacyOpenAppID != "" && legacyOpenAppSecret != ""
+ hasLegacyMP := legacyMPAppID != "" && legacyMPAppSecret != ""
+
+ if shouldApplyLegacyWeChatEnv("wechat_connect.enabled", "WECHAT_CONNECT_ENABLED") && (hasLegacyOpen || hasLegacyMP) {
+ cfg.Enabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.open_enabled", "WECHAT_CONNECT_OPEN_ENABLED") && hasLegacyOpen {
+ cfg.OpenEnabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mp_enabled", "WECHAT_CONNECT_MP_ENABLED") && hasLegacyMP {
+ cfg.MPEnabled = true
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.mode", "WECHAT_CONNECT_MODE") {
+ switch {
+ case hasLegacyMP && !hasLegacyOpen:
+ cfg.Mode = "mp"
+ case hasLegacyOpen:
+ cfg.Mode = "open"
+ }
+ }
+ if shouldApplyLegacyWeChatEnv("wechat_connect.scopes", "WECHAT_CONNECT_SCOPES") {
+ switch {
+ case hasLegacyMP && !hasLegacyOpen:
+ cfg.Scopes = defaultWeChatConnectScopesForMode("mp")
+ case hasLegacyOpen:
+ cfg.Scopes = defaultWeChatConnectScopesForMode("open")
+ }
+ }
+}
+
+func normalizeWeChatConnectConfig(cfg *WeChatConnectConfig) {
+ if cfg == nil {
+ return
+ }
+
+ cfg.AppID = strings.TrimSpace(cfg.AppID)
+ cfg.AppSecret = strings.TrimSpace(cfg.AppSecret)
+ cfg.OpenAppID = strings.TrimSpace(cfg.OpenAppID)
+ cfg.OpenAppSecret = strings.TrimSpace(cfg.OpenAppSecret)
+ cfg.MPAppID = strings.TrimSpace(cfg.MPAppID)
+ cfg.MPAppSecret = strings.TrimSpace(cfg.MPAppSecret)
+ cfg.MobileAppID = strings.TrimSpace(cfg.MobileAppID)
+ cfg.MobileAppSecret = strings.TrimSpace(cfg.MobileAppSecret)
+ cfg.Mode = normalizeWeChatConnectMode(cfg.Mode)
+ cfg.RedirectURL = strings.TrimSpace(cfg.RedirectURL)
+ cfg.FrontendRedirectURL = strings.TrimSpace(cfg.FrontendRedirectURL)
+
+ cfg.AppID = firstNonEmptyString(cfg.AppID, cfg.OpenAppID, cfg.MPAppID, cfg.MobileAppID)
+ cfg.AppSecret = firstNonEmptyString(cfg.AppSecret, cfg.OpenAppSecret, cfg.MPAppSecret, cfg.MobileAppSecret)
+ cfg.OpenAppID = firstNonEmptyString(cfg.OpenAppID, cfg.AppID)
+ cfg.OpenAppSecret = firstNonEmptyString(cfg.OpenAppSecret, cfg.AppSecret)
+ cfg.MPAppID = firstNonEmptyString(cfg.MPAppID, cfg.AppID)
+ cfg.MPAppSecret = firstNonEmptyString(cfg.MPAppSecret, cfg.AppSecret)
+ cfg.MobileAppID = firstNonEmptyString(cfg.MobileAppID, cfg.AppID)
+ cfg.MobileAppSecret = firstNonEmptyString(cfg.MobileAppSecret, cfg.AppSecret)
+
+ if !cfg.OpenEnabled && !cfg.MPEnabled && !cfg.MobileEnabled && cfg.Enabled {
+ switch cfg.Mode {
+ case "mp":
+ cfg.MPEnabled = true
+ case "mobile":
+ cfg.MobileEnabled = true
+ default:
+ cfg.OpenEnabled = true
+ }
+ }
+ cfg.Mode = normalizeWeChatConnectStoredMode(cfg.OpenEnabled, cfg.MPEnabled, cfg.MobileEnabled, cfg.Mode)
+ cfg.Scopes = normalizeWeChatConnectScopes(cfg.Scopes, cfg.Mode)
+ if cfg.FrontendRedirectURL == "" {
+ cfg.FrontendRedirectURL = defaultWeChatConnectFrontendRedirect
+ }
+}
+
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -197,8 +471,6 @@ type TokenRefreshConfig struct {
MaxRetries int `mapstructure:"max_retries"`
// 重试退避基础时间(秒)
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
- // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭)
- SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"`
}
type PricingConfig struct {
@@ -303,59 +575,6 @@ type ConcurrencyConfig struct {
PingInterval int `mapstructure:"ping_interval"`
}
-// SoraConfig 直连 Sora 配置
-type SoraConfig struct {
- Client SoraClientConfig `mapstructure:"client"`
- Storage SoraStorageConfig `mapstructure:"storage"`
-}
-
-// SoraClientConfig 直连 Sora 客户端配置
-type SoraClientConfig struct {
- BaseURL string `mapstructure:"base_url"`
- TimeoutSeconds int `mapstructure:"timeout_seconds"`
- MaxRetries int `mapstructure:"max_retries"`
- CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
- PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
- MaxPollAttempts int `mapstructure:"max_poll_attempts"`
- RecentTaskLimit int `mapstructure:"recent_task_limit"`
- RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"`
- Debug bool `mapstructure:"debug"`
- UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
- Headers map[string]string `mapstructure:"headers"`
- UserAgent string `mapstructure:"user_agent"`
- DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
- CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
-}
-
-// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置
-type SoraCurlCFFISidecarConfig struct {
- Enabled bool `mapstructure:"enabled"`
- BaseURL string `mapstructure:"base_url"`
- Impersonate string `mapstructure:"impersonate"`
- TimeoutSeconds int `mapstructure:"timeout_seconds"`
- SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
- SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
-}
-
-// SoraStorageConfig 媒体存储配置
-type SoraStorageConfig struct {
- Type string `mapstructure:"type"`
- LocalPath string `mapstructure:"local_path"`
- FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
- MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
- DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
- MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
- Debug bool `mapstructure:"debug"`
- Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
-}
-
-// SoraStorageCleanupConfig 媒体清理配置
-type SoraStorageCleanupConfig struct {
- Enabled bool `mapstructure:"enabled"`
- Schedule string `mapstructure:"schedule"`
- RetentionDays int `mapstructure:"retention_days"`
-}
-
// GatewayConfig API网关相关配置
type GatewayConfig struct {
// 等待上游响应头的超时时间(秒),0表示无超时
@@ -374,6 +593,12 @@ type GatewayConfig struct {
// ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。
// 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
+ // ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。
+ // 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。
+ ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
+ // ForcedCodexInstructionsTemplate: 启动时从模板文件读取并缓存的模板内容。
+ // 该字段不直接参与配置反序列化,仅用于请求热路径避免重复读盘。
+ ForcedCodexInstructionsTemplate string `mapstructure:"-"`
// OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头
// 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
@@ -424,24 +649,6 @@ type GatewayConfig struct {
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
FailoverOn400 bool `mapstructure:"failover_on_400"`
- // Sora 专用配置
- // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size)
- SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
- // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制)
- SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
- // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制)
- SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
- // SoraStreamMode: stream 强制策略(force/error)
- SoraStreamMode string `mapstructure:"sora_stream_mode"`
- // SoraModelFilters: 模型列表过滤配置
- SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
- // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key
- SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
- // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名)
- SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
- // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用)
- SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
-
// 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限)
MaxAccountSwitches int `mapstructure:"max_account_switches"`
// Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格)
@@ -639,12 +846,6 @@ type GatewayUsageRecordConfig struct {
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
}
-// SoraModelFiltersConfig Sora 模型过滤配置
-type SoraModelFiltersConfig struct {
- // HidePromptEnhance 是否隐藏 prompt-enhance 模型
- HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"`
-}
-
// TLSFingerprintConfig TLS指纹伪装配置
// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端
type TLSFingerprintConfig struct {
@@ -700,6 +901,10 @@ type GatewaySchedulingConfig struct {
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
+ // 快照桶读取时的 MGET 分块大小
+ SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
+ // 快照重建时的缓存写入分块大小
+ SnapshotWriteChunkSize int `mapstructure:"snapshot_write_chunk_size"`
// 过期槽位清理周期(0 表示禁用)
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
@@ -1048,6 +1253,27 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath)
cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath)
cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath)
+ applyLegacyWeChatConnectEnvCompatibility(&cfg.WeChat)
+ normalizeWeChatConnectConfig(&cfg.WeChat)
+ cfg.OIDC.ProviderName = strings.TrimSpace(cfg.OIDC.ProviderName)
+ cfg.OIDC.ClientID = strings.TrimSpace(cfg.OIDC.ClientID)
+ cfg.OIDC.ClientSecret = strings.TrimSpace(cfg.OIDC.ClientSecret)
+ cfg.OIDC.IssuerURL = strings.TrimSpace(cfg.OIDC.IssuerURL)
+ cfg.OIDC.DiscoveryURL = strings.TrimSpace(cfg.OIDC.DiscoveryURL)
+ cfg.OIDC.AuthorizeURL = strings.TrimSpace(cfg.OIDC.AuthorizeURL)
+ cfg.OIDC.TokenURL = strings.TrimSpace(cfg.OIDC.TokenURL)
+ cfg.OIDC.UserInfoURL = strings.TrimSpace(cfg.OIDC.UserInfoURL)
+ cfg.OIDC.JWKSURL = strings.TrimSpace(cfg.OIDC.JWKSURL)
+ cfg.OIDC.Scopes = strings.TrimSpace(cfg.OIDC.Scopes)
+ cfg.OIDC.RedirectURL = strings.TrimSpace(cfg.OIDC.RedirectURL)
+ cfg.OIDC.FrontendRedirectURL = strings.TrimSpace(cfg.OIDC.FrontendRedirectURL)
+ cfg.OIDC.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.OIDC.TokenAuthMethod))
+ cfg.OIDC.AllowedSigningAlgs = strings.TrimSpace(cfg.OIDC.AllowedSigningAlgs)
+ cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
+ cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
+ cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
+ cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
+ cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
@@ -1059,6 +1285,14 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment)
cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel))
cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath)
+ cfg.Gateway.ForcedCodexInstructionsTemplateFile = strings.TrimSpace(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
+ if cfg.Gateway.ForcedCodexInstructionsTemplateFile != "" {
+ content, err := os.ReadFile(cfg.Gateway.ForcedCodexInstructionsTemplateFile)
+ if err != nil {
+ return nil, fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err)
+ }
+ cfg.Gateway.ForcedCodexInstructionsTemplate = string(content)
+ }
// 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。
// 新键未配置(<=0)时回退旧键;新键优先。
@@ -1218,6 +1452,48 @@ func setDefaults() {
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
+ // WeChat Connect OAuth 登录
+ viper.SetDefault("wechat_connect.enabled", false)
+ viper.SetDefault("wechat_connect.app_id", "")
+ viper.SetDefault("wechat_connect.app_secret", "")
+ viper.SetDefault("wechat_connect.open_app_id", "")
+ viper.SetDefault("wechat_connect.open_app_secret", "")
+ viper.SetDefault("wechat_connect.mp_app_id", "")
+ viper.SetDefault("wechat_connect.mp_app_secret", "")
+ viper.SetDefault("wechat_connect.mobile_app_id", "")
+ viper.SetDefault("wechat_connect.mobile_app_secret", "")
+ viper.SetDefault("wechat_connect.open_enabled", false)
+ viper.SetDefault("wechat_connect.mp_enabled", false)
+ viper.SetDefault("wechat_connect.mobile_enabled", false)
+ viper.SetDefault("wechat_connect.mode", defaultWeChatConnectMode)
+ viper.SetDefault("wechat_connect.scopes", defaultWeChatConnectScopes)
+ viper.SetDefault("wechat_connect.redirect_url", "")
+ viper.SetDefault("wechat_connect.frontend_redirect_url", defaultWeChatConnectFrontendRedirect)
+
+ // Generic OIDC OAuth 登录
+ viper.SetDefault("oidc_connect.enabled", false)
+ viper.SetDefault("oidc_connect.provider_name", "OIDC")
+ viper.SetDefault("oidc_connect.client_id", "")
+ viper.SetDefault("oidc_connect.client_secret", "")
+ viper.SetDefault("oidc_connect.issuer_url", "")
+ viper.SetDefault("oidc_connect.discovery_url", "")
+ viper.SetDefault("oidc_connect.authorize_url", "")
+ viper.SetDefault("oidc_connect.token_url", "")
+ viper.SetDefault("oidc_connect.userinfo_url", "")
+ viper.SetDefault("oidc_connect.jwks_url", "")
+ viper.SetDefault("oidc_connect.scopes", "openid email profile")
+ viper.SetDefault("oidc_connect.redirect_url", "")
+ viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
+ viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
+ viper.SetDefault("oidc_connect.use_pkce", true)
+ viper.SetDefault("oidc_connect.validate_id_token", true)
+ viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
+ viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
+ viper.SetDefault("oidc_connect.require_email_verified", false)
+ viper.SetDefault("oidc_connect.userinfo_email_path", "")
+ viper.SetDefault("oidc_connect.userinfo_id_path", "")
+ viper.SetDefault("oidc_connect.userinfo_username_path", "")
+
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
@@ -1399,16 +1675,9 @@ func setDefaults() {
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
- viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
+ viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes)
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
- viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024))
- viper.SetDefault("gateway.sora_stream_timeout_seconds", 900)
- viper.SetDefault("gateway.sora_request_timeout_seconds", 180)
- viper.SetDefault("gateway.sora_stream_mode", "force")
- viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true)
- viper.SetDefault("gateway.sora_media_require_api_key", true)
- viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化)
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大)
@@ -1427,6 +1696,8 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
+ viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
+ viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
@@ -1465,45 +1736,12 @@ func setDefaults() {
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
viper.SetDefault("concurrency.ping_interval", 10)
- // Sora 直连配置
- viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend")
- viper.SetDefault("sora.client.timeout_seconds", 120)
- viper.SetDefault("sora.client.max_retries", 3)
- viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900)
- viper.SetDefault("sora.client.poll_interval_seconds", 2)
- viper.SetDefault("sora.client.max_poll_attempts", 600)
- viper.SetDefault("sora.client.recent_task_limit", 50)
- viper.SetDefault("sora.client.recent_task_limit_max", 200)
- viper.SetDefault("sora.client.debug", false)
- viper.SetDefault("sora.client.use_openai_token_provider", false)
- viper.SetDefault("sora.client.headers", map[string]string{})
- viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- viper.SetDefault("sora.client.disable_tls_fingerprint", false)
- viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true)
- viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080")
- viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131")
- viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60)
- viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true)
- viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600)
-
- viper.SetDefault("sora.storage.type", "local")
- viper.SetDefault("sora.storage.local_path", "")
- viper.SetDefault("sora.storage.fallback_to_upstream", true)
- viper.SetDefault("sora.storage.max_concurrent_downloads", 4)
- viper.SetDefault("sora.storage.download_timeout_seconds", 120)
- viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20))
- viper.SetDefault("sora.storage.debug", false)
- viper.SetDefault("sora.storage.cleanup.enabled", true)
- viper.SetDefault("sora.storage.cleanup.retention_days", 7)
- viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *")
-
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token)
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
- viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token
// Gemini OAuth - configure via environment variables or config file
// GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET
@@ -1659,9 +1897,6 @@ func (c *Config) Validate() error {
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
- if method == "none" && !c.LinuxDo.UsePKCE {
- return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
- }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
@@ -1692,6 +1927,123 @@ func (c *Config) Validate() error {
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
}
+ if c.WeChat.Enabled {
+ weChat := c.WeChat
+ normalizeWeChatConnectConfig(&weChat)
+
+ if weChat.OpenEnabled {
+ if strings.TrimSpace(weChat.OpenAppID) == "" {
+ return fmt.Errorf("wechat_connect.open_app_id is required when wechat_connect.open_enabled=true")
+ }
+ if strings.TrimSpace(weChat.OpenAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.open_app_secret is required when wechat_connect.open_enabled=true")
+ }
+ }
+ if weChat.MPEnabled {
+ if strings.TrimSpace(weChat.MPAppID) == "" {
+ return fmt.Errorf("wechat_connect.mp_app_id is required when wechat_connect.mp_enabled=true")
+ }
+ if strings.TrimSpace(weChat.MPAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.mp_app_secret is required when wechat_connect.mp_enabled=true")
+ }
+ }
+ if weChat.MobileEnabled {
+ if strings.TrimSpace(weChat.MobileAppID) == "" {
+ return fmt.Errorf("wechat_connect.mobile_app_id is required when wechat_connect.mobile_enabled=true")
+ }
+ if strings.TrimSpace(weChat.MobileAppSecret) == "" {
+ return fmt.Errorf("wechat_connect.mobile_app_secret is required when wechat_connect.mobile_enabled=true")
+ }
+ }
+ if v := strings.TrimSpace(weChat.RedirectURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("wechat_connect.redirect_url invalid: %w", err)
+ }
+ warnIfInsecureURL("wechat_connect.redirect_url", v)
+ }
+ if err := ValidateFrontendRedirectURL(weChat.FrontendRedirectURL); err != nil {
+ return fmt.Errorf("wechat_connect.frontend_redirect_url invalid: %w", err)
+ }
+ warnIfInsecureURL("wechat_connect.frontend_redirect_url", weChat.FrontendRedirectURL)
+ }
+ if c.OIDC.Enabled {
+ if strings.TrimSpace(c.OIDC.ClientID) == "" {
+ return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true")
+ }
+ if strings.TrimSpace(c.OIDC.IssuerURL) == "" {
+ return fmt.Errorf("oidc_connect.issuer_url is required when oidc_connect.enabled=true")
+ }
+ if strings.TrimSpace(c.OIDC.RedirectURL) == "" {
+ return fmt.Errorf("oidc_connect.redirect_url is required when oidc_connect.enabled=true")
+ }
+ if strings.TrimSpace(c.OIDC.FrontendRedirectURL) == "" {
+ return fmt.Errorf("oidc_connect.frontend_redirect_url is required when oidc_connect.enabled=true")
+ }
+ if !scopeContainsOpenID(c.OIDC.Scopes) {
+ return fmt.Errorf("oidc_connect.scopes must contain openid")
+ }
+
+ method := strings.ToLower(strings.TrimSpace(c.OIDC.TokenAuthMethod))
+ switch method {
+ case "", "client_secret_post", "client_secret_basic", "none":
+ default:
+ return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
+ }
+ if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
+ strings.TrimSpace(c.OIDC.ClientSecret) == "" {
+ return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
+ }
+ if c.OIDC.ClockSkewSeconds < 0 || c.OIDC.ClockSkewSeconds > 600 {
+ return fmt.Errorf("oidc_connect.clock_skew_seconds must be between 0 and 600")
+ }
+ if c.OIDC.ValidateIDToken && strings.TrimSpace(c.OIDC.AllowedSigningAlgs) == "" {
+ return fmt.Errorf("oidc_connect.allowed_signing_algs is required when oidc_connect.validate_id_token=true")
+ }
+
+ if err := ValidateAbsoluteHTTPURL(c.OIDC.IssuerURL); err != nil {
+ return fmt.Errorf("oidc_connect.issuer_url invalid: %w", err)
+ }
+ if v := strings.TrimSpace(c.OIDC.DiscoveryURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("oidc_connect.discovery_url invalid: %w", err)
+ }
+ }
+ if v := strings.TrimSpace(c.OIDC.AuthorizeURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("oidc_connect.authorize_url invalid: %w", err)
+ }
+ }
+ if v := strings.TrimSpace(c.OIDC.TokenURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("oidc_connect.token_url invalid: %w", err)
+ }
+ }
+ if v := strings.TrimSpace(c.OIDC.UserInfoURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("oidc_connect.userinfo_url invalid: %w", err)
+ }
+ }
+ if v := strings.TrimSpace(c.OIDC.JWKSURL); v != "" {
+ if err := ValidateAbsoluteHTTPURL(v); err != nil {
+ return fmt.Errorf("oidc_connect.jwks_url invalid: %w", err)
+ }
+ }
+ if err := ValidateAbsoluteHTTPURL(c.OIDC.RedirectURL); err != nil {
+ return fmt.Errorf("oidc_connect.redirect_url invalid: %w", err)
+ }
+ if err := ValidateFrontendRedirectURL(c.OIDC.FrontendRedirectURL); err != nil {
+ return fmt.Errorf("oidc_connect.frontend_redirect_url invalid: %w", err)
+ }
+
+ warnIfInsecureURL("oidc_connect.issuer_url", c.OIDC.IssuerURL)
+ warnIfInsecureURL("oidc_connect.discovery_url", c.OIDC.DiscoveryURL)
+ warnIfInsecureURL("oidc_connect.authorize_url", c.OIDC.AuthorizeURL)
+ warnIfInsecureURL("oidc_connect.token_url", c.OIDC.TokenURL)
+ warnIfInsecureURL("oidc_connect.userinfo_url", c.OIDC.UserInfoURL)
+ warnIfInsecureURL("oidc_connect.jwks_url", c.OIDC.JWKSURL)
+ warnIfInsecureURL("oidc_connect.redirect_url", c.OIDC.RedirectURL)
+ warnIfInsecureURL("oidc_connect.frontend_redirect_url", c.OIDC.FrontendRedirectURL)
+ }
if c.Billing.CircuitBreaker.Enabled {
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
@@ -1879,86 +2231,6 @@ func (c *Config) Validate() error {
if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 {
return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive")
}
- if c.Gateway.SoraMaxBodySize < 0 {
- return fmt.Errorf("gateway.sora_max_body_size must be non-negative")
- }
- if c.Gateway.SoraStreamTimeoutSeconds < 0 {
- return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative")
- }
- if c.Gateway.SoraRequestTimeoutSeconds < 0 {
- return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative")
- }
- if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 {
- return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative")
- }
- if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" {
- switch mode {
- case "force", "error":
- default:
- return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error")
- }
- }
- if c.Sora.Client.TimeoutSeconds < 0 {
- return fmt.Errorf("sora.client.timeout_seconds must be non-negative")
- }
- if c.Sora.Client.MaxRetries < 0 {
- return fmt.Errorf("sora.client.max_retries must be non-negative")
- }
- if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 {
- return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative")
- }
- if c.Sora.Client.PollIntervalSeconds < 0 {
- return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative")
- }
- if c.Sora.Client.MaxPollAttempts < 0 {
- return fmt.Errorf("sora.client.max_poll_attempts must be non-negative")
- }
- if c.Sora.Client.RecentTaskLimit < 0 {
- return fmt.Errorf("sora.client.recent_task_limit must be non-negative")
- }
- if c.Sora.Client.RecentTaskLimitMax < 0 {
- return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative")
- }
- if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 &&
- c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit {
- c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit
- }
- if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative")
- }
- if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative")
- }
- if !c.Sora.Client.CurlCFFISidecar.Enabled {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true")
- }
- if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" {
- return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required")
- }
- if c.Sora.Storage.MaxConcurrentDownloads < 0 {
- return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative")
- }
- if c.Sora.Storage.DownloadTimeoutSeconds < 0 {
- return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative")
- }
- if c.Sora.Storage.MaxDownloadBytes < 0 {
- return fmt.Errorf("sora.storage.max_download_bytes must be non-negative")
- }
- if c.Sora.Storage.Cleanup.Enabled {
- if c.Sora.Storage.Cleanup.RetentionDays <= 0 {
- return fmt.Errorf("sora.storage.cleanup.retention_days must be positive")
- }
- if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" {
- return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled")
- }
- } else {
- if c.Sora.Storage.Cleanup.RetentionDays < 0 {
- return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative")
- }
- }
- if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" {
- return fmt.Errorf("sora.storage.type must be 'local'")
- }
if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" {
switch c.Gateway.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
@@ -2201,6 +2473,12 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
+ if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
+ return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
+ }
+ if c.Gateway.Scheduling.SnapshotWriteChunkSize <= 0 {
+ return fmt.Errorf("gateway.scheduling.snapshot_write_chunk_size must be positive")
+ }
if c.Gateway.Scheduling.SlotCleanupInterval < 0 {
return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative")
}
@@ -2384,6 +2662,15 @@ func ValidateFrontendRedirectURL(raw string) error {
return nil
}
+func scopeContainsOpenID(scopes string) bool {
+ for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
+ if scope == "openid" {
+ return true
+ }
+ }
+ return false
+}
+
// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index abb76549..6ba86aa1 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -1,6 +1,8 @@
package config
import (
+ "os"
+ "path/filepath"
"strings"
"testing"
"time"
@@ -223,6 +225,70 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) {
}
}
+func TestLoadWeChatConnectConfigFromLegacyEnv(t *testing.T) {
+ resetViperWithJWTSecret(t)
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
+ t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
+ t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app")
+ t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret")
+ t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/legacy-callback")
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.True(t, cfg.WeChat.Enabled)
+ require.True(t, cfg.WeChat.OpenEnabled)
+ require.True(t, cfg.WeChat.MPEnabled)
+ require.False(t, cfg.WeChat.MobileEnabled)
+ require.Equal(t, "open", cfg.WeChat.Mode)
+ require.Equal(t, "wx-open-app", cfg.WeChat.OpenAppID)
+ require.Equal(t, "wx-open-secret", cfg.WeChat.OpenAppSecret)
+ require.Equal(t, "wx-mp-app", cfg.WeChat.MPAppID)
+ require.Equal(t, "wx-mp-secret", cfg.WeChat.MPAppSecret)
+ require.Equal(t, "/auth/wechat/legacy-callback", cfg.WeChat.FrontendRedirectURL)
+}
+
+func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.True(t, cfg.OIDC.UsePKCE)
+ require.True(t, cfg.OIDC.ValidateIDToken)
+ require.False(t, cfg.OIDC.UsePKCEExplicit)
+ require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
+}
+
+func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
+ resetViperWithJWTSecret(t)
+ t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
+ t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.False(t, cfg.OIDC.UsePKCE)
+ require.False(t, cfg.OIDC.ValidateIDToken)
+ require.True(t, cfg.OIDC.UsePKCEExplicit)
+ require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
+}
+
+func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ tempDir := t.TempDir()
+ templatePath := filepath.Join(tempDir, "codex-instructions.md.tmpl")
+ configPath := filepath.Join(tempDir, "config.yaml")
+
+ require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
+ yamlSafePath := filepath.ToSlash(templatePath)
+ require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644))
+ t.Setenv("DATA_DIR", tempDir)
+
+ cfg, err := Load()
+ require.NoError(t, err)
+ require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
+ require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
+}
+
func TestLoadDefaultSecurityToggles(t *testing.T) {
resetViperWithJWTSecret(t)
@@ -314,7 +380,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
cfg.LinuxDo.ClientSecret = "test-secret"
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
- cfg.LinuxDo.UsePKCE = false
+ cfg.LinuxDo.UsePKCE = true
cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)"
err = cfg.Validate()
@@ -326,7 +392,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) {
}
}
-func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
+func TestValidateLinuxDoAllowsDisablingPKCEForCompatibility(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
@@ -343,11 +409,93 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) {
cfg.LinuxDo.UsePKCE = false
err = cfg.Validate()
- if err == nil {
- t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil")
+ if err != nil {
+ t.Fatalf("Validate() expected LinuxDo config without PKCE to pass for compatibility, got: %v", err)
}
- if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") {
- t.Fatalf("Validate() expected use_pkce error, got: %v", err)
+}
+
+func TestValidateOIDCScopesMustContainOpenID(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.OIDC.Enabled = true
+ cfg.OIDC.ClientID = "oidc-client"
+ cfg.OIDC.ClientSecret = "oidc-secret"
+ cfg.OIDC.IssuerURL = "https://issuer.example.com"
+ cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
+ cfg.OIDC.TokenURL = "https://issuer.example.com/token"
+ cfg.OIDC.JWKSURL = "https://issuer.example.com/jwks"
+ cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
+ cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
+ cfg.OIDC.Scopes = "profile email"
+ cfg.OIDC.UsePKCE = true
+
+ err = cfg.Validate()
+ if err == nil {
+ t.Fatalf("Validate() expected error when scopes do not include openid, got nil")
+ }
+ if !strings.Contains(err.Error(), "oidc_connect.scopes") {
+ t.Fatalf("Validate() expected oidc_connect.scopes error, got: %v", err)
+ }
+}
+
+func TestValidateOIDCAllowsIssuerOnlyEndpointsWithDiscoveryFallback(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.OIDC.Enabled = true
+ cfg.OIDC.ClientID = "oidc-client"
+ cfg.OIDC.ClientSecret = "oidc-secret"
+ cfg.OIDC.IssuerURL = "https://issuer.example.com"
+ cfg.OIDC.AuthorizeURL = ""
+ cfg.OIDC.TokenURL = ""
+ cfg.OIDC.JWKSURL = ""
+ cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
+ cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
+ cfg.OIDC.Scopes = "openid email profile"
+ cfg.OIDC.ValidateIDToken = true
+ cfg.OIDC.UsePKCE = true
+
+ err = cfg.Validate()
+ if err != nil {
+ t.Fatalf("Validate() expected issuer-only OIDC config to pass with discovery fallback, got: %v", err)
+ }
+}
+
+func TestValidateOIDCAllowsExplicitCompatibilityOverridesForPKCEAndIDTokenValidation(t *testing.T) {
+ resetViperWithJWTSecret(t)
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.OIDC.Enabled = true
+ cfg.OIDC.ClientID = "oidc-client"
+ cfg.OIDC.ClientSecret = "oidc-secret"
+ cfg.OIDC.IssuerURL = "https://issuer.example.com"
+ cfg.OIDC.AuthorizeURL = "https://issuer.example.com/auth"
+ cfg.OIDC.TokenURL = "https://issuer.example.com/token"
+ cfg.OIDC.UserInfoURL = "https://issuer.example.com/userinfo"
+ cfg.OIDC.RedirectURL = "https://example.com/api/v1/auth/oauth/oidc/callback"
+ cfg.OIDC.FrontendRedirectURL = "/auth/oidc/callback"
+ cfg.OIDC.Scopes = "openid email profile"
+ cfg.OIDC.UsePKCE = false
+ cfg.OIDC.ValidateIDToken = false
+ cfg.OIDC.JWKSURL = ""
+ cfg.OIDC.AllowedSigningAlgs = ""
+
+ err = cfg.Validate()
+ if err != nil {
+ t.Fatalf("Validate() expected OIDC config without PKCE/id_token validation to pass for compatibility, got: %v", err)
}
}
@@ -766,6 +914,7 @@ func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
+ cfg.LinuxDo.UsePKCE = true
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate() unexpected error: %v", err)
@@ -916,6 +1065,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo client id required",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
+ c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = ""
},
wantErr: "linuxdo_connect.client_id",
@@ -924,6 +1074,7 @@ func TestValidateConfigErrors(t *testing.T) {
name: "linuxdo token auth method",
mutate: func(c *Config) {
c.LinuxDo.Enabled = true
+ c.LinuxDo.UsePKCE = true
c.LinuxDo.ClientID = "client"
c.LinuxDo.ClientSecret = "secret"
c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
@@ -1554,94 +1705,6 @@ func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) {
}
}
-func TestSoraCurlCFFISidecarDefaults(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- if !cfg.Sora.Client.CurlCFFISidecar.Enabled {
- t.Fatalf("Sora curl_cffi sidecar should be enabled by default")
- }
- if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 {
- t.Fatalf("Sora cloudflare challenge cooldown should be positive by default")
- }
- if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" {
- t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default")
- }
- if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" {
- t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default")
- }
- if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled {
- t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default")
- }
- if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 {
- t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default")
- }
-}
-
-func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CurlCFFISidecar.Enabled = false
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") {
- t.Fatalf("Validate() error = %v, want sidecar enabled error", err)
- }
-}
-
-func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CurlCFFISidecar.BaseURL = " "
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") {
- t.Fatalf("Validate() error = %v, want sidecar base_url required error", err)
- }
-}
-
-func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") {
- t.Fatalf("Validate() error = %v, want sidecar session ttl error", err)
- }
-}
-
-func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) {
- resetViperWithJWTSecret(t)
-
- cfg, err := Load()
- if err != nil {
- t.Fatalf("Load() error: %v", err)
- }
-
- cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1
- err = cfg.Validate()
- if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") {
- t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err)
- }
-}
-
func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
resetViperWithJWTSecret(t)
cfg, err := Load()
diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go
index 4e69ca02..a57f7067 100644
--- a/backend/internal/domain/constants.go
+++ b/backend/internal/domain/constants.go
@@ -22,7 +22,6 @@ const (
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
- PlatformSora = "sora"
)
// Account type constants
@@ -72,6 +71,7 @@ const (
// 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致
var DefaultAntigravityModelMapping = map[string]string{
// Claude 白名单
+ "claude-opus-4-7": "claude-opus-4-7", // 官方模型
"claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型
"claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射
"claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型
@@ -121,6 +121,7 @@ var DefaultAntigravityModelMapping = map[string]string{
// aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等)
var DefaultBedrockModelMapping = map[string]string{
// Claude Opus
+ "claude-opus-4-7": "us.anthropic.claude-opus-4-7-v1",
"claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1",
"claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0",
diff --git a/backend/internal/domain/openai_messages_dispatch.go b/backend/internal/domain/openai_messages_dispatch.go
new file mode 100644
index 00000000..6b018f1c
--- /dev/null
+++ b/backend/internal/domain/openai_messages_dispatch.go
@@ -0,0 +1,10 @@
+package domain
+
+// OpenAIMessagesDispatchModelConfig controls how Anthropic /v1/messages
+// requests are mapped onto OpenAI/Codex models.
+type OpenAIMessagesDispatchModelConfig struct {
+ OpusMappedModel string `json:"opus_mapped_model,omitempty"`
+ SonnetMappedModel string `json:"sonnet_mapped_model,omitempty"`
+ HaikuMappedModel string `json:"haiku_mapped_model,omitempty"`
+ ExactModelMappings map[string]string `json:"exact_model_mappings,omitempty"`
+}
diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go
index 12139b51..00da4821 100644
--- a/backend/internal/handler/admin/account_data.go
+++ b/backend/internal/handler/admin/account_data.go
@@ -10,6 +10,7 @@ import (
"log/slog"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -359,7 +360,7 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e
pageSize := dataPageCap
var out []service.Proxy
for {
- items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "")
+ items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "", "created_at", "desc")
if err != nil {
return nil, err
}
@@ -372,12 +373,12 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e
return out, nil
}
-func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) {
+func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string, groupID int64, privacyMode, sortBy, sortOrder string) ([]service.Account, error) {
page := 1
pageSize := dataPageCap
var out []service.Account
for {
- items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "")
+ items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
if err != nil {
return nil, err
}
@@ -409,11 +410,28 @@ func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64,
platform := c.Query("platform")
accountType := c.Query("type")
status := c.Query("status")
+ privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
search := strings.TrimSpace(c.Query("search"))
+ sortBy := c.DefaultQuery("sort_by", "name")
+ sortOrder := c.DefaultQuery("sort_order", "asc")
if len(search) > 100 {
search = search[:100]
}
- return h.listAccountsFiltered(ctx, platform, accountType, status, search)
+
+ groupID := int64(0)
+ if groupIDStr := c.Query("group"); groupIDStr != "" {
+ if groupIDStr == accountListGroupUngroupedQueryValue {
+ groupID = service.AccountListGroupUngrouped
+ } else {
+ parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64)
+ if parseErr != nil || parsedGroupID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter")
+ }
+ groupID = parsedGroupID
+ }
+ }
+
+ return h.listAccountsFiltered(ctx, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
}
func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) {
@@ -567,15 +585,15 @@ func defaultProxyName(name string) string {
// enrichCredentialsFromIDToken performs best-effort extraction of user info fields
// (email, plan_type, chatgpt_account_id, etc.) from id_token in credentials.
-// Only applies to OpenAI/Sora OAuth accounts. Skips expired token errors silently.
+// Only applies to OpenAI OAuth accounts. Skips expired token errors silently.
// Existing credential values are never overwritten — only missing fields are filled.
func enrichCredentialsFromIDToken(item *DataAccount) {
if item.Credentials == nil {
return
}
- // Only enrich OpenAI/Sora OAuth accounts
+ // Only enrich OpenAI OAuth accounts
platform := strings.ToLower(strings.TrimSpace(item.Platform))
- if platform != service.PlatformOpenAI && platform != service.PlatformSora {
+ if platform != service.PlatformOpenAI {
return
}
if strings.ToLower(strings.TrimSpace(item.Type)) != service.AccountTypeOAuth {
diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go
index 285033a1..5793983c 100644
--- a/backend/internal/handler/admin/account_data_handler_test.go
+++ b/backend/internal/handler/admin/account_data_handler_test.go
@@ -172,6 +172,51 @@ func TestExportDataWithoutProxies(t *testing.T) {
require.Nil(t, resp.Data.Accounts[0].ProxyKey)
}
+func TestExportDataPassesAccountFiltersAndSort(t *testing.T) {
+ router, adminSvc := setupAccountDataRouter()
+ adminSvc.accounts = []service.Account{
+ {ID: 1, Name: "acc-1", Status: service.StatusActive},
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/accounts/data?platform=openai&type=oauth&status=active&group=12&privacy_mode=blocked&search=keyword&sort_by=priority&sort_order=desc",
+ nil,
+ )
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ require.Equal(t, 1, adminSvc.lastListAccounts.calls)
+ require.Equal(t, "openai", adminSvc.lastListAccounts.platform)
+ require.Equal(t, "oauth", adminSvc.lastListAccounts.accountType)
+ require.Equal(t, "active", adminSvc.lastListAccounts.status)
+ require.Equal(t, int64(12), adminSvc.lastListAccounts.groupID)
+ require.Equal(t, "blocked", adminSvc.lastListAccounts.privacyMode)
+ require.Equal(t, "keyword", adminSvc.lastListAccounts.search)
+ require.Equal(t, "priority", adminSvc.lastListAccounts.sortBy)
+ require.Equal(t, "desc", adminSvc.lastListAccounts.sortOrder)
+}
+
+func TestExportDataSelectedIDsOverrideFilters(t *testing.T) {
+ router, adminSvc := setupAccountDataRouter()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/accounts/data?ids=1,2&platform=openai&search=keyword&sort_by=priority&sort_order=desc",
+ nil,
+ )
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp dataResponse
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Len(t, resp.Data.Accounts, 2)
+ require.Equal(t, 0, adminSvc.lastListAccounts.calls)
+}
+
func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) {
router, adminSvc := setupAccountDataRouter()
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index ce5cffe4..7454451a 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -221,6 +221,8 @@ func (h *AccountHandler) List(c *gin.Context) {
status := c.Query("status")
search := c.Query("search")
privacyMode := strings.TrimSpace(c.Query("privacy_mode"))
+ sortBy := c.DefaultQuery("sort_by", "name")
+ sortOrder := c.DefaultQuery("sort_order", "asc")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
@@ -246,7 +248,7 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
- accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode)
+ accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -650,6 +652,7 @@ func (h *AccountHandler) Delete(c *gin.Context) {
type TestAccountRequest struct {
ModelID string `json:"model_id"`
Prompt string `json:"prompt"`
+ Mode string `json:"mode"`
}
type SyncFromCRSRequest struct {
@@ -680,7 +683,7 @@ func (h *AccountHandler) Test(c *gin.Context) {
_ = c.ShouldBindJSON(&req)
// Use AccountTestService to test the account with SSE streaming
- if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt); err != nil {
+ if err := h.accountTestService.TestAccountConnection(c, accountID, req.ModelID, req.Prompt, req.Mode); err != nil {
// Error already sent via SSE, just log
return
}
@@ -839,6 +842,7 @@ func (h *AccountHandler) refreshSingleAccount(ctx context.Context, account *serv
if updateErr != nil {
return nil, "", fmt.Errorf("failed to update credentials: %w", updateErr)
}
+ h.adminService.EnsureAntigravityPrivacy(ctx, updatedAccount)
return updatedAccount, "missing_project_id_temporary", nil
}
@@ -1409,6 +1413,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
+ "details": gin.H{
+ "group_id": mixedErr.GroupID,
+ "group_name": mixedErr.GroupName,
+ "current_platform": mixedErr.CurrentPlatform,
+ "other_platform": mixedErr.OtherPlatform,
+ },
})
return
}
@@ -1874,12 +1884,6 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
return
}
- // Handle Sora accounts
- if account.Platform == service.PlatformSora {
- response.Success(c, service.DefaultSoraModels(nil))
- return
- }
-
// Handle Claude/Anthropic accounts
// For OAuth and Setup-Token accounts: return default models
if account.IsOAuth() {
@@ -2034,7 +2038,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
- allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "")
+ allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", "name", "asc")
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index cba3ae21..ddeaab02 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -23,6 +23,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
+ router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity)
router.POST("/api/v1/admin/users", userHandler.Create)
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
@@ -75,8 +76,26 @@ func TestUserHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
+ bindBody := map[string]any{
+ "provider_type": "wechat",
+ "provider_key": "wechat-main",
+ "provider_subject": "union-123",
+ "metadata": map[string]any{"source": "admin-repair"},
+ "channel": map[string]any{
+ "channel": "open",
+ "channel_app_id": "wx-open",
+ "channel_subject": "openid-123",
+ },
+ }
+ body, _ := json.Marshal(bindBody)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
- body, _ := json.Marshal(createBody)
+ body, _ = json.Marshal(createBody)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
@@ -113,6 +132,33 @@ func TestUserHandlerEndpoints(t *testing.T) {
require.Equal(t, http.StatusOK, rec.Code)
}
+func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) {
+ router, adminSvc := setupAdminRouter()
+
+ body, err := json.Marshal(map[string]any{
+ "provider_type": "oidc",
+ "provider_key": "https://issuer.example",
+ "provider_subject": "subject-123",
+ "issuer": "https://issuer.example",
+ "metadata": map[string]any{"report_id": 12},
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor)
+ require.NotNil(t, adminSvc.boundAuthIdentity)
+ require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType)
+ require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey)
+ require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject)
+ require.Nil(t, adminSvc.boundAuthIdentity.Channel)
+ require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"])
+}
+
func TestGroupHandlerEndpoints(t *testing.T) {
router, _ := setupAdminRouter()
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index 9759cef5..2fe29fa3 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -17,6 +17,8 @@ type stubAdminService struct {
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
+ boundAuthIdentity *service.AdminBindAuthIdentityInput
+ boundAuthIdentityFor int64
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
@@ -31,6 +33,41 @@ type stubAdminService struct {
platform string
groupIDs []int64
}
+ lastListAccounts struct {
+ platform string
+ accountType string
+ status string
+ search string
+ groupID int64
+ privacyMode string
+ sortBy string
+ sortOrder string
+ calls int
+ }
+ lastListUsers struct {
+ page int
+ pageSize int
+ filters service.UserListFilters
+ sortBy string
+ sortOrder string
+ calls int
+ }
+ lastListProxies struct {
+ protocol string
+ status string
+ search string
+ sortBy string
+ sortOrder string
+ calls int
+ }
+ lastListRedeemCodes struct {
+ codeType string
+ status string
+ search string
+ sortBy string
+ sortOrder string
+ calls int
+ }
mu sync.Mutex
}
@@ -99,7 +136,13 @@ func newStubAdminService() *stubAdminService {
}
}
-func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
+func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) {
+ s.lastListUsers.page = page
+ s.lastListUsers.pageSize = pageSize
+ s.lastListUsers.filters = filters
+ s.lastListUsers.sortBy = sortBy
+ s.lastListUsers.sortOrder = sortOrder
+ s.lastListUsers.calls++
return s.users, int64(len(s.users)), nil
}
@@ -132,7 +175,7 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64,
return &user, nil
}
-func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
+func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) {
return s.apiKeys, int64(len(s.apiKeys)), nil
}
@@ -140,7 +183,64 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64,
return map[string]any{"user_id": userID}, nil
}
-func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
+func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) {
+ user, err := s.GetUser(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ return &service.UserRPMStatus{
+ UserRPMUsed: 0,
+ UserRPMLimit: user.RPMLimit,
+ }, nil
+}
+
+func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) {
+ s.boundAuthIdentityFor = userID
+ copied := input
+ if input.Metadata != nil {
+ copied.Metadata = map[string]any{}
+ for key, value := range input.Metadata {
+ copied.Metadata[key] = value
+ }
+ }
+ if input.Channel != nil {
+ channel := *input.Channel
+ if input.Channel.Metadata != nil {
+ channel.Metadata = map[string]any{}
+ for key, value := range input.Channel.Metadata {
+ channel.Metadata[key] = value
+ }
+ }
+ copied.Channel = &channel
+ }
+ s.boundAuthIdentity = &copied
+
+ now := time.Now().UTC()
+ result := &service.AdminBoundAuthIdentity{
+ UserID: userID,
+ ProviderType: input.ProviderType,
+ ProviderKey: input.ProviderKey,
+ ProviderSubject: input.ProviderSubject,
+ VerifiedAt: &now,
+ Issuer: input.Issuer,
+ Metadata: input.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ if input.Channel != nil {
+ result.Channel = &service.AdminBoundAuthIdentityChannel{
+ Channel: input.Channel.Channel,
+ ChannelAppID: input.Channel.ChannelAppID,
+ ChannelSubject: input.Channel.ChannelSubject,
+ Metadata: input.Channel.Metadata,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ }
+ return result, nil
+}
+
+func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) {
return s.groups, int64(len(s.groups)), nil
}
@@ -187,7 +287,24 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int
return nil
}
-func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) {
+func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
+ return nil
+}
+
+func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error {
+ return nil
+}
+
+func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) {
+ s.lastListAccounts.platform = platform
+ s.lastListAccounts.accountType = accountType
+ s.lastListAccounts.status = status
+ s.lastListAccounts.search = search
+ s.lastListAccounts.groupID = groupID
+ s.lastListAccounts.privacyMode = privacyMode
+ s.lastListAccounts.sortBy = sortBy
+ s.lastListAccounts.sortOrder = sortOrder
+ s.lastListAccounts.calls++
return s.accounts, int64(len(s.accounts)), nil
}
@@ -261,7 +378,13 @@ func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAcc
return s.checkMixedErr
}
-func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
+func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.Proxy, int64, error) {
+ s.lastListProxies.protocol = protocol
+ s.lastListProxies.status = status
+ s.lastListProxies.search = search
+ s.lastListProxies.sortBy = sortBy
+ s.lastListProxies.sortOrder = sortOrder
+ s.lastListProxies.calls++
search = strings.TrimSpace(strings.ToLower(search))
filtered := make([]service.Proxy, 0, len(s.proxies))
for _, proxy := range s.proxies {
@@ -283,7 +406,7 @@ func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int,
return filtered, int64(len(filtered)), nil
}
-func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
+func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.ProxyWithAccountCount, int64, error) {
return s.proxyCounts, int64(len(s.proxyCounts)), nil
}
@@ -380,12 +503,17 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
- {Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
-func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
+func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]service.RedeemCode, int64, error) {
+ s.lastListRedeemCodes.codeType = codeType
+ s.lastListRedeemCodes.status = status
+ s.lastListRedeemCodes.search = search
+ s.lastListRedeemCodes.sortBy = sortBy
+ s.lastListRedeemCodes.sortOrder = sortOrder
+ s.lastListRedeemCodes.calls++
return s.redeems, int64(len(s.redeems)), nil
}
diff --git a/backend/internal/handler/admin/affiliate_handler.go b/backend/internal/handler/admin/affiliate_handler.go
new file mode 100644
index 00000000..97e649ec
--- /dev/null
+++ b/backend/internal/handler/admin/affiliate_handler.go
@@ -0,0 +1,183 @@
+package admin
+
+import (
+ "strconv"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AffiliateHandler handles admin affiliate (邀请返利) management:
+// listing users with custom settings, updating per-user invite codes
+// and exclusive rebate rates, and batch operations.
+type AffiliateHandler struct {
+ affiliateService *service.AffiliateService
+ adminService service.AdminService
+}
+
+// NewAffiliateHandler creates a new admin affiliate handler.
+func NewAffiliateHandler(affiliateService *service.AffiliateService, adminService service.AdminService) *AffiliateHandler {
+ return &AffiliateHandler{
+ affiliateService: affiliateService,
+ adminService: adminService,
+ }
+}
+
+// ListUsers returns paginated users with custom affiliate settings.
+// GET /api/v1/admin/affiliates/users
+func (h *AffiliateHandler) ListUsers(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ search := c.Query("search")
+
+ entries, total, err := h.affiliateService.AdminListCustomUsers(c.Request.Context(), service.AffiliateAdminFilter{
+ Search: search,
+ Page: page,
+ PageSize: pageSize,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Paginated(c, entries, total, page, pageSize)
+}
+
+// UpdateUserSettings updates a user's affiliate settings.
+// PUT /api/v1/admin/affiliates/users/:user_id
+//
+// Both fields are optional and applied independently.
+type UpdateAffiliateUserRequest struct {
+ AffCode *string `json:"aff_code"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
+ // ClearRebateRate explicitly clears the per-user rate (sets it to NULL).
+ // Used to disambiguate from "field not provided".
+ ClearRebateRate bool `json:"clear_rebate_rate"`
+}
+
+func (h *AffiliateHandler) UpdateUserSettings(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
+ if err != nil || userID <= 0 {
+ response.BadRequest(c, "Invalid user_id")
+ return
+ }
+
+ var req UpdateAffiliateUserRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if req.AffCode != nil {
+ if err := h.affiliateService.AdminUpdateUserAffCode(c.Request.Context(), userID, *req.AffCode); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ if req.ClearRebateRate {
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ } else if req.AffRebateRatePercent != nil {
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, req.AffRebateRatePercent); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ response.Success(c, gin.H{"user_id": userID})
+}
+
+// ClearUserSettings removes ALL of a user's custom affiliate settings — clears
+// the exclusive rebate rate AND regenerates the invite code as a new system
+// random one. Conceptually this "removes the user from the custom list".
+//
+// Both writes happen in this handler; failure of one leaves the other applied,
+// but the operation is idempotent so the admin can re-run it safely.
+// DELETE /api/v1/admin/affiliates/users/:user_id
+func (h *AffiliateHandler) ClearUserSettings(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("user_id"), 10, 64)
+ if err != nil || userID <= 0 {
+ response.BadRequest(c, "Invalid user_id")
+ return
+ }
+ if err := h.affiliateService.AdminSetUserRebateRate(c.Request.Context(), userID, nil); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if _, err := h.affiliateService.AdminResetUserAffCode(c.Request.Context(), userID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"user_id": userID})
+}
+
+// BatchSetRate applies the same rebate rate (or clears it) to multiple users.
+//
+// Protocol: pass `clear: true` to clear rates (aff_rebate_rate_percent is
+// ignored). Otherwise aff_rebate_rate_percent is required and applied to
+// every user_id. The explicit `clear` flag exists because Go's JSON unmarshal
+// can't distinguish a missing field from `null`, and a silent clear from a
+// frontend that forgot to include the rate would be a footgun.
+//
+// POST /api/v1/admin/affiliates/users/batch-rate
+type BatchSetRateRequest struct {
+ UserIDs []int64 `json:"user_ids" binding:"required"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent"`
+ Clear bool `json:"clear"`
+}
+
+func (h *AffiliateHandler) BatchSetRate(c *gin.Context) {
+ var req BatchSetRateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if len(req.UserIDs) == 0 {
+ response.BadRequest(c, "user_ids cannot be empty")
+ return
+ }
+ if !req.Clear && req.AffRebateRatePercent == nil {
+ response.BadRequest(c, "aff_rebate_rate_percent is required unless clear=true")
+ return
+ }
+ rate := req.AffRebateRatePercent
+ if req.Clear {
+ rate = nil
+ }
+ if err := h.affiliateService.AdminBatchSetUserRebateRate(c.Request.Context(), req.UserIDs, rate); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"affected": len(req.UserIDs)})
+}
+
+// AffiliateUserSummary is the minimal user shape returned by LookupUsers,
+// shared with the frontend's add-custom-user picker.
+type AffiliateUserSummary struct {
+ ID int64 `json:"id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+}
+
+// LookupUsers searches users by email/username for the "add custom user" modal.
+// GET /api/v1/admin/affiliates/users/lookup?q=
+func (h *AffiliateHandler) LookupUsers(c *gin.Context) {
+ keyword := c.Query("q")
+ if keyword == "" {
+ response.Success(c, []AffiliateUserSummary{})
+ return
+ }
+ users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 20, service.UserListFilters{Search: keyword}, "email", "asc")
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ result := make([]AffiliateUserSummary, len(users))
+ for i, u := range users {
+ result[i] = AffiliateUserSummary{ID: u.ID, Email: u.Email, Username: u.Username}
+ }
+ response.Success(c, result)
+}
diff --git a/backend/internal/handler/admin/announcement_handler.go b/backend/internal/handler/admin/announcement_handler.go
index d1312bc0..d3b9d173 100644
--- a/backend/internal/handler/admin/announcement_handler.go
+++ b/backend/internal/handler/admin/announcement_handler.go
@@ -52,13 +52,17 @@ func (h *AnnouncementHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := strings.TrimSpace(c.Query("status"))
search := strings.TrimSpace(c.Query("search"))
+ sortBy := c.DefaultQuery("sort_by", "created_at")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 200 {
search = search[:200]
}
params := pagination.PaginationParams{
- Page: page,
- PageSize: pageSize,
+ Page: page,
+ PageSize: pageSize,
+ SortBy: sortBy,
+ SortOrder: sortOrder,
}
items, paginationResult, err := h.announcementService.List(
@@ -227,8 +231,10 @@ func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
params := pagination.PaginationParams{
- Page: page,
- PageSize: pageSize,
+ Page: page,
+ PageSize: pageSize,
+ SortBy: c.DefaultQuery("sort_by", "email"),
+ SortOrder: c.DefaultQuery("sort_order", "asc"),
}
search := strings.TrimSpace(c.Query("search"))
if len(search) > 200 {
diff --git a/backend/internal/handler/admin/announcement_handler_sort_test.go b/backend/internal/handler/admin/announcement_handler_sort_test.go
new file mode 100644
index 00000000..545e619e
--- /dev/null
+++ b/backend/internal/handler/admin/announcement_handler_sort_test.go
@@ -0,0 +1,138 @@
+package admin
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type announcementRepoCapture struct {
+ service.AnnouncementRepository
+ listParams pagination.PaginationParams
+}
+
+func (r *announcementRepoCapture) List(ctx context.Context, params pagination.PaginationParams, filters service.AnnouncementListFilters) ([]service.Announcement, *pagination.PaginationResult, error) {
+ r.listParams = params
+ return []service.Announcement{}, &pagination.PaginationResult{
+ Total: 0,
+ Page: params.Page,
+ PageSize: params.PageSize,
+ Pages: 0,
+ }, nil
+}
+
+func (r *announcementRepoCapture) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
+ return &service.Announcement{
+ ID: id,
+ Title: "announcement",
+ Content: "content",
+ Status: service.AnnouncementStatusActive,
+ CreatedAt: time.Now(),
+ UpdatedAt: time.Now(),
+ }, nil
+}
+
+type announcementUserRepoCapture struct {
+ service.UserRepository
+ listParams pagination.PaginationParams
+}
+
+func (r *announcementUserRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ r.listParams = params
+ return []service.User{}, &pagination.PaginationResult{
+ Total: 0,
+ Page: params.Page,
+ PageSize: params.PageSize,
+ Pages: 0,
+ }, nil
+}
+
+type announcementReadRepoCapture struct {
+ service.AnnouncementReadRepository
+}
+
+func (r *announcementReadRepoCapture) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
+ return map[int64]time.Time{}, nil
+}
+
+type announcementUserSubRepoCapture struct {
+ service.UserSubscriptionRepository
+}
+
+func newAnnouncementSortTestRouter(announcementRepo *announcementRepoCapture, userRepo *announcementUserRepoCapture) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ svc := service.NewAnnouncementService(
+ announcementRepo,
+ &announcementReadRepoCapture{},
+ userRepo,
+ &announcementUserSubRepoCapture{},
+ )
+ handler := NewAnnouncementHandler(svc)
+ router := gin.New()
+ router.GET("/admin/announcements", handler.List)
+ router.GET("/admin/announcements/:id/read-status", handler.ListReadStatus)
+ return router
+}
+
+func TestAdminAnnouncementListSortParams(t *testing.T) {
+ announcementRepo := &announcementRepoCapture{}
+ userRepo := &announcementUserRepoCapture{}
+ router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/announcements?sort_by=title&sort_order=ASC", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "title", announcementRepo.listParams.SortBy)
+ require.Equal(t, "ASC", announcementRepo.listParams.SortOrder)
+}
+
+func TestAdminAnnouncementListSortDefaults(t *testing.T) {
+ announcementRepo := &announcementRepoCapture{}
+ userRepo := &announcementUserRepoCapture{}
+ router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/announcements", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "created_at", announcementRepo.listParams.SortBy)
+ require.Equal(t, "desc", announcementRepo.listParams.SortOrder)
+}
+
+func TestAdminAnnouncementReadStatusSortParams(t *testing.T) {
+ announcementRepo := &announcementRepoCapture{}
+ userRepo := &announcementUserRepoCapture{}
+ router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status?sort_by=balance&sort_order=DESC", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "balance", userRepo.listParams.SortBy)
+ require.Equal(t, "DESC", userRepo.listParams.SortOrder)
+}
+
+func TestAdminAnnouncementReadStatusSortDefaults(t *testing.T) {
+ announcementRepo := &announcementRepoCapture{}
+ userRepo := &announcementUserRepoCapture{}
+ router := newAnnouncementSortTestRouter(announcementRepo, userRepo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "email", userRepo.listParams.SortBy)
+ require.Equal(t, "asc", userRepo.listParams.SortOrder)
+}
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
new file mode 100644
index 00000000..950e6e72
--- /dev/null
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -0,0 +1,502 @@
+package admin
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// ChannelHandler handles admin channel management
+type ChannelHandler struct {
+ channelService *service.ChannelService
+ billingService *service.BillingService
+}
+
+// NewChannelHandler creates a new admin channel handler
+func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
+ return &ChannelHandler{channelService: channelService, billingService: billingService}
+}
+
+// --- Request / Response types ---
+
+type createChannelRequest struct {
+ Name string `json:"name" binding:"required,max=100"`
+ Description string `json:"description"`
+ GroupIDs []int64 `json:"group_ids"`
+ ModelPricing []channelModelPricingRequest `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
+ RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
+ ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
+}
+
+type updateChannelRequest struct {
+ Name string `json:"name" binding:"omitempty,max=100"`
+ Description *string `json:"description"`
+ Status string `json:"status" binding:"omitempty,oneof=active disabled"`
+ GroupIDs *[]int64 `json:"group_ids"`
+ ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
+ RestrictModels *bool `json:"restrict_models"`
+ Features *string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
+ ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
+}
+
+type channelModelPricingRequest struct {
+ Platform string `json:"platform" binding:"omitempty,max=50"`
+ Models []string `json:"models" binding:"required,min=1,max=100"`
+ BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
+ InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"`
+ OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"`
+ CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
+ CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
+ ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
+ PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"`
+ Intervals []pricingIntervalRequest `json:"intervals"`
+}
+
+type pricingIntervalRequest struct {
+ MinTokens int `json:"min_tokens"`
+ MaxTokens *int `json:"max_tokens"`
+ TierLabel string `json:"tier_label"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+ SortOrder int `json:"sort_order"`
+}
+
+type accountStatsPricingRuleRequest struct {
+ Name string `json:"name"`
+ GroupIDs []int64 `json:"group_ids"`
+ AccountIDs []int64 `json:"account_ids"`
+ Pricing []channelModelPricingRequest `json:"pricing"`
+}
+
+type channelResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Status string `json:"status"`
+ BillingModelSource string `json:"billing_model_source"`
+ RestrictModels bool `json:"restrict_models"`
+ Features string `json:"features"`
+ FeaturesConfig map[string]any `json:"features_config"`
+ GroupIDs []int64 `json:"group_ids"`
+ ModelPricing []channelModelPricingResponse `json:"model_pricing"`
+ ModelMapping map[string]map[string]string `json:"model_mapping"`
+ ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
+ AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
+}
+
+type channelModelPricingResponse struct {
+ ID int64 `json:"id"`
+ Platform string `json:"platform"`
+ Models []string `json:"models"`
+ BillingMode string `json:"billing_mode"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ ImageOutputPrice *float64 `json:"image_output_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+ Intervals []pricingIntervalResponse `json:"intervals"`
+}
+
+type pricingIntervalResponse struct {
+ ID int64 `json:"id"`
+ MinTokens int `json:"min_tokens"`
+ MaxTokens *int `json:"max_tokens"`
+ TierLabel string `json:"tier_label,omitempty"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+ SortOrder int `json:"sort_order"`
+}
+
+type accountStatsPricingRuleResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ GroupIDs []int64 `json:"group_ids"`
+ AccountIDs []int64 `json:"account_ids"`
+ Pricing []channelModelPricingResponse `json:"pricing"`
+}
+
+func channelToResponse(ch *service.Channel) *channelResponse {
+ if ch == nil {
+ return nil
+ }
+ resp := &channelResponse{
+ ID: ch.ID,
+ Name: ch.Name,
+ Description: ch.Description,
+ Status: ch.Status,
+ RestrictModels: ch.RestrictModels,
+ Features: ch.Features,
+ FeaturesConfig: ch.FeaturesConfig,
+ GroupIDs: ch.GroupIDs,
+ ModelMapping: ch.ModelMapping,
+ CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
+ UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
+ }
+ resp.BillingModelSource = ch.BillingModelSource
+ if resp.GroupIDs == nil {
+ resp.GroupIDs = []int64{}
+ }
+ if resp.ModelMapping == nil {
+ resp.ModelMapping = map[string]map[string]string{}
+ }
+
+ resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
+ for _, p := range ch.ModelPricing {
+ resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
+ }
+
+ resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
+ resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
+ for _, rule := range ch.AccountStatsPricingRules {
+ ruleResp := accountStatsPricingRuleResponse{
+ ID: rule.ID,
+ Name: rule.Name,
+ GroupIDs: rule.GroupIDs,
+ AccountIDs: rule.AccountIDs,
+ Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
+ }
+ if ruleResp.GroupIDs == nil {
+ ruleResp.GroupIDs = []int64{}
+ }
+ if ruleResp.AccountIDs == nil {
+ ruleResp.AccountIDs = []int64{}
+ }
+ for i := range rule.Pricing {
+ ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
+ }
+ resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
+ }
+
+ return resp
+}
+
+func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse {
+ models := p.Models
+ if models == nil {
+ models = []string{}
+ }
+ billingMode := string(p.BillingMode)
+ if billingMode == "" {
+ billingMode = string(service.BillingModeToken)
+ }
+ platform := p.Platform
+ if platform == "" {
+ platform = service.PlatformAnthropic
+ }
+ intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
+ for _, iv := range p.Intervals {
+ intervals = append(intervals, intervalToResponse(iv))
+ }
+ return channelModelPricingResponse{
+ ID: p.ID,
+ Platform: platform,
+ Models: models,
+ BillingMode: billingMode,
+ InputPrice: p.InputPrice,
+ OutputPrice: p.OutputPrice,
+ CacheWritePrice: p.CacheWritePrice,
+ CacheReadPrice: p.CacheReadPrice,
+ ImageOutputPrice: p.ImageOutputPrice,
+ PerRequestPrice: p.PerRequestPrice,
+ Intervals: intervals,
+ }
+}
+
+func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse {
+ return pricingIntervalResponse{
+ ID: iv.ID,
+ MinTokens: iv.MinTokens,
+ MaxTokens: iv.MaxTokens,
+ TierLabel: iv.TierLabel,
+ InputPrice: iv.InputPrice,
+ OutputPrice: iv.OutputPrice,
+ CacheWritePrice: iv.CacheWritePrice,
+ CacheReadPrice: iv.CacheReadPrice,
+ PerRequestPrice: iv.PerRequestPrice,
+ SortOrder: iv.SortOrder,
+ }
+}
+
+func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing {
+ result := make([]service.ChannelModelPricing, 0, len(reqs))
+ for _, r := range reqs {
+ billingMode := service.BillingMode(r.BillingMode)
+ if billingMode == "" {
+ billingMode = service.BillingModeToken
+ }
+ platform := r.Platform
+ intervals := make([]service.PricingInterval, 0, len(r.Intervals))
+ for _, iv := range r.Intervals {
+ intervals = append(intervals, service.PricingInterval{
+ MinTokens: iv.MinTokens,
+ MaxTokens: iv.MaxTokens,
+ TierLabel: iv.TierLabel,
+ InputPrice: iv.InputPrice,
+ OutputPrice: iv.OutputPrice,
+ CacheWritePrice: iv.CacheWritePrice,
+ CacheReadPrice: iv.CacheReadPrice,
+ PerRequestPrice: iv.PerRequestPrice,
+ SortOrder: iv.SortOrder,
+ })
+ }
+ result = append(result, service.ChannelModelPricing{
+ Platform: platform,
+ Models: r.Models,
+ BillingMode: billingMode,
+ InputPrice: r.InputPrice,
+ OutputPrice: r.OutputPrice,
+ CacheWritePrice: r.CacheWritePrice,
+ CacheReadPrice: r.CacheReadPrice,
+ ImageOutputPrice: r.ImageOutputPrice,
+ PerRequestPrice: r.PerRequestPrice,
+ Intervals: intervals,
+ })
+ }
+ return result
+}
+
+func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
+ return service.AccountStatsPricingRule{
+ Name: r.Name,
+ GroupIDs: r.GroupIDs,
+ AccountIDs: r.AccountIDs,
+ Pricing: pricingRequestToService(r.Pricing),
+ }
+}
+
+// --- Handlers ---
+
+// List handles listing channels with pagination
+// GET /api/v1/admin/channels
+func (h *ChannelHandler) List(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ status := c.Query("status")
+ search := strings.TrimSpace(c.Query("search"))
+ if len(search) > 100 {
+ search = search[:100]
+ }
+
+ channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{
+ Page: page,
+ PageSize: pageSize,
+ SortBy: c.DefaultQuery("sort_by", "created_at"),
+ SortOrder: c.DefaultQuery("sort_order", "desc"),
+ }, status, search)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ out := make([]*channelResponse, 0, len(channels))
+ for i := range channels {
+ out = append(out, channelToResponse(&channels[i]))
+ }
+ response.Paginated(c, out, pag.Total, page, pageSize)
+}
+
+// GetByID handles getting a channel by ID
+// GET /api/v1/admin/channels/:id
+func (h *ChannelHandler) GetByID(c *gin.Context) {
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
+ return
+ }
+
+ channel, err := h.channelService.GetByID(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, channelToResponse(channel))
+}
+
+// Create handles creating a new channel.
+// POST /api/v1/admin/channels
+//
+// Validates the JSON body, defaults each pricing entry's platform, and
+// rejects account-stats pricing rules with an empty scope or empty pricing
+// before delegating to the channel service.
+func (h *ChannelHandler) Create(c *gin.Context) {
+	var req createChannelRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+		return
+	}
+
+	pricing := pricingRequestToService(req.ModelPricing)
+	// Main model_pricing requires a platform; default to anthropic for backward compatibility.
+	for i := range pricing {
+		if pricing[i].Platform == "" {
+			pricing[i].Platform = service.PlatformAnthropic
+		}
+	}
+
+	// Each rule must target at least one group or account and carry at least
+	// one pricing entry; SortOrder records the rule's position in the request.
+	var statsRules []service.AccountStatsPricingRule
+	for i, r := range req.AccountStatsPricingRules {
+		if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
+			response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
+				fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
+			return
+		}
+		if len(r.Pricing) == 0 {
+			response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
+				fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
+			return
+		}
+		rule := accountStatsPricingRuleRequestToService(r)
+		rule.SortOrder = i
+		statsRules = append(statsRules, rule)
+	}
+
+	channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
+		Name:                       req.Name,
+		Description:                req.Description,
+		GroupIDs:                   req.GroupIDs,
+		ModelPricing:               pricing,
+		ModelMapping:               req.ModelMapping,
+		BillingModelSource:         req.BillingModelSource,
+		RestrictModels:             req.RestrictModels,
+		Features:                   req.Features,
+		FeaturesConfig:             req.FeaturesConfig,
+		ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
+		AccountStatsPricingRules:   statsRules,
+	})
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	response.Success(c, channelToResponse(channel))
+}
+
+// Update handles updating a channel.
+// PUT /api/v1/admin/channels/:id
+//
+// Pointer fields in updateChannelRequest implement partial-update semantics:
+// nil means "leave unchanged", while a non-nil (possibly empty) value
+// replaces the stored one. Pricing-rule validation mirrors Create.
+func (h *ChannelHandler) Update(c *gin.Context) {
+	id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
+		return
+	}
+
+	var req updateChannelRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+		return
+	}
+
+	input := &service.UpdateChannelInput{
+		Name:                       req.Name,
+		Description:                req.Description,
+		Status:                     req.Status,
+		GroupIDs:                   req.GroupIDs,
+		ModelMapping:               req.ModelMapping,
+		BillingModelSource:         req.BillingModelSource,
+		RestrictModels:             req.RestrictModels,
+		Features:                   req.Features,
+		FeaturesConfig:             req.FeaturesConfig,
+		ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
+	}
+	if req.ModelPricing != nil {
+		pricing := pricingRequestToService(*req.ModelPricing)
+		// Same backward-compatibility default as Create: empty platform -> anthropic.
+		for i := range pricing {
+			if pricing[i].Platform == "" {
+				pricing[i].Platform = service.PlatformAnthropic
+			}
+		}
+		input.ModelPricing = &pricing
+	}
+	if req.AccountStatsPricingRules != nil {
+		statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
+		for i, r := range *req.AccountStatsPricingRules {
+			if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
+				response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
+					fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
+				return
+			}
+			if len(r.Pricing) == 0 {
+				response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
+					fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
+				return
+			}
+			rule := accountStatsPricingRuleRequestToService(r)
+			rule.SortOrder = i
+			statsRules = append(statsRules, rule)
+		}
+		input.AccountStatsPricingRules = &statsRules
+	}
+
+	channel, err := h.channelService.Update(c.Request.Context(), id, input)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	response.Success(c, channelToResponse(channel))
+}
+
+// Delete handles deleting a channel.
+// DELETE /api/v1/admin/channels/:id
+func (h *ChannelHandler) Delete(c *gin.Context) {
+	id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
+		return
+	}
+
+	if err := h.channelService.Delete(c.Request.Context(), id); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	response.Success(c, gin.H{"message": "Channel deleted successfully"})
+}
+
+// GetModelDefaultPricing returns the default pricing for a model (used by the
+// frontend to auto-fill pricing forms).
+// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4
+func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
+	model := strings.TrimSpace(c.Query("model"))
+	if model == "" {
+		response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required").
+			WithMetadata(map[string]string{"param": "model"}))
+		return
+	}
+
+	pricing, err := h.billingService.GetModelPricing(model)
+	if err != nil {
+		// Model is not in the pricing table: deliberately a 200 with
+		// found=false rather than an error, so the frontend can fall back.
+		response.Success(c, gin.H{"found": false})
+		return
+	}
+
+	response.Success(c, gin.H{
+		"found":              true,
+		"input_price":        pricing.InputPricePerToken,
+		"output_price":       pricing.OutputPricePerToken,
+		"cache_write_price":  pricing.CacheCreationPricePerToken,
+		"cache_read_price":   pricing.CacheReadPricePerToken,
+		"image_output_price": pricing.ImageOutputPricePerToken,
+	})
+}
diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go
new file mode 100644
index 00000000..12cd4bdd
--- /dev/null
+++ b/backend/internal/handler/admin/channel_handler_test.go
@@ -0,0 +1,418 @@
+//go:build unit
+
+package admin
+
+import (
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// helpers
+// ---------------------------------------------------------------------------
+
+// float64Ptr returns a pointer to v; fixture helper for optional price fields.
+func float64Ptr(v float64) *float64 { return &v }
+
+// intPtr returns a pointer to v; fixture helper for optional int fields.
+func intPtr(v int) *int { return &v }
+
+// ---------------------------------------------------------------------------
+// 1. channelToResponse
+// ---------------------------------------------------------------------------
+
+// A nil channel must map to a nil response rather than panicking.
+func TestChannelToResponse_NilInput(t *testing.T) {
+	require.Nil(t, channelToResponse(nil))
+}
+
+// Verifies field-by-field mapping of a fully-populated channel, including
+// RFC3339 timestamp formatting, model mapping, and one pricing entry.
+func TestChannelToResponse_FullChannel(t *testing.T) {
+	now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC)
+	ch := &service.Channel{
+		ID:                 42,
+		Name:               "test-channel",
+		Description:        "desc",
+		Status:             "active",
+		BillingModelSource: "upstream",
+		RestrictModels:     true,
+		CreatedAt:          now,
+		UpdatedAt:          now.Add(time.Hour),
+		GroupIDs:           []int64{1, 2, 3},
+		ModelPricing: []service.ChannelModelPricing{
+			{
+				ID:              10,
+				Platform:        "openai",
+				Models:          []string{"gpt-4"},
+				BillingMode:     service.BillingModeToken,
+				InputPrice:      float64Ptr(0.01),
+				OutputPrice:     float64Ptr(0.03),
+				CacheWritePrice: float64Ptr(0.005),
+				CacheReadPrice:  float64Ptr(0.002),
+				PerRequestPrice: float64Ptr(0.5),
+			},
+		},
+		ModelMapping: map[string]map[string]string{
+			"anthropic": {"claude-3-haiku": "claude-haiku-3"},
+		},
+	}
+
+	resp := channelToResponse(ch)
+	require.NotNil(t, resp)
+	require.Equal(t, int64(42), resp.ID)
+	require.Equal(t, "test-channel", resp.Name)
+	require.Equal(t, "desc", resp.Description)
+	require.Equal(t, "active", resp.Status)
+	require.Equal(t, "upstream", resp.BillingModelSource)
+	require.True(t, resp.RestrictModels)
+	require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs)
+	require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt)
+	require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt)
+
+	// model mapping
+	require.Len(t, resp.ModelMapping, 1)
+	require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"])
+
+	// pricing
+	require.Len(t, resp.ModelPricing, 1)
+	p := resp.ModelPricing[0]
+	require.Equal(t, int64(10), p.ID)
+	require.Equal(t, "openai", p.Platform)
+	require.Equal(t, []string{"gpt-4"}, p.Models)
+	require.Equal(t, "token", p.BillingMode)
+	require.Equal(t, float64Ptr(0.01), p.InputPrice)
+	require.Equal(t, float64Ptr(0.03), p.OutputPrice)
+	require.Equal(t, float64Ptr(0.005), p.CacheWritePrice)
+	require.Equal(t, float64Ptr(0.002), p.CacheReadPrice)
+	require.Equal(t, float64Ptr(0.5), p.PerRequestPrice)
+	require.Empty(t, p.Intervals)
+}
+
+// Verifies defaulting behavior: nil GroupIDs/ModelMapping become empty (not
+// nil) collections, and empty Platform/BillingMode in a pricing entry are
+// normalized to "anthropic"/"token".
+func TestChannelToResponse_EmptyDefaults(t *testing.T) {
+	now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+	ch := &service.Channel{
+		ID:                 1,
+		Name:               "ch",
+		BillingModelSource: service.BillingModelSourceChannelMapped,
+		CreatedAt:          now,
+		UpdatedAt:          now,
+		GroupIDs:           nil,
+		ModelMapping:       nil,
+		ModelPricing: []service.ChannelModelPricing{
+			{
+				Platform:    "",
+				BillingMode: "",
+				Models:      []string{"m1"},
+			},
+		},
+	}
+
+	// channelToResponse in the handler layer is now a pure pass-through: the
+	// empty-value fallback for BillingModelSource has been pushed down to the
+	// service layer (handled uniformly at the Create/GetByID/List/Update/
+	// ListAvailable exits), so this fixture passes the normalized value in.
+	resp := channelToResponse(ch)
+	require.Equal(t, "channel_mapped", resp.BillingModelSource)
+	require.NotNil(t, resp.GroupIDs)
+	require.Empty(t, resp.GroupIDs)
+	require.NotNil(t, resp.ModelMapping)
+	require.Empty(t, resp.ModelMapping)
+
+	require.Len(t, resp.ModelPricing, 1)
+	require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
+	require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
+}
+
+// The handler no longer applies a BillingModelSource fallback: an empty value
+// must pass through unchanged (the service layer owns the default fill-in).
+func TestChannelToResponse_BillingModelSourcePassthrough(t *testing.T) {
+	ch := &service.Channel{
+		ID:                 1,
+		Name:               "ch",
+		BillingModelSource: "",
+		CreatedAt:          time.Now(),
+		UpdatedAt:          time.Now(),
+	}
+	resp := channelToResponse(ch)
+	require.Equal(t, "", resp.BillingModelSource, "handler 应纯透传,默认值由 service.normalizeBillingModelSource 负责")
+}
+
+// A pricing entry with a nil Models slice must serialize as an empty slice,
+// so JSON output is [] rather than null.
+func TestChannelToResponse_NilModels(t *testing.T) {
+	now := time.Now()
+	ch := &service.Channel{
+		ID:        1,
+		Name:      "ch",
+		CreatedAt: now,
+		UpdatedAt: now,
+		ModelPricing: []service.ChannelModelPricing{
+			{
+				Models: nil,
+			},
+		},
+	}
+
+	resp := channelToResponse(ch)
+	require.Len(t, resp.ModelPricing, 1)
+	require.NotNil(t, resp.ModelPricing[0].Models)
+	require.Empty(t, resp.ModelPricing[0].Models)
+}
+
+// Verifies that tiered pricing intervals are mapped through, including
+// optional MaxTokens (nil = unbounded) and all price pointers.
+func TestChannelToResponse_WithIntervals(t *testing.T) {
+	now := time.Now()
+	ch := &service.Channel{
+		ID:        1,
+		Name:      "ch",
+		CreatedAt: now,
+		UpdatedAt: now,
+		ModelPricing: []service.ChannelModelPricing{
+			{
+				Models:      []string{"m1"},
+				BillingMode: service.BillingModePerRequest,
+				Intervals: []service.PricingInterval{
+					{
+						ID:              100,
+						MinTokens:       0,
+						MaxTokens:       intPtr(1000),
+						TierLabel:       "1K",
+						InputPrice:      float64Ptr(0.01),
+						OutputPrice:     float64Ptr(0.02),
+						CacheWritePrice: float64Ptr(0.003),
+						CacheReadPrice:  float64Ptr(0.001),
+						PerRequestPrice: float64Ptr(0.1),
+						SortOrder:       1,
+					},
+					{
+						ID:        101,
+						MinTokens: 1000,
+						MaxTokens: nil,
+						TierLabel: "unlimited",
+						SortOrder: 2,
+					},
+				},
+			},
+		},
+	}
+
+	resp := channelToResponse(ch)
+	require.Len(t, resp.ModelPricing, 1)
+	intervals := resp.ModelPricing[0].Intervals
+	require.Len(t, intervals, 2)
+
+	iv0 := intervals[0]
+	require.Equal(t, int64(100), iv0.ID)
+	require.Equal(t, 0, iv0.MinTokens)
+	require.Equal(t, intPtr(1000), iv0.MaxTokens)
+	require.Equal(t, "1K", iv0.TierLabel)
+	require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
+	require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
+	require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
+	require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
+	require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
+	require.Equal(t, 1, iv0.SortOrder)
+
+	iv1 := intervals[1]
+	require.Equal(t, int64(101), iv1.ID)
+	require.Equal(t, 1000, iv1.MinTokens)
+	require.Nil(t, iv1.MaxTokens)
+	require.Equal(t, "unlimited", iv1.TierLabel)
+	require.Equal(t, 2, iv1.SortOrder)
+}
+
+// Verifies that multiple pricing entries across platforms and billing modes
+// are mapped in order without cross-contamination.
+func TestChannelToResponse_MultipleEntries(t *testing.T) {
+	now := time.Now()
+	ch := &service.Channel{
+		ID:        1,
+		Name:      "multi",
+		CreatedAt: now,
+		UpdatedAt: now,
+		ModelPricing: []service.ChannelModelPricing{
+			{
+				ID:          1,
+				Platform:    "anthropic",
+				Models:      []string{"claude-sonnet-4"},
+				BillingMode: service.BillingModeToken,
+				InputPrice:  float64Ptr(0.003),
+				OutputPrice: float64Ptr(0.015),
+			},
+			{
+				ID:              2,
+				Platform:        "openai",
+				Models:          []string{"gpt-4", "gpt-4o"},
+				BillingMode:     service.BillingModePerRequest,
+				PerRequestPrice: float64Ptr(1.0),
+			},
+			{
+				ID:               3,
+				Platform:         "gemini",
+				Models:           []string{"gemini-2.5-pro"},
+				BillingMode:      service.BillingModeImage,
+				ImageOutputPrice: float64Ptr(0.05),
+				PerRequestPrice:  float64Ptr(0.2),
+			},
+		},
+	}
+
+	resp := channelToResponse(ch)
+	require.Len(t, resp.ModelPricing, 3)
+
+	require.Equal(t, int64(1), resp.ModelPricing[0].ID)
+	require.Equal(t, "anthropic", resp.ModelPricing[0].Platform)
+	require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models)
+	require.Equal(t, "token", resp.ModelPricing[0].BillingMode)
+
+	require.Equal(t, int64(2), resp.ModelPricing[1].ID)
+	require.Equal(t, "openai", resp.ModelPricing[1].Platform)
+	require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models)
+	require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode)
+
+	require.Equal(t, int64(3), resp.ModelPricing[2].ID)
+	require.Equal(t, "gemini", resp.ModelPricing[2].Platform)
+	require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models)
+	require.Equal(t, "image", resp.ModelPricing[2].BillingMode)
+	require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice)
+}
+
+// ---------------------------------------------------------------------------
+// 2. pricingRequestToService
+// ---------------------------------------------------------------------------
+
+// Table-driven check of pricingRequestToService defaulting: empty BillingMode
+// becomes token, while an empty Platform is left empty (the handler's Create/
+// Update paths apply the platform default, not this converter).
+func TestPricingRequestToService_Defaults(t *testing.T) {
+	tests := []struct {
+		name      string
+		req       channelModelPricingRequest
+		wantField string // which default field to check
+		wantValue string
+	}{
+		{
+			name: "empty billing mode defaults to token",
+			req: channelModelPricingRequest{
+				Models:      []string{"m1"},
+				BillingMode: "",
+			},
+			wantField: "BillingMode",
+			wantValue: string(service.BillingModeToken),
+		},
+		{
+			name: "empty platform stays empty",
+			req: channelModelPricingRequest{
+				Models:   []string{"m1"},
+				Platform: "",
+			},
+			wantField: "Platform",
+			wantValue: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := pricingRequestToService([]channelModelPricingRequest{tt.req})
+			require.Len(t, result, 1)
+			switch tt.wantField {
+			case "BillingMode":
+				require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode)
+			case "Platform":
+				require.Equal(t, tt.wantValue, result[0].Platform)
+			}
+		})
+	}
+}
+
+// Verifies a fully-populated request entry is converted with every price
+// pointer carried through unchanged.
+func TestPricingRequestToService_WithAllFields(t *testing.T) {
+	reqs := []channelModelPricingRequest{
+		{
+			Platform:         "openai",
+			Models:           []string{"gpt-4", "gpt-4o"},
+			BillingMode:      "per_request",
+			InputPrice:       float64Ptr(0.01),
+			OutputPrice:      float64Ptr(0.03),
+			CacheWritePrice:  float64Ptr(0.005),
+			CacheReadPrice:   float64Ptr(0.002),
+			ImageOutputPrice: float64Ptr(0.04),
+			PerRequestPrice:  float64Ptr(0.5),
+		},
+	}
+
+	result := pricingRequestToService(reqs)
+	require.Len(t, result, 1)
+	r := result[0]
+	require.Equal(t, "openai", r.Platform)
+	require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models)
+	require.Equal(t, service.BillingModePerRequest, r.BillingMode)
+	require.Equal(t, float64Ptr(0.01), r.InputPrice)
+	require.Equal(t, float64Ptr(0.03), r.OutputPrice)
+	require.Equal(t, float64Ptr(0.005), r.CacheWritePrice)
+	require.Equal(t, float64Ptr(0.002), r.CacheReadPrice)
+	require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice)
+	require.Equal(t, float64Ptr(0.5), r.PerRequestPrice)
+}
+
+// Verifies interval sub-requests are converted in order, preserving optional
+// MaxTokens (nil = open-ended tier) and SortOrder.
+func TestPricingRequestToService_WithIntervals(t *testing.T) {
+	reqs := []channelModelPricingRequest{
+		{
+			Models:      []string{"m1"},
+			BillingMode: "per_request",
+			Intervals: []pricingIntervalRequest{
+				{
+					MinTokens:       0,
+					MaxTokens:       intPtr(2000),
+					TierLabel:       "small",
+					InputPrice:      float64Ptr(0.01),
+					OutputPrice:     float64Ptr(0.02),
+					CacheWritePrice: float64Ptr(0.003),
+					CacheReadPrice:  float64Ptr(0.001),
+					PerRequestPrice: float64Ptr(0.1),
+					SortOrder:       1,
+				},
+				{
+					MinTokens: 2000,
+					MaxTokens: nil,
+					TierLabel: "large",
+					SortOrder: 2,
+				},
+			},
+		},
+	}
+
+	result := pricingRequestToService(reqs)
+	require.Len(t, result, 1)
+	require.Len(t, result[0].Intervals, 2)
+
+	iv0 := result[0].Intervals[0]
+	require.Equal(t, 0, iv0.MinTokens)
+	require.Equal(t, intPtr(2000), iv0.MaxTokens)
+	require.Equal(t, "small", iv0.TierLabel)
+	require.Equal(t, float64Ptr(0.01), iv0.InputPrice)
+	require.Equal(t, float64Ptr(0.02), iv0.OutputPrice)
+	require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice)
+	require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice)
+	require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice)
+	require.Equal(t, 1, iv0.SortOrder)
+
+	iv1 := result[0].Intervals[1]
+	require.Equal(t, 2000, iv1.MinTokens)
+	require.Nil(t, iv1.MaxTokens)
+	require.Equal(t, "large", iv1.TierLabel)
+	require.Equal(t, 2, iv1.SortOrder)
+}
+
+// An empty input slice must yield a non-nil empty result (JSON [] not null).
+func TestPricingRequestToService_EmptySlice(t *testing.T) {
+	result := pricingRequestToService([]channelModelPricingRequest{})
+	require.NotNil(t, result)
+	require.Empty(t, result)
+}
+
+// Unset (nil) price pointers must stay nil after conversion — the converter
+// must not materialize zero values for absent prices.
+func TestPricingRequestToService_NilPriceFields(t *testing.T) {
+	reqs := []channelModelPricingRequest{
+		{
+			Models:      []string{"m1"},
+			BillingMode: "token",
+			// all price fields are nil by default
+		},
+	}
+
+	result := pricingRequestToService(reqs)
+	require.Len(t, result, 1)
+	r := result[0]
+	require.Nil(t, r.InputPrice)
+	require.Nil(t, r.OutputPrice)
+	require.Nil(t, r.CacheWritePrice)
+	require.Nil(t, r.CacheReadPrice)
+	require.Nil(t, r.ImageOutputPrice)
+	require.Nil(t, r.PerRequestPrice)
+}
diff --git a/backend/internal/handler/admin/channel_monitor_handler.go b/backend/internal/handler/admin/channel_monitor_handler.go
new file mode 100644
index 00000000..e92c81fe
--- /dev/null
+++ b/backend/internal/handler/admin/channel_monitor_handler.go
@@ -0,0 +1,427 @@
+package admin
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+	// monitorMaxPageSize caps the page size for list pagination.
+	monitorMaxPageSize = 100
+	// monitorAPIKeyMaskPrefix is the number of plaintext prefix characters
+	// kept when masking an API key.
+	monitorAPIKeyMaskPrefix = 4
+	// monitorAPIKeyMaskSuffix is the placeholder appended after masking.
+	monitorAPIKeyMaskSuffix = "***"
+)
+
+// ChannelMonitorHandler is the admin-console handler for channel monitors.
+type ChannelMonitorHandler struct {
+	monitorService *service.ChannelMonitorService
+}
+
+// NewChannelMonitorHandler constructs a ChannelMonitorHandler.
+func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *ChannelMonitorHandler {
+	return &ChannelMonitorHandler{monitorService: monitorService}
+}
+
+// --- Request / Response ---
+
+// channelMonitorCreateRequest is the JSON body for creating a monitor.
+// IntervalSeconds is bounded to 15s–1h by binding validation.
+type channelMonitorCreateRequest struct {
+	Name             string            `json:"name" binding:"required,max=100"`
+	Provider         string            `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+	Endpoint         string            `json:"endpoint" binding:"required,max=500"`
+	APIKey           string            `json:"api_key" binding:"required,max=2000"`
+	PrimaryModel     string            `json:"primary_model" binding:"required,max=200"`
+	ExtraModels      []string          `json:"extra_models"`
+	GroupName        string            `json:"group_name" binding:"max=100"`
+	Enabled          *bool             `json:"enabled"`
+	IntervalSeconds  int               `json:"interval_seconds" binding:"required,min=15,max=3600"`
+	TemplateID       *int64            `json:"template_id"`
+	ExtraHeaders     map[string]string `json:"extra_headers"`
+	BodyOverrideMode string            `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+	BodyOverride     map[string]any    `json:"body_override"`
+}
+
+// channelMonitorUpdateRequest is the JSON body for partial updates: nil
+// pointer fields mean "leave unchanged".
+type channelMonitorUpdateRequest struct {
+	Name             *string            `json:"name" binding:"omitempty,max=100"`
+	Provider         *string            `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
+	Endpoint         *string            `json:"endpoint" binding:"omitempty,max=500"`
+	APIKey           *string            `json:"api_key" binding:"omitempty,max=2000"`
+	PrimaryModel     *string            `json:"primary_model" binding:"omitempty,max=200"`
+	ExtraModels      *[]string          `json:"extra_models"`
+	GroupName        *string            `json:"group_name" binding:"omitempty,max=100"`
+	Enabled          *bool              `json:"enabled"`
+	IntervalSeconds  *int               `json:"interval_seconds" binding:"omitempty,min=15,max=3600"`
+	TemplateID       *int64             `json:"template_id"`
+	ClearTemplate    bool               `json:"clear_template"` // when true, clears template_id and ignores TemplateID
+	ExtraHeaders     *map[string]string `json:"extra_headers"`
+	BodyOverrideMode *string            `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+	BodyOverride     *map[string]any    `json:"body_override"`
+}
+
+// channelMonitorResponse is the API representation of a monitor. The API key
+// is never returned in plaintext — only a masked form.
+type channelMonitorResponse struct {
+	ID                  int64                                `json:"id"`
+	Name                string                               `json:"name"`
+	Provider            string                               `json:"provider"`
+	Endpoint            string                               `json:"endpoint"`
+	APIKeyMasked        string                               `json:"api_key_masked"`
+	APIKeyDecryptFailed bool                                 `json:"api_key_decrypt_failed"`
+	PrimaryModel        string                               `json:"primary_model"`
+	ExtraModels         []string                             `json:"extra_models"`
+	GroupName           string                               `json:"group_name"`
+	Enabled             bool                                 `json:"enabled"`
+	IntervalSeconds     int                                  `json:"interval_seconds"`
+	LastCheckedAt       *string                              `json:"last_checked_at"`
+	CreatedBy           int64                                `json:"created_by"`
+	CreatedAt           string                               `json:"created_at"`
+	UpdatedAt           string                               `json:"updated_at"`
+	PrimaryStatus       string                               `json:"primary_status"`
+	PrimaryLatencyMs    *int                                 `json:"primary_latency_ms"`
+	Availability7d      float64                              `json:"availability_7d"`
+	ExtraModelsStatus   []dto.ChannelMonitorExtraModelStatus `json:"extra_models_status"`
+	// Request-customization snapshot: used by the frontend to edit and
+	// display the "advanced settings" section.
+	TemplateID       *int64            `json:"template_id"`
+	ExtraHeaders     map[string]string `json:"extra_headers"`
+	BodyOverrideMode string            `json:"body_override_mode"`
+	BodyOverride     map[string]any    `json:"body_override"`
+}
+
+// channelMonitorCheckResultResponse is one model's result from a manual run.
+type channelMonitorCheckResultResponse struct {
+	Model         string `json:"model"`
+	Status        string `json:"status"`
+	LatencyMs     *int   `json:"latency_ms"`
+	PingLatencyMs *int   `json:"ping_latency_ms"`
+	Message       string `json:"message"`
+	CheckedAt     string `json:"checked_at"`
+}
+
+// channelMonitorHistoryItemResponse is one row of a monitor's check history.
+type channelMonitorHistoryItemResponse struct {
+	ID            int64  `json:"id"`
+	Model         string `json:"model"`
+	Status        string `json:"status"`
+	LatencyMs     *int   `json:"latency_ms"`
+	PingLatencyMs *int   `json:"ping_latency_ms"`
+	Message       string `json:"message"`
+	CheckedAt     string `json:"checked_at"`
+}
+
+// maskAPIKey masks an API-key plaintext: keep the first 4 characters and
+// append "***"; keys of 4 bytes or fewer are rendered as "***" only.
+func maskAPIKey(plain string) string {
+	if len(plain) > monitorAPIKeyMaskPrefix {
+		return plain[:monitorAPIKeyMaskPrefix] + monitorAPIKeyMaskSuffix
+	}
+	return monitorAPIKeyMaskSuffix
+}
+
+// channelMonitorToResponse converts a service monitor into its API form.
+// Returns nil for nil input. Nil slices/maps are replaced with empty ones so
+// JSON renders [] / {} rather than null, and the API key is masked.
+func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse {
+	if m == nil {
+		return nil
+	}
+	extras := m.ExtraModels
+	if extras == nil {
+		extras = []string{}
+	}
+	headers := m.ExtraHeaders
+	if headers == nil {
+		headers = map[string]string{}
+	}
+	resp := &channelMonitorResponse{
+		ID:                  m.ID,
+		Name:                m.Name,
+		Provider:            m.Provider,
+		Endpoint:            m.Endpoint,
+		APIKeyMasked:        maskAPIKey(m.APIKey),
+		APIKeyDecryptFailed: m.APIKeyDecryptFailed,
+		PrimaryModel:        m.PrimaryModel,
+		ExtraModels:         extras,
+		GroupName:           m.GroupName,
+		Enabled:             m.Enabled,
+		IntervalSeconds:     m.IntervalSeconds,
+		CreatedBy:           m.CreatedBy,
+		CreatedAt:           m.CreatedAt.UTC().Format(time.RFC3339),
+		UpdatedAt:           m.UpdatedAt.UTC().Format(time.RFC3339),
+		TemplateID:          m.TemplateID,
+		ExtraHeaders:        headers,
+		BodyOverrideMode:    m.BodyOverrideMode,
+		BodyOverride:        m.BodyOverride,
+		// PrimaryStatus / PrimaryLatencyMs / Availability7d are filled in by
+		// the List handler after batch aggregation.
+	}
+	if m.LastCheckedAt != nil {
+		s := m.LastCheckedAt.UTC().Format(time.RFC3339)
+		resp.LastCheckedAt = &s
+	}
+	return resp
+}
+
+// checkResultToResponse maps a single check result to its API form,
+// formatting the timestamp as UTC RFC3339.
+func checkResultToResponse(r *service.CheckResult) channelMonitorCheckResultResponse {
+	return channelMonitorCheckResultResponse{
+		Model:         r.Model,
+		Status:        r.Status,
+		LatencyMs:     r.LatencyMs,
+		PingLatencyMs: r.PingLatencyMs,
+		Message:       r.Message,
+		CheckedAt:     r.CheckedAt.UTC().Format(time.RFC3339),
+	}
+}
+
+// historyEntryToResponse maps one stored history entry to its API form,
+// formatting the timestamp as UTC RFC3339.
+func historyEntryToResponse(e *service.ChannelMonitorHistoryEntry) channelMonitorHistoryItemResponse {
+	return channelMonitorHistoryItemResponse{
+		ID:            e.ID,
+		Model:         e.Model,
+		Status:        e.Status,
+		LatencyMs:     e.LatencyMs,
+		PingLatencyMs: e.PingLatencyMs,
+		Message:       e.Message,
+		CheckedAt:     e.CheckedAt.UTC().Format(time.RFC3339),
+	}
+}
+
+// ParseChannelMonitorID extracts and validates the :id path parameter
+// (shared by the admin and user handlers). The id must be a positive int64.
+// On failure a 4xx response has already been written; callers only return.
+func ParseChannelMonitorID(c *gin.Context) (int64, bool) {
+	id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+	if err != nil || id <= 0 {
+		response.ErrorFrom(c, infraerrors.BadRequest("INVALID_MONITOR_ID", "invalid monitor id"))
+		return 0, false
+	}
+	return id, true
+}
+
+// parseListEnabled parses the `enabled` query parameter. Accepted truthy
+// forms: "true"/"1"/"yes"; falsy: "false"/"0"/"no" (case-insensitive, after
+// trimming). Empty or unrecognized values return nil, meaning "no filter".
+func parseListEnabled(raw string) *bool {
+	switch strings.ToLower(strings.TrimSpace(raw)) {
+	case "true", "1", "yes":
+		v := true
+		return &v
+	case "false", "0", "no":
+		v := false
+		return &v
+	default:
+		return nil
+	}
+}
+
+// --- Handlers ---
+
+// List GET /api/v1/admin/channel-monitors
+//
+// Supports pagination (page size capped at monitorMaxPageSize) plus
+// provider/enabled/search filters; each row is augmented with a
+// batch-aggregated status summary.
+func (h *ChannelMonitorHandler) List(c *gin.Context) {
+	page, pageSize := response.ParsePagination(c)
+	if pageSize > monitorMaxPageSize {
+		pageSize = monitorMaxPageSize
+	}
+
+	params := service.ChannelMonitorListParams{
+		Page:     page,
+		PageSize: pageSize,
+		Provider: strings.TrimSpace(c.Query("provider")),
+		Enabled:  parseListEnabled(c.Query("enabled")),
+		Search:   strings.TrimSpace(c.Query("search")),
+	}
+
+	items, total, err := h.monitorService.List(c.Request.Context(), params)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// One batched summary query for the whole page instead of per-row lookups.
+	summaries := h.batchSummaryFor(c, items)
+	out := make([]*channelMonitorResponse, 0, len(items))
+	for _, m := range items {
+		out = append(out, buildListItemResponse(m, summaries[m.ID]))
+	}
+	response.Paginated(c, out, total, page, pageSize)
+}
+
+// batchSummaryFor aggregates the latest status and 7-day availability for a
+// page of monitors in one batch call, avoiding 2 SQL queries per row
+// (eliminates the N+1 pattern).
+func (h *ChannelMonitorHandler) batchSummaryFor(c *gin.Context, items []*service.ChannelMonitor) map[int64]service.MonitorStatusSummary {
+	ids := make([]int64, 0, len(items))
+	primaryByID := make(map[int64]string, len(items))
+	extrasByID := make(map[int64][]string, len(items))
+	for _, m := range items {
+		ids = append(ids, m.ID)
+		primaryByID[m.ID] = m.PrimaryModel
+		extrasByID[m.ID] = m.ExtraModels
+	}
+	return h.monitorService.BatchMonitorStatusSummary(c.Request.Context(), ids, primaryByID, extrasByID)
+}
+
+// buildListItemResponse assembles a monitor plus its aggregated summary into
+// one admin-list response row.
+func buildListItemResponse(m *service.ChannelMonitor, summary service.MonitorStatusSummary) *channelMonitorResponse {
+	resp := channelMonitorToResponse(m)
+	resp.PrimaryStatus = summary.PrimaryStatus
+	resp.PrimaryLatencyMs = summary.PrimaryLatencyMs
+	resp.Availability7d = summary.Availability7d
+	resp.ExtraModelsStatus = make([]dto.ChannelMonitorExtraModelStatus, 0, len(summary.ExtraModels))
+	for _, e := range summary.ExtraModels {
+		resp.ExtraModelsStatus = append(resp.ExtraModelsStatus, dto.ChannelMonitorExtraModelStatus{
+			Model:     e.Model,
+			Status:    e.Status,
+			LatencyMs: e.LatencyMs,
+		})
+	}
+	return resp
+}
+
+// Get GET /api/v1/admin/channel-monitors/:id
+func (h *ChannelMonitorHandler) Get(c *gin.Context) {
+	id, ok := ParseChannelMonitorID(c)
+	if !ok {
+		// ParseChannelMonitorID already wrote the 4xx response.
+		return
+	}
+	m, err := h.monitorService.Get(c.Request.Context(), id)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, channelMonitorToResponse(m))
+}
+
+// Create POST /api/v1/admin/channel-monitors
+//
+// Binds and validates the request, attributes the monitor to the
+// authenticated admin, and defaults Enabled to true when unspecified.
+func (h *ChannelMonitorHandler) Create(c *gin.Context) {
+	var req channelMonitorCreateRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+		return
+	}
+
+	// NOTE(review): the lookup error is discarded — presumably the auth
+	// middleware guarantees a subject on this route; confirm, otherwise
+	// CreatedBy could silently be the zero value.
+	subject, _ := middleware2.GetAuthSubjectFromContext(c)
+
+	// Omitted "enabled" defaults to true.
+	enabled := true
+	if req.Enabled != nil {
+		enabled = *req.Enabled
+	}
+
+	m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
+		Name:             req.Name,
+		Provider:         req.Provider,
+		Endpoint:         req.Endpoint,
+		APIKey:           req.APIKey,
+		PrimaryModel:     req.PrimaryModel,
+		ExtraModels:      req.ExtraModels,
+		GroupName:        req.GroupName,
+		Enabled:          enabled,
+		IntervalSeconds:  req.IntervalSeconds,
+		CreatedBy:        subject.UserID,
+		TemplateID:       req.TemplateID,
+		ExtraHeaders:     req.ExtraHeaders,
+		BodyOverrideMode: req.BodyOverrideMode,
+		BodyOverride:     req.BodyOverride,
+	})
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Created(c, channelMonitorToResponse(m))
+}
+
+// Update PUT /api/v1/admin/channel-monitors/:id
+//
+// Partial update: nil pointer fields in the request leave the corresponding
+// stored values unchanged; ClearTemplate takes precedence over TemplateID.
+func (h *ChannelMonitorHandler) Update(c *gin.Context) {
+	id, ok := ParseChannelMonitorID(c)
+	if !ok {
+		return
+	}
+	var req channelMonitorUpdateRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+		return
+	}
+
+	m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
+		Name:             req.Name,
+		Provider:         req.Provider,
+		Endpoint:         req.Endpoint,
+		APIKey:           req.APIKey,
+		PrimaryModel:     req.PrimaryModel,
+		ExtraModels:      req.ExtraModels,
+		GroupName:        req.GroupName,
+		Enabled:          req.Enabled,
+		IntervalSeconds:  req.IntervalSeconds,
+		TemplateID:       req.TemplateID,
+		ClearTemplate:    req.ClearTemplate,
+		ExtraHeaders:     req.ExtraHeaders,
+		BodyOverrideMode: req.BodyOverrideMode,
+		BodyOverride:     req.BodyOverride,
+	})
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, channelMonitorToResponse(m))
+}
+
+// Delete DELETE /api/v1/admin/channel-monitors/:id
+func (h *ChannelMonitorHandler) Delete(c *gin.Context) {
+	id, ok := ParseChannelMonitorID(c)
+	if !ok {
+		return
+	}
+	if err := h.monitorService.Delete(c.Request.Context(), id); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, nil)
+}
+
+// Run POST /api/v1/admin/channel-monitors/:id/run
+//
+// Triggers an immediate check and returns the per-model results.
+func (h *ChannelMonitorHandler) Run(c *gin.Context) {
+	id, ok := ParseChannelMonitorID(c)
+	if !ok {
+		return
+	}
+	results, err := h.monitorService.RunCheck(c.Request.Context(), id)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	out := make([]channelMonitorCheckResultResponse, 0, len(results))
+	for _, r := range results {
+		out = append(out, checkResultToResponse(r))
+	}
+	response.Success(c, gin.H{"results": out})
+}
+
+// History GET /api/v1/admin/channel-monitors/:id/history
+//
+// Optional query params: limit (clamped by parseHistoryLimit) and model
+// (empty = all models).
+func (h *ChannelMonitorHandler) History(c *gin.Context) {
+	id, ok := ParseChannelMonitorID(c)
+	if !ok {
+		return
+	}
+	limit := parseHistoryLimit(c.Query("limit"))
+	model := strings.TrimSpace(c.Query("model"))
+
+	entries, err := h.monitorService.ListHistory(c.Request.Context(), id, model, limit)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	out := make([]channelMonitorHistoryItemResponse, 0, len(entries))
+	for _, e := range entries {
+		out = append(out, historyEntryToResponse(e))
+	}
+	response.Success(c, gin.H{"items": out})
+}
+
+// parseHistoryLimit parses the history endpoint's limit query parameter.
+// Empty, unparsable, or non-positive input falls back to the service
+// package's default; values above the service maximum are clamped to it.
+// The shared service constants are used so this handler does not re-declare
+// the same magic numbers.
+func parseHistoryLimit(raw string) int {
+	if strings.TrimSpace(raw) == "" {
+		return service.MonitorHistoryDefaultLimit
+	}
+	n, err := strconv.Atoi(raw)
+	switch {
+	case err != nil || n <= 0:
+		return service.MonitorHistoryDefaultLimit
+	case n > service.MonitorHistoryMaxLimit:
+		return service.MonitorHistoryMaxLimit
+	default:
+		return n
+	}
+}
diff --git a/backend/internal/handler/admin/channel_monitor_template_handler.go b/backend/internal/handler/admin/channel_monitor_template_handler.go
new file mode 100644
index 00000000..bebe0929
--- /dev/null
+++ b/backend/internal/handler/admin/channel_monitor_template_handler.go
@@ -0,0 +1,234 @@
+package admin
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// ChannelMonitorRequestTemplateHandler 请求模板管理后台 handler。
+type ChannelMonitorRequestTemplateHandler struct {
+ templateService *service.ChannelMonitorRequestTemplateService
+}
+
+// NewChannelMonitorRequestTemplateHandler 创建 handler。
+func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMonitorRequestTemplateService) *ChannelMonitorRequestTemplateHandler {
+ return &ChannelMonitorRequestTemplateHandler{templateService: templateService}
+}
+
+// --- DTO ---
+
+type channelMonitorTemplateCreateRequest struct {
+ Name string `json:"name" binding:"required,max=100"`
+ Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+ Description string `json:"description" binding:"max=500"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride map[string]any `json:"body_override"`
+}
+
+type channelMonitorTemplateUpdateRequest struct {
+ Name *string `json:"name" binding:"omitempty,max=100"`
+ Description *string `json:"description" binding:"omitempty,max=500"`
+ ExtraHeaders *map[string]string `json:"extra_headers"`
+ BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
+ BodyOverride *map[string]any `json:"body_override"`
+}
+
+type channelMonitorTemplateResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ Description string `json:"description"`
+ ExtraHeaders map[string]string `json:"extra_headers"`
+ BodyOverrideMode string `json:"body_override_mode"`
+ BodyOverride map[string]any `json:"body_override"`
+ CreatedAt string `json:"created_at"`
+ UpdatedAt string `json:"updated_at"`
+ AssociatedMonitors int64 `json:"associated_monitors"`
+}
+
+func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *service.ChannelMonitorRequestTemplate) *channelMonitorTemplateResponse {
+ if t == nil {
+ return nil
+ }
+ headers := t.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ count, _ := h.templateService.CountAssociatedMonitors(c.Request.Context(), t.ID)
+ return &channelMonitorTemplateResponse{
+ ID: t.ID,
+ Name: t.Name,
+ Provider: t.Provider,
+ Description: t.Description,
+ ExtraHeaders: headers,
+ BodyOverrideMode: t.BodyOverrideMode,
+ BodyOverride: t.BodyOverride,
+ CreatedAt: t.CreatedAt.UTC().Format(time.RFC3339),
+ UpdatedAt: t.UpdatedAt.UTC().Format(time.RFC3339),
+ AssociatedMonitors: count,
+ }
+}
+
+// parseTemplateID 提取并校验 :id。
+func parseTemplateID(c *gin.Context) (int64, bool) {
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || id <= 0 {
+ response.ErrorFrom(c, infraerrors.BadRequest("INVALID_TEMPLATE_ID", "invalid template id"))
+ return 0, false
+ }
+ return id, true
+}
+
+// --- Handlers ---
+
+// List GET /api/v1/admin/channel-monitor-templates?provider=anthropic
+func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
+ items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
+ Provider: strings.TrimSpace(c.Query("provider")),
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]*channelMonitorTemplateResponse, 0, len(items))
+ for _, t := range items {
+ out = append(out, h.toResponse(c, t))
+ }
+ response.Success(c, gin.H{"items": out})
+}
+
+// Get GET /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Get(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ t, err := h.templateService.Get(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, h.toResponse(c, t))
+}
+
+// Create POST /api/v1/admin/channel-monitor-templates
+func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
+ var req channelMonitorTemplateCreateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
+ Name: req.Name,
+ Provider: req.Provider,
+ Description: req.Description,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Created(c, h.toResponse(c, t))
+}
+
+// Update PUT /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ var req channelMonitorTemplateUpdateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
+ Name: req.Name,
+ Description: req.Description,
+ ExtraHeaders: req.ExtraHeaders,
+ BodyOverrideMode: req.BodyOverrideMode,
+ BodyOverride: req.BodyOverride,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, h.toResponse(c, t))
+}
+
+// Delete DELETE /api/v1/admin/channel-monitor-templates/:id
+func (h *ChannelMonitorRequestTemplateHandler) Delete(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ if err := h.templateService.Delete(c.Request.Context(), id); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, nil)
+}
+
+type channelMonitorTemplateApplyRequest struct {
+ // MonitorIDs 必填、非空:用户在 picker 里勾选的要被覆盖的监控 ID 列表。
+ // 仅当对应监控当前 template_id == :id 时才会真的被覆盖。
+ MonitorIDs []int64 `json:"monitor_ids" binding:"required,min=1"`
+}
+
+// Apply POST /api/v1/admin/channel-monitor-templates/:id/apply
+// 把模板当前配置覆盖到 monitor_ids 列表里的关联监控(picker 选中的子集)。
+func (h *ChannelMonitorRequestTemplateHandler) Apply(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ var req channelMonitorTemplateApplyRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
+ return
+ }
+ affected, err := h.templateService.ApplyToMonitors(c.Request.Context(), id, req.MonitorIDs)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"affected": affected})
+}
+
+type associatedMonitorBriefResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ Enabled bool `json:"enabled"`
+}
+
+// AssociatedMonitors GET /api/v1/admin/channel-monitor-templates/:id/monitors
+// 列出关联监控(picker 弹窗用)。
+func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context) {
+ id, ok := parseTemplateID(c)
+ if !ok {
+ return
+ }
+ items, err := h.templateService.ListAssociatedMonitors(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]associatedMonitorBriefResponse, 0, len(items))
+ for _, m := range items {
+ out = append(out, associatedMonitorBriefResponse{
+ ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
+ })
+ }
+ response.Success(c, gin.H{"items": out})
+}
diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go
index 2a214471..460f6357 100644
--- a/backend/internal/handler/admin/dashboard_handler.go
+++ b/backend/internal/handler/admin/dashboard_handler.go
@@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) {
dim.Endpoint = c.Query("endpoint")
dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound")
+ // Additional filter conditions
+ if v := c.Query("user_id"); v != "" {
+ if id, err := strconv.ParseInt(v, 10, 64); err == nil {
+ dim.UserID = id
+ }
+ }
+ if v := c.Query("api_key_id"); v != "" {
+ if id, err := strconv.ParseInt(v, 10, 64); err == nil {
+ dim.APIKeyID = id
+ }
+ }
+ if v := c.Query("account_id"); v != "" {
+ if id, err := strconv.ParseInt(v, 10, 64); err == nil {
+ dim.AccountID = id
+ }
+ }
+ if v := c.Query("request_type"); v != "" {
+ if rt, err := strconv.ParseInt(v, 10, 16); err == nil {
+ rtVal := int16(rt)
+ dim.RequestType = &rtVal
+ }
+ }
+ if v := c.Query("stream"); v != "" {
+ if s, err := strconv.ParseBool(v); err == nil {
+ dim.Stream = &s
+ }
+ }
+ if v := c.Query("billing_type"); v != "" {
+ if bt, err := strconv.ParseInt(v, 10, 8); err == nil {
+ btVal := int8(bt)
+ dim.BillingType = &btVal
+ }
+ }
+
limit := 50
if v := c.Query("limit"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 {
diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go
index caa27bc3..65e5ec78 100644
--- a/backend/internal/handler/admin/group_handler.go
+++ b/backend/internal/handler/admin/group_handler.go
@@ -84,7 +84,7 @@ func NewGroupHandler(adminService service.AdminService, dashboardService *servic
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
- Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
+ Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -95,10 +95,6 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
- SoraImagePrice360 *float64 `json:"sora_image_price_360"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -108,13 +104,14 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
- // Sora 存储配额
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
- AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
- RequireOAuthOnly bool `json:"require_oauth_only"`
- RequirePrivacySet bool `json:"require_privacy_set"`
- DefaultMappedModel string `json:"default_mapped_model"`
+ AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
+ RequireOAuthOnly bool `json:"require_oauth_only"`
+ RequirePrivacySet bool `json:"require_privacy_set"`
+ DefaultMappedModel string `json:"default_mapped_model"`
+ MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
+ // 分组 RPM 上限(0 = 不限制)
+ RPMLimit int `json:"rpm_limit"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -123,7 +120,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
- Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
+ Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -135,10 +132,6 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
- SoraImagePrice360 *float64 `json:"sora_image_price_360"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -148,13 +141,14 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
- // Sora 存储配额
- SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
- AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
- RequireOAuthOnly *bool `json:"require_oauth_only"`
- RequirePrivacySet *bool `json:"require_privacy_set"`
- DefaultMappedModel *string `json:"default_mapped_model"`
+ AllowMessagesDispatch *bool `json:"allow_messages_dispatch"`
+ RequireOAuthOnly *bool `json:"require_oauth_only"`
+ RequirePrivacySet *bool `json:"require_privacy_set"`
+ DefaultMappedModel *string `json:"default_mapped_model"`
+ MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
+ // 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动
+ RPMLimit *int `json:"rpm_limit"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -172,6 +166,8 @@ func (h *GroupHandler) List(c *gin.Context) {
search = search[:100]
}
isExclusiveStr := c.Query("is_exclusive")
+ sortBy := c.DefaultQuery("sort_by", "sort_order")
+ sortOrder := c.DefaultQuery("sort_order", "asc")
var isExclusive *bool
if isExclusiveStr != "" {
@@ -179,7 +175,7 @@ func (h *GroupHandler) List(c *gin.Context) {
isExclusive = &val
}
- groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive)
+ groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -258,10 +254,6 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
- SoraImagePrice360: req.SoraImagePrice360,
- SoraImagePrice540: req.SoraImagePrice540,
- SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -269,11 +261,12 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
+ MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
+ RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -313,10 +306,6 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
- SoraImagePrice360: req.SoraImagePrice360,
- SoraImagePrice540: req.SoraImagePrice540,
- SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -324,11 +313,12 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
AllowMessagesDispatch: req.AllowMessagesDispatch,
RequireOAuthOnly: req.RequireOAuthOnly,
RequirePrivacySet: req.RequirePrivacySet,
DefaultMappedModel: req.DefaultMappedModel,
+ MessagesDispatchModelConfig: req.MessagesDispatchModelConfig,
+ RPMLimit: req.RPMLimit,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -493,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) {
response.Success(c, gin.H{"message": "Rate multipliers updated successfully"})
}
+// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request
+type BatchSetGroupRPMOverridesRequest struct {
+ Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"`
+}
+
+// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group
+// PUT /api/v1/admin/groups/:id/rpm-overrides
+func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) {
+ groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid group ID")
+ return
+ }
+
+ var req BatchSetGroupRPMOverridesRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "RPM overrides updated successfully"})
+}
+
+// ClearGroupRPMOverrides handles clearing all rpm_override for a group
+// DELETE /api/v1/admin/groups/:id/rpm-overrides
+func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) {
+ groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid group ID")
+ return
+ }
+
+ if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "RPM overrides cleared successfully"})
+}
+
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {
diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go
index 4e6179db..cc0c9337 100644
--- a/backend/internal/handler/admin/openai_oauth_handler.go
+++ b/backend/internal/handler/admin/openai_oauth_handler.go
@@ -19,9 +19,6 @@ type OpenAIOAuthHandler struct {
}
func oauthPlatformFromPath(c *gin.Context) string {
- if strings.Contains(c.FullPath(), "/admin/sora/") {
- return service.PlatformSora
- }
return service.PlatformOpenAI
}
@@ -105,7 +102,6 @@ type OpenAIRefreshTokenRequest struct {
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
-// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
@@ -145,39 +141,8 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
-// ExchangeSoraSessionToken exchanges Sora session token to access token
-// POST /api/v1/admin/sora/st2at
-func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
- var req struct {
- SessionToken string `json:"session_token"`
- ST string `json:"st"`
- ProxyID *int64 `json:"proxy_id"`
- }
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- sessionToken := strings.TrimSpace(req.SessionToken)
- if sessionToken == "" {
- sessionToken = strings.TrimSpace(req.ST)
- }
- if sessionToken == "" {
- response.BadRequest(c, "session_token is required")
- return
- }
-
- tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, tokenInfo)
-}
-
-// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
+// RefreshAccountToken refreshes token for a specific OpenAI account
// POST /api/v1/admin/openai/accounts/:id/refresh
-// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -232,9 +197,8 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
-// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
+// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
-// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
@@ -276,11 +240,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
name = tokenInfo.Email
}
if name == "" {
- if platform == service.PlatformSora {
- name = "Sora OAuth Account"
- } else {
- name = "OpenAI OAuth Account"
- }
+ name = "OpenAI OAuth Account"
}
// Create account
diff --git a/backend/internal/handler/admin/payment_handler.go b/backend/internal/handler/admin/payment_handler.go
new file mode 100644
index 00000000..84359cd9
--- /dev/null
+++ b/backend/internal/handler/admin/payment_handler.go
@@ -0,0 +1,344 @@
+package admin
+
+import (
+ "strconv"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// PaymentHandler handles admin payment management.
+type PaymentHandler struct {
+ paymentService *service.PaymentService
+ configService *service.PaymentConfigService
+}
+
+// NewPaymentHandler creates a new admin PaymentHandler.
+func NewPaymentHandler(paymentService *service.PaymentService, configService *service.PaymentConfigService) *PaymentHandler {
+ return &PaymentHandler{
+ paymentService: paymentService,
+ configService: configService,
+ }
+}
+
+// --- Dashboard ---
+
+// GetDashboard returns payment dashboard statistics.
+// GET /api/v1/admin/payment/dashboard
+func (h *PaymentHandler) GetDashboard(c *gin.Context) {
+ days := 30
+ if d := c.Query("days"); d != "" {
+ if v, err := strconv.Atoi(d); err == nil && v > 0 {
+ days = v
+ }
+ }
+ stats, err := h.paymentService.GetDashboardStats(c.Request.Context(), days)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, stats)
+}
+
+// --- Orders ---
+
+// ListOrders returns a paginated list of all payment orders.
+// GET /api/v1/admin/payment/orders
+func (h *PaymentHandler) ListOrders(c *gin.Context) {
+ page, pageSize := response.ParsePagination(c)
+ var userID int64
+ if uid := c.Query("user_id"); uid != "" {
+ if v, err := strconv.ParseInt(uid, 10, 64); err == nil {
+ userID = v
+ }
+ }
+ orders, total, err := h.paymentService.AdminListOrders(c.Request.Context(), userID, service.OrderListParams{
+ Page: page,
+ PageSize: pageSize,
+ Status: c.Query("status"),
+ OrderType: c.Query("order_type"),
+ PaymentType: c.Query("payment_type"),
+ Keyword: c.Query("keyword"),
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize)
+}
+
+// GetOrderDetail returns detailed information about a single order.
+// GET /api/v1/admin/payment/orders/:id
+func (h *PaymentHandler) GetOrderDetail(c *gin.Context) {
+ orderID, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+ order, err := h.paymentService.GetOrderByID(c.Request.Context(), orderID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID)
+ response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs})
+}
+
+// CancelOrder cancels a pending order (admin).
+// POST /api/v1/admin/payment/orders/:id/cancel
+func (h *PaymentHandler) CancelOrder(c *gin.Context) {
+ orderID, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+ msg, err := h.paymentService.AdminCancelOrder(c.Request.Context(), orderID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"message": msg})
+}
+
+// RetryFulfillment retries fulfillment for a paid order.
+// POST /api/v1/admin/payment/orders/:id/retry
+func (h *PaymentHandler) RetryFulfillment(c *gin.Context) {
+ orderID, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+ if err := h.paymentService.RetryFulfillment(c.Request.Context(), orderID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"message": "fulfillment retried"})
+}
+
+func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizeAdminPaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
+// AdminProcessRefundRequest is the request body for admin refund processing.
+type AdminProcessRefundRequest struct {
+ Amount float64 `json:"amount"`
+ Reason string `json:"reason"`
+ Force bool `json:"force"`
+ DeductBalance bool `json:"deduct_balance"`
+}
+
+// ProcessRefund processes a refund for an order (admin).
+// POST /api/v1/admin/payment/orders/:id/refund
+func (h *PaymentHandler) ProcessRefund(c *gin.Context) {
+ orderID, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+
+ var req AdminProcessRefundRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ plan, earlyResult, err := h.paymentService.PrepareRefund(c.Request.Context(), orderID, req.Amount, req.Reason, req.Force, req.DeductBalance)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if earlyResult != nil {
+ response.Success(c, earlyResult)
+ return
+ }
+
+ result, err := h.paymentService.ExecuteRefund(c.Request.Context(), plan)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
+
+// --- Subscription Plans ---
+
+// ListPlans returns all subscription plans.
+// GET /api/v1/admin/payment/plans
+func (h *PaymentHandler) ListPlans(c *gin.Context) {
+ plans, err := h.configService.ListPlans(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, plans)
+}
+
+// CreatePlan creates a new subscription plan.
+// POST /api/v1/admin/payment/plans
+func (h *PaymentHandler) CreatePlan(c *gin.Context) {
+ var req service.CreatePlanRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ plan, err := h.configService.CreatePlan(c.Request.Context(), req)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Created(c, plan)
+}
+
+// UpdatePlan updates an existing subscription plan.
+// PUT /api/v1/admin/payment/plans/:id
+func (h *PaymentHandler) UpdatePlan(c *gin.Context) {
+ id, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+ var req service.UpdatePlanRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ plan, err := h.configService.UpdatePlan(c.Request.Context(), id, req)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, plan)
+}
+
+// DeletePlan deletes a subscription plan.
+// DELETE /api/v1/admin/payment/plans/:id
+func (h *PaymentHandler) DeletePlan(c *gin.Context) {
+ id, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+ if err := h.configService.DeletePlan(c.Request.Context(), id); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"message": "deleted"})
+}
+
+// --- Provider Instances ---
+
+// ListProviders returns all payment provider instances.
+// GET /api/v1/admin/payment/providers
+func (h *PaymentHandler) ListProviders(c *gin.Context) {
+ providers, err := h.configService.ListProviderInstancesWithConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, providers)
+}
+
+// CreateProvider creates a new payment provider instance.
+// POST /api/v1/admin/payment/providers
+func (h *PaymentHandler) CreateProvider(c *gin.Context) {
+ var req service.CreateProviderInstanceRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ inst, err := h.configService.CreateProviderInstance(c.Request.Context(), req)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ h.paymentService.RefreshProviders(c.Request.Context())
+ response.Created(c, inst)
+}
+
+// UpdateProvider updates an existing payment provider instance.
+// PUT /api/v1/admin/payment/providers/:id
+func (h *PaymentHandler) UpdateProvider(c *gin.Context) {
+ id, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+ var req service.UpdateProviderInstanceRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ inst, err := h.configService.UpdateProviderInstance(c.Request.Context(), id, req)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ h.paymentService.RefreshProviders(c.Request.Context())
+ response.Success(c, inst)
+}
+
+// DeleteProvider deletes a payment provider instance.
+// DELETE /api/v1/admin/payment/providers/:id
+func (h *PaymentHandler) DeleteProvider(c *gin.Context) {
+ id, ok := parseIDParam(c, "id")
+ if !ok {
+ return
+ }
+ if err := h.configService.DeleteProviderInstance(c.Request.Context(), id); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ h.paymentService.RefreshProviders(c.Request.Context())
+ response.Success(c, gin.H{"message": "deleted"})
+}
+
+// parseIDParam parses an int64 path parameter.
+// Returns the parsed ID and true on success; on failure it writes a BadRequest response and returns false.
+func parseIDParam(c *gin.Context, paramName string) (int64, bool) {
+ id, err := strconv.ParseInt(c.Param(paramName), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid "+paramName)
+ return 0, false
+ }
+ return id, true
+}
+
+// --- Config ---
+
+// GetConfig returns the payment configuration (admin view).
+// GET /api/v1/admin/payment/config
+func (h *PaymentHandler) GetConfig(c *gin.Context) {
+ cfg, err := h.configService.GetPaymentConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, cfg)
+}
+
+// UpdateConfig updates the payment configuration.
+// PUT /api/v1/admin/payment/config
+func (h *PaymentHandler) UpdateConfig(c *gin.Context) {
+ var req service.UpdatePaymentConfigRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if err := h.configService.UpdatePaymentConfig(c.Request.Context(), req); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"message": "updated"})
+}
diff --git a/backend/internal/handler/admin/promo_handler.go b/backend/internal/handler/admin/promo_handler.go
index 3eafa380..77d5f171 100644
--- a/backend/internal/handler/admin/promo_handler.go
+++ b/backend/internal/handler/admin/promo_handler.go
@@ -55,8 +55,10 @@ func (h *PromoHandler) List(c *gin.Context) {
}
params := pagination.PaginationParams{
- Page: page,
- PageSize: pageSize,
+ Page: page,
+ PageSize: pageSize,
+ SortBy: c.DefaultQuery("sort_by", "created_at"),
+ SortOrder: c.DefaultQuery("sort_order", "desc"),
}
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
diff --git a/backend/internal/handler/admin/proxy_data.go b/backend/internal/handler/admin/proxy_data.go
index 72ecd6c1..8149ce3b 100644
--- a/backend/internal/handler/admin/proxy_data.go
+++ b/backend/internal/handler/admin/proxy_data.go
@@ -33,11 +33,13 @@ func (h *ProxyHandler) ExportData(c *gin.Context) {
protocol := c.Query("protocol")
status := c.Query("status")
search := strings.TrimSpace(c.Query("search"))
+ sortBy := c.DefaultQuery("sort_by", "id")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
if len(search) > 100 {
search = search[:100]
}
- proxies, err = h.listProxiesFiltered(ctx, protocol, status, search)
+ proxies, err = h.listProxiesFiltered(ctx, protocol, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -89,7 +91,7 @@ func (h *ProxyHandler) ImportData(c *gin.Context) {
ctx := c.Request.Context()
result := DataImportResult{}
- existingProxies, err := h.listProxiesFiltered(ctx, "", "", "")
+ existingProxies, err := h.listProxiesFiltered(ctx, "", "", "", "id", "desc")
if err != nil {
response.ErrorFrom(c, err)
return
@@ -220,18 +222,33 @@ func parseProxyIDs(c *gin.Context) ([]int64, error) {
return ids, nil
}
-func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) {
+func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search, sortBy, sortOrder string) ([]service.Proxy, error) {
page := 1
pageSize := dataPageCap
var out []service.Proxy
+ sortBy = strings.TrimSpace(sortBy)
+ useAccountCountSort := strings.EqualFold(sortBy, "account_count")
for {
- items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search)
- if err != nil {
- return nil, err
- }
- out = append(out, items...)
- if len(out) >= int(total) || len(items) == 0 {
- break
+ if useAccountCountSort {
+ items, total, err := h.adminService.ListProxiesWithAccountCount(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder)
+ if err != nil {
+ return nil, err
+ }
+ for i := range items {
+ out = append(out, items[i].Proxy)
+ }
+ if len(out) >= int(total) || len(items) == 0 {
+ break
+ }
+ } else {
+ items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, items...)
+ if len(out) >= int(total) || len(items) == 0 {
+ break
+ }
}
page++
}
diff --git a/backend/internal/handler/admin/proxy_data_handler_test.go b/backend/internal/handler/admin/proxy_data_handler_test.go
index 803f9b61..8cd035ed 100644
--- a/backend/internal/handler/admin/proxy_data_handler_test.go
+++ b/backend/internal/handler/admin/proxy_data_handler_test.go
@@ -74,6 +74,10 @@ func TestProxyExportDataRespectsFilters(t *testing.T) {
require.Len(t, resp.Data.Proxies, 1)
require.Len(t, resp.Data.Accounts, 0)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
+ require.Equal(t, 1, adminSvc.lastListProxies.calls)
+ require.Equal(t, "https", adminSvc.lastListProxies.protocol)
+ require.Equal(t, "id", adminSvc.lastListProxies.sortBy)
+ require.Equal(t, "desc", adminSvc.lastListProxies.sortOrder)
}
func TestProxyExportDataWithSelectedIDs(t *testing.T) {
@@ -113,6 +117,96 @@ func TestProxyExportDataWithSelectedIDs(t *testing.T) {
require.Len(t, resp.Data.Proxies, 1)
require.Equal(t, "https", resp.Data.Proxies[0].Protocol)
require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host)
+ require.Equal(t, 0, adminSvc.lastListProxies.calls)
+}
+
+func TestProxyExportDataPassesSortParams(t *testing.T) {
+ router, adminSvc := setupProxyDataRouter()
+
+ adminSvc.proxies = []service.Proxy{
+ {
+ ID: 1,
+ Name: "proxy-a",
+ Protocol: "http",
+ Host: "127.0.0.1",
+ Port: 8080,
+ Username: "user",
+ Password: "pass",
+ Status: service.StatusActive,
+ },
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=http&status=active&search=proxy&sort_by=name&sort_order=asc", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ require.Equal(t, 1, adminSvc.lastListProxies.calls)
+ require.Equal(t, "http", adminSvc.lastListProxies.protocol)
+ require.Equal(t, "active", adminSvc.lastListProxies.status)
+ require.Equal(t, "proxy", adminSvc.lastListProxies.search)
+ require.Equal(t, "name", adminSvc.lastListProxies.sortBy)
+ require.Equal(t, "asc", adminSvc.lastListProxies.sortOrder)
+}
+
+func TestProxyExportDataSortByAccountCountUsesAccountCountListing(t *testing.T) {
+ router, adminSvc := setupProxyDataRouter()
+
+ adminSvc.proxies = []service.Proxy{
+ {
+ ID: 1,
+ Name: "proxy-id-1",
+ Protocol: "http",
+ Host: "127.0.0.1",
+ Port: 8080,
+ Status: service.StatusActive,
+ },
+ {
+ ID: 2,
+ Name: "proxy-id-2",
+ Protocol: "http",
+ Host: "127.0.0.2",
+ Port: 8081,
+ Status: service.StatusActive,
+ },
+ }
+ adminSvc.proxyCounts = []service.ProxyWithAccountCount{
+ {
+ Proxy: service.Proxy{
+ ID: 2,
+ Name: "proxy-count-high",
+ Protocol: "http",
+ Host: "127.0.0.2",
+ Port: 8081,
+ Status: service.StatusActive,
+ },
+ AccountCount: 9,
+ },
+ {
+ Proxy: service.Proxy{
+ ID: 1,
+ Name: "proxy-count-low",
+ Protocol: "http",
+ Host: "127.0.0.1",
+ Port: 8080,
+ Status: service.StatusActive,
+ },
+ AccountCount: 1,
+ },
+ }
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?sort_by=account_count&sort_order=desc", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp proxyDataResponse
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Len(t, resp.Data.Proxies, 2)
+ require.Equal(t, "proxy-count-high", resp.Data.Proxies[0].Name)
+ require.Equal(t, "proxy-count-low", resp.Data.Proxies[1].Name)
+ require.Equal(t, 0, adminSvc.lastListProxies.calls)
}
func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) {
diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go
index e8ae0ce2..f97fcb0a 100644
--- a/backend/internal/handler/admin/proxy_handler.go
+++ b/backend/internal/handler/admin/proxy_handler.go
@@ -52,13 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) {
protocol := c.Query("protocol")
status := c.Query("status")
search := c.Query("search")
+ sortBy := c.DefaultQuery("sort_by", "id")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
- proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search)
+ proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/admin/redeem_export_handler_test.go b/backend/internal/handler/admin/redeem_export_handler_test.go
new file mode 100644
index 00000000..9983fe31
--- /dev/null
+++ b/backend/internal/handler/admin/redeem_export_handler_test.go
@@ -0,0 +1,49 @@
+package admin
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func setupRedeemExportRouter() (*gin.Engine, *stubAdminService) {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ adminSvc := newStubAdminService()
+
+ h := NewRedeemHandler(adminSvc, nil)
+ router.GET("/api/v1/admin/redeem-codes/export", h.Export)
+ return router, adminSvc
+}
+
+func TestRedeemExportPassesSearchAndSort(t *testing.T) {
+ router, adminSvc := setupRedeemExportRouter()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export?type=balance&status=unused&search=ABC&sort_by=value&sort_order=asc", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls)
+ require.Equal(t, "balance", adminSvc.lastListRedeemCodes.codeType)
+ require.Equal(t, "unused", adminSvc.lastListRedeemCodes.status)
+ require.Equal(t, "ABC", adminSvc.lastListRedeemCodes.search)
+ require.Equal(t, "value", adminSvc.lastListRedeemCodes.sortBy)
+ require.Equal(t, "asc", adminSvc.lastListRedeemCodes.sortOrder)
+}
+
+func TestRedeemExportSortDefaults(t *testing.T) {
+ router, adminSvc := setupRedeemExportRouter()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls)
+ require.Equal(t, "id", adminSvc.lastListRedeemCodes.sortBy)
+ require.Equal(t, "desc", adminSvc.lastListRedeemCodes.sortOrder)
+}
diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go
index 13ea88d9..24365f3d 100644
--- a/backend/internal/handler/admin/redeem_handler.go
+++ b/backend/internal/handler/admin/redeem_handler.go
@@ -35,9 +35,9 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service.
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
- Value float64 `json:"value" binding:"min=0"`
- GroupID *int64 `json:"group_id"` // 订阅类型必填
- ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
+ Value float64 `json:"value"`
+ GroupID *int64 `json:"group_id"` // 订阅类型必填
+ ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
@@ -45,10 +45,10 @@ type GenerateRedeemCodesRequest struct {
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
- Value float64 `json:"value" binding:"required,gt=0"`
+ Value float64 `json:"value" binding:"required"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
- GroupID *int64 `json:"group_id"` // subscription 类型必填
- ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // subscription 类型必填,>0
+ GroupID *int64 `json:"group_id"` // subscription 类型必填
+ ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
Notes string `json:"notes"`
}
@@ -59,13 +59,15 @@ func (h *RedeemHandler) List(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
+ sortBy := c.DefaultQuery("sort_by", "id")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
}
- codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
+ codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -150,8 +152,8 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
response.BadRequest(c, "group_id is required for subscription type")
return
}
- if req.ValidityDays <= 0 {
- response.BadRequest(c, "validity_days must be greater than 0 for subscription type")
+ if req.ValidityDays == 0 {
+ response.BadRequest(c, "validity_days must not be zero for subscription type")
return
}
}
@@ -300,9 +302,15 @@ func (h *RedeemHandler) GetStats(c *gin.Context) {
func (h *RedeemHandler) Export(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
+ search := strings.TrimSpace(c.Query("search"))
+ sortBy := c.DefaultQuery("sort_by", "id")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
+ if len(search) > 100 {
+ search = search[:100]
+ }
// Get all codes without pagination (use large page size)
- codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
+ codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, search, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/admin/redeem_handler_test.go b/backend/internal/handler/admin/redeem_handler_test.go
index 0d42f64f..f1f7778f 100644
--- a/backend/internal/handler/admin/redeem_handler_test.go
+++ b/backend/internal/handler/admin/redeem_handler_test.go
@@ -76,32 +76,38 @@ func TestCreateAndRedeem_SubscriptionRequiresGroupID(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, code)
}
-func TestCreateAndRedeem_SubscriptionRequiresPositiveValidityDays(t *testing.T) {
+func TestCreateAndRedeem_SubscriptionRequiresNonZeroValidityDays(t *testing.T) {
groupID := int64(5)
h := newCreateAndRedeemHandler()
- cases := []struct {
- name string
- validityDays int
- }{
- {"zero", 0},
- {"negative", -1},
- }
-
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- code := postCreateAndRedeemValidation(t, h, map[string]any{
- "code": "test-sub-bad-days-" + tc.name,
- "type": "subscription",
- "value": 29.9,
- "user_id": 1,
- "group_id": groupID,
- "validity_days": tc.validityDays,
- })
-
- assert.Equal(t, http.StatusBadRequest, code)
+ // zero should be rejected
+ t.Run("zero", func(t *testing.T) {
+ code := postCreateAndRedeemValidation(t, h, map[string]any{
+ "code": "test-sub-bad-days-zero",
+ "type": "subscription",
+ "value": 29.9,
+ "user_id": 1,
+ "group_id": groupID,
+ "validity_days": 0,
})
- }
+
+ assert.Equal(t, http.StatusBadRequest, code)
+ })
+
+ // negative should pass validation (used for refund/reduction)
+ t.Run("negative_passes_validation", func(t *testing.T) {
+ code := postCreateAndRedeemValidation(t, h, map[string]any{
+ "code": "test-sub-negative-days",
+ "type": "subscription",
+ "value": 29.9,
+ "user_id": 1,
+ "group_id": groupID,
+ "validity_days": -7,
+ })
+
+ assert.NotEqual(t, http.StatusBadRequest, code,
+ "negative validity_days should pass validation for refund")
+ })
}
func TestCreateAndRedeem_SubscriptionValidParamsPassValidation(t *testing.T) {
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 397526a7..320dbd6b 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -5,11 +5,10 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
- "log"
+ "log/slog"
"net/http"
"regexp"
"strings"
- "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -35,23 +34,43 @@ func generateMenuItemID() (string, error) {
return hex.EncodeToString(b), nil
}
+func scopesContainOpenID(scopes string) bool {
+ for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
+ if scope == "openid" {
+ return true
+ }
+ }
+ return false
+}
+
+func firstNonEmpty(values ...string) string {
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
// SettingHandler 系统设置处理器
type SettingHandler struct {
- settingService *service.SettingService
- emailService *service.EmailService
- turnstileService *service.TurnstileService
- opsService *service.OpsService
- soraS3Storage *service.SoraS3Storage
+ settingService *service.SettingService
+ emailService *service.EmailService
+ turnstileService *service.TurnstileService
+ opsService *service.OpsService
+ paymentConfigService *service.PaymentConfigService
+ paymentService *service.PaymentService
}
// NewSettingHandler 创建系统设置处理器
-func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
+func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler {
return &SettingHandler{
- settingService: settingService,
- emailService: emailService,
- turnstileService: turnstileService,
- opsService: opsService,
- soraS3Storage: soraS3Storage,
+ settingService: settingService,
+ emailService: emailService,
+ turnstileService: turnstileService,
+ opsService: opsService,
+ paymentConfigService: paymentConfigService,
+ paymentService: paymentService,
}
}
@@ -63,6 +82,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
@@ -74,64 +98,157 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
})
}
- response.Success(c, dto.SystemSettings{
- RegistrationEnabled: settings.RegistrationEnabled,
- EmailVerifyEnabled: settings.EmailVerifyEnabled,
- RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
- PromoCodeEnabled: settings.PromoCodeEnabled,
- PasswordResetEnabled: settings.PasswordResetEnabled,
- FrontendURL: settings.FrontendURL,
- InvitationCodeEnabled: settings.InvitationCodeEnabled,
- TotpEnabled: settings.TotpEnabled,
- TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
- SMTPHost: settings.SMTPHost,
- SMTPPort: settings.SMTPPort,
- SMTPUsername: settings.SMTPUsername,
- SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
- SMTPFrom: settings.SMTPFrom,
- SMTPFromName: settings.SMTPFromName,
- SMTPUseTLS: settings.SMTPUseTLS,
- TurnstileEnabled: settings.TurnstileEnabled,
- TurnstileSiteKey: settings.TurnstileSiteKey,
- TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
- LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
- LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
- LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
- SiteName: settings.SiteName,
- SiteLogo: settings.SiteLogo,
- SiteSubtitle: settings.SiteSubtitle,
- APIBaseURL: settings.APIBaseURL,
- ContactInfo: settings.ContactInfo,
- DocURL: settings.DocURL,
- HomeContent: settings.HomeContent,
- HideCcsImportButton: settings.HideCcsImportButton,
- PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
- PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
- SoraClientEnabled: settings.SoraClientEnabled,
- CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
- CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
- DefaultConcurrency: settings.DefaultConcurrency,
- DefaultBalance: settings.DefaultBalance,
- DefaultSubscriptions: defaultSubscriptions,
- EnableModelFallback: settings.EnableModelFallback,
- FallbackModelAnthropic: settings.FallbackModelAnthropic,
- FallbackModelOpenAI: settings.FallbackModelOpenAI,
- FallbackModelGemini: settings.FallbackModelGemini,
- FallbackModelAntigravity: settings.FallbackModelAntigravity,
- EnableIdentityPatch: settings.EnableIdentityPatch,
- IdentityPatchPrompt: settings.IdentityPatchPrompt,
- OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
- OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
- OpsQueryModeDefault: settings.OpsQueryModeDefault,
- OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
- MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
- AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
- BackendModeEnabled: settings.BackendModeEnabled,
- EnableFingerprintUnification: settings.EnableFingerprintUnification,
- EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
- })
+ // Load payment config
+ var paymentCfg *service.PaymentConfig
+ if h.paymentConfigService != nil {
+ paymentCfg, _ = h.paymentConfigService.GetPaymentConfig(c.Request.Context())
+ }
+ if paymentCfg == nil {
+ paymentCfg = &service.PaymentConfig{}
+ }
+
+ payload := dto.SystemSettings{
+ RegistrationEnabled: settings.RegistrationEnabled,
+ EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
+ PasswordResetEnabled: settings.PasswordResetEnabled,
+ FrontendURL: settings.FrontendURL,
+ InvitationCodeEnabled: settings.InvitationCodeEnabled,
+ TotpEnabled: settings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
+ SMTPHost: settings.SMTPHost,
+ SMTPPort: settings.SMTPPort,
+ SMTPUsername: settings.SMTPUsername,
+ SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
+ SMTPFrom: settings.SMTPFrom,
+ SMTPFromName: settings.SMTPFromName,
+ SMTPUseTLS: settings.SMTPUseTLS,
+ TurnstileEnabled: settings.TurnstileEnabled,
+ TurnstileSiteKey: settings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
+ LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
+ LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: settings.WeChatConnectEnabled,
+ WeChatConnectAppID: settings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
+ WeChatConnectOpenAppID: settings.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecretConfigured: settings.WeChatConnectOpenAppSecretConfigured,
+ WeChatConnectMPAppID: settings.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecretConfigured: settings.WeChatConnectMPAppSecretConfigured,
+ WeChatConnectMobileAppID: settings.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecretConfigured: settings.WeChatConnectMobileAppSecretConfigured,
+ WeChatConnectOpenEnabled: settings.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: settings.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: settings.WeChatConnectMobileEnabled,
+ WeChatConnectMode: settings.WeChatConnectMode,
+ WeChatConnectScopes: settings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: settings.OIDCConnectEnabled,
+ OIDCConnectProviderName: settings.OIDCConnectProviderName,
+ OIDCConnectClientID: settings.OIDCConnectClientID,
+ OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
+ OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
+ OIDCConnectScopes: settings.OIDCConnectScopes,
+ OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
+ OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
+ SiteName: settings.SiteName,
+ SiteLogo: settings.SiteLogo,
+ SiteSubtitle: settings.SiteSubtitle,
+ APIBaseURL: settings.APIBaseURL,
+ ContactInfo: settings.ContactInfo,
+ DocURL: settings.DocURL,
+ HomeContent: settings.HomeContent,
+ HideCcsImportButton: settings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
+ TableDefaultPageSize: settings.TableDefaultPageSize,
+ TablePageSizeOptions: settings.TablePageSizeOptions,
+ CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
+ CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
+ DefaultConcurrency: settings.DefaultConcurrency,
+ DefaultBalance: settings.DefaultBalance,
+ AffiliateRebateRate: settings.AffiliateRebateRate,
+ AffiliateRebateFreezeHours: settings.AffiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: settings.AffiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: settings.AffiliateRebatePerInviteeCap,
+ DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
+ DefaultSubscriptions: defaultSubscriptions,
+ EnableModelFallback: settings.EnableModelFallback,
+ FallbackModelAnthropic: settings.FallbackModelAnthropic,
+ FallbackModelOpenAI: settings.FallbackModelOpenAI,
+ FallbackModelGemini: settings.FallbackModelGemini,
+ FallbackModelAntigravity: settings.FallbackModelAntigravity,
+ EnableIdentityPatch: settings.EnableIdentityPatch,
+ IdentityPatchPrompt: settings.IdentityPatchPrompt,
+ OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
+ OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
+ OpsQueryModeDefault: settings.OpsQueryModeDefault,
+ OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
+ MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
+ AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
+ BackendModeEnabled: settings.BackendModeEnabled,
+ EnableFingerprintUnification: settings.EnableFingerprintUnification,
+ EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
+ EnableCCHSigning: settings.EnableCCHSigning,
+ WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
+ PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
+ PaymentEnabled: paymentCfg.Enabled,
+ PaymentMinAmount: paymentCfg.MinAmount,
+ PaymentMaxAmount: paymentCfg.MaxAmount,
+ PaymentDailyLimit: paymentCfg.DailyLimit,
+ PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
+ PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
+ PaymentEnabledTypes: paymentCfg.EnabledTypes,
+ PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
+ PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
+ PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
+ PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
+ PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
+ PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
+ PaymentHelpImageURL: paymentCfg.HelpImageURL,
+ PaymentHelpText: paymentCfg.HelpText,
+ PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
+ PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
+ PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
+ PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
+ PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
+
+ ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
+
+ AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: settings.AffiliateEnabled,
+ }
+ response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// UpdateSettingsRequest 更新设置请求
@@ -166,6 +283,48 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ // WeChat Connect OAuth 登录
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecret string `json:"wechat_connect_app_secret"`
+ WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
+ WeChatConnectOpenAppSecret string `json:"wechat_connect_open_app_secret"`
+ WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
+ WeChatConnectMPAppSecret string `json:"wechat_connect_mp_app_secret"`
+ WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
+ WeChatConnectMobileAppSecret string `json:"wechat_connect_mobile_app_secret"`
+ WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
+ WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
+ WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
+ // Generic OIDC OAuth 登录
+ OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
+ OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
+ OIDCConnectClientID string `json:"oidc_connect_client_id"`
+ OIDCConnectClientSecret string `json:"oidc_connect_client_secret"`
+ OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"`
+ OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"`
+ OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"`
+ OIDCConnectTokenURL string `json:"oidc_connect_token_url"`
+ OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"`
+ OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"`
+ OIDCConnectScopes string `json:"oidc_connect_scopes"`
+ OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
+ OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
+ OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
+ OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"`
+ OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"`
+ OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
+ OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
+ OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
+ OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"`
+ OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
+ OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
+
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
@@ -177,14 +336,41 @@ type UpdateSettingsRequest struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
+ TableDefaultPageSize int `json:"table_default_page_size"`
+ TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
+ AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
+ AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
+ AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
+ DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
+ DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
+ AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
+ AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
+ AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
+ AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
+ AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
+ AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
+ AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
+ AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
+ AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
+ AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
+ AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
+ AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
+ AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
+ AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
+ AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
+ AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
+ AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
+ AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
+ AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
+ ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -215,6 +401,57 @@ type UpdateSettingsRequest struct {
// Gateway forwarding behavior
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
+ EnableCCHSigning *bool `json:"enable_cch_signing"`
+
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
+
+ // Balance low notification
+ BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
+ BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
+ AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
+ AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
+
+ // Payment configuration (integrated into settings, full replace)
+ PaymentEnabled *bool `json:"payment_enabled"`
+ PaymentMinAmount *float64 `json:"payment_min_amount"`
+ PaymentMaxAmount *float64 `json:"payment_max_amount"`
+ PaymentDailyLimit *float64 `json:"payment_daily_limit"`
+ PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"`
+ PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"`
+ PaymentEnabledTypes []string `json:"payment_enabled_types"`
+ PaymentBalanceDisabled *bool `json:"payment_balance_disabled"`
+ PaymentBalanceRechargeMultiplier *float64 `json:"payment_balance_recharge_multiplier"`
+ PaymentRechargeFeeRate *float64 `json:"payment_recharge_fee_rate"`
+ PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"`
+ PaymentProductNamePrefix *string `json:"payment_product_name_prefix"`
+ PaymentProductNameSuffix *string `json:"payment_product_name_suffix"`
+ PaymentHelpImageURL *string `json:"payment_help_image_url"`
+ PaymentHelpText *string `json:"payment_help_text"`
+
+ // Cancel rate limit
+ PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"`
+ PaymentCancelRateLimitMax *int `json:"payment_cancel_rate_limit_max"`
+ PaymentCancelRateLimitWindow *int `json:"payment_cancel_rate_limit_window"`
+ PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
+ PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
+
+ // Channel Monitor feature switch
+ ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
+
+ // Available Channels feature switch (user-facing)
+ AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature switch
+ AffiliateEnabled *bool `json:"affiliate_enabled"`
}
// UpdateSettings 更新系统设置
@@ -231,6 +468,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
// 验证参数
if req.DefaultConcurrency < 1 {
@@ -239,6 +481,50 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
+ affiliateRebateRate := previousSettings.AffiliateRebateRate
+ if req.AffiliateRebateRate != nil {
+ affiliateRebateRate = *req.AffiliateRebateRate
+ }
+ if affiliateRebateRate < service.AffiliateRebateRateMin {
+ affiliateRebateRate = service.AffiliateRebateRateMin
+ }
+ if affiliateRebateRate > service.AffiliateRebateRateMax {
+ affiliateRebateRate = service.AffiliateRebateRateMax
+ }
+ affiliateRebateFreezeHours := previousSettings.AffiliateRebateFreezeHours
+ if req.AffiliateRebateFreezeHours != nil {
+ affiliateRebateFreezeHours = *req.AffiliateRebateFreezeHours
+ }
+ if affiliateRebateFreezeHours < 0 {
+ affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursDefault
+ }
+ if affiliateRebateFreezeHours > service.AffiliateRebateFreezeHoursMax {
+ affiliateRebateFreezeHours = service.AffiliateRebateFreezeHoursMax
+ }
+ affiliateRebateDurationDays := previousSettings.AffiliateRebateDurationDays
+ if req.AffiliateRebateDurationDays != nil {
+ affiliateRebateDurationDays = *req.AffiliateRebateDurationDays
+ }
+ if affiliateRebateDurationDays < 0 {
+ affiliateRebateDurationDays = service.AffiliateRebateDurationDaysDefault
+ }
+ if affiliateRebateDurationDays > service.AffiliateRebateDurationDaysMax {
+ affiliateRebateDurationDays = service.AffiliateRebateDurationDaysMax
+ }
+ affiliateRebatePerInviteeCap := previousSettings.AffiliateRebatePerInviteeCap
+ if req.AffiliateRebatePerInviteeCap != nil {
+ affiliateRebatePerInviteeCap = *req.AffiliateRebatePerInviteeCap
+ }
+ if affiliateRebatePerInviteeCap < 0 {
+ affiliateRebatePerInviteeCap = service.AffiliateRebatePerInviteeCapDefault
+ }
+ // 通用表格配置:兼容旧客户端未传字段时保留当前值。
+ if req.TableDefaultPageSize <= 0 {
+ req.TableDefaultPageSize = previousSettings.TableDefaultPageSize
+ }
+ if req.TablePageSizeOptions == nil {
+ req.TablePageSizeOptions = previousSettings.TablePageSizeOptions
+ }
req.SMTPHost = strings.TrimSpace(req.SMTPHost)
req.SMTPUsername = strings.TrimSpace(req.SMTPUsername)
req.SMTPPassword = strings.TrimSpace(req.SMTPPassword)
@@ -248,6 +534,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
+ req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
+ req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
+ req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
+ req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@@ -326,6 +616,275 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ if req.WeChatConnectEnabled {
+ req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
+ req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
+ req.WeChatConnectOpenAppID = strings.TrimSpace(req.WeChatConnectOpenAppID)
+ req.WeChatConnectOpenAppSecret = strings.TrimSpace(req.WeChatConnectOpenAppSecret)
+ req.WeChatConnectMPAppID = strings.TrimSpace(req.WeChatConnectMPAppID)
+ req.WeChatConnectMPAppSecret = strings.TrimSpace(req.WeChatConnectMPAppSecret)
+ req.WeChatConnectMobileAppID = strings.TrimSpace(req.WeChatConnectMobileAppID)
+ req.WeChatConnectMobileAppSecret = strings.TrimSpace(req.WeChatConnectMobileAppSecret)
+ req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode))
+ req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes)
+ req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL)
+ req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL)
+ req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL))
+ req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL))
+ if req.WeChatConnectMode == "" {
+ req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode))
+ }
+ if req.WeChatConnectScopes == "" {
+ req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes)
+ }
+
+ if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled {
+ response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time")
+ return
+ }
+ if req.WeChatConnectMode != "" {
+ switch req.WeChatConnectMode {
+ case "open", "mp", "mobile":
+ default:
+ response.BadRequest(c, "WeChat mode must be open, mp, or mobile")
+ return
+ }
+ }
+ if !req.WeChatConnectOpenEnabled && !req.WeChatConnectMPEnabled && !req.WeChatConnectMobileEnabled {
+ switch req.WeChatConnectMode {
+ case "mp":
+ req.WeChatConnectMPEnabled = true
+ case "mobile":
+ req.WeChatConnectMobileEnabled = true
+ default:
+ req.WeChatConnectOpenEnabled = true
+ }
+ }
+ if req.WeChatConnectMode == "" {
+ if req.WeChatConnectMPEnabled {
+ req.WeChatConnectMode = "mp"
+ } else if req.WeChatConnectMobileEnabled {
+ req.WeChatConnectMode = "mobile"
+ } else {
+ req.WeChatConnectMode = "open"
+ }
+ }
+
+ req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID))
+ req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID))
+
+ if req.WeChatConnectOpenAppSecret == "" {
+ req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectMPAppSecret == "" {
+ req.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMPAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectMobileAppSecret == "" {
+ req.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
+ }
+ if req.WeChatConnectAppSecret == "" {
+ req.WeChatConnectAppSecret = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppSecret, req.WeChatConnectMPAppSecret, req.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret))
+ }
+
+ if req.WeChatConnectOpenEnabled {
+ if req.WeChatConnectOpenAppID == "" {
+ response.BadRequest(c, "WeChat PC App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectOpenAppSecret == "" {
+ response.BadRequest(c, "WeChat PC App Secret is required when enabled")
+ return
+ }
+ }
+ if req.WeChatConnectMPEnabled {
+ if req.WeChatConnectMPAppID == "" {
+ response.BadRequest(c, "WeChat Official Account App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectMPAppSecret == "" {
+ response.BadRequest(c, "WeChat Official Account App Secret is required when enabled")
+ return
+ }
+ }
+ if req.WeChatConnectMobileEnabled {
+ if req.WeChatConnectMobileAppID == "" {
+ response.BadRequest(c, "WeChat Mobile App ID is required when enabled")
+ return
+ }
+ if req.WeChatConnectMobileAppSecret == "" {
+ response.BadRequest(c, "WeChat Mobile App Secret is required when enabled")
+ return
+ }
+ }
+
+ if req.WeChatConnectScopes == "" {
+ if req.WeChatConnectMPEnabled {
+ req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode("mp")
+ } else {
+ req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
+ }
+ }
+ if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
+ if req.WeChatConnectRedirectURL == "" {
+ response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
+ return
+ }
+ if req.WeChatConnectFrontendRedirectURL == "" {
+ req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
+ }
+ if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
+ response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
+ return
+ }
+ }
+ }
+
+ // Generic OIDC 参数验证
+ oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if req.OIDCConnectEnabled {
+ req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
+ req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
+ req.OIDCConnectClientSecret = strings.TrimSpace(req.OIDCConnectClientSecret)
+ req.OIDCConnectIssuerURL = strings.TrimSpace(req.OIDCConnectIssuerURL)
+ req.OIDCConnectDiscoveryURL = strings.TrimSpace(req.OIDCConnectDiscoveryURL)
+ req.OIDCConnectAuthorizeURL = strings.TrimSpace(req.OIDCConnectAuthorizeURL)
+ req.OIDCConnectTokenURL = strings.TrimSpace(req.OIDCConnectTokenURL)
+ req.OIDCConnectUserInfoURL = strings.TrimSpace(req.OIDCConnectUserInfoURL)
+ req.OIDCConnectJWKSURL = strings.TrimSpace(req.OIDCConnectJWKSURL)
+ req.OIDCConnectScopes = strings.TrimSpace(req.OIDCConnectScopes)
+ req.OIDCConnectRedirectURL = strings.TrimSpace(req.OIDCConnectRedirectURL)
+ req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(req.OIDCConnectFrontendRedirectURL)
+ req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(req.OIDCConnectTokenAuthMethod))
+ req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(req.OIDCConnectAllowedSigningAlgs)
+ req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
+ req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
+ req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
+ req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC"))
+ req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID))
+ req.OIDCConnectIssuerURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL))
+ req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL))
+ req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL))
+ req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL))
+ req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL))
+ req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL))
+ req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile"))
+ req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL))
+ req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback"))
+ req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post")))
+ req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256"))
+ req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath))
+ req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath))
+ req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath))
+ if req.OIDCConnectUsePKCE != nil {
+ oidcUsePKCE = *req.OIDCConnectUsePKCE
+ }
+ if req.OIDCConnectValidateIDToken != nil {
+ oidcValidateIDToken = *req.OIDCConnectValidateIDToken
+ }
+ if req.OIDCConnectClockSkewSeconds == 0 {
+ req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds
+ if req.OIDCConnectClockSkewSeconds == 0 {
+ req.OIDCConnectClockSkewSeconds = 120
+ }
+ }
+
+ if req.OIDCConnectClientID == "" {
+ response.BadRequest(c, "OIDC Client ID is required when enabled")
+ return
+ }
+ if req.OIDCConnectIssuerURL == "" {
+ response.BadRequest(c, "OIDC Issuer URL is required when enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectIssuerURL); err != nil {
+ response.BadRequest(c, "OIDC Issuer URL must be an absolute http(s) URL")
+ return
+ }
+ if req.OIDCConnectDiscoveryURL != "" {
+ if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectDiscoveryURL); err != nil {
+ response.BadRequest(c, "OIDC Discovery URL must be an absolute http(s) URL")
+ return
+ }
+ }
+ if req.OIDCConnectAuthorizeURL != "" {
+ if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectAuthorizeURL); err != nil {
+ response.BadRequest(c, "OIDC Authorize URL must be an absolute http(s) URL")
+ return
+ }
+ }
+ if req.OIDCConnectTokenURL != "" {
+ if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectTokenURL); err != nil {
+ response.BadRequest(c, "OIDC Token URL must be an absolute http(s) URL")
+ return
+ }
+ }
+ if req.OIDCConnectUserInfoURL != "" {
+ if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectUserInfoURL); err != nil {
+ response.BadRequest(c, "OIDC UserInfo URL must be an absolute http(s) URL")
+ return
+ }
+ }
+ if req.OIDCConnectRedirectURL == "" {
+ response.BadRequest(c, "OIDC Redirect URL is required when enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectRedirectURL); err != nil {
+ response.BadRequest(c, "OIDC Redirect URL must be an absolute http(s) URL")
+ return
+ }
+ if req.OIDCConnectFrontendRedirectURL == "" {
+ response.BadRequest(c, "OIDC Frontend Redirect URL is required when enabled")
+ return
+ }
+ if err := config.ValidateFrontendRedirectURL(req.OIDCConnectFrontendRedirectURL); err != nil {
+ response.BadRequest(c, "OIDC Frontend Redirect URL is invalid")
+ return
+ }
+ if !scopesContainOpenID(req.OIDCConnectScopes) {
+ response.BadRequest(c, "OIDC scopes must contain openid")
+ return
+ }
+ switch req.OIDCConnectTokenAuthMethod {
+ case "", "client_secret_post", "client_secret_basic", "none":
+ default:
+ response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
+ return
+ }
+ if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
+ response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
+ return
+ }
+ if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
+ response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
+ return
+ }
+ if req.OIDCConnectJWKSURL != "" {
+ if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
+ response.BadRequest(c, "OIDC JWKS URL must be an absolute http(s) URL")
+ return
+ }
+ }
+ if req.OIDCConnectTokenAuthMethod == "" || req.OIDCConnectTokenAuthMethod == "client_secret_post" || req.OIDCConnectTokenAuthMethod == "client_secret_basic" {
+ if req.OIDCConnectClientSecret == "" {
+ if previousSettings.OIDCConnectClientSecret == "" {
+ response.BadRequest(c, "OIDC Client Secret is required when enabled")
+ return
+ }
+ req.OIDCConnectClientSecret = previousSettings.OIDCConnectClientSecret
+ }
+ }
+ }
+
// “购买订阅”页面配置验证
purchaseEnabled := previousSettings.PurchaseSubscriptionEnabled
if req.PurchaseSubscriptionEnabled != nil {
@@ -556,6 +1115,44 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: req.WeChatConnectEnabled,
+ WeChatConnectAppID: req.WeChatConnectAppID,
+ WeChatConnectAppSecret: req.WeChatConnectAppSecret,
+ WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
+ WeChatConnectMPAppID: req.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
+ WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
+ WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
+ WeChatConnectMode: req.WeChatConnectMode,
+ WeChatConnectScopes: req.WeChatConnectScopes,
+ WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: req.OIDCConnectEnabled,
+ OIDCConnectProviderName: req.OIDCConnectProviderName,
+ OIDCConnectClientID: req.OIDCConnectClientID,
+ OIDCConnectClientSecret: req.OIDCConnectClientSecret,
+ OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: req.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
+ OIDCConnectScopes: req.OIDCConnectScopes,
+ OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: oidcUsePKCE,
+ OIDCConnectValidateIDToken: oidcValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
@@ -566,11 +1163,17 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
- SoraClientEnabled: req.SoraClientEnabled,
+ TableDefaultPageSize: req.TableDefaultPageSize,
+ TablePageSizeOptions: req.TablePageSizeOptions,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
+ AffiliateRebateRate: affiliateRebateRate,
+ AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: affiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
+ DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
@@ -619,14 +1222,170 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableMetadataPassthrough
}(),
+ EnableCCHSigning: func() bool {
+ if req.EnableCCHSigning != nil {
+ return *req.EnableCCHSigning
+ }
+ return previousSettings.EnableCCHSigning
+ }(),
+ PaymentVisibleMethodAlipaySource: func() string {
+ if req.PaymentVisibleMethodAlipaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
+ }
+ return previousSettings.PaymentVisibleMethodAlipaySource
+ }(),
+ PaymentVisibleMethodWxpaySource: func() string {
+ if req.PaymentVisibleMethodWxpaySource != nil {
+ return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource)
+ }
+ return previousSettings.PaymentVisibleMethodWxpaySource
+ }(),
+ PaymentVisibleMethodAlipayEnabled: func() bool {
+ if req.PaymentVisibleMethodAlipayEnabled != nil {
+ return *req.PaymentVisibleMethodAlipayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodAlipayEnabled
+ }(),
+ PaymentVisibleMethodWxpayEnabled: func() bool {
+ if req.PaymentVisibleMethodWxpayEnabled != nil {
+ return *req.PaymentVisibleMethodWxpayEnabled
+ }
+ return previousSettings.PaymentVisibleMethodWxpayEnabled
+ }(),
+ OpenAIAdvancedSchedulerEnabled: func() bool {
+ if req.OpenAIAdvancedSchedulerEnabled != nil {
+ return *req.OpenAIAdvancedSchedulerEnabled
+ }
+ return previousSettings.OpenAIAdvancedSchedulerEnabled
+ }(),
+ BalanceLowNotifyEnabled: func() bool {
+ if req.BalanceLowNotifyEnabled != nil {
+ return *req.BalanceLowNotifyEnabled
+ }
+ return previousSettings.BalanceLowNotifyEnabled
+ }(),
+ BalanceLowNotifyThreshold: func() float64 {
+ if req.BalanceLowNotifyThreshold != nil {
+ return *req.BalanceLowNotifyThreshold
+ }
+ return previousSettings.BalanceLowNotifyThreshold
+ }(),
+ BalanceLowNotifyRechargeURL: func() string {
+ if req.BalanceLowNotifyRechargeURL != nil {
+ return *req.BalanceLowNotifyRechargeURL
+ }
+ return previousSettings.BalanceLowNotifyRechargeURL
+ }(),
+ AccountQuotaNotifyEnabled: func() bool {
+ if req.AccountQuotaNotifyEnabled != nil {
+ return *req.AccountQuotaNotifyEnabled
+ }
+ return previousSettings.AccountQuotaNotifyEnabled
+ }(),
+ AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
+ if req.AccountQuotaNotifyEmails != nil {
+ return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
+ }
+ return previousSettings.AccountQuotaNotifyEmails
+ }(),
+ ChannelMonitorEnabled: func() bool {
+ if req.ChannelMonitorEnabled != nil {
+ return *req.ChannelMonitorEnabled
+ }
+ return previousSettings.ChannelMonitorEnabled
+ }(),
+ ChannelMonitorDefaultIntervalSeconds: func() int {
+ if req.ChannelMonitorDefaultIntervalSeconds != nil {
+ return *req.ChannelMonitorDefaultIntervalSeconds
+ }
+ return previousSettings.ChannelMonitorDefaultIntervalSeconds
+ }(),
+ AvailableChannelsEnabled: func() bool {
+ if req.AvailableChannelsEnabled != nil {
+ return *req.AvailableChannelsEnabled
+ }
+ return previousSettings.AvailableChannelsEnabled
+ }(),
+ AffiliateEnabled: func() bool {
+ if req.AffiliateEnabled != nil {
+ return *req.AffiliateEnabled
+ }
+ return previousSettings.AffiliateEnabled
+ }(),
}
- if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
+ authSourceDefaults := &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
+ },
+ LinuxDo: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
+ },
+ OIDC: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
+ },
+ WeChat: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
+ },
+ ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
+ }
+ if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
response.ErrorFrom(c, err)
return
}
- h.auditSettingsUpdate(c, previousSettings, settings, req)
+ // Update payment configuration (integrated into system settings).
+ // Skip if no payment fields were provided (prevents accidental wipe).
+ if h.paymentConfigService != nil && hasPaymentFields(req) {
+ paymentReq := service.UpdatePaymentConfigRequest{
+ Enabled: req.PaymentEnabled,
+ MinAmount: req.PaymentMinAmount,
+ MaxAmount: req.PaymentMaxAmount,
+ DailyLimit: req.PaymentDailyLimit,
+ OrderTimeoutMin: req.PaymentOrderTimeoutMin,
+ MaxPendingOrders: req.PaymentMaxPendingOrders,
+ EnabledTypes: req.PaymentEnabledTypes,
+ BalanceDisabled: req.PaymentBalanceDisabled,
+ BalanceRechargeMultiplier: req.PaymentBalanceRechargeMultiplier,
+ RechargeFeeRate: req.PaymentRechargeFeeRate,
+ LoadBalanceStrategy: req.PaymentLoadBalanceStrat,
+ ProductNamePrefix: req.PaymentProductNamePrefix,
+ ProductNameSuffix: req.PaymentProductNameSuffix,
+ HelpImageURL: req.PaymentHelpImageURL,
+ HelpText: req.PaymentHelpText,
+ CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled,
+ CancelRateLimitMax: req.PaymentCancelRateLimitMax,
+ CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
+ CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
+ CancelRateLimitMode: req.PaymentCancelRateLimitMode,
+ }
+ if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ // Refresh in-memory provider registry so config changes take effect immediately
+ if h.paymentService != nil {
+ h.paymentService.RefreshProviders(c.Request.Context())
+ }
+ }
+
+ h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
@@ -634,6 +1393,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
@@ -642,87 +1406,193 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
})
}
- response.Success(c, dto.SystemSettings{
- RegistrationEnabled: updatedSettings.RegistrationEnabled,
- EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
- RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
- PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
- PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
- FrontendURL: updatedSettings.FrontendURL,
- InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
- TotpEnabled: updatedSettings.TotpEnabled,
- TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
- SMTPHost: updatedSettings.SMTPHost,
- SMTPPort: updatedSettings.SMTPPort,
- SMTPUsername: updatedSettings.SMTPUsername,
- SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
- SMTPFrom: updatedSettings.SMTPFrom,
- SMTPFromName: updatedSettings.SMTPFromName,
- SMTPUseTLS: updatedSettings.SMTPUseTLS,
- TurnstileEnabled: updatedSettings.TurnstileEnabled,
- TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
- TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
- LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
- LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
- LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
- SiteName: updatedSettings.SiteName,
- SiteLogo: updatedSettings.SiteLogo,
- SiteSubtitle: updatedSettings.SiteSubtitle,
- APIBaseURL: updatedSettings.APIBaseURL,
- ContactInfo: updatedSettings.ContactInfo,
- DocURL: updatedSettings.DocURL,
- HomeContent: updatedSettings.HomeContent,
- HideCcsImportButton: updatedSettings.HideCcsImportButton,
- PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
- PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
- SoraClientEnabled: updatedSettings.SoraClientEnabled,
- CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
- CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
- DefaultConcurrency: updatedSettings.DefaultConcurrency,
- DefaultBalance: updatedSettings.DefaultBalance,
- DefaultSubscriptions: updatedDefaultSubscriptions,
- EnableModelFallback: updatedSettings.EnableModelFallback,
- FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
- FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
- FallbackModelGemini: updatedSettings.FallbackModelGemini,
- FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
- EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
- IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
- OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
- OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
- OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
- OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
- MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
- AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
- BackendModeEnabled: updatedSettings.BackendModeEnabled,
- EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
- EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
- })
+ // Reload payment config for response
+ var updatedPaymentCfg *service.PaymentConfig
+ if h.paymentConfigService != nil {
+ updatedPaymentCfg, _ = h.paymentConfigService.GetPaymentConfig(c.Request.Context())
+ }
+ if updatedPaymentCfg == nil {
+ updatedPaymentCfg = &service.PaymentConfig{}
+ }
+
+ payload := dto.SystemSettings{
+ RegistrationEnabled: updatedSettings.RegistrationEnabled,
+ EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
+ RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
+ PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
+ PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
+ FrontendURL: updatedSettings.FrontendURL,
+ InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
+ TotpEnabled: updatedSettings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
+ SMTPHost: updatedSettings.SMTPHost,
+ SMTPPort: updatedSettings.SMTPPort,
+ SMTPUsername: updatedSettings.SMTPUsername,
+ SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
+ SMTPFrom: updatedSettings.SMTPFrom,
+ SMTPFromName: updatedSettings.SMTPFromName,
+ SMTPUseTLS: updatedSettings.SMTPUseTLS,
+ TurnstileEnabled: updatedSettings.TurnstileEnabled,
+ TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
+ TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
+ LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
+ LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
+ WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
+ WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
+ WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
+ WeChatConnectOpenAppID: updatedSettings.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecretConfigured: updatedSettings.WeChatConnectOpenAppSecretConfigured,
+ WeChatConnectMPAppID: updatedSettings.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecretConfigured: updatedSettings.WeChatConnectMPAppSecretConfigured,
+ WeChatConnectMobileAppID: updatedSettings.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecretConfigured: updatedSettings.WeChatConnectMobileAppSecretConfigured,
+ WeChatConnectOpenEnabled: updatedSettings.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: updatedSettings.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: updatedSettings.WeChatConnectMobileEnabled,
+ WeChatConnectMode: updatedSettings.WeChatConnectMode,
+ WeChatConnectScopes: updatedSettings.WeChatConnectScopes,
+ WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
+ OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
+ OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
+ OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
+ OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
+ OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
+ OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
+ OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
+ SiteName: updatedSettings.SiteName,
+ SiteLogo: updatedSettings.SiteLogo,
+ SiteSubtitle: updatedSettings.SiteSubtitle,
+ APIBaseURL: updatedSettings.APIBaseURL,
+ ContactInfo: updatedSettings.ContactInfo,
+ DocURL: updatedSettings.DocURL,
+ HomeContent: updatedSettings.HomeContent,
+ HideCcsImportButton: updatedSettings.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
+ PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
+ TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
+ TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
+ CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
+ CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
+ DefaultConcurrency: updatedSettings.DefaultConcurrency,
+ DefaultBalance: updatedSettings.DefaultBalance,
+ AffiliateRebateRate: updatedSettings.AffiliateRebateRate,
+ AffiliateRebateFreezeHours: updatedSettings.AffiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: updatedSettings.AffiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: updatedSettings.AffiliateRebatePerInviteeCap,
+ DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
+ DefaultSubscriptions: updatedDefaultSubscriptions,
+ EnableModelFallback: updatedSettings.EnableModelFallback,
+ FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
+ FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
+ FallbackModelGemini: updatedSettings.FallbackModelGemini,
+ FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
+ EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
+ IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
+ OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
+ OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
+ OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
+ OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
+ MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
+ MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
+ AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
+ BackendModeEnabled: updatedSettings.BackendModeEnabled,
+ EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
+ EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
+ EnableCCHSigning: updatedSettings.EnableCCHSigning,
+ PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
+ PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
+ PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
+ PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled,
+ OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled,
+ BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
+ BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
+ AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
+ AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
+ PaymentEnabled: updatedPaymentCfg.Enabled,
+ PaymentMinAmount: updatedPaymentCfg.MinAmount,
+ PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
+ PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
+ PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
+ PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
+ PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
+ PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
+ PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
+ PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
+ PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
+ PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
+ PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
+ PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
+ PaymentHelpText: updatedPaymentCfg.HelpText,
+ PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
+ PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
+ PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
+ PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
+ PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
+
+ ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
+
+ AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: updatedSettings.AffiliateEnabled,
+ }
+ response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
-func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
+// hasPaymentFields returns true if any payment-related field was explicitly provided.
+func hasPaymentFields(req UpdateSettingsRequest) bool {
+ return req.PaymentEnabled != nil || req.PaymentMinAmount != nil ||
+ req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
+ req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil ||
+ req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil ||
+ req.PaymentBalanceRechargeMultiplier != nil || req.PaymentRechargeFeeRate != nil ||
+ req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil ||
+ req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
+ req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
+ req.PaymentCancelRateLimitMax != nil || req.PaymentCancelRateLimitWindow != nil ||
+ req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
+}
+
+func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
- changed := diffSettings(before, after, req)
+ changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req)
if len(changed) == 0 {
return
}
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
- log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
- time.Now().UTC().Format(time.RFC3339),
- subject.UserID,
- role,
- changed,
+ slog.Info("settings updated",
+ "audit", true,
+ "user_id", subject.UserID,
+ "role", role,
+ "changed", changed,
)
}
-func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
+func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
@@ -733,6 +1603,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist")
}
+ if before.PromoCodeEnabled != after.PromoCodeEnabled {
+ changed = append(changed, "promo_code_enabled")
+ }
+ if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
+ changed = append(changed, "invitation_code_enabled")
+ }
if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled")
}
@@ -784,6 +1660,120 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
+ if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
+ changed = append(changed, "wechat_connect_enabled")
+ }
+ if before.WeChatConnectAppID != after.WeChatConnectAppID {
+ changed = append(changed, "wechat_connect_app_id")
+ }
+ if req.WeChatConnectAppSecret != "" {
+ changed = append(changed, "wechat_connect_app_secret")
+ }
+ if before.WeChatConnectOpenAppID != after.WeChatConnectOpenAppID {
+ changed = append(changed, "wechat_connect_open_app_id")
+ }
+ if req.WeChatConnectOpenAppSecret != "" {
+ changed = append(changed, "wechat_connect_open_app_secret")
+ }
+ if before.WeChatConnectMPAppID != after.WeChatConnectMPAppID {
+ changed = append(changed, "wechat_connect_mp_app_id")
+ }
+ if req.WeChatConnectMPAppSecret != "" {
+ changed = append(changed, "wechat_connect_mp_app_secret")
+ }
+ if before.WeChatConnectMobileAppID != after.WeChatConnectMobileAppID {
+ changed = append(changed, "wechat_connect_mobile_app_id")
+ }
+ if req.WeChatConnectMobileAppSecret != "" {
+ changed = append(changed, "wechat_connect_mobile_app_secret")
+ }
+ if before.WeChatConnectOpenEnabled != after.WeChatConnectOpenEnabled {
+ changed = append(changed, "wechat_connect_open_enabled")
+ }
+ if before.WeChatConnectMPEnabled != after.WeChatConnectMPEnabled {
+ changed = append(changed, "wechat_connect_mp_enabled")
+ }
+ if before.WeChatConnectMobileEnabled != after.WeChatConnectMobileEnabled {
+ changed = append(changed, "wechat_connect_mobile_enabled")
+ }
+ if before.WeChatConnectMode != after.WeChatConnectMode {
+ changed = append(changed, "wechat_connect_mode")
+ }
+ if before.WeChatConnectScopes != after.WeChatConnectScopes {
+ changed = append(changed, "wechat_connect_scopes")
+ }
+ if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL {
+ changed = append(changed, "wechat_connect_redirect_url")
+ }
+ if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL {
+ changed = append(changed, "wechat_connect_frontend_redirect_url")
+ }
+ if before.OIDCConnectEnabled != after.OIDCConnectEnabled {
+ changed = append(changed, "oidc_connect_enabled")
+ }
+ if before.OIDCConnectProviderName != after.OIDCConnectProviderName {
+ changed = append(changed, "oidc_connect_provider_name")
+ }
+ if before.OIDCConnectClientID != after.OIDCConnectClientID {
+ changed = append(changed, "oidc_connect_client_id")
+ }
+ if req.OIDCConnectClientSecret != "" {
+ changed = append(changed, "oidc_connect_client_secret")
+ }
+ if before.OIDCConnectIssuerURL != after.OIDCConnectIssuerURL {
+ changed = append(changed, "oidc_connect_issuer_url")
+ }
+ if before.OIDCConnectDiscoveryURL != after.OIDCConnectDiscoveryURL {
+ changed = append(changed, "oidc_connect_discovery_url")
+ }
+ if before.OIDCConnectAuthorizeURL != after.OIDCConnectAuthorizeURL {
+ changed = append(changed, "oidc_connect_authorize_url")
+ }
+ if before.OIDCConnectTokenURL != after.OIDCConnectTokenURL {
+ changed = append(changed, "oidc_connect_token_url")
+ }
+ if before.OIDCConnectUserInfoURL != after.OIDCConnectUserInfoURL {
+ changed = append(changed, "oidc_connect_userinfo_url")
+ }
+ if before.OIDCConnectJWKSURL != after.OIDCConnectJWKSURL {
+ changed = append(changed, "oidc_connect_jwks_url")
+ }
+ if before.OIDCConnectScopes != after.OIDCConnectScopes {
+ changed = append(changed, "oidc_connect_scopes")
+ }
+ if before.OIDCConnectRedirectURL != after.OIDCConnectRedirectURL {
+ changed = append(changed, "oidc_connect_redirect_url")
+ }
+ if before.OIDCConnectFrontendRedirectURL != after.OIDCConnectFrontendRedirectURL {
+ changed = append(changed, "oidc_connect_frontend_redirect_url")
+ }
+ if before.OIDCConnectTokenAuthMethod != after.OIDCConnectTokenAuthMethod {
+ changed = append(changed, "oidc_connect_token_auth_method")
+ }
+ if before.OIDCConnectUsePKCE != after.OIDCConnectUsePKCE {
+ changed = append(changed, "oidc_connect_use_pkce")
+ }
+ if before.OIDCConnectValidateIDToken != after.OIDCConnectValidateIDToken {
+ changed = append(changed, "oidc_connect_validate_id_token")
+ }
+ if before.OIDCConnectAllowedSigningAlgs != after.OIDCConnectAllowedSigningAlgs {
+ changed = append(changed, "oidc_connect_allowed_signing_algs")
+ }
+ if before.OIDCConnectClockSkewSeconds != after.OIDCConnectClockSkewSeconds {
+ changed = append(changed, "oidc_connect_clock_skew_seconds")
+ }
+ if before.OIDCConnectRequireEmailVerified != after.OIDCConnectRequireEmailVerified {
+ changed = append(changed, "oidc_connect_require_email_verified")
+ }
+ if before.OIDCConnectUserInfoEmailPath != after.OIDCConnectUserInfoEmailPath {
+ changed = append(changed, "oidc_connect_userinfo_email_path")
+ }
+ if before.OIDCConnectUserInfoIDPath != after.OIDCConnectUserInfoIDPath {
+ changed = append(changed, "oidc_connect_userinfo_id_path")
+ }
+ if before.OIDCConnectUserInfoUsernamePath != after.OIDCConnectUserInfoUsernamePath {
+ changed = append(changed, "oidc_connect_userinfo_username_path")
+ }
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
@@ -814,6 +1804,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
+ if before.AffiliateRebateRate != after.AffiliateRebateRate {
+ changed = append(changed, "affiliate_rebate_rate")
+ }
+ if before.AffiliateRebateFreezeHours != after.AffiliateRebateFreezeHours {
+ changed = append(changed, "affiliate_rebate_freeze_hours")
+ }
+ if before.AffiliateRebateDurationDays != after.AffiliateRebateDurationDays {
+ changed = append(changed, "affiliate_rebate_duration_days")
+ }
+ if before.AffiliateRebatePerInviteeCap != after.AffiliateRebatePerInviteeCap {
+ changed = append(changed, "affiliate_rebate_per_invitee_cap")
+ }
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
@@ -868,15 +1870,114 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
changed = append(changed, "purchase_subscription_url")
}
+ if before.TableDefaultPageSize != after.TableDefaultPageSize {
+ changed = append(changed, "table_default_page_size")
+ }
+ if !equalIntSlice(before.TablePageSizeOptions, after.TablePageSizeOptions) {
+ changed = append(changed, "table_page_size_options")
+ }
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
+ if before.CustomEndpoints != after.CustomEndpoints {
+ changed = append(changed, "custom_endpoints")
+ }
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification")
}
if before.EnableMetadataPassthrough != after.EnableMetadataPassthrough {
changed = append(changed, "enable_metadata_passthrough")
}
+ if before.EnableCCHSigning != after.EnableCCHSigning {
+ changed = append(changed, "enable_cch_signing")
+ }
+ if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
+ changed = append(changed, "payment_visible_method_alipay_source")
+ }
+ if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource {
+ changed = append(changed, "payment_visible_method_wxpay_source")
+ }
+ if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled {
+ changed = append(changed, "payment_visible_method_alipay_enabled")
+ }
+ if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled {
+ changed = append(changed, "payment_visible_method_wxpay_enabled")
+ }
+ if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
+ changed = append(changed, "openai_advanced_scheduler_enabled")
+ }
+ // Balance & quota notification
+ if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
+ changed = append(changed, "balance_low_notify_enabled")
+ }
+ if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold {
+ changed = append(changed, "balance_low_notify_threshold")
+ }
+ if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
+ changed = append(changed, "balance_low_notify_recharge_url")
+ }
+ if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
+ changed = append(changed, "account_quota_notify_enabled")
+ }
+ if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
+ changed = append(changed, "account_quota_notify_emails")
+ }
+ if before.ChannelMonitorEnabled != after.ChannelMonitorEnabled {
+ changed = append(changed, "channel_monitor_enabled")
+ }
+ if before.ChannelMonitorDefaultIntervalSeconds != after.ChannelMonitorDefaultIntervalSeconds {
+ changed = append(changed, "channel_monitor_default_interval_seconds")
+ }
+ if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
+ changed = append(changed, "available_channels_enabled")
+ }
+ if before.AffiliateEnabled != after.AffiliateEnabled {
+ changed = append(changed, "affiliate_enabled")
+ }
+ changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
+ return changed
+}
+
+func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string {
+ if before == nil {
+ before = &service.AuthSourceDefaultSettings{}
+ }
+ if after == nil {
+ after = &service.AuthSourceDefaultSettings{}
+ }
+
+ type providerDefaultGrantField struct {
+ name string
+ before service.ProviderDefaultGrantSettings
+ after service.ProviderDefaultGrantSettings
+ }
+
+ fields := []providerDefaultGrantField{
+ {name: "email", before: before.Email, after: after.Email},
+ {name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
+ {name: "oidc", before: before.OIDC, after: after.OIDC},
+ {name: "wechat", before: before.WeChat, after: after.WeChat},
+ }
+ for _, field := range fields {
+ if field.before.Balance != field.after.Balance {
+ changed = append(changed, "auth_source_default_"+field.name+"_balance")
+ }
+ if field.before.Concurrency != field.after.Concurrency {
+ changed = append(changed, "auth_source_default_"+field.name+"_concurrency")
+ }
+ if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) {
+ changed = append(changed, "auth_source_default_"+field.name+"_subscriptions")
+ }
+ if field.before.GrantOnSignup != field.after.GrantOnSignup {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup")
+ }
+ if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
+ changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
+ }
+ }
+ if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
+ changed = append(changed, "force_email_on_third_party_signup")
+ }
return changed
}
@@ -897,6 +1998,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
+func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
+ if input == nil {
+ return nil
+ }
+ normalized := normalizeDefaultSubscriptions(*input)
+ return &normalized
+}
+
+func float64ValueOrDefault(value *float64, fallback float64) float64 {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func intValueOrDefault(value *int, fallback int) int {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func boolValueOrDefault(value *bool, fallback bool) bool {
+ if value == nil {
+ return fallback
+ }
+ return *value
+}
+
+func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
+ if input == nil {
+ return fallback
+ }
+ result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
+ for _, item := range *input {
+ result = append(result, service.DefaultSubscriptionSetting{
+ GroupID: item.GroupID,
+ ValidityDays: item.ValidityDays,
+ })
+ }
+ return result
+}
+
+func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
+ data := make(map[string]any)
+ raw, err := json.Marshal(settings)
+ if err == nil {
+ _ = json.Unmarshal(raw, &data)
+ }
+ if authSourceDefaults == nil {
+ authSourceDefaults = &service.AuthSourceDefaultSettings{}
+ }
+
+ data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
+ data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
+ data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
+ data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
+ data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
+ data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
+ data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
+ data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
+ data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
+ data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
+ data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
+ data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
+ data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
+ data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
+ data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
+ data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
+ data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
+ data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
+ data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
+ data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
+ data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
+
+ return data
+}
+
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
@@ -921,6 +2100,30 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
return true
}
+func equalIntSlice(a, b []int) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled {
+ return false
+ }
+ }
+ return true
+}
+
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"`
@@ -1207,384 +2410,6 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
})
}
-func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
- if settings == nil {
- return dto.SoraS3Settings{}
- }
- return dto.SoraS3Settings{
- Enabled: settings.Enabled,
- Endpoint: settings.Endpoint,
- Region: settings.Region,
- Bucket: settings.Bucket,
- AccessKeyID: settings.AccessKeyID,
- SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
- Prefix: settings.Prefix,
- ForcePathStyle: settings.ForcePathStyle,
- CDNURL: settings.CDNURL,
- DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
- }
-}
-
-func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
- return dto.SoraS3Profile{
- ProfileID: profile.ProfileID,
- Name: profile.Name,
- IsActive: profile.IsActive,
- Enabled: profile.Enabled,
- Endpoint: profile.Endpoint,
- Region: profile.Region,
- Bucket: profile.Bucket,
- AccessKeyID: profile.AccessKeyID,
- SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
- Prefix: profile.Prefix,
- ForcePathStyle: profile.ForcePathStyle,
- CDNURL: profile.CDNURL,
- DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
- UpdatedAt: profile.UpdatedAt,
- }
-}
-
-func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
- if !enabled {
- return nil
- }
- if strings.TrimSpace(endpoint) == "" {
- return fmt.Errorf("S3 Endpoint is required when enabled")
- }
- if strings.TrimSpace(bucket) == "" {
- return fmt.Errorf("S3 Bucket is required when enabled")
- }
- if strings.TrimSpace(accessKeyID) == "" {
- return fmt.Errorf("S3 Access Key ID is required when enabled")
- }
- if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
- return nil
- }
- return fmt.Errorf("S3 Secret Access Key is required when enabled")
-}
-
-func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
- for idx := range items {
- if items[idx].ProfileID == profileID {
- return &items[idx]
- }
- }
- return nil
-}
-
-// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
-// GET /api/v1/admin/settings/sora-s3
-func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
- settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, toSoraS3SettingsDTO(settings))
-}
-
-// ListSoraS3Profiles 获取 Sora S3 多配置
-// GET /api/v1/admin/settings/sora-s3/profiles
-func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
- result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- items := make([]dto.SoraS3Profile, 0, len(result.Items))
- for idx := range result.Items {
- items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
- }
- response.Success(c, dto.ListSoraS3ProfilesResponse{
- ActiveProfileID: result.ActiveProfileID,
- Items: items,
- })
-}
-
-// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
-type UpdateSoraS3SettingsRequest struct {
- ProfileID string `json:"profile_id"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-type CreateSoraS3ProfileRequest struct {
- ProfileID string `json:"profile_id"`
- Name string `json:"name"`
- SetActive bool `json:"set_active"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-type UpdateSoraS3ProfileRequest struct {
- Name string `json:"name"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKey string `json:"secret_access_key"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
-
-// CreateSoraS3Profile 创建 Sora S3 配置
-// POST /api/v1/admin/settings/sora-s3/profiles
-func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
- var req CreateSoraS3ProfileRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- if req.DefaultStorageQuotaBytes < 0 {
- req.DefaultStorageQuotaBytes = 0
- }
- if strings.TrimSpace(req.Name) == "" {
- response.BadRequest(c, "Name is required")
- return
- }
- if strings.TrimSpace(req.ProfileID) == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
- if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
- response.BadRequest(c, err.Error())
- return
- }
-
- created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
- ProfileID: req.ProfileID,
- Name: req.Name,
- Enabled: req.Enabled,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
- }, req.SetActive)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- response.Success(c, toSoraS3ProfileDTO(*created))
-}
-
-// UpdateSoraS3Profile 更新 Sora S3 配置
-// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
-func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
- profileID := strings.TrimSpace(c.Param("profile_id"))
- if profileID == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
-
- var req UpdateSoraS3ProfileRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- if req.DefaultStorageQuotaBytes < 0 {
- req.DefaultStorageQuotaBytes = 0
- }
- if strings.TrimSpace(req.Name) == "" {
- response.BadRequest(c, "Name is required")
- return
- }
-
- existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- existing := findSoraS3ProfileByID(existingList.Items, profileID)
- if existing == nil {
- response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
- return
- }
- if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
- response.BadRequest(c, err.Error())
- return
- }
-
- updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
- Name: req.Name,
- Enabled: req.Enabled,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
- })
- if updateErr != nil {
- response.ErrorFrom(c, updateErr)
- return
- }
-
- response.Success(c, toSoraS3ProfileDTO(*updated))
-}
-
-// DeleteSoraS3Profile 删除 Sora S3 配置
-// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
-func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
- profileID := strings.TrimSpace(c.Param("profile_id"))
- if profileID == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
- if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"deleted": true})
-}
-
-// SetActiveSoraS3Profile 切换激活 Sora S3 配置
-// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
-func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
- profileID := strings.TrimSpace(c.Param("profile_id"))
- if profileID == "" {
- response.BadRequest(c, "Profile ID is required")
- return
- }
- active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, toSoraS3ProfileDTO(*active))
-}
-
-// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
-// PUT /api/v1/admin/settings/sora-s3
-func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
- var req UpdateSoraS3SettingsRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
-
- existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- if req.DefaultStorageQuotaBytes < 0 {
- req.DefaultStorageQuotaBytes = 0
- }
- if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
- response.BadRequest(c, err.Error())
- return
- }
-
- settings := &service.SoraS3Settings{
- Enabled: req.Enabled,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
- }
- if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, toSoraS3SettingsDTO(updatedSettings))
-}
-
-// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket)
-// POST /api/v1/admin/settings/sora-s3/test
-func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
- if h.soraS3Storage == nil {
- response.Error(c, 500, "S3 存储服务未初始化")
- return
- }
-
- var req UpdateSoraS3SettingsRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
- if !req.Enabled {
- response.BadRequest(c, "S3 未启用,无法测试连接")
- return
- }
-
- if req.SecretAccessKey == "" {
- if req.ProfileID != "" {
- profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
- if err == nil {
- profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
- if profile != nil {
- req.SecretAccessKey = profile.SecretAccessKey
- }
- }
- }
- if req.SecretAccessKey == "" {
- existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
- if err == nil {
- req.SecretAccessKey = existing.SecretAccessKey
- }
- }
- }
-
- testCfg := &service.SoraS3Settings{
- Enabled: true,
- Endpoint: req.Endpoint,
- Region: req.Region,
- Bucket: req.Bucket,
- AccessKeyID: req.AccessKeyID,
- SecretAccessKey: req.SecretAccessKey,
- Prefix: req.Prefix,
- ForcePathStyle: req.ForcePathStyle,
- CDNURL: req.CDNURL,
- }
- if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
- response.Error(c, 400, "S3 连接测试失败: "+err.Error())
- return
- }
- response.Success(c, gin.H{"message": "S3 连接成功"})
-}
-
// GetRectifierSettings 获取请求整流器配置
// GET /api/v1/admin/settings/rectifier
func (h *SettingHandler) GetRectifierSettings(c *gin.Context) {
@@ -1779,3 +2604,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}
+
+// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
+// GET /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
+ cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg))
+}
+
+// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
+// PUT /api/v1/admin/settings/web-search-emulation
+func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
+ var cfg service.WebSearchEmulationConfig
+ if err := c.ShouldBindJSON(&cfg); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Re-read (with sanitized api keys) to return current state
+ updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated))
+}
+
+// ResetWebSearchUsage 重置指定 provider 的配额用量
+// POST /api/v1/admin/settings/web-search-emulation/reset-usage
+func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) {
+ var req struct {
+ ProviderType string `json:"provider_type"`
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if req.ProviderType == "" {
+ response.BadRequest(c, "provider_type is required")
+ return
+ }
+ if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, nil)
+}
+
+// TestWebSearchEmulation 测试 Web Search 搜索
+// POST /api/v1/admin/settings/web-search-emulation/test
+func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
+ var req struct {
+ Query string `json:"query"`
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ if strings.TrimSpace(req.Query) == "" {
+ req.Query = "搜索今年世界大事件"
+ }
+
+ result, err := service.TestWebSearch(c.Request.Context(), req.Query)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
new file mode 100644
index 00000000..9a33a93a
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -0,0 +1,503 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerRepoStub struct {
+ values map[string]string
+ lastUpdates map[string]string
+}
+
+func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.lastUpdates = make(map[string]string, len(settings))
+ for key, value := range settings {
+ s.lastUpdates[key] = value
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+type failingAuthSourceSettingsRepoStub struct {
+ values map[string]string
+ err error
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok {
+ return s.err
+ }
+ for key, value := range settings {
+ if s.values == nil {
+ s.values = map[string]string{}
+ }
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
+
+ handler.GetSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 9.5, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+
+ subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
+ require.True(t, ok)
+ require.Len(t, subscriptions, 1)
+}
+
+func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+ require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
+ require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
+ require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, 12.75, data["auth_source_default_email_balance"])
+ require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
+ require.Equal(t, true, data["force_email_on_third_party_signup"])
+}
+
+func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "easypay",
+ "payment_visible_method_wxpay_source": "wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+ require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled])
+ require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"])
+ require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"])
+ require.Equal(t, true, data["payment_visible_method_alipay_enabled"])
+ require.Equal(t, false, data["payment_visible_method_wxpay_enabled"])
+ require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
+}
+
+func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodAlipaySource: "",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ service.SettingPaymentVisibleMethodWxpaySource: "",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": false,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource])
+ require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
+}
+
+func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyOIDCConnectEnabled: "true",
+ service.SettingKeyOIDCConnectProviderName: "OIDC",
+ service.SettingKeyOIDCConnectClientID: "oidc-client",
+ service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
+ service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
+ service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
+ service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
+ service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
+ service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
+ service.SettingKeyOIDCConnectScopes: "openid email profile",
+ service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
+ service.SettingKeyOIDCConnectValidateIDToken: "true",
+ service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
+ service.SettingKeyOIDCConnectClockSkewSeconds: "120",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "oidc_connect_enabled": true,
+ "oidc_connect_use_pkce": false,
+ "oidc_connect_validate_id_token": false,
+ "oidc_connect_allowed_signing_algs": "",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, false, data["oidc_connect_use_pkce"])
+ require.Equal(t, false, data["oidc_connect_validate_id_token"])
+}
+
+func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyOIDCConnectEnabled: "true",
+ service.SettingKeyOIDCConnectProviderName: "OIDC",
+ service.SettingKeyOIDCConnectClientID: "oidc-client",
+ service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
+ service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
+ service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
+ service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
+ service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
+ service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
+ service.SettingKeyOIDCConnectScopes: "openid email profile",
+ service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
+ service.SettingKeyOIDCConnectClockSkewSeconds: "120",
+ service.SettingKeyOIDCConnectRequireEmailVerified: "false",
+ service.SettingKeyOIDCConnectUserInfoEmailPath: "",
+ service.SettingKeyOIDCConnectUserInfoIDPath: "",
+ service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{
+ Default: config.DefaultConfig{UserConcurrency: 5},
+ OIDC: config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: "https://issuer.example.com",
+ AuthorizeURL: "https://issuer.example.com/auth",
+ TokenURL: "https://issuer.example.com/token",
+ UserInfoURL: "https://issuer.example.com/userinfo",
+ JWKSURL: "https://issuer.example.com/jwks",
+ Scopes: "openid email profile",
+ RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ },
+ })
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "oidc_connect_enabled": true,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
+ require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
+}
+
+func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &settingHandlerRepoStub{
+ values: map[string]string{
+ service.SettingKeyPromoCodeEnabled: "true",
+ },
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "promo_code_enabled": true,
+ "payment_visible_method_alipay_source": "bogus",
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
+}
+
+func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := &failingAuthSourceSettingsRepoStub{
+ values: map[string]string{
+ service.SettingKeyRegistrationEnabled: "false",
+ service.SettingKeyPromoCodeEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
+ },
+ err: errors.New("write auth source defaults failed"),
+ }
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+
+ body := map[string]any{
+ "registration_enabled": true,
+ "promo_code_enabled": true,
+ "auth_source_default_email_balance": 12.75,
+ }
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+ require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled])
+ require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
+}
+
+func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) {
+ changed := diffSettings(
+ &service.SystemSettings{},
+ &service.SystemSettings{},
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 0,
+ Concurrency: 5,
+ Subscriptions: nil,
+ GrantOnSignup: true,
+ GrantOnFirstBind: false,
+ },
+ ForceEmailOnThirdPartySignup: false,
+ },
+ &service.AuthSourceDefaultSettings{
+ Email: service.ProviderDefaultGrantSettings{
+ Balance: 12.5,
+ Concurrency: 7,
+ Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}},
+ GrantOnSignup: false,
+ GrantOnFirstBind: true,
+ },
+ ForceEmailOnThirdPartySignup: true,
+ },
+ UpdateSettingsRequest{},
+ )
+
+ require.Contains(t, changed, "auth_source_default_email_balance")
+ require.Contains(t, changed, "auth_source_default_email_concurrency")
+ require.Contains(t, changed, "auth_source_default_email_subscriptions")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_signup")
+ require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind")
+ require.Contains(t, changed, "force_email_on_third_party_signup")
+}
diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go
index 7a3135b8..0857a138 100644
--- a/backend/internal/handler/admin/usage_handler.go
+++ b/backend/internal/handler/admin/usage_handler.go
@@ -110,6 +110,7 @@ func (h *UsageHandler) List(c *gin.Context) {
}
model := c.Query("model")
+ billingMode := strings.TrimSpace(c.Query("billing_mode"))
var requestType *int16
var stream *bool
@@ -164,7 +165,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t
}
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+ params := pagination.PaginationParams{
+ Page: page,
+ PageSize: pageSize,
+ SortBy: c.DefaultQuery("sort_by", "created_at"),
+ SortOrder: c.DefaultQuery("sort_order", "desc"),
+ }
filters := usagestats.UsageLogFilters{
UserID: userID,
APIKeyID: apiKeyID,
@@ -174,6 +180,7 @@ func (h *UsageHandler) List(c *gin.Context) {
RequestType: requestType,
Stream: stream,
BillingType: billingType,
+ BillingMode: billingMode,
StartTime: startTime,
EndTime: endTime,
ExactTotal: exactTotal,
@@ -234,6 +241,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
model := c.Query("model")
+ billingMode := strings.TrimSpace(c.Query("billing_mode"))
var requestType *int16
var stream *bool
@@ -312,6 +320,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
RequestType: requestType,
Stream: stream,
BillingType: billingType,
+ BillingMode: billingMode,
StartTime: &startTime,
EndTime: &endTime,
}
@@ -335,7 +344,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
}
// Limit to 30 results
- users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword})
+ users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}, "email", "asc")
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go
index 3f158316..882cbe93 100644
--- a/backend/internal/handler/admin/usage_handler_request_type_test.go
+++ b/backend/internal/handler/admin/usage_handler_request_type_test.go
@@ -15,11 +15,13 @@ import (
type adminUsageRepoCapture struct {
service.UsageLogRepository
+ listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters
statsFilters usagestats.UsageLogFilters
}
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
+ s.listParams = params
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
diff --git a/backend/internal/handler/admin/usage_handler_sort_test.go b/backend/internal/handler/admin/usage_handler_sort_test.go
new file mode 100644
index 00000000..dac82676
--- /dev/null
+++ b/backend/internal/handler/admin/usage_handler_sort_test.go
@@ -0,0 +1,35 @@
+package admin
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAdminUsageListSortParams(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage?sort_by=model&sort_order=ASC", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "model", repo.listParams.SortBy)
+ require.Equal(t, "ASC", repo.listParams.SortOrder)
+}
+
+func TestAdminUsageListSortDefaults(t *testing.T) {
+ repo := &adminUsageRepoCapture{}
+ router := newAdminUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/admin/usage", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "created_at", repo.listParams.SortBy)
+ require.Equal(t, "desc", repo.listParams.SortOrder)
+}
diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go
index 998308dd..3d80107f 100644
--- a/backend/internal/handler/admin/user_handler.go
+++ b/backend/internal/handler/admin/user_handler.go
@@ -34,14 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
- Email string `json:"email" binding:"required,email"`
- Password string `json:"password" binding:"required,min=6"`
- Username string `json:"username"`
- Notes string `json:"notes"`
- Balance float64 `json:"balance"`
- Concurrency int `json:"concurrency"`
- AllowedGroups []int64 `json:"allowed_groups"`
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required,min=6"`
+ Username string `json:"username"`
+ Notes string `json:"notes"`
+ Balance float64 `json:"balance"`
+ Concurrency int `json:"concurrency"`
+ RPMLimit int `json:"rpm_limit"`
+ AllowedGroups []int64 `json:"allowed_groups"`
}
// UpdateUserRequest represents admin update user request
@@ -53,12 +53,12 @@ type UpdateUserRequest struct {
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
+ RPMLimit *int `json:"rpm_limit"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
- GroupRates map[int64]*float64 `json:"group_rates"`
- SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
+ GroupRates map[int64]*float64 `json:"group_rates"`
}
// UpdateBalanceRequest represents balance update request
@@ -68,6 +68,22 @@ type UpdateBalanceRequest struct {
Notes string `json:"notes"`
}
+type BindUserAuthIdentityRequest struct {
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ Issuer *string `json:"issuer"`
+ Metadata map[string]any `json:"metadata"`
+ Channel *BindUserAuthIdentityChannelRequest `json:"channel"`
+}
+
+type BindUserAuthIdentityChannelRequest struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+}
+
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
@@ -93,12 +109,14 @@ func (h *UserHandler) List(c *gin.Context) {
GroupName: strings.TrimSpace(c.Query("group_name")),
Attributes: parseAttributeFilters(c),
}
+ sortBy := c.DefaultQuery("sort_by", "created_at")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
if raw, ok := c.GetQuery("include_subscriptions"); ok {
includeSubscriptions := parseBoolQueryWithDefault(raw, true)
filters.IncludeSubscriptions = &includeSubscriptions
}
- users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters)
+ users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -172,6 +190,45 @@ func (h *UserHandler) GetByID(c *gin.Context) {
response.Success(c, dto.UserFromServiceAdmin(user))
}
+// BindAuthIdentity manually binds a canonical auth identity to a user.
+// POST /api/v1/admin/users/:id/auth-identities
+func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid user ID")
+ return
+ }
+
+ var req BindUserAuthIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ input := service.AdminBindAuthIdentityInput{
+ ProviderType: req.ProviderType,
+ ProviderKey: req.ProviderKey,
+ ProviderSubject: req.ProviderSubject,
+ Issuer: req.Issuer,
+ Metadata: req.Metadata,
+ }
+ if req.Channel != nil {
+ input.Channel = &service.AdminBindAuthIdentityChannelInput{
+ Channel: req.Channel.Channel,
+ ChannelAppID: req.Channel.ChannelAppID,
+ ChannelSubject: req.Channel.ChannelSubject,
+ Metadata: req.Channel.Metadata,
+ }
+ }
+
+ result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, result)
+}
+
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
@@ -182,14 +239,14 @@ func (h *UserHandler) Create(c *gin.Context) {
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
- Email: req.Email,
- Password: req.Password,
- Username: req.Username,
- Notes: req.Notes,
- Balance: req.Balance,
- Concurrency: req.Concurrency,
- AllowedGroups: req.AllowedGroups,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
+ Email: req.Email,
+ Password: req.Password,
+ Username: req.Username,
+ Notes: req.Notes,
+ Balance: req.Balance,
+ Concurrency: req.Concurrency,
+ RPMLimit: req.RPMLimit,
+ AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -216,16 +273,16 @@ func (h *UserHandler) Update(c *gin.Context) {
// 使用指针类型直接传递,nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
- Email: req.Email,
- Password: req.Password,
- Username: req.Username,
- Notes: req.Notes,
- Balance: req.Balance,
- Concurrency: req.Concurrency,
- Status: req.Status,
- AllowedGroups: req.AllowedGroups,
- GroupRates: req.GroupRates,
- SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
+ Email: req.Email,
+ Password: req.Password,
+ Username: req.Username,
+ Notes: req.Notes,
+ Balance: req.Balance,
+ Concurrency: req.Concurrency,
+ RPMLimit: req.RPMLimit,
+ Status: req.Status,
+ AllowedGroups: req.AllowedGroups,
+ GroupRates: req.GroupRates,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -294,8 +351,10 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
}
page, pageSize := response.ParsePagination(c)
+ sortBy := c.DefaultQuery("sort_by", "created_at")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
- keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
+ keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -400,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
"migrated_keys": result.MigratedKeys,
})
}
+
+// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
+// GET /api/v1/admin/users/:id/rpm-status
+func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
+ userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid user ID")
+ return
+ }
+
+ status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, status)
+}
diff --git a/backend/internal/handler/admin/user_handler_activity_test.go b/backend/internal/handler/admin/user_handler_activity_test.go
new file mode 100644
index 00000000..bfba2408
--- /dev/null
+++ b/backend/internal/handler/admin/user_handler_activity_test.go
@@ -0,0 +1,114 @@
+//go:build unit
+
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 7,
+ Email: "activity@example.com",
+ Username: "activity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/users?sort_by=last_used_at&sort_order=asc&search=activity",
+ nil,
+ )
+
+ handler.List(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, "last_used_at", adminSvc.lastListUsers.sortBy)
+ require.Equal(t, "asc", adminSvc.lastListUsers.sortOrder)
+ require.Equal(t, "activity", adminSvc.lastListUsers.filters.Search)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Items []struct {
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"items"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Len(t, resp.Data.Items, 1)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second)
+}
+
+func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(30 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(90 * time.Minute)
+
+ adminSvc := newStubAdminService()
+ adminSvc.users = []service.User{
+ {
+ ID: 8,
+ Email: "detail@example.com",
+ Username: "detail-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ CreatedAt: lastLoginAt.Add(-24 * time.Hour),
+ UpdatedAt: lastLoginAt,
+ },
+ }
+ handler := NewUserHandler(adminSvc, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Params = gin.Params{{Key: "id", Value: "8"}}
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/8", nil)
+
+ handler.GetByID(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ LastActiveAt *time.Time `json:"last_active_at"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second)
+}
diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go
index 951aed08..9d6c6c15 100644
--- a/backend/internal/handler/api_key_handler.go
+++ b/backend/internal/handler/api_key_handler.go
@@ -72,7 +72,12 @@ func (h *APIKeyHandler) List(c *gin.Context) {
}
page, pageSize := response.ParsePagination(c)
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+ params := pagination.PaginationParams{
+ Page: page,
+ PageSize: pageSize,
+ SortBy: c.DefaultQuery("sort_by", "created_at"),
+ SortOrder: c.DefaultQuery("sort_order", "desc"),
+ }
// Parse filter parameters
var filters service.APIKeyListFilters
diff --git a/backend/internal/handler/auth_current_user_test.go b/backend/internal/handler/auth_current_user_test.go
new file mode 100644
index 00000000..cb3e4ba5
--- /dev/null
+++ b/backend/internal/handler/auth_current_user_test.go
@@ -0,0 +1,86 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 31,
+ Email: "me@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-31",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+
+ handler := &AuthHandler{
+ userService: service.NewUserService(repo, nil, nil, nil),
+ }
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31})
+
+ handler.GetCurrentUser(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
+}
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index f4ddf890..1f9a66ff 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -1,11 +1,13 @@
package handler
import (
+ "context"
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -46,6 +48,7 @@ type RegisterRequest struct {
TurnstileToken string `json:"turnstile_token"`
PromoCode string `json:"promo_code"` // 注册优惠码
InvitationCode string `json:"invitation_code"` // 邀请码
+ AffCode string `json:"aff_code"` // 邀请返利码
}
// SendVerifyCodeRequest 发送验证码请求
@@ -76,9 +79,24 @@ type AuthResponse struct {
User *dto.User `json:"user"`
}
+func ensureLoginUserActive(user *service.User) error {
+ if user == nil {
+ return infraerrors.Unauthorized("INVALID_USER", "user not found")
+ }
+ if !user.IsActive() {
+ return service.ErrUserNotActive
+ }
+ return nil
+}
+
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
+ if err := ensureLoginUserActive(user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
@@ -104,6 +122,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
})
}
+func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error {
+ if user == nil {
+ return infraerrors.Unauthorized("INVALID_USER", "user not found")
+ }
+ if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error {
+ if h == nil || !h.isBackendModeEnabled(ctx) {
+ return nil
+ }
+ return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
+}
+
+func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ settings, err := h.settingSvc.GetPublicSettings(ctx)
+ if err == nil && settings != nil {
+ return settings.BackendModeEnabled
+ }
+ return h.settingSvc.IsBackendModeEnabled(ctx)
+}
+
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
@@ -119,7 +165,15 @@ func (h *AuthHandler) Register(c *gin.Context) {
return
}
- _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode)
+ _, user, err := h.authService.RegisterWithVerification(
+ c.Request.Context(),
+ req.Email,
+ req.Password,
+ req.VerifyCode,
+ req.PromoCode,
+ req.InvitationCode,
+ req.AffCode,
+ )
if err != nil {
response.ErrorFrom(c, err)
return
@@ -177,6 +231,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
@@ -194,11 +253,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
- // Backend mode: only admin can login
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
- return
- }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
h.respondWithTokenPair(c, user)
}
@@ -262,16 +317,80 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
-
- // Backend mode: only admin can login (check BEFORE deleting session)
- if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
- response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
+ if err := ensureLoginUserActive(user); err != nil {
+ response.ErrorFrom(c, err)
return
}
+ if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ if session.PendingOAuthBind != nil {
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ pendingSession, err := pendingSvc.GetBrowserSession(
+ c.Request.Context(),
+ session.PendingOAuthBind.PendingSessionToken,
+ session.PendingOAuthBind.BrowserSessionKey,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{})
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthBinding(
+ c.Request.Context(),
+ h.entClient(),
+ h.authService,
+ h.userService,
+ pendingSession,
+ decision,
+ &user.ID,
+ true,
+ true,
+ ); err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+ return
+ }
+ if _, err := pendingSvc.ConsumeBrowserSession(
+ c.Request.Context(),
+ pendingSession.SessionToken,
+ pendingSession.BrowserSessionKey,
+ ); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+
+ user, err = h.userService.GetByID(c.Request.Context(), session.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
+ if session.PendingOAuthBind == nil {
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ }
+
h.respondWithTokenPair(c, user)
}
@@ -290,8 +409,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
}
+ identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
type UserResponse struct {
- *dto.User
+ userProfileResponse
RunMode string `json:"run_mode"`
}
@@ -300,7 +425,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
runMode = h.cfg.RunMode
}
- response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
+ response.Success(c, UserResponse{
+ userProfileResponse: userProfileResponseFromService(user, identities),
+ RunMode: runMode,
+ })
}
// ValidatePromoCodeRequest 验证优惠码请求
@@ -578,6 +706,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程
}
}
+ h.consumePendingOAuthSessionOnLogout(c)
+ clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
@@ -598,7 +728,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
return
}
- if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
+ if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 0c7c2da7..7df4abfd 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -2,6 +2,8 @@ package handler
import (
"context"
+ "crypto/hmac"
+ "crypto/sha256"
"encoding/base64"
"errors"
"fmt"
@@ -13,10 +15,13 @@ import (
"time"
"unicode/utf8"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -25,17 +30,24 @@ import (
)
const (
- linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
- linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
- linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
- linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
- linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
- linuxDoOAuthDefaultRedirectTo = "/dashboard"
- linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
+ linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
+ oauthBindAccessTokenCookiePath = "/api/v1/auth/oauth"
+ linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
+ linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
+ linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
+ linuxDoOAuthIntentCookieName = "linuxdo_oauth_intent"
+ linuxDoOAuthBindUserCookieName = "linuxdo_oauth_bind_user"
+ oauthBindAccessTokenCookieName = "oauth_bind_access_token"
+ linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
+ linuxDoOAuthDefaultRedirectTo = "/dashboard"
+ linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
+
+ oauthIntentLogin = "login"
+ oauthIntentBindCurrentUser = "bind_current_user"
)
type linuxDoTokenResponse struct {
@@ -87,9 +99,29 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ intent := normalizeOAuthIntent(c.Query("intent"))
+ setCookie(c, linuxDoOAuthIntentCookieName, encodeCookieValue(intent), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ setCookie(c, linuxDoOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), linuxDoOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
+ }
codeChallenge := ""
if cfg.UsePKCE {
@@ -148,6 +180,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
+ clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
+ clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
@@ -161,6 +195,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+ return
+ }
+ intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
+ intent = normalizeOAuthIntent(intent)
codeVerifier := ""
if cfg.UsePKCE {
@@ -198,52 +239,205 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
- email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
+ email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
+ compatEmail := strings.TrimSpace(email)
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
-
- // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
- if err != nil {
- if errors.Is(err, service.ErrOAuthInvitationRequired) {
- pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
- if tokenErr != nil {
- redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
- return
- }
- fragment := url.Values{}
- fragment.Set("error", "invitation_required")
- fragment.Set("pending_oauth_token", pendingToken)
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ identityKey := service.PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: subject,
+ }
+ upstreamClaims := map[string]any{
+ "email": email,
+ "username": username,
+ "subject": subject,
+ "suggested_display_name": displayName,
+ "suggested_avatar_url": avatarURL,
+ }
+ if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+ upstreamClaims["compat_email"] = compatEmail
+ }
+ if intent == oauthIntentBindCurrentUser {
+ targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "")
return
}
- // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
- redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentBindCurrentUser,
+ Identity: identityKey,
+ TargetUserID: &targetUserID,
+ ResolvedEmail: email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
return
}
- fragment := url.Values{}
- fragment.Set("access_token", tokenPair.AccessToken)
- fragment.Set("refresh_token", tokenPair.RefreshToken)
- fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
- fragment.Set("token_type", "Bearer")
- fragment.Set("redirect", redirectTo)
- redirectWithFragment(c, frontendCallback, fragment)
+ existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if existingIdentityUser != nil {
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identityKey,
+ TargetUserID: &existingIdentityUser.ID,
+ ResolvedEmail: existingIdentityUser.Email,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{
+ "redirect": redirectTo,
+ },
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ if err := h.createLinuxDoOAuthChoicePendingSession(
+ c,
+ identityKey,
+ email,
+ email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ compatEmail,
+ compatEmailUser,
+ h.isForceEmailOnThirdPartySignup(c.Request.Context()),
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntity, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ switch len(userEntity) {
+ case 0:
+ return nil, nil
+ case 1:
+ return userEntity[0], nil
+ default:
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+}
+
+func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+ if forceEmailOnSignup && compatEmailUser == nil {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+
+ var targetUserID *int64
+ if compatEmailUser != nil && compatEmailUser.ID > 0 {
+ targetUserID = &compatEmailUser.ID
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ TargetUserID: targetUserID,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
}
type completeLinuxDoOAuthRequest struct {
- PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
- InvitationCode string `json:"invitation_code" binding:"required"`
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AffCode string `json:"aff_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
@@ -256,17 +450,87 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
- email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
-
- tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+ if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
@@ -303,7 +567,7 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
@@ -353,11 +617,11 @@ func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
-) (email string, username string, subject string, err error) {
+) (email string, username string, subject string, displayName string, avatarURL string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
- return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
+ return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
@@ -366,16 +630,16 @@ func linuxDoFetchUserInfo(
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
- return "", "", "", fmt.Errorf("request userinfo: %w", err)
+ return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
- return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
+ return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
-func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
+func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
@@ -400,12 +664,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
getGJSON(body, "user.id"),
)
+ displayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "user.name"),
+ getGJSON(body, "user.username"),
+ username,
+ )
+ avatarURL = firstNonEmpty(
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "picture"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
+
subject = strings.TrimSpace(subject)
if subject == "" {
- return "", "", "", errors.New("userinfo missing id field")
+ return "", "", "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
- return "", "", "", errors.New("userinfo returned invalid id field")
+ return "", "", "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
@@ -418,8 +699,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
if username == "" {
username = "linuxdo_" + subject
}
+ displayName = strings.TrimSpace(displayName)
+ if displayName == "" {
+ displayName = username
+ }
+ avatarURL = strings.TrimSpace(avatarURL)
- return email, username, subject, nil
+ return email, username, subject, displayName, avatarURL, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
@@ -436,7 +722,7 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
- if cfg.UsePKCE {
+ if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
@@ -670,6 +956,30 @@ func clearCookie(c *gin.Context, name string, secure bool) {
})
}
+func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthBindAccessTokenCookieName,
+ Value: "",
+ Path: oauthBindAccessTokenCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: oauthBindAccessTokenCookieName,
+ Value: url.QueryEscape(strings.TrimSpace(token)),
+ Path: oauthBindAccessTokenCookiePath,
+ MaxAge: linuxDoOAuthCookieMaxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
@@ -728,3 +1038,127 @@ func linuxDoSyntheticEmail(subject string) string {
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
+
+func normalizeOAuthIntent(raw string) string {
+ switch strings.ToLower(strings.TrimSpace(raw)) {
+ case "", oauthIntentLogin:
+ return oauthIntentLogin
+ case "bind", oauthIntentBindCurrentUser:
+ return oauthIntentBindCurrentUser
+ default:
+ return oauthIntentLogin
+ }
+}
+
+func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (string, error) {
+ userID, err := h.resolveOAuthBindTargetUserID(c)
+ if err != nil || userID == nil || *userID <= 0 {
+ return "", infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")
+ }
+ return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret())
+}
+
+func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) {
+ const bearerPrefix = "Bearer "
+
+ authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
+ if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) {
+ response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
+ return
+ }
+
+ token := strings.TrimSpace(authHeader[len(bearerPrefix):])
+ if token == "" {
+ response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
+ return
+ }
+
+ setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c))
+ c.Status(http.StatusNoContent)
+ c.Writer.WriteHeaderNow()
+}
+
+func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) {
+ if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
+ return &subject.UserID, nil
+ }
+ if h == nil || h.authService == nil || h.userService == nil {
+ return nil, service.ErrInvalidToken
+ }
+
+ ck, err := c.Request.Cookie(oauthBindAccessTokenCookieName)
+ clearOAuthBindAccessTokenCookie(c, isRequestHTTPS(c))
+ if err != nil {
+ return nil, err
+ }
+
+ tokenString, err := url.QueryUnescape(strings.TrimSpace(ck.Value))
+ if err != nil {
+ return nil, err
+ }
+ if tokenString == "" {
+ return nil, service.ErrInvalidToken
+ }
+
+ claims, err := h.authService.ValidateToken(tokenString)
+ if err != nil {
+ return nil, err
+ }
+ user, err := h.userService.GetByID(c.Request.Context(), claims.UserID)
+ if err != nil {
+ return nil, err
+ }
+ if user == nil || !user.IsActive() || claims.TokenVersion != user.TokenVersion {
+ return nil, service.ErrInvalidToken
+ }
+ return &user.ID, nil
+}
+
+func (h *AuthHandler) readOAuthBindUserIDFromCookie(c *gin.Context, cookieName string) (int64, error) {
+ value, err := readCookieDecoded(c, cookieName)
+ if err != nil {
+ return 0, err
+ }
+ return parseOAuthBindUserCookieValue(value, h.oauthBindCookieSecret())
+}
+
+func (h *AuthHandler) oauthBindCookieSecret() string {
+ if h == nil || h.cfg == nil {
+ return ""
+ }
+ return strings.TrimSpace(h.cfg.JWT.Secret)
+}
+
+func buildOAuthBindUserCookieValue(userID int64, secret string) (string, error) {
+ secret = strings.TrimSpace(secret)
+ if userID <= 0 || secret == "" {
+ return "", errors.New("invalid oauth bind cookie input")
+ }
+ payload := strconv.FormatInt(userID, 10)
+ mac := hmac.New(sha256.New, []byte(secret))
+ _, _ = mac.Write([]byte(payload))
+ signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ return payload + "." + signature, nil
+}
+
+func parseOAuthBindUserCookieValue(value string, secret string) (int64, error) {
+ secret = strings.TrimSpace(secret)
+ if secret == "" {
+ return 0, errors.New("missing oauth bind cookie secret")
+ }
+ payload, signature, ok := strings.Cut(strings.TrimSpace(value), ".")
+ if !ok || payload == "" || signature == "" {
+ return 0, errors.New("invalid oauth bind cookie")
+ }
+ mac := hmac.New(sha256.New, []byte(secret))
+ _, _ = mac.Write([]byte(payload))
+ expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
+ if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
+ return 0, errors.New("invalid oauth bind cookie signature")
+ }
+ userID, err := strconv.ParseInt(payload, 10, 64)
+ if err != nil || userID <= 0 {
+ return 0, errors.New("invalid oauth bind cookie user")
+ }
+ return userID, nil
+}
diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go
index ff169c52..8b01ab41 100644
--- a/backend/internal/handler/auth_linuxdo_oauth_test.go
+++ b/backend/internal/handler/auth_linuxdo_oauth_test.go
@@ -1,10 +1,24 @@
package handler
import (
+ "bytes"
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
"strings"
"testing"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -41,11 +55,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "Alice", displayName)
+ require.Equal(t, "https://cdn.example/avatar.png", avatarURL)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
@@ -53,11 +69,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
+ email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
+ require.Equal(t, "linuxdo_123", displayName)
+ require.Equal(t, "", avatarURL)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
@@ -65,11 +83,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
- _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
+ _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
- _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
+ _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
@@ -106,3 +124,906 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
+
+func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
+ handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ c.Request = req
+ c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 42})
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ require.Contains(t, location, "connect.linux.do/oauth/authorize")
+ require.Contains(t, location, "client_id=linuxdo-client")
+ require.Contains(t, location, "code_challenge=")
+
+ cookies := recorder.Result().Cookies()
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthStateCookieName))
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthRedirectCookie))
+ require.NotNil(t, findCookie(cookies, linuxDoOAuthVerifierCookie))
+ require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName))
+
+ intentCookie := findCookie(cookies, linuxDoOAuthIntentCookieName)
+ require.NotNil(t, intentCookie)
+ require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value))
+
+ bindCookie := findCookie(cookies, linuxDoOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, int64(42), userID)
+}
+
+func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
+ handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
+ require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
+}
+
+func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, r.ParseForm())
+ require.Empty(t, r.PostForm.Get("code_verifier"))
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: false,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+ require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+}
+
+func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: "https://connect.linux.do/oauth/authorize",
+ TokenURL: "https://connect.linux.do/oauth/token",
+ UserInfoURL: "https://connect.linux.do/api/user",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ user, err := client.User.Create().
+ SetEmail("bind-cookie@example.com").
+ SetUsername("bind-cookie-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(context.Background())
+ require.NoError(t, err)
+
+ token, err := handler.authService.GenerateToken(&service.User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ PasswordHash: user.PasswordHash,
+ Role: user.Role,
+ Status: user.Status,
+ })
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=/settings/connections", nil)
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: token, Path: oauthBindAccessTokenCookiePath})
+ c.Request = req
+
+ handler.LinuxDoOAuthStart(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+
+ bindCookie := findCookie(recorder.Result().Cookies(), linuxDoOAuthBindUserCookieName)
+ require.NotNil(t, bindCookie)
+ userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
+ require.NoError(t, err)
+ require.Equal(t, user.ID, userID)
+
+ accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
+ require.NotNil(t, accessTokenCookie)
+ require.Equal(t, -1, accessTokenCookie.MaxAge)
+}
+
+func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) {
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{})
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil)
+ req.Header.Set("Authorization", "Bearer access-token-value")
+ c.Request = req
+
+ handler.PrepareOAuthBindAccessTokenCookie(c)
+
+ require.Equal(t, http.StatusNoContent, recorder.Code)
+ accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
+ require.NotNil(t, accessTokenCookie)
+ require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path)
+ require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge)
+ require.True(t, accessTokenCookie.HttpOnly)
+ require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value)
+}
+
+func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"321","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(linuxDoSyntheticEmail("321")).
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("321").
+ SetMetadata(map[string]any{"username": "legacy-user"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-123&state=state-123", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-123"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail)
+ require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+ _, hasRefreshToken := completion["refresh_token"]
+ require.False(t, hasRefreshToken)
+ require.Nil(t, completion["error"])
+}
+
+func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(linuxDoSyntheticEmail("654")).
+ SetUsername("disabled-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusDisabled).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingUser.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("654").
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail(" Legacy@Example.com ").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
+ require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
+ require.Equal(t, true, completion["existing_account_bindable"])
+ require.Equal(t, "compat_email_match", completion["choice_reason"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestLinuxDoOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_invite","name":"Need Invite","avatar_url":"https://cdn.example/invite.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, true, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-456&state=state-456", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-456"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-456"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Nil(t, session.TargetUserID)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) {
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
+ case "/userinfo":
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"999","username":"bind_user","name":"Bind Display","avatar_url":"https://cdn.example/bind.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+
+ handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
+ Enabled: true,
+ ClientID: "linuxdo-client",
+ ClientSecret: "linuxdo-secret",
+ AuthorizeURL: upstream.URL + "/authorize",
+ TokenURL: upstream.URL + "/token",
+ UserInfoURL: upstream.URL + "/userinfo",
+ Scopes: "read",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
+ FrontendRedirectURL: "/auth/linuxdo/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ })
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-bind&state=state-bind", nil)
+ req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-bind"))
+ req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-bind"))
+ req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentBindCurrentUser))
+ req.AddCookie(encodedCookie(linuxDoOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind"))
+ c.Request = req
+
+ handler.LinuxDoOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentBindCurrentUser, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Empty(t, completion["access_token"])
+ require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, userCount)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-1").
+ SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Display",
+ "suggested_avatar_url": "https://cdn.example/linuxdo.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "LinuxDo Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("linuxdo-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-choice-subject-1").
+ SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-subject-no-adoption").
+ SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ "suggested_display_name": "LinuxDo Legacy",
+ "suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "linuxdo_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingOwner.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-conflict-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("linuxdo-complete-conflict-session").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("linuxdo-conflict-subject").
+ SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
+ SetBrowserSessionKey("linuxdo-conflict-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "linuxdo_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
+ c.Request = req
+
+ handler.CompleteLinuxDoOAuthRegistration(c)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
+
+ userCount, err := client.User.Query().
+ Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
+ t.Helper()
+ handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
+ return handler
+}
+
+func newLinuxDoOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+ handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled)
+ handler.settingSvc = nil
+ handler.cfg = &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ LinuxDo: oauthCfg,
+ }
+ return handler, client
+}
diff --git a/backend/internal/handler/auth_oauth_logout_test.go b/backend/internal/handler/auth_oauth_logout_test.go
new file mode 100644
index 00000000..0d4f94b1
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_logout_test.go
@@ -0,0 +1,68 @@
+package handler
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("logout-pending-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("logout-subject-123").
+ SetBrowserSessionKey("logout-browser-session-key").
+ SetResolvedEmail("logout@example.com").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
+ req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
+ req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
+ req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
+ req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
+ ginCtx.Request = req
+
+ handler.Logout(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ cookies := recorder.Result().Cookies()
+ for _, name := range []string{
+ oauthPendingSessionCookieName,
+ oauthPendingBrowserCookieName,
+ oauthBindAccessTokenCookieName,
+ linuxDoOAuthStateCookieName,
+ oidcOAuthStateCookieName,
+ wechatOAuthStateCookieName,
+ wechatPaymentOAuthStateName,
+ } {
+ cookie := findCookie(cookies, name)
+ require.NotNil(t, cookie, name)
+ require.Equal(t, -1, cookie.MaxAge, name)
+ require.True(t, cookie.HttpOnly, name)
+ }
+
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
new file mode 100644
index 00000000..490afd0f
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -0,0 +1,1946 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ entsql "entgo.io/ent/dialect/sql"
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ // Both pending-flow cookies are scoped to the OAuth API prefix and expire
+ // with the pending session (10 minutes, see oauthPendingCookieMaxAgeSec).
+ oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
+ oauthPendingBrowserCookieName = "oauth_pending_browser_session"
+ oauthPendingSessionCookiePath = "/api/v1/auth/oauth"
+ oauthPendingSessionCookieName = "oauth_pending_session"
+ oauthPendingCookieMaxAgeSec = 10 * 60
+ // Step value telling the frontend the user must choose an account action.
+ oauthPendingChoiceStep = "choose_account_action_required"
+
+ // Key under PendingAuthSession.LocalFlowState holding the completion payload.
+ oauthCompletionResponseKey = "completion_response"
+)
+
+// pendingOAuthCreateAccountPreCommitHook, when non-nil, is invoked during
+// pending-account creation before commit. NOTE(review): no assignment site is
+// visible in this file — presumably set by tests or wiring code; confirm.
+var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
+
+// oauthPendingSessionPayload is the input used to persist a pending OAuth
+// session after a provider callback, before the user completes the flow.
+type oauthPendingSessionPayload struct {
+ Intent string
+ Identity service.PendingAuthIdentityKey
+ TargetUserID *int64
+ ResolvedEmail string
+ RedirectTo string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ CompletionResponse map[string]any
+}
+
+// oauthAdoptionDecisionRequest carries the user's optional choice to adopt
+// the provider's display name and/or avatar onto the local account.
+type oauthAdoptionDecisionRequest struct {
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+// bindPendingOAuthLoginRequest binds a pending OAuth identity to an existing
+// account authenticated by email + password.
+type bindPendingOAuthLoginRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Password string `json:"password" binding:"required"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+// createPendingOAuthAccountRequest creates a brand-new account for a pending
+// OAuth identity; verification/invitation codes are optional per settings.
+type createPendingOAuthAccountRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code,omitempty"`
+ Password string `json:"password" binding:"required,min=6"`
+ InvitationCode string `json:"invitation_code,omitempty"`
+ AffCode string `json:"aff_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+// sendPendingOAuthVerifyCodeRequest requests an email verification code for
+// the pending OAuth account-creation flow.
+type sendPendingOAuthVerifyCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ TurnstileToken string `json:"turnstile_token,omitempty"`
+ PendingAuthToken string `json:"pending_auth_token,omitempty"`
+ PendingOAuthToken string `json:"pending_oauth_token,omitempty"`
+}
+
+// adoptionDecision extracts just the adoption flags from a bind request.
+func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+// adoptionDecision extracts just the adoption flags from a create request.
+func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
+ return oauthAdoptionDecisionRequest{
+ AdoptDisplayName: r.AdoptDisplayName,
+ AdoptAvatar: r.AdoptAvatar,
+ }
+}
+
+// pendingIdentityService returns a pending-identity service backed by the
+// handler's ent client, or a service-unavailable error when the handler,
+// auth service, or ent client is not wired up.
+func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
+ if h == nil || h.authService == nil || h.authService.EntClient() == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+ return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil
+}
+
+// generateOAuthPendingBrowserSession returns a fresh random key for the
+// pending-browser cookie, reusing the OAuth state generator.
+func generateOAuthPendingBrowserSession() (string, error) {
+ return oauth.GenerateState()
+}
+
+// setOAuthPendingBrowserCookie attaches the short-lived, HttpOnly browser
+// session cookie (Lax, scoped to the OAuth API prefix) to the response.
+func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) {
+    cookie := http.Cookie{
+        Name:     oauthPendingBrowserCookieName,
+        Value:    encodeCookieValue(sessionKey),
+        Path:     oauthPendingBrowserCookiePath,
+        MaxAge:   oauthPendingCookieMaxAgeSec,
+        HttpOnly: true,
+        Secure:   secure,
+        SameSite: http.SameSiteLaxMode,
+    }
+    http.SetCookie(c.Writer, &cookie)
+}
+
+// clearOAuthPendingBrowserCookie expires the pending-browser cookie by
+// writing an empty value with MaxAge -1 on the same path/flags it was set with.
+func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) {
+    expired := http.Cookie{
+        Name:     oauthPendingBrowserCookieName,
+        Value:    "",
+        Path:     oauthPendingBrowserCookiePath,
+        MaxAge:   -1,
+        HttpOnly: true,
+        Secure:   secure,
+        SameSite: http.SameSiteLaxMode,
+    }
+    http.SetCookie(c.Writer, &expired)
+}
+
+// readOAuthPendingBrowserCookie reads and decodes the pending-browser cookie.
+func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingBrowserCookieName)
+}
+
+// setOAuthPendingSessionCookie attaches the HttpOnly pending-session token
+// cookie (Lax, scoped to the OAuth API prefix) to the response.
+func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) {
+    cookie := http.Cookie{
+        Name:     oauthPendingSessionCookieName,
+        Value:    encodeCookieValue(sessionToken),
+        Path:     oauthPendingSessionCookiePath,
+        MaxAge:   oauthPendingCookieMaxAgeSec,
+        HttpOnly: true,
+        Secure:   secure,
+        SameSite: http.SameSiteLaxMode,
+    }
+    http.SetCookie(c.Writer, &cookie)
+}
+
+// clearOAuthPendingSessionCookie expires the pending-session cookie by
+// writing an empty value with MaxAge -1 on the same path/flags it was set with.
+func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) {
+    expired := http.Cookie{
+        Name:     oauthPendingSessionCookieName,
+        Value:    "",
+        Path:     oauthPendingSessionCookiePath,
+        MaxAge:   -1,
+        HttpOnly: true,
+        Secure:   secure,
+        SameSite: http.SameSiteLaxMode,
+    }
+    http.SetCookie(c.Writer, &expired)
+}
+
+// readOAuthPendingSessionCookie reads and decodes the pending-session cookie.
+func readOAuthPendingSessionCookie(c *gin.Context) (string, error) {
+ return readCookieDecoded(c, oauthPendingSessionCookieName)
+}
+
+// redirectToFrontendCallback 302-redirects to the frontend callback URL.
+// Unparseable URLs and absolute URLs with a non-http(s) scheme fall back to
+// the default redirect; the fragment is stripped and caching is disabled so
+// the (possibly sensitive) redirect is never stored.
+func redirectToFrontendCallback(c *gin.Context, frontendCallback string) {
+ u, err := url.Parse(frontendCallback)
+ if err != nil {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ // Empty scheme (relative URL) is allowed; anything other than http/https is not.
+ if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
+ c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
+ return
+ }
+ u.Fragment = ""
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ c.Redirect(http.StatusFound, u.String())
+}
+
+// createOAuthPendingSession persists a pending auth session (storing the
+// completion payload under oauthCompletionResponseKey in the local flow
+// state) and sets the session-token cookie on the response.
+func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error {
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return err
+ }
+
+ session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
+ Intent: strings.TrimSpace(payload.Intent),
+ Identity: payload.Identity,
+ TargetUserID: payload.TargetUserID,
+ ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
+ RedirectTo: strings.TrimSpace(payload.RedirectTo),
+ BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
+ UpstreamIdentityClaims: payload.UpstreamIdentityClaims,
+ LocalFlowState: map[string]any{
+ oauthCompletionResponseKey: payload.CompletionResponse,
+ },
+ })
+ if err != nil {
+ return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
+ }
+
+ setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c))
+ return nil
+}
+
+// readCompletionResponse extracts the completion payload stored under
+// oauthCompletionResponseKey. The second result is false when the key is
+// absent or the stored value is not a map.
+func readCompletionResponse(session map[string]any) (map[string]any, bool) {
+    raw, ok := session[oauthCompletionResponseKey]
+    if !ok {
+        return nil, false
+    }
+    payload, isMap := raw.(map[string]any)
+    if !isMap {
+        return nil, false
+    }
+    return payload, true
+}
+
+// clonePendingMap returns a shallow copy of values; nil or empty input yields
+// a non-nil empty map so callers can mutate the result safely.
+func clonePendingMap(values map[string]any) map[string]any {
+    out := make(map[string]any, len(values))
+    for k, v := range values {
+        out[k] = v
+    }
+    return out
+}
+
+// mergePendingCompletionResponse builds the effective completion payload:
+// the stored payload, plus the session redirect (only if not already set),
+// plus overrides (a nil override value deletes the key), finished with the
+// suggested-profile fields derived from the upstream claims.
+func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ merged := clonePendingMap(payload)
+ if strings.TrimSpace(session.RedirectTo) != "" {
+ if _, exists := merged["redirect"]; !exists {
+ merged["redirect"] = session.RedirectTo
+ }
+ }
+ for key, value := range overrides {
+ // nil is a tombstone: it removes the key instead of storing a null.
+ if value == nil {
+ delete(merged, key)
+ continue
+ }
+ merged[key] = value
+ }
+ applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
+ return merged
+}
+
+// pendingSessionStringValue returns the trimmed string stored at key, or ""
+// when the key is absent or holds a non-string value.
+func pendingSessionStringValue(values map[string]any, key string) string {
+    if raw, ok := values[key]; ok {
+        if s, isString := raw.(string); isString {
+            return strings.TrimSpace(s)
+        }
+    }
+    return ""
+}
+
+// pendingSessionWantsInvitation reports whether the completion payload's
+// "error" field signals that an invitation code is required.
+// (pendingSessionStringValue already trims, so no extra trimming is needed.)
+func pendingSessionWantsInvitation(payload map[string]any) bool {
+    errValue := pendingSessionStringValue(payload, "error")
+    return strings.EqualFold(errValue, "invitation_required")
+}
+
+// pendingOAuthCompletionCanIssueTokenPair reports whether the session may be
+// completed with a token pair immediately: intent must be login, a positive
+// target user must be resolved, no invitation may be required, and no
+// follow-up step may be pending in the completion payload.
+func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
+ if session == nil {
+ return false
+ }
+ if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
+ return false
+ }
+ if session.TargetUserID == nil || *session.TargetUserID <= 0 {
+ return false
+ }
+ if pendingSessionWantsInvitation(payload) {
+ return false
+ }
+ return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
+}
+
+// ensurePendingOAuthCompleteRegistrationSession validates that session is a
+// login-intent session with no resolved target user and not in the
+// bind_login_required step — i.e. one eligible for account creation.
+// All failure modes return the same generic invalid-session error.
+func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
+ if session == nil {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ if strings.TrimSpace(session.Intent) != oauthIntentLogin {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ // A resolved target user means this is a login/bind, not a registration.
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ payload, _ := readCompletionResponse(session.LocalFlowState)
+ if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+ return nil
+}
+
+// buildLegacyCompleteRegistrationPendingResponse builds the completion
+// payload that moves a legacy complete-registration session into the
+// choose-account-action step. Existing payload keys are preserved; email and
+// choice_reason are only seeded when absent.
+func buildLegacyCompleteRegistrationPendingResponse(
+ session *dbent.PendingAuthSession,
+ forceEmailOnSignup bool,
+ emailVerificationRequired bool,
+) map[string]any {
+ completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "create_account_allowed": true,
+ "force_email_on_signup": forceEmailOnSignup,
+ }))
+
+ if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+ if _, exists := completionResponse["email"]; !exists {
+ completionResponse["email"] = email
+ }
+ if _, exists := completionResponse["resolved_email"]; !exists {
+ completionResponse["resolved_email"] = email
+ }
+ }
+ if _, exists := completionResponse["choice_reason"]; !exists {
+ // Priority: forced email > email verification > plain third-party signup.
+ switch {
+ case forceEmailOnSignup:
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ case emailVerificationRequired:
+ completionResponse["choice_reason"] = "email_verification_required"
+ default:
+ completionResponse["choice_reason"] = "third_party_signup"
+ }
+ }
+ return completionResponse
+}
+
+// legacyCompleteRegistrationSessionStatus decides whether a legacy
+// registration session still needs a user choice. Returns (session, true)
+// when a step is already pending or was just written; (session, false) when
+// registration can proceed directly. The session is persisted with the
+// choice-step payload when email verification or forced email applies.
+func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
+ c *gin.Context,
+ session *dbent.PendingAuthSession,
+) (*dbent.PendingAuthSession, bool, error) {
+ if session == nil {
+ return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ // An existing step means the frontend is already mid-choice; don't rewrite.
+ payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
+ if step := pendingSessionStringValue(payload, "step"); step != "" {
+ return session, true, nil
+ }
+
+ emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
+ forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
+ if !emailVerificationRequired && !forceEmailOnSignup {
+ return session, false, nil
+ }
+
+ client := h.entClient()
+ if client == nil {
+ return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ updatedSession, err := updatePendingOAuthSessionProgress(
+ c.Request.Context(),
+ client,
+ session,
+ strings.TrimSpace(session.Intent),
+ strings.TrimSpace(session.ResolvedEmail),
+ nil,
+ buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
+ )
+ if err != nil {
+ return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
+ }
+ return updatedSession, true, nil
+}
+
+// hasDecision reports whether the request sets at least one adoption flag.
+func (r oauthAdoptionDecisionRequest) hasDecision() bool {
+    return !(r.AdoptDisplayName == nil && r.AdoptAvatar == nil)
+}
+
+// bindOptionalOAuthAdoptionDecision binds an adoption-decision JSON body,
+// treating a missing or empty body (io.EOF) as an empty decision rather than
+// an error. Malformed JSON is still reported.
+func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
+ var req oauthAdoptionDecisionRequest
+ if c == nil || c.Request == nil || c.Request.Body == nil {
+ return req, nil
+ }
+ if err := c.ShouldBindJSON(&req); err != nil {
+ // Empty request body is legal for this endpoint family.
+ if errors.Is(err, io.EOF) {
+ return req, nil
+ }
+ return req, err
+ }
+ return req, nil
+}
+
+// cloneOAuthMetadata returns a shallow copy of values; nil or empty input
+// yields a non-nil empty map.
+func cloneOAuthMetadata(values map[string]any) map[string]any {
+    copied := make(map[string]any, len(values))
+    for key, value := range values {
+        copied[key] = value
+    }
+    return copied
+}
+
+// mergeOAuthMetadata returns a new map containing base's entries with
+// overlay's entries layered on top (overlay wins on key conflicts). The
+// result is always non-nil and never aliases either input.
+func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any {
+    merged := make(map[string]any, len(base)+len(overlay))
+    for key, value := range base {
+        merged[key] = value
+    }
+    for key, value := range overlay {
+        merged[key] = value
+    }
+    return merged
+}
+
+// normalizeAdoptedOAuthDisplayName trims the value and truncates it to at
+// most 100 runes (not bytes), so multi-byte names are cut safely.
+func normalizeAdoptedOAuthDisplayName(value string) string {
+    trimmed := strings.TrimSpace(value)
+    runes := []rune(trimmed)
+    if len(runes) > 100 {
+        return string(runes[:100])
+    }
+    return trimmed
+}
+
+// entClient is a nil-safe accessor for the handler's ent client; returns nil
+// when the handler or auth service is not wired.
+func (h *AuthHandler) entClient() *dbent.Client {
+ if h == nil || h.authService == nil {
+ return nil
+ }
+ return h.authService.EntClient()
+}
+
+// isForceEmailOnThirdPartySignup reads the auth-source default setting; any
+// lookup failure is treated as "not forced" (false) rather than an error.
+func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
+ if h == nil || h.settingSvc == nil {
+ return false
+ }
+ defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
+ if err != nil || defaults == nil {
+ return false
+ }
+ return defaults.ForceEmailOnThirdPartySignup
+}
+
+// findOAuthIdentityUser resolves the active user owning the given provider
+// identity. Returns (nil, nil) when no identity record exists; an inactive
+// owner surfaces via findActiveUserByID's error.
+func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ record, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ return findActiveUserByID(ctx, client, record.UserID)
+}
+
+// Provider-specific route entry points; each delegates to the shared
+// bind/create flow with the provider name (empty = resolve from session).
+func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
+func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
+func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
+func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
+
+func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "linuxdo")
+}
+
+func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
+
+func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "wechat")
+}
+
+func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "")
+}
+
+// SendPendingOAuthVerifyCode sends a verification code for a browser-bound
+// pending OAuth account-creation flow.
+// POST /api/v1/auth/oauth/pending/send-verify-code
+//
+// Flow: validate body -> verify Turnstile -> load & validate the pending
+// session -> if the email already belongs to a user, switch the session to
+// the choose-account-action state instead of sending a code -> otherwise
+// dispatch the verification email.
+func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
+ var req sendPendingOAuthVerifyCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ _, session, _, err := readPendingOAuthBrowserSession(c, h)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+
+ // If the email is already registered, the user must choose between binding
+ // and other actions — no verification code is sent in that case.
+ email := strings.TrimSpace(strings.ToLower(req.Email))
+ if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
+ session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+ return
+ } else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, SendVerifyCodeResponse{
+ Message: "Verification code sent successfully",
+ Countdown: result.Countdown,
+ })
+}
+
+// upsertPendingOAuthAdoptionDecision stores the user's adoption flags for a
+// pending session. Semantics:
+//   - no existing row, no flags in req -> (nil, nil): nothing to persist;
+//   - existing row, no flags in req    -> return row unchanged;
+//   - flags present                    -> merge onto the existing row (if any)
+//     and upsert via the pending-identity service.
+func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ existing, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
+ Only(c.Request.Context())
+ if err != nil && !dbent.IsNotFound(err) {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
+ }
+ if existing != nil && !req.hasDecision() {
+ return existing, nil
+ }
+ if existing == nil && !req.hasDecision() {
+ return nil, nil
+ }
+
+ // Start from the stored decision so an update with only one flag set does
+ // not clobber the other flag.
+ input := service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ }
+ if existing != nil {
+ input.AdoptDisplayName = existing.AdoptDisplayName
+ input.AdoptAvatar = existing.AdoptAvatar
+ input.IdentityID = existing.IdentityID
+ }
+ if req.AdoptDisplayName != nil {
+ input.AdoptDisplayName = *req.AdoptDisplayName
+ }
+ if req.AdoptAvatar != nil {
+ input.AdoptAvatar = *req.AdoptAvatar
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+// ensurePendingOAuthAdoptionDecision is like upsertPendingOAuthAdoptionDecision
+// but guarantees a decision row exists: when the upsert yields nothing (no
+// prior row and no flags in req), a default (all-false) decision is created.
+func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
+ c *gin.Context,
+ sessionID int64,
+ req oauthAdoptionDecisionRequest,
+) (*dbent.IdentityAdoptionDecision, error) {
+ decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
+ if err != nil {
+ return nil, err
+ }
+ if decision != nil {
+ return decision, nil
+ }
+
+ svc, err := h.pendingIdentityService()
+ if err != nil {
+ return nil, err
+ }
+ decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: sessionID,
+ })
+ if err != nil {
+ return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
+ }
+ return decision, nil
+}
+
+// updatePendingOAuthSessionProgress persists the session's current flow
+// state: intent, resolved email, target user (cleared when nil or <= 0), and
+// the completion payload under oauthCompletionResponseKey. Other local flow
+// state keys are preserved. Returns the updated session.
+func updatePendingOAuthSessionProgress(
+ ctx context.Context,
+ client *dbent.Client,
+ session *dbent.PendingAuthSession,
+ intent string,
+ resolvedEmail string,
+ targetUserID *int64,
+ completionResponse map[string]any,
+) (*dbent.PendingAuthSession, error) {
+ if client == nil || session == nil {
+ return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+
+ // Clone before mutating so the in-memory session is not aliased.
+ localFlowState := clonePendingMap(session.LocalFlowState)
+ localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
+
+ update := client.PendingAuthSession.UpdateOneID(session.ID).
+ SetIntent(strings.TrimSpace(intent)).
+ SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
+ SetLocalFlowState(localFlowState)
+ if targetUserID != nil && *targetUserID > 0 {
+ update = update.SetTargetUserID(*targetUserID)
+ } else {
+ update = update.ClearTargetUserID()
+ }
+ return update.Save(ctx)
+}
+
+// resolvePendingOAuthTargetUserID returns the session's explicit target user
+// when set; otherwise it resolves the user by the session's resolved email.
+// Errors distinguish a missing email (bad request) from a missing user
+// (internal inconsistency — the email was previously resolved to a user).
+func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
+ if session == nil {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+ }
+ if session.TargetUserID != nil && *session.TargetUserID > 0 {
+ return *session.TargetUserID, nil
+ }
+ email := strings.TrimSpace(session.ResolvedEmail)
+ if email == "" {
+ return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
+ }
+
+ userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+ if err != nil {
+ if errors.Is(err, service.ErrUserNotFound) {
+ return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
+ }
+ return 0, err
+ }
+ return userEntity.ID, nil
+}
+
+// userNormalizedEmailPredicate builds an ent predicate matching users by
+// case-insensitive, whitespace-trimmed email via SQL
+// LOWER(TRIM(email)) = ?. The normalized value is bound as an argument (no
+// string interpolation). An input that normalizes to "" falls back to an
+// exact equality on the raw value.
+func userNormalizedEmailPredicate(email string) predicate.User {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return dbuser.EmailEQ(email)
+ }
+ return predicate.User(func(s *entsql.Selector) {
+ s.Where(entsql.P(func(b *entsql.Builder) {
+ b.WriteString("LOWER(TRIM(").
+ Ident(s.C(dbuser.FieldEmail)).
+ WriteString(")) = ").
+ Arg(normalized)
+ }))
+ })
+}
+
+// findUserByNormalizedEmail finds the single user whose email matches after
+// normalization. Returns service.ErrUserNotFound for zero matches and a
+// conflict error when normalization collapses multiple users onto one email.
+func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) {
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ matches, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
+ }
+ if len(matches) > 1 {
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+ return matches[0], nil
+}
+
+// ensurePendingOAuthRegistrationIdentityAvailable checks whether the session's
+// provider identity may be claimed by a new account. It is only a conflict
+// when an identity record exists AND its owner is an active user; a missing
+// record or an inactive/deleted owner leaves the identity claimable.
+func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
+ if client == nil || session == nil {
+ return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+ }
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+ ).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil
+ }
+ return err
+ }
+ if identity == nil || identity.UserID <= 0 {
+ return nil
+ }
+
+ activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
+ if err != nil {
+ return err
+ }
+ if activeOwner != nil {
+ return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ return nil
+}
+
+// oauthIdentityIssuer derives the issuer to store with the identity: for
+// OIDC the provider key is the issuer (falling back to the "issuer" claim);
+// for other providers only the claim is used. Returns nil when unknown.
+func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
+ if session == nil {
+ return nil
+ }
+ switch strings.TrimSpace(session.ProviderType) {
+ case "oidc":
+ issuer := strings.TrimSpace(session.ProviderKey)
+ if issuer == "" {
+ issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ }
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ default:
+ issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
+ if issuer == "" {
+ return nil
+ }
+ return &issuer
+ }
+}
+
+// ensurePendingOAuthIdentityForUser creates or re-points the AuthIdentity row
+// for the session's provider identity so it belongs to userID. WeChat is
+// delegated to its channel-aware variant. An identity owned by a different
+// *active* user is a conflict; an identity owned by a missing/inactive user
+// is reassigned. Returns the (created or updated) identity.
+func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+    // Guard first: the non-WeChat path below dereferences tx and session
+    // unconditionally; previously a nil session only short-circuited the
+    // WeChat check and then panicked on session.ProviderType.
+    if tx == nil || session == nil {
+        return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
+    }
+    if strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") {
+        return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID)
+    }
+
+    client := tx.Client()
+    identity, err := client.AuthIdentity.Query().
+        Where(
+            authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
+            authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
+            authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
+        ).
+        Only(ctx)
+    if err != nil && !dbent.IsNotFound(err) {
+        return nil, err
+    }
+    if identity != nil {
+        if identity.UserID != userID {
+            // Reassign only when the current owner is no longer active.
+            activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
+            if err != nil {
+                return nil, err
+            }
+            if activeOwner != nil {
+                return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+            }
+            return client.AuthIdentity.UpdateOneID(identity.ID).
+                SetUserID(userID).
+                Save(ctx)
+        }
+        return identity, nil
+    }
+
+    // No existing record: create one carrying the upstream claims as metadata.
+    create := client.AuthIdentity.Create().
+        SetUserID(userID).
+        SetProviderType(strings.TrimSpace(session.ProviderType)).
+        SetProviderKey(strings.TrimSpace(session.ProviderKey)).
+        SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
+        SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
+    if issuer := oauthIdentityIssuer(session); issuer != nil {
+        create = create.SetIssuer(strings.TrimSpace(*issuer))
+    }
+    return create.Save(ctx)
+}
+
+// ensurePendingWeChatOAuthIdentityForUser reconciles WeChat identities, which
+// have extra complexity: compatible provider keys, per-channel (app) records,
+// and a legacy subject. NOTE(review): channel_subject appears to be a legacy
+// per-channel openid used before the current subject format — confirm.
+//
+// Resolution order:
+//  1. find an identity matching the current subject (preferring the
+//     canonical provider key) and update it in place;
+//  2. else migrate a legacy-subject identity to the current key/subject;
+//  3. else create a fresh identity.
+// Afterwards, when full channel info is present in the claims, the matching
+// AuthIdentityChannel row is created or re-pointed at the identity.
+func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
+ client := tx.Client()
+ providerType := strings.TrimSpace(session.ProviderType)
+ providerKey := strings.TrimSpace(session.ProviderKey)
+ providerSubject := strings.TrimSpace(session.ProviderSubject)
+ providerKeys := wechatCompatibleProviderKeys(providerKey)
+ channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
+ channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
+ channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
+ metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
+
+ // Step 1: look up an identity under the current subject.
+ identityRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(ctx, client, identityRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+
+ // Step 2: look up a legacy-subject identity eligible for migration.
+ var legacyOpenIDIdentity *dbent.AuthIdentity
+ if channelSubject != "" && channelSubject != providerSubject {
+ legacyOpenIDRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(providerKeys...),
+ authidentity.ProviderSubjectEQ(channelSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(ctx, client, legacyOpenIDRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ switch {
+ case identity != nil:
+ // Current-subject match: refresh metadata/ownership; only rewrite the
+ // provider key when no canonical-key record exists for this subject.
+ update := client.AuthIdentity.UpdateOneID(identity.ID).
+ SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
+ if identity.UserID != userID {
+ update = update.SetUserID(userID)
+ }
+ if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey {
+ update = update.SetProviderKey(providerKey)
+ }
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ case legacyOpenIDIdentity != nil:
+ // Legacy match only: migrate the record to the current key/subject.
+ update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ default:
+ // No match at all: create a new identity.
+ create := client.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetProviderSubject(providerSubject).
+ SetMetadata(metadata)
+ if issuer := oauthIdentityIssuer(session); issuer != nil {
+ create = create.SetIssuer(strings.TrimSpace(*issuer))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Channel reconciliation requires the full channel triple from the claims.
+ if channel == "" || channelAppID == "" || channelSubject == "" {
+ return identity, nil
+ }
+
+ channelRecords, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(providerKeys...),
+ authidentitychannel.ChannelEQ(channel),
+ authidentitychannel.ChannelAppIDEQ(channelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+ channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(ctx, client, channelRecords, userID, providerKey)
+ if err != nil {
+ return nil, err
+ }
+
+ channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
+ if channelRecord == nil {
+ if _, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(providerKey).
+ SetChannel(channel).
+ SetChannelAppID(channelAppID).
+ SetChannelSubject(channelSubject).
+ SetMetadata(channelMetadata).
+ Save(ctx); err != nil {
+ return nil, err
+ }
+ return identity, nil
+ }
+
+ updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
+ SetIdentityID(identity.ID).
+ SetMetadata(channelMetadata)
+ if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey {
+ updateChannel = updateChannel.SetProviderKey(providerKey)
+ }
+ _, err = updateChannel.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return identity, nil
+}
+
+// chooseWeChatIdentityForUser picks one identity from records, preferring the
+// first record whose provider key equals preferredProviderKey (the canonical
+// key) and otherwise falling back to the first other record. Any record owned
+// by a different *active* user is an ownership conflict. The bool result
+// reports whether a canonical-key record was seen at all.
+func chooseWeChatIdentityForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) {
+ var preferred *dbent.AuthIdentity
+ var fallback *dbent.AuthIdentity
+ hasCanonicalKey := false
+ for _, record := range records {
+ if record == nil {
+ continue
+ }
+ if record.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, record.UserID)
+ if err != nil {
+ return nil, false, err
+ }
+ if activeOwner != nil {
+ return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ }
+ if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
+ hasCanonicalKey = true
+ if preferred == nil {
+ preferred = record
+ }
+ continue
+ }
+ if fallback == nil {
+ fallback = record
+ }
+ }
+ if preferred != nil {
+ return preferred, hasCanonicalKey, nil
+ }
+ return fallback, hasCanonicalKey, nil
+}
+
+// chooseWeChatChannelForUser mirrors chooseWeChatIdentityForUser for channel
+// records: prefer the canonical provider key, conflict when the channel's
+// identity is owned by another active user, and report whether a
+// canonical-key record was seen.
+func chooseWeChatChannelForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) {
+ var preferred *dbent.AuthIdentityChannel
+ var fallback *dbent.AuthIdentityChannel
+ hasCanonicalKey := false
+ for _, record := range records {
+ if record == nil {
+ continue
+ }
+ // Ownership is checked via the eager-loaded Identity edge.
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ activeOwner, err := findActiveUserByID(ctx, client, record.Edges.Identity.UserID)
+ if err != nil {
+ return nil, false, err
+ }
+ if activeOwner != nil {
+ return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ }
+ if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
+ hasCanonicalKey = true
+ if preferred == nil {
+ preferred = record
+ }
+ continue
+ }
+ if fallback == nil {
+ fallback = record
+ }
+ }
+ if preferred != nil {
+ return preferred, hasCanonicalKey, nil
+ }
+ return fallback, hasCanonicalKey, nil
+}
+
+// findActiveUserByID loads the user with userID and verifies it is active.
+// Returns (nil, nil) when the client is nil, userID is non-positive, or the
+// user does not exist. Returns service.ErrUserNotActive when the user exists
+// but its Status is not service.StatusActive (trimmed, case-insensitive).
+// Other lookup failures are wrapped as an internal-server error.
+func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) (*dbent.User, error) {
+	if client == nil || userID <= 0 {
+		return nil, nil
+	}
+	userEntity, err := client.User.Get(ctx, userID)
+	if err != nil {
+		if dbent.IsNotFound(err) {
+			// Missing users are not an error for callers; they check for a nil result.
+			return nil, nil
+		}
+		return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
+	}
+	if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) {
+		return nil, service.ErrUserNotActive
+	}
+	return userEntity, nil
+}
+
+// channelRecordMetadata returns a defensive copy of the channel's metadata map,
+// or an empty map when the channel is nil, so callers can mutate the result
+// without affecting the stored entity.
+func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
+	if channel == nil {
+		return map[string]any{}
+	}
+	return cloneOAuthMetadata(channel.Metadata)
+}
+
+// shouldBindPendingOAuthIdentity reports whether the pending session's identity
+// should be bound to a user. Binding always proceeds for the
+// "bind_current_user", "login", and "adopt_existing_user_by_email" intents
+// (matched case-insensitively after trimming); for any other intent it only
+// proceeds when the decision adopts the display name or avatar. Both arguments
+// must be non-nil.
+func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
+	if session == nil || decision == nil {
+		return false
+	}
+	switch strings.ToLower(strings.TrimSpace(session.Intent)) {
+	case "bind_current_user", "login", "adopt_existing_user_by_email":
+		return true
+	default:
+		return decision.AdoptDisplayName || decision.AdoptAvatar
+	}
+}
+
+// shouldSkipAvatarAdoption reports whether an avatar-validation error is a
+// benign rejection (invalid, too large, or not an image) that should silently
+// skip avatar adoption instead of failing the whole binding flow.
+func shouldSkipAvatarAdoption(err error) bool {
+	return errors.Is(err, service.ErrAvatarInvalid) ||
+		errors.Is(err, service.ErrAvatarTooLarge) ||
+		errors.Is(err, service.ErrAvatarNotImage)
+}
+
+// applyPendingOAuthBinding binds the pending OAuth session's identity to a
+// user, applying any profile-adoption decision. It is a no-op (nil error) when
+// client or session is nil, or when neither forceBind nor
+// shouldBindPendingOAuthIdentity approves the binding. If the context already
+// carries an ent transaction it is reused; otherwise a new transaction is
+// opened, committed on success, and rolled back on error.
+func applyPendingOAuthBinding(
+	ctx context.Context,
+	client *dbent.Client,
+	authService *service.AuthService,
+	userService *service.UserService,
+	session *dbent.PendingAuthSession,
+	decision *dbent.IdentityAdoptionDecision,
+	overrideUserID *int64,
+	forceBind bool,
+	applyFirstBindDefaults bool,
+) error {
+	if client == nil || session == nil {
+		return nil
+	}
+	if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+		return nil
+	}
+
+	// Reuse an ambient transaction when the caller already opened one.
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
+	}
+
+	tx, err := client.Tx(ctx)
+	if err != nil {
+		return err
+	}
+	// Rollback after a successful Commit is a harmless no-op.
+	defer func() { _ = tx.Rollback() }()
+
+	txCtx := dbent.NewTxContext(ctx, tx)
+	if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// applyPendingOAuthBindingTx performs the actual binding work inside an open
+// transaction: it resolves the target user (overrideUserID wins when positive),
+// optionally adopts the suggested display name and avatar from the session's
+// upstream claims, ensures an AuthIdentity row exists for the user, merges the
+// upstream claims into the identity metadata, re-points the adoption decision
+// at that identity, and applies provider first-bind defaults when requested.
+// It is a no-op (nil error) when tx or session is nil, or when the binding is
+// neither forced nor approved by shouldBindPendingOAuthIdentity.
+func applyPendingOAuthBindingTx(
+	ctx context.Context,
+	tx *dbent.Tx,
+	authService *service.AuthService,
+	userService *service.UserService,
+	session *dbent.PendingAuthSession,
+	decision *dbent.IdentityAdoptionDecision,
+	overrideUserID *int64,
+	forceBind bool,
+	applyFirstBindDefaults bool,
+) error {
+	if tx == nil || session == nil {
+		return nil
+	}
+	if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
+		return nil
+	}
+
+	// Resolve which user the identity is bound to.
+	targetUserID := int64(0)
+	if overrideUserID != nil && *overrideUserID > 0 {
+		targetUserID = *overrideUserID
+	} else {
+		resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
+		if err != nil {
+			return err
+		}
+		targetUserID = resolvedUserID
+	}
+
+	// Extract the adoptable profile values from the upstream identity claims.
+	adoptedDisplayName := ""
+	if decision != nil && decision.AdoptDisplayName {
+		adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
+	}
+	adoptedAvatarURL := ""
+	if decision != nil && decision.AdoptAvatar {
+		adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
+	}
+	shouldAdoptAvatar := false
+	if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
+		if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil {
+			shouldAdoptAvatar = true
+		} else if !shouldSkipAvatarAdoption(err) {
+			// Benign avatar rejections skip adoption; anything else aborts the binding.
+			return err
+		}
+	}
+
+	if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
+		if err := tx.Client().User.UpdateOneID(targetUserID).
+			SetUsername(adoptedDisplayName).
+			Exec(ctx); err != nil {
+			return err
+		}
+	}
+
+	identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
+	if err != nil {
+		return err
+	}
+
+	// Merge upstream claims into the identity metadata; adopted values win last.
+	metadata := cloneOAuthMetadata(identity.Metadata)
+	for key, value := range session.UpstreamIdentityClaims {
+		metadata[key] = value
+	}
+	if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
+		metadata["display_name"] = adoptedDisplayName
+	}
+	if shouldAdoptAvatar {
+		metadata["avatar_url"] = adoptedAvatarURL
+	}
+
+	updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
+	if issuer := oauthIdentityIssuer(session); issuer != nil {
+		updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
+	}
+	if _, err := updateIdentity.Save(ctx); err != nil {
+		return err
+	}
+
+	// Re-point the adoption decision at this identity, detaching any other
+	// decisions that currently reference it (one decision per identity).
+	if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
+		if _, err := tx.Client().IdentityAdoptionDecision.Update().
+			Where(
+				identityadoptiondecision.IdentityIDEQ(identity.ID),
+				identityadoptiondecision.IDNEQ(decision.ID),
+			).
+			ClearIdentityID().
+			Save(ctx); err != nil {
+			return err
+		}
+		if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
+			SetIdentityID(identity.ID).
+			Save(ctx); err != nil {
+			return err
+		}
+	}
+
+	if applyFirstBindDefaults && authService != nil {
+		if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
+			return err
+		}
+	}
+
+	// NOTE(review): SetAvatar runs through userService with the tx-carrying ctx;
+	// presumably it participates in the same transaction — confirm.
+	if shouldAdoptAvatar && userService != nil {
+		if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// consumePendingOAuthBrowserSessionTx marks the pending auth session as
+// consumed inside an open transaction. It re-reads the stored session row and
+// rejects sessions that are missing, already consumed, expired, or whose
+// browser-session key does not match the caller's copy. On success it stamps
+// ConsumedAt and clears the completion-code hash/expiry so the one-time code
+// cannot be replayed.
+func consumePendingOAuthBrowserSessionTx(
+	ctx context.Context,
+	tx *dbent.Tx,
+	session *dbent.PendingAuthSession,
+) error {
+	if tx == nil || session == nil {
+		return service.ErrPendingAuthSessionNotFound
+	}
+
+	// Re-read inside the transaction so consumption checks use current state.
+	storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
+	if err != nil {
+		if dbent.IsNotFound(err) {
+			return service.ErrPendingAuthSessionNotFound
+		}
+		return err
+	}
+
+	now := time.Now().UTC()
+	if storedSession.ConsumedAt != nil {
+		return service.ErrPendingAuthSessionConsumed
+	}
+	if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
+		return service.ErrPendingAuthSessionExpired
+	}
+	// An empty stored key is treated as "no binding"; otherwise it must match.
+	if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
+		strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+		return service.ErrPendingAuthBrowserMismatch
+	}
+
+	if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
+		SetConsumedAt(now).
+		SetCompletionCodeHash("").
+		ClearCompletionCodeExpiresAt().
+		Save(ctx); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// applyPendingOAuthAdoptionAndConsumeSession applies the profile-adoption
+// decision for userID and consumes the pending session, both within a single
+// transaction so the adoption never lands without the session being spent.
+// The transaction is attached to ctx so the nested adoption call reuses it
+// (via dbent.TxFromContext) instead of opening its own.
+func applyPendingOAuthAdoptionAndConsumeSession(
+	ctx context.Context,
+	client *dbent.Client,
+	authService *service.AuthService,
+	userService *service.UserService,
+	session *dbent.PendingAuthSession,
+	decision *dbent.IdentityAdoptionDecision,
+	userID int64,
+) error {
+	if client == nil {
+		return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+	}
+	if session == nil || userID <= 0 {
+		return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
+	}
+
+	tx, err := client.Tx(ctx)
+	if err != nil {
+		return err
+	}
+	// Rollback after a successful Commit is a harmless no-op.
+	defer func() { _ = tx.Rollback() }()
+
+	txCtx := dbent.NewTxContext(ctx, tx)
+	if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
+		return err
+	}
+	if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// applyPendingOAuthAdoption is a thin wrapper around applyPendingOAuthBinding
+// for adoption flows: binding is not forced (the decision must approve it),
+// and provider first-bind defaults are applied only when the session intent is
+// "bind_current_user".
+func applyPendingOAuthAdoption(
+	ctx context.Context,
+	client *dbent.Client,
+	authService *service.AuthService,
+	userService *service.UserService,
+	session *dbent.PendingAuthSession,
+	decision *dbent.IdentityAdoptionDecision,
+	overrideUserID *int64,
+) error {
+	return applyPendingOAuthBinding(
+		ctx,
+		client,
+		authService,
+		userService,
+		session,
+		decision,
+		overrideUserID,
+		false,
+		strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"),
+	)
+}
+
+// applySuggestedProfileToCompletionResponse copies the upstream-suggested
+// display name and avatar URL into the completion payload without overwriting
+// values the payload already has, and flags adoption_required when either
+// suggestion is present. Empty payload or upstream maps are left untouched.
+func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
+	if len(payload) == 0 || len(upstream) == 0 {
+		return
+	}
+
+	displayName := pendingSessionStringValue(upstream, "suggested_display_name")
+	avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url")
+
+	// Existing payload values always win over upstream suggestions.
+	if displayName != "" {
+		if _, exists := payload["suggested_display_name"]; !exists {
+			payload["suggested_display_name"] = displayName
+		}
+	}
+	if avatarURL != "" {
+		if _, exists := payload["suggested_avatar_url"]; !exists {
+			payload["suggested_avatar_url"] = avatarURL
+		}
+	}
+	// Note: this unconditionally overwrites any pre-existing adoption_required value.
+	if displayName != "" || avatarURL != "" {
+		payload["adoption_required"] = true
+	}
+}
+
+// pendingOAuthIdentityExistsForUser reports whether userID already has an
+// AuthIdentity matching the session's provider type/subject. For WeChat
+// providers the key match is widened to all compatible provider keys;
+// otherwise a non-empty provider key must match exactly. Missing inputs
+// (nil client/session, non-positive userID, or blank type/subject) yield
+// (false, nil) rather than an error.
+func pendingOAuthIdentityExistsForUser(
+	ctx context.Context,
+	client *dbent.Client,
+	session *dbent.PendingAuthSession,
+	userID int64,
+) (bool, error) {
+	if client == nil || session == nil || userID <= 0 {
+		return false, nil
+	}
+
+	providerType := strings.TrimSpace(session.ProviderType)
+	providerKey := strings.TrimSpace(session.ProviderKey)
+	providerSubject := strings.TrimSpace(session.ProviderSubject)
+	if providerType == "" || providerSubject == "" {
+		return false, nil
+	}
+
+	query := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ(providerType),
+			authidentity.ProviderSubjectEQ(providerSubject),
+			authidentity.UserIDEQ(userID),
+		)
+	// WeChat historically stored several equivalent provider keys; accept any of them.
+	if strings.EqualFold(providerType, "wechat") {
+		query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...))
+	} else if providerKey != "" {
+		query = query.Where(authidentity.ProviderKeyEQ(providerKey))
+	}
+
+	count, err := query.Count(ctx)
+	if err != nil {
+		return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+	}
+	return count > 0, nil
+}
+
+// shouldSkipPendingOAuthAdoptionPrompt reports whether the adoption prompt can
+// be skipped: only when the session can already issue a token pair, upstream
+// suggested a display name or avatar, and the user already has this provider
+// identity bound (i.e. a returning login, not a first bind).
+//
+// NOTE(review): the *session.TargetUserID dereference relies on
+// pendingOAuthCompletionCanIssueTokenPair guaranteeing a non-nil TargetUserID —
+// confirm that invariant holds in that helper.
+func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
+	ctx context.Context,
+	session *dbent.PendingAuthSession,
+	payload map[string]any,
+) (bool, error) {
+	if session == nil || len(payload) == 0 {
+		return false, nil
+	}
+	if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
+		return false, nil
+	}
+	// No suggestions means there is nothing to prompt about anyway.
+	if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
+		pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" {
+		return false, nil
+	}
+
+	return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID)
+}
+
+// readPendingOAuthBrowserSession loads the pending OAuth session referenced by
+// the request's session-token and browser-key cookies. On any failure it
+// clears both cookies immediately and returns the error; the returned
+// clearCookies func is always non-nil so callers can clear again after
+// consuming the session. Returns the pending-identity service, the session,
+// the cleanup func, and an error.
+func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
+	secureCookie := isRequestHTTPS(c)
+	clearCookies := func() {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+	}
+
+	sessionToken, err := readOAuthPendingSessionCookie(c)
+	if err != nil || strings.TrimSpace(sessionToken) == "" {
+		clearCookies()
+		return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
+	}
+	browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+	if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+		clearCookies()
+		return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
+	}
+
+	svc, err := h.pendingIdentityService()
+	if err != nil {
+		clearCookies()
+		return nil, nil, clearCookies, err
+	}
+
+	session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+	if err != nil {
+		clearCookies()
+		return nil, nil, clearCookies, err
+	}
+
+	return svc, session, clearCookies, nil
+}
+
+// consumePendingOAuthSessionOnLogout best-effort consumes any pending OAuth
+// browser session identified by the request's cookies so it cannot be resumed
+// after logout. All failures (missing cookies, unavailable service, consume
+// errors) are deliberately ignored — logout must not fail because of this.
+func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
+	if c == nil || c.Request == nil {
+		return
+	}
+
+	sessionToken, err := readOAuthPendingSessionCookie(c)
+	if err != nil || strings.TrimSpace(sessionToken) == "" {
+		return
+	}
+	browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+	if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+		return
+	}
+
+	svc, err := h.pendingIdentityService()
+	if err != nil {
+		return
+	}
+	_, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+}
+
+// clearOAuthLogoutCookies clears every OAuth-related cookie this handler family
+// sets — pending-session/browser/bind-token cookies plus the per-provider
+// state/verifier/redirect/intent cookies for LinuxDo, OIDC, WeChat, and WeChat
+// payment — so a logout leaves no stale OAuth flow state in the browser.
+func clearOAuthLogoutCookies(c *gin.Context) {
+	secureCookie := isRequestHTTPS(c)
+
+	clearOAuthPendingSessionCookie(c, secureCookie)
+	clearOAuthPendingBrowserCookie(c, secureCookie)
+	clearOAuthBindAccessTokenCookie(c, secureCookie)
+
+	clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
+	clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
+	clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
+	clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
+	clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
+
+	oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
+	oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
+	oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
+	oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
+	oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
+	oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
+
+	wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+	wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+	wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+	wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+	wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+
+	wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
+	wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
+	wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
+	wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
+}
+
+// buildPendingOAuthSessionStatusPayload builds the JSON payload describing a
+// pending session's current state: the normalized completion response merged
+// over base fields (auth_result, provider, intent), plus the resolved email
+// when present. Completion-response keys can override the base fields.
+func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
+	completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
+	payload := gin.H{
+		"auth_result": "pending_session",
+		"provider":    strings.TrimSpace(session.ProviderType),
+		"intent":      strings.TrimSpace(session.Intent),
+	}
+	for key, value := range completionResponse {
+		payload[key] = value
+	}
+	if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
+		payload["email"] = email
+	}
+	return payload
+}
+
+// normalizePendingOAuthCompletionResponse returns a sanitized copy of the
+// completion payload: token material is stripped, legacy/alias step names are
+// collapsed to the canonical choice step, and adoption_required is set when
+// the payload is in the choice step or carries choice-only fields.
+func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
+	normalized := clonePendingMap(payload)
+	// Never leak token material through a status/normalization path.
+	for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
+		delete(normalized, key)
+	}
+	step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
+	switch step {
+	case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
+		// Collapse historical step aliases onto the canonical choice step.
+		normalized["step"] = oauthPendingChoiceStep
+	}
+	if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
+		normalized["adoption_required"] = true
+	}
+	if _, exists := normalized["adoption_required"]; !exists {
+		// Presence of email_binding_required implies a choice is pending.
+		if _, hasChoiceFields := normalized["email_binding_required"]; hasChoiceFields {
+			normalized["adoption_required"] = true
+		}
+	}
+	return normalized
+}
+
+// pendingOAuthChoiceCompletionResponse builds the completion response that
+// moves a pending session into the account-choice step (sign up vs. bind an
+// existing account), flagging that email binding is required and that an
+// existing account is bindable. A non-blank email is echoed back under both
+// "email" and "resolved_email".
+func pendingOAuthChoiceCompletionResponse(session *dbent.PendingAuthSession, email string) map[string]any {
+	response := mergePendingCompletionResponse(session, map[string]any{
+		"step":                      oauthPendingChoiceStep,
+		"adoption_required":         true,
+		"force_email_on_signup":     true,
+		"email_binding_required":    true,
+		"existing_account_bindable": true,
+	})
+	if email = strings.TrimSpace(email); email != "" {
+		response["email"] = email
+		response["resolved_email"] = email
+	}
+	return response
+}
+
+// transitionPendingOAuthAccountToChoiceState persists a pending session's move
+// into the account-choice step (the submitted email matched an existing user).
+// It stores the choice completion response, records the existing user as the
+// session's target when available, and returns the updated session. Update
+// failures are wrapped as an internal-server error.
+func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
+	c *gin.Context,
+	client *dbent.Client,
+	session *dbent.PendingAuthSession,
+	targetUser *dbent.User,
+	email string,
+) (*dbent.PendingAuthSession, error) {
+	completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
+	var targetUserID *int64
+	if targetUser != nil && targetUser.ID > 0 {
+		targetUserID = &targetUser.ID
+	}
+	session, err := updatePendingOAuthSessionProgress(
+		c.Request.Context(),
+		client,
+		session,
+		strings.TrimSpace(session.Intent),
+		email,
+		targetUserID,
+		completionResponse,
+	)
+	if err != nil {
+		return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
+	}
+	return session, nil
+}
+
+// writeOAuthTokenPairResponse writes the standard OAuth token response body
+// (access/refresh tokens, expiry, and a fixed "Bearer" token type) with HTTP 200.
+func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
+	c.JSON(http.StatusOK, gin.H{
+		"access_token":  tokenPair.AccessToken,
+		"refresh_token": tokenPair.RefreshToken,
+		"expires_in":    tokenPair.ExpiresIn,
+		"token_type":    "Bearer",
+	})
+}
+
+// bindPendingOAuthLogin handles "log in with password to bind this OAuth
+// identity": it validates the pending browser session (optionally checking the
+// provider matches), verifies the user's email/password, enforces any session
+// target-user lock and backend-mode policy, then either defers to a TOTP
+// challenge or binds the identity, consumes the session, and returns a token
+// pair. Cookies are cleared once the session has been consumed (or on consume
+// failure).
+func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
+	var req bindPendingOAuthLoginRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.BadRequest(c, "Invalid request: "+err.Error())
+		return
+	}
+
+	pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+		response.BadRequest(c, "Pending oauth session provider mismatch")
+		return
+	}
+
+	user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	// A session locked to a specific user may only be completed by that user.
+	if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
+		response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
+		return
+	}
+	if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	// 2FA-enabled users get a temp token; binding resumes after TOTP verification,
+	// leaving the pending session unconsumed.
+	if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
+		tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession(
+			c.Request.Context(),
+			user.ID,
+			user.Email,
+			session.SessionToken,
+			session.BrowserSessionKey,
+		)
+		if err != nil {
+			response.InternalError(c, "Failed to create 2FA session")
+			return
+		}
+		response.Success(c, TotpLoginResponse{
+			Requires2FA:     true,
+			TempToken:       tempToken,
+			UserEmailMasked: service.MaskEmail(user.Email),
+		})
+		return
+	}
+	if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
+		respondPendingOAuthBindingApplyError(c, err)
+		return
+	}
+
+	h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+	tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
+	if err != nil {
+		response.InternalError(c, "Failed to generate token pair")
+		return
+	}
+	if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
+		clearCookies()
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	clearCookies()
+	writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+// respondPendingOAuthBindingApplyError writes a binding-apply failure to the
+// client: 4xx errors pass through unchanged (they carry caller-relevant
+// detail), everything else is wrapped as a generic internal-server error so
+// internals are not leaked.
+func respondPendingOAuthBindingApplyError(c *gin.Context, err error) {
+	if code := infraerrors.Code(err); code >= http.StatusBadRequest && code < http.StatusInternalServerError {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+}
+
+// createPendingOAuthAccount handles "create a new account for this OAuth
+// identity". Flow: validate the pending browser session and provider; if the
+// submitted email already belongs to a user, move the session into the
+// account-choice step instead of registering; otherwise register a new email
+// account, then inside a single transaction bind the OAuth identity, finalize
+// the account (invitation/aff handling), and consume the session. If anything
+// fails after the user row was created, the freshly created account is rolled
+// back; a rollback failure is itself reported to the client.
+func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
+	var req createPendingOAuthAccountRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.BadRequest(c, "Invalid request: "+err.Error())
+		return
+	}
+
+	_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
+		response.BadRequest(c, "Pending oauth session provider mismatch")
+		return
+	}
+
+	client := h.entClient()
+	if client == nil {
+		response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+		return
+	}
+
+	// If the email already has an account, offer binding instead of signup.
+	email := strings.TrimSpace(strings.ToLower(req.Email))
+	existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
+	if err != nil {
+		switch {
+		case errors.Is(err, service.ErrUserNotFound):
+			existingUser = nil
+		case infraerrors.Code(err) >= http.StatusBadRequest && infraerrors.Code(err) < http.StatusInternalServerError:
+			response.ErrorFrom(c, err)
+			return
+		default:
+			response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
+			return
+		}
+	}
+	if existingUser != nil {
+		session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+		if err != nil {
+			response.ErrorFrom(c, err)
+			return
+		}
+		c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+		return
+	}
+	if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
+		c.Request.Context(),
+		email,
+		req.Password,
+		strings.TrimSpace(req.VerifyCode),
+		strings.TrimSpace(req.InvitationCode),
+		strings.TrimSpace(session.ProviderType),
+	)
+	if err != nil {
+		// Races with concurrent signups: fall back to the choice step.
+		if errors.Is(err, service.ErrEmailExists) {
+			existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
+			if lookupErr != nil {
+				response.ErrorFrom(c, lookupErr)
+				return
+			}
+			session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
+			if err != nil {
+				response.ErrorFrom(c, err)
+				return
+			}
+			c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
+			return
+		}
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// rollbackCreatedUser undoes the just-created account. It returns true when
+	// the rollback itself failed (a response has already been written); on
+	// success it clears `user` so later attempts become no-ops and returns false
+	// so the caller still reports originalErr.
+	rollbackCreatedUser := func(originalErr error) bool {
+		if user == nil || user.ID <= 0 {
+			return false
+		}
+		if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation(
+			c.Request.Context(),
+			user.ID,
+			strings.TrimSpace(req.InvitationCode),
+		); rollbackErr != nil {
+			response.ErrorFrom(c, infraerrors.InternalServer(
+				"PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED",
+				"failed to rollback pending oauth account creation",
+			).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr)))
+			return true
+		}
+		user = nil
+		return false
+	}
+
+	decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
+	if err != nil {
+		if rollbackCreatedUser(err) {
+			return
+		}
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	tx, err := client.Tx(c.Request.Context())
+	if err != nil {
+		if rollbackCreatedUser(err) {
+			return
+		}
+		response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+		return
+	}
+	// Deferred rollback covers early returns; explicit rollbacks below release
+	// the tx before the (non-transactional) user-account rollback runs.
+	defer func() { _ = tx.Rollback() }()
+	txCtx := dbent.NewTxContext(c.Request.Context(), tx)
+
+	if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
+		_ = tx.Rollback()
+		if rollbackCreatedUser(err) {
+			return
+		}
+		respondPendingOAuthBindingApplyError(c, err)
+		return
+	}
+
+	if err := h.authService.FinalizeOAuthEmailAccount(
+		txCtx,
+		user,
+		strings.TrimSpace(req.InvitationCode),
+		strings.TrimSpace(session.ProviderType),
+		strings.TrimSpace(req.AffCode),
+	); err != nil {
+		_ = tx.Rollback()
+		if rollbackCreatedUser(err) {
+			return
+		}
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
+		_ = tx.Rollback()
+		if rollbackCreatedUser(err) {
+			return
+		}
+		clearCookies()
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// Test hook: lets tests inject failures just before commit.
+	if pendingOAuthCreateAccountPreCommitHook != nil {
+		if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
+			_ = tx.Rollback()
+			if rollbackCreatedUser(err) {
+				return
+			}
+			respondPendingOAuthBindingApplyError(c, err)
+			return
+		}
+	}
+
+	if err := tx.Commit(); err != nil {
+		if rollbackCreatedUser(err) {
+			return
+		}
+		response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
+		return
+	}
+
+	h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+	clearCookies()
+	writeOAuthTokenPairResponse(c, tokenPair)
+}
+
+// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
+// POST /api/v1/auth/oauth/pending/exchange
+//
+// Flow: load the session from cookies, normalize its stored completion
+// response (stripping token material), merge in redirect and suggested-profile
+// hints, then either (a) return early when the frontend still owes input
+// (invitation step, or an adoption prompt with no decision submitted), or
+// (b) apply the adoption decision, consume the session, clear the cookies,
+// and — when the session is login-ready — attach a freshly issued token pair.
+func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
+	secureCookie := isRequestHTTPS(c)
+	clearCookies := func() {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+	}
+	adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
+	if err != nil {
+		response.BadRequest(c, "Invalid request: "+err.Error())
+		return
+	}
+
+	sessionToken, err := readOAuthPendingSessionCookie(c)
+	if err != nil || strings.TrimSpace(sessionToken) == "" {
+		clearCookies()
+		response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+		return
+	}
+	browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+	if err != nil || strings.TrimSpace(browserSessionKey) == "" {
+		clearCookies()
+		response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+		return
+	}
+
+	svc, err := h.pendingIdentityService()
+	if err != nil {
+		clearCookies()
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+	if err != nil {
+		clearCookies()
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	payload, ok := readCompletionResponse(session.LocalFlowState)
+	if !ok {
+		clearCookies()
+		response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
+		return
+	}
+	payload = normalizePendingOAuthCompletionResponse(payload)
+	if strings.TrimSpace(session.RedirectTo) != "" {
+		// Stored payload redirect wins over the session-level redirect target.
+		if _, exists := payload["redirect"]; !exists {
+			payload["redirect"] = session.RedirectTo
+		}
+	}
+	applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
+
+	// When the session is login-ready, validate the target user up front so we
+	// never consume the session for a user who cannot log in.
+	canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
+	var loginUser *service.User
+	if canIssueTokenPair {
+		loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
+		if err != nil {
+			clearCookies()
+			response.ErrorFrom(c, err)
+			return
+		}
+		if err := ensureLoginUserActive(loginUser); err != nil {
+			clearCookies()
+			response.ErrorFrom(c, err)
+			return
+		}
+		if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
+			clearCookies()
+			response.ErrorFrom(c, err)
+			return
+		}
+	}
+	skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
+	if err != nil {
+		clearCookies()
+		response.ErrorFrom(c, err)
+		return
+	}
+	if skipAdoptionPrompt {
+		delete(payload, "adoption_required")
+	}
+
+	// Invitation step: persist any submitted decision but keep the session
+	// alive — the frontend still needs to collect the invitation code.
+	if pendingSessionWantsInvitation(payload) {
+		if adoptionDecision.hasDecision() {
+			decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
+			if err != nil {
+				response.ErrorFrom(c, err)
+				return
+			}
+			_ = decision
+		}
+		response.Success(c, payload)
+		return
+	}
+	// Adoption prompt pending and no decision submitted: return the prompt
+	// payload without consuming the session.
+	if !adoptionDecision.hasDecision() {
+		adoptionRequired, _ := payload["adoption_required"].(bool)
+		if adoptionRequired {
+			response.Success(c, payload)
+			return
+		}
+	}
+
+	// No explicit decision needed or given: default to adopting nothing.
+	decisionReq := adoptionDecision
+	if !decisionReq.hasDecision() {
+		adoptDisplayName := false
+		adoptAvatar := false
+		decisionReq = oauthAdoptionDecisionRequest{
+			AdoptDisplayName: &adoptDisplayName,
+			AdoptAvatar:      &adoptAvatar,
+		}
+	}
+
+	decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+		return
+	}
+
+	if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+		clearCookies()
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	if canIssueTokenPair {
+		tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
+		if err != nil {
+			clearCookies()
+			response.InternalError(c, "Failed to generate token pair")
+			return
+		}
+		h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
+		payload["access_token"] = tokenPair.AccessToken
+		payload["refresh_token"] = tokenPair.RefreshToken
+		payload["expires_in"] = tokenPair.ExpiresIn
+		payload["token_type"] = "Bearer"
+	}
+
+	clearCookies()
+	response.Success(c, payload)
+}
diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go
new file mode 100644
index 00000000..ffe9ff5f
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_pending_flow_test.go
@@ -0,0 +1,2996 @@
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/pquerna/otp/totp"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+// TestApplySuggestedProfileToCompletionResponse verifies that suggested
+// profile fields from the upstream claims are copied into the completion
+// payload and that adoption_required is forced to true when they are added.
+func TestApplySuggestedProfileToCompletionResponse(t *testing.T) {
+	payload := map[string]any{
+		"access_token": "token",
+	}
+	upstream := map[string]any{
+		"suggested_display_name": "Alice",
+		"suggested_avatar_url":   "https://cdn.example/avatar.png",
+	}
+
+	applySuggestedProfileToCompletionResponse(payload, upstream)
+
+	require.Equal(t, "Alice", payload["suggested_display_name"])
+	require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+	require.Equal(t, true, payload["adoption_required"])
+}
+
+// TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues
+// verifies that values already present in the payload are not overwritten by
+// upstream suggestions (display name stays "Existing"), while missing fields
+// are still filled in and adoption_required is still raised to true.
+func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) {
+	payload := map[string]any{
+		"suggested_display_name": "Existing",
+		"adoption_required":      false,
+	}
+	upstream := map[string]any{
+		"suggested_display_name": "Alice",
+		"suggested_avatar_url":   "https://cdn.example/avatar.png",
+	}
+
+	applySuggestedProfileToCompletionResponse(payload, upstream)
+
+	// Existing payload value wins; only the absent avatar URL is adopted.
+	require.Equal(t, "Existing", payload["suggested_display_name"])
+	require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"])
+	require.Equal(t, true, payload["adoption_required"])
+}
+
+// TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix verifies
+// that the pending-session cookie is scoped to the shared OAuth path prefix
+// ("/api/v1/auth/oauth") rather than the provider-specific callback path.
+func TestSetOAuthPendingSessionCookieUsesProviderCompletionPathPrefix(t *testing.T) {
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
+
+	setOAuthPendingSessionCookie(ginCtx, "pending-session-token", false)
+
+	cookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+	require.NotNil(t, cookie)
+	require.Equal(t, "/api/v1/auth/oauth", cookie.Path)
+}
+
+// TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision
+// exercises the two-step exchange flow for a login intent: a body-less preview
+// request returns the suggested profile without consuming the session, then a
+// finalize request with adopt flags applies the adoption (username, identity
+// metadata, avatar record), persists the decision, and consumes the session.
+func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("linuxdo-123@linuxdo-connect.invalid").
+		SetUsername("legacy-name").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Pending session carrying upstream suggestions and a stored completion
+	// payload that the exchange endpoint will return to the client.
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("pending-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"username":               "linuxdo_user",
+			"suggested_display_name": "Alice Example",
+			"suggested_avatar_url":   "https://cdn.example/alice.png",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "access-token",
+				"redirect":     "/dashboard",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Step 1: preview — no JSON body, so no adoption decision is supplied.
+	previewRecorder := httptest.NewRecorder()
+	previewCtx, _ := gin.CreateTestContext(previewRecorder)
+	previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+	previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+	previewCtx.Request = previewReq
+
+	handler.ExchangePendingOAuthCompletion(previewCtx)
+
+	require.Equal(t, http.StatusOK, previewRecorder.Code)
+	previewData := decodeJSONResponseData(t, previewRecorder)
+	require.Equal(t, "Alice Example", previewData["suggested_display_name"])
+	require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"])
+	require.Equal(t, true, previewData["adoption_required"])
+
+	// Preview must not mutate the user or consume the pending session.
+	storedUser, err := client.User.Get(ctx, userEntity.ID)
+	require.NoError(t, err)
+	require.Equal(t, "legacy-name", storedUser.Username)
+
+	previewSession, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Nil(t, previewSession.ConsumedAt)
+
+	// Step 2: finalize with both adoption flags set to true.
+	body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+	finalizeRecorder := httptest.NewRecorder()
+	finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+	finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+	finalizeReq.Header.Set("Content-Type", "application/json")
+	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")})
+	finalizeCtx.Request = finalizeReq
+
+	handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+	require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+	// Display name was adopted onto the user record.
+	storedUser, err = client.User.Get(ctx, userEntity.ID)
+	require.NoError(t, err)
+	require.Equal(t, "Alice Example", storedUser.Username)
+
+	// Identity was bound to the user and carries the adopted metadata.
+	identity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("linuxdo"),
+			authidentity.ProviderKeyEQ("linuxdo"),
+			authidentity.ProviderSubjectEQ("123"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, userEntity.ID, identity.UserID)
+	require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+	require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"])
+
+	// Avatar adoption stored the remote URL as the user's avatar.
+	avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+	require.NotNil(t, avatar)
+	require.Equal(t, "remote_url", avatar.StorageProvider)
+	require.Equal(t, "https://cdn.example/alice.png", avatar.URL)
+
+	// The persisted decision references the created identity with both flags.
+	decision, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, decision.IdentityID)
+	require.Equal(t, identity.ID, *decision.IdentityID)
+	require.True(t, decision.AdoptDisplayName)
+	require.True(t, decision.AdoptAvatar)
+
+	// Finalize consumed the pending session.
+	consumed, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, consumed.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion
+// verifies that a non-absolute suggested avatar URL ("/avatars/alice.png") is
+// silently skipped during adoption: the display name is still adopted, no
+// avatar record or metadata entry is created, and the exchange completes
+// (HTTP 200, session consumed) rather than failing.
+func TestExchangePendingOAuthCompletionSkipsInvalidAvatarAdoptionWithoutBlockingCompletion(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("invalid-avatar@example.com").
+		SetUsername("legacy-name").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Note the relative avatar URL — the invalid value under test.
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("pending-invalid-avatar-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("invalid-avatar-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("browser-invalid-avatar-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"username":               "linuxdo_user",
+			"suggested_display_name": "Alice Example",
+			"suggested_avatar_url":   "/avatars/alice.png",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "access-token",
+				"redirect":     "/dashboard",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Finalize directly with both adoption flags enabled.
+	body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`)
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+	req.Header.Set("Content-Type", "application/json")
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-invalid-avatar-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+
+	// Display name adoption succeeded; avatar metadata was not written.
+	identity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("linuxdo"),
+			authidentity.ProviderKeyEQ("linuxdo"),
+			authidentity.ProviderSubjectEQ("invalid-avatar-123"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, "Alice Example", identity.Metadata["display_name"])
+	_, hasAdoptedAvatar := identity.Metadata["avatar_url"]
+	require.False(t, hasAdoptedAvatar)
+
+	// No avatar record was created for the user.
+	avatar := loadUserAvatarRecord(t, client, userEntity.ID)
+	require.Nil(t, avatar)
+
+	// The session was still consumed despite the skipped avatar.
+	consumed, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, consumed.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption
+// exercises the "bind_current_user" intent: preview returns suggestions
+// without creating an identity; finalizing with both adopt flags false binds
+// the identity to the target user, keeps the user's name unchanged, stores
+// only the suggested_* metadata (no adopted display_name/avatar_url keys),
+// records a false/false decision, and consumes the session.
+func TestExchangePendingOAuthCompletionBindCurrentUserPreviewThenFinalizeBindsIdentityWithoutAdoption(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("bind-target@example.com").
+		SetUsername("legacy-name").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("bind-pending-session-token").
+		SetIntent("bind_current_user").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("bind-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("bind-browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"username":               "linuxdo_user",
+			"suggested_display_name": "Bound Example",
+			"suggested_avatar_url":   "https://cdn.example/bound.png",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "access-token",
+				"redirect":     "/settings/profile",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Step 1: preview without a body.
+	previewRecorder := httptest.NewRecorder()
+	previewCtx, _ := gin.CreateTestContext(previewRecorder)
+	previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+	previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
+	previewCtx.Request = previewReq
+
+	handler.ExchangePendingOAuthCompletion(previewCtx)
+
+	require.Equal(t, http.StatusOK, previewRecorder.Code)
+	previewData := decodeJSONResponseData(t, previewRecorder)
+	require.Equal(t, "Bound Example", previewData["suggested_display_name"])
+	require.Equal(t, "https://cdn.example/bound.png", previewData["suggested_avatar_url"])
+	require.Equal(t, true, previewData["adoption_required"])
+
+	// Preview must not have created the identity yet.
+	identityCount, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("linuxdo"),
+			authidentity.ProviderKeyEQ("linuxdo"),
+			authidentity.ProviderSubjectEQ("bind-123"),
+		).
+		Count(ctx)
+	require.NoError(t, err)
+	require.Zero(t, identityCount)
+
+	previewSession, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Nil(t, previewSession.ConsumedAt)
+
+	// Step 2: finalize declining both adoptions.
+	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+	finalizeRecorder := httptest.NewRecorder()
+	finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder)
+	finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+	finalizeReq.Header.Set("Content-Type", "application/json")
+	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-browser-session-key")})
+	finalizeCtx.Request = finalizeReq
+
+	handler.ExchangePendingOAuthCompletion(finalizeCtx)
+
+	require.Equal(t, http.StatusOK, finalizeRecorder.Code)
+
+	// Username untouched because adoption was declined.
+	storedUser, err := client.User.Get(ctx, userEntity.ID)
+	require.NoError(t, err)
+	require.Equal(t, "legacy-name", storedUser.Username)
+
+	// Identity was bound; suggestions are kept but not adopted.
+	identity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("linuxdo"),
+			authidentity.ProviderKeyEQ("linuxdo"),
+			authidentity.ProviderSubjectEQ("bind-123"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, userEntity.ID, identity.UserID)
+	require.Equal(t, "Bound Example", identity.Metadata["suggested_display_name"])
+	require.Equal(t, "https://cdn.example/bound.png", identity.Metadata["suggested_avatar_url"])
+	_, hasDisplayName := identity.Metadata["display_name"]
+	require.False(t, hasDisplayName)
+	_, hasAvatarURL := identity.Metadata["avatar_url"]
+	require.False(t, hasAvatarURL)
+
+	// Decision persisted with both flags false, referencing the identity.
+	decision, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, decision.IdentityID)
+	require.Equal(t, identity.ID, *decision.IdentityID)
+	require.False(t, decision.AdoptDisplayName)
+	require.False(t, decision.AdoptAvatar)
+
+	consumed, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, consumed.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict verifies
+// that binding fails when the provider identity already belongs to a
+// different user: the handler returns 500 with reason
+// PENDING_AUTH_ADOPTION_APPLY_FAILED, the existing identity keeps its owner,
+// the stored decision has no identity reference, and the session is NOT
+// consumed (allowing retry/inspection).
+func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	targetUser, err := client.User.Create().
+		SetEmail("bind-conflict-target@example.com").
+		SetUsername("target-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	ownerUser, err := client.User.Create().
+		SetEmail("bind-conflict-owner@example.com").
+		SetUsername("owner-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// The conflicting identity: same provider subject, owned by ownerUser.
+	existingIdentity, err := client.AuthIdentity.Create().
+		SetUserID(ownerUser.ID).
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("conflict-123").
+		SetMetadata(map[string]any{"username": "owner-user"}).
+		Save(ctx)
+	require.NoError(t, err)
+
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("bind-conflict-session-token").
+		SetIntent("bind_current_user").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("conflict-123").
+		SetTargetUserID(targetUser.ID).
+		SetResolvedEmail(targetUser.Email).
+		SetBrowserSessionKey("bind-conflict-browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"suggested_display_name": "Conflict Example",
+			"suggested_avatar_url":   "https://cdn.example/conflict.png",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "access-token",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+	req.Header.Set("Content-Type", "application/json")
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-conflict-browser-session-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusInternalServerError, recorder.Code)
+	payload := decodeJSONBody(t, recorder)
+	require.Equal(t, "PENDING_AUTH_ADOPTION_APPLY_FAILED", payload["reason"])
+
+	// Ownership of the existing identity is unchanged.
+	identity, err := client.AuthIdentity.Get(ctx, existingIdentity.ID)
+	require.NoError(t, err)
+	require.Equal(t, ownerUser.ID, identity.UserID)
+
+	// Decision was created before the failure but holds no identity link.
+	decision, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Nil(t, decision.IdentityID)
+	require.False(t, decision.AdoptDisplayName)
+	require.False(t, decision.AdoptAvatar)
+
+	// Session remains unconsumed after the failed bind.
+	storedSession, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption
+// verifies that a login-intent finalize with both adopt flags false still
+// creates and binds the provider identity, records the false/false decision
+// with the identity reference, and consumes the session.
+func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("login-false@example.com").
+		SetUsername("legacy-name").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("login-false-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("login-false-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("login-false-browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"suggested_display_name": "Login Example",
+			"suggested_avatar_url":   "https://cdn.example/login.png",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "access-token",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+	req.Header.Set("Content-Type", "application/json")
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-false-browser-session-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+
+	// Identity bound despite declined adoption.
+	identity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("linuxdo"),
+			authidentity.ProviderKeyEQ("linuxdo"),
+			authidentity.ProviderSubjectEQ("login-false-123"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, userEntity.ID, identity.UserID)
+
+	// Decision persisted with identity reference and both flags false.
+	decision, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, decision.IdentityID)
+	require.Equal(t, identity.ID, *decision.IdentityID)
+	require.False(t, decision.AdoptDisplayName)
+	require.False(t, decision.AdoptAvatar)
+
+	storedSession, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, storedSession.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionLoginReassignsExistingDecisionIdentityReference
+// verifies that when a previous pending session's decision already references
+// the same identity, finalizing the current session moves the identity
+// reference: the previous decision's IdentityID is cleared and the current
+// session's decision now points at the existing identity.
+func TestExchangePendingOAuthCompletionLoginReassignsExistingDecisionIdentityReference(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("login-reassign@example.com").
+		SetUsername("legacy-name").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Identity already exists and is owned by the same user.
+	existingIdentity, err := client.AuthIdentity.Create().
+		SetUserID(userEntity.ID).
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("login-reassign-123").
+		SetMetadata(map[string]any{}).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// An older pending session whose decision currently owns the identity ref.
+	previousSession, err := client.PendingAuthSession.Create().
+		SetSessionToken("login-reassign-previous-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("login-reassign-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("login-reassign-previous-browser-session-key").
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "previous-access-token",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	previousDecision, err := client.IdentityAdoptionDecision.Create().
+		SetPendingAuthSessionID(previousSession.ID).
+		SetIdentityID(existingIdentity.ID).
+		SetAdoptDisplayName(true).
+		SetAdoptAvatar(true).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// The current pending session under test.
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("login-reassign-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("login-reassign-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("login-reassign-browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"suggested_display_name": "Login Reassign",
+			"suggested_avatar_url":   "https://cdn.example/login-reassign.png",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "access-token",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Pre-existing decision for the current session, not yet linked.
+	_, err = client.IdentityAdoptionDecision.Create().
+		SetPendingAuthSessionID(session.ID).
+		SetAdoptDisplayName(false).
+		SetAdoptAvatar(false).
+		Save(ctx)
+	require.NoError(t, err)
+
+	body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+	req.Header.Set("Content-Type", "application/json")
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-reassign-browser-session-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+
+	// The older decision lost its identity reference.
+	reloadedPrevious, err := client.IdentityAdoptionDecision.Get(ctx, previousDecision.ID)
+	require.NoError(t, err)
+	require.Nil(t, reloadedPrevious.IdentityID)
+
+	// The current session's decision now references the existing identity.
+	currentDecision, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, currentDecision.IdentityID)
+	require.Equal(t, existingIdentity.ID, *currentDecision.IdentityID)
+
+	storedSession, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, storedSession.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity
+// verifies that a login session with no suggested profile (so no adoption
+// prompt) completes on a single body-less request: the identity is bound to
+// the user and the session is consumed without any decision being required.
+func TestExchangePendingOAuthCompletionLoginWithoutDecisionStillBindsIdentity(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("login-nodecision@example.com").
+		SetUsername("legacy-name").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Claims carry only a username — no suggested_* keys.
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("login-nodecision-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("login-nodecision-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("login-nodecision-browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"username": "login-nodecision-user",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token": "access-token",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("login-nodecision-browser-session-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+
+	identity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("linuxdo"),
+			authidentity.ProviderKeyEQ("linuxdo"),
+			authidentity.ProviderSubjectEQ("login-nodecision-123"),
+		).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, userEntity.ID, identity.UserID)
+
+	storedSession, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.IDEQ(session.ID)).
+		Only(ctx)
+	require.NoError(t, err)
+	require.NotNil(t, storedSession.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt
+// verifies the fast path for a user whose identity is already bound: the
+// exchange skips the adoption prompt (no adoption_required key), mints a
+// fresh token pair instead of replaying the legacy tokens stored in the
+// session, scrubs the legacy token fields from the persisted flow state,
+// and consumes the session.
+func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdoptionPrompt(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("existing-login@example.com").
+		SetUsername("existing-login-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Identity already linked — this is what suppresses the adoption prompt.
+	_, err = client.AuthIdentity.Create().
+		SetUserID(userEntity.ID).
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("existing-login-123").
+		SetMetadata(map[string]any{
+			"username": "existing-login-user",
+		}).
+		Save(ctx)
+	require.NoError(t, err)
+
+	// Session holds stale ("legacy") tokens that must NOT be returned.
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("existing-login-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("existing-login-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("existing-login-browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"suggested_display_name": "Existing Login Example",
+			"suggested_avatar_url":   "https://cdn.example/existing-login.png",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token":  "legacy-access-token",
+				"refresh_token": "legacy-refresh-token",
+				"expires_in":    float64(3600),
+				"token_type":    "Bearer",
+				"redirect":      "/dashboard",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-login-browser-session-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusOK, recorder.Code)
+
+	// Fresh tokens were issued; legacy values were not replayed.
+	payload := decodeJSONResponseData(t, recorder)
+	require.NotEmpty(t, payload["access_token"])
+	require.NotEmpty(t, payload["refresh_token"])
+	require.NotEqual(t, "legacy-access-token", payload["access_token"])
+	require.NotEqual(t, "legacy-refresh-token", payload["refresh_token"])
+	require.Equal(t, "/dashboard", payload["redirect"])
+	require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
+	require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])
+	require.NotContains(t, payload, "adoption_required")
+
+	// The issued access token validates against the user's current token version.
+	accessToken, ok := payload["access_token"].(string)
+	require.True(t, ok)
+	claims, err := handler.authService.ValidateToken(accessToken)
+	require.NoError(t, err)
+	reloadedUser, err := handler.userService.GetByID(ctx, userEntity.ID)
+	require.NoError(t, err)
+	require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
+
+	decisionCount, err := client.IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+		Count(ctx)
+	require.NoError(t, err)
+	require.Equal(t, 1, decisionCount)
+
+	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+	require.NoError(t, err)
+	require.NotNil(t, storedSession.ConsumedAt)
+
+	// Token material was scrubbed from the persisted completion state.
+	completion, ok := storedSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+	require.True(t, ok)
+	require.NotContains(t, completion, "access_token")
+	require.NotContains(t, completion, "refresh_token")
+	require.NotContains(t, completion, "expires_in")
+	require.NotContains(t, completion, "token_type")
+	require.Equal(t, "/dashboard", completion["redirect"])
+}
+
+// TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload
+// verifies that when the backend-mode setting is enabled, the exchange is
+// rejected with 403 before any token payload is returned, and the pending
+// session is left unconsumed.
+func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+		settingValues: map[string]string{
+			service.SettingKeyBackendModeEnabled: "true",
+		},
+	})
+	ctx := context.Background()
+
+	userEntity, err := client.User.Create().
+		SetEmail("blocked@example.com").
+		SetUsername("blocked-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("blocked-backend-mode-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("blocked-subject-123").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("blocked-backend-mode-browser-session-key").
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"access_token":  "access-token",
+				"refresh_token": "refresh-token",
+				"expires_in":    float64(3600),
+				"token_type":    "Bearer",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusForbidden, recorder.Code)
+
+	// Rejection happens before consumption — session stays reusable.
+	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+	require.NoError(t, err)
+	require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestExchangePendingOAuthCompletionRejectsDisabledTargetUser verifies that
+// a pending session whose target user is disabled is rejected with 403 and
+// the session is not consumed.
+func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) {
+	handler, client := newOAuthPendingFlowTestHandler(t, false)
+	ctx := context.Background()
+
+	// Target user is created in the disabled state.
+	userEntity, err := client.User.Create().
+		SetEmail("disabled-linked@example.com").
+		SetUsername("disabled-linked-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusDisabled).
+		Save(ctx)
+	require.NoError(t, err)
+
+	session, err := client.PendingAuthSession.Create().
+		SetSessionToken("disabled-linked-session-token").
+		SetIntent("login").
+		SetProviderType("linuxdo").
+		SetProviderKey("linuxdo").
+		SetProviderSubject("disabled-linked-subject").
+		SetTargetUserID(userEntity.ID).
+		SetResolvedEmail(userEntity.Email).
+		SetBrowserSessionKey("disabled-linked-browser-session-key").
+		SetUpstreamIdentityClaims(map[string]any{
+			"suggested_display_name": "Disabled Linked User",
+		}).
+		SetLocalFlowState(map[string]any{
+			oauthCompletionResponseKey: map[string]any{
+				"redirect": "/dashboard",
+			},
+		}).
+		SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	ginCtx, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
+	req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+	req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")})
+	ginCtx.Request = req
+
+	handler.ExchangePendingOAuthCompletion(ginCtx)
+
+	require.Equal(t, http.StatusForbidden, recorder.Code)
+
+	// Session remains unconsumed after the rejection.
+	storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+	require.NoError(t, err)
+	require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload verifies
+// that the four legacy token fields are stripped from a stored completion
+// payload while non-token fields (here "redirect") are preserved.
+func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
+ legacyPayload := map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ }
+
+ normalized := normalizePendingOAuthCompletionResponse(legacyPayload)
+
+ // Token-bearing fields must never survive normalization.
+ for _, scrubbedKey := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
+ require.NotContains(t, normalized, scrubbedKey)
+ }
+ require.Equal(t, "/dashboard", normalized["redirect"])
+}
+
+// TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding
+// verifies that exchanging a pending session whose stored completion is
+// "invitation_required" with an adopt-nothing body (adopt_display_name=false,
+// adopt_avatar=false) returns the invitation_required payload, records the
+// adoption decision without creating an auth identity, and keeps the pending
+// session unconsumed.
+func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, true)
+ ctx := context.Background()
+
+ // Arrange: pending session whose stored completion response carries the
+ // "invitation_required" error marker.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("invitation-required-session-token").
+ SetIntent("login").
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("invitation-123").
+ SetBrowserSessionKey("invitation-required-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Invite Example",
+ "suggested_avatar_url": "https://cdn.example/invite.png",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "error": "invitation_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: exchange with an explicit decline of both adoption flags.
+ body := bytes.NewBufferString(`{"adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("invitation-required-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.ExchangePendingOAuthCompletion(ginCtx)
+
+ // Assert: the stored invitation_required payload is surfaced.
+ require.Equal(t, http.StatusOK, recorder.Code)
+ data := decodeJSONResponseData(t, recorder)
+ require.Equal(t, "invitation_required", data["error"])
+ require.Equal(t, true, data["adoption_required"])
+
+ // Assert: no auth identity was bound for this provider subject.
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("linuxdo"),
+ authidentity.ProviderKeyEQ("linuxdo"),
+ authidentity.ProviderSubjectEQ("invitation-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ // Assert: the adoption decision was persisted without an identity link and
+ // with both adoption flags false.
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+
+ // Assert: the pending session remains unconsumed.
+ storedSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession covers
+// the happy path of OIDC account creation: with a valid email verification
+// code, the handler returns a token pair, creates an active user for the new
+// email, binds the OIDC identity to that user, and consumes the pending
+// session.
+func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) {
+ // Handler seeded with a verification code "246810" for fresh@example.com.
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-123").
+ SetBrowserSessionKey("create-account-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Fresh OIDC User",
+ "suggested_avatar_url": "https://cdn.example/fresh.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: create the account with the matching verify code and a password.
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ // Assert: a full bearer token pair is issued.
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ // Assert: the new user exists and is active.
+ createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusActive, createdUser.Status)
+
+ // Assert: the OIDC identity is bound to the freshly created user.
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-create-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, createdUser.ID, identity.UserID)
+
+ // Assert: success consumes the pending session.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+// TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState
+// verifies that attempting account creation with an email already owned by an
+// existing user does not create an identity; instead the handler returns a
+// "pending_session" payload at the choice step, and the pending session is
+// updated to target the existing user with the resolved email, still
+// unconsumed.
+func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ // Arrange: the email is already taken by an active user.
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-123").
+ SetBrowserSessionKey("existing-email-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: submit create-account with the already-owned email.
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ // Assert: the response is a choice-step pending_session payload echoing
+ // the email and the upstream suggestions.
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, oauthIntentLogin, payload["intent"])
+ require.Equal(t, "oidc", payload["provider"])
+ require.Equal(t, "/dashboard", payload["redirect"])
+ require.Equal(t, true, payload["adoption_required"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+ require.Equal(t, "Existing OIDC User", payload["suggested_display_name"])
+ require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"])
+
+ // Assert: the session now targets the existing user and stays unconsumed.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+ require.Nil(t, storedSession.ConsumedAt)
+
+ // Assert: no identity binding happened for this subject.
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-existing-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+}
+
+// TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase
+// verifies that an existing user stored with a legacy, non-normalized email
+// (" Owner@Example.com ") is still matched when the request submits the
+// normalized form, and that the session's resolved email is stored
+// normalized.
+func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ // Arrange: deliberately store the email with legacy padding and mixed
+ // case — the normalization under test must still find this user.
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-normalized-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-normalized-123").
+ SetBrowserSessionKey("existing-email-normalized-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Existing OIDC User",
+ "suggested_avatar_url": "https://cdn.example/existing.png",
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: submit the normalized (lowercase, trimmed) email.
+ body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ // Assert: the flow lands on the choice step, as in the exact-match case.
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, oauthIntentLogin, payload["intent"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+
+ // Assert: the legacy-formatted user was matched and the resolved email is
+ // stored in normalized form.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
+// TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState verifies
+// that requesting a verify code for an email that already belongs to an
+// existing user short-circuits to the choice step (bind-login state) and
+// updates the pending session to target that user.
+func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Arrange: the session is currently at the email_required step.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("existing-email-send-code-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-existing-send-code-123").
+ SetBrowserSessionKey("existing-email-send-code-browser-session-key").
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "email_required",
+ },
+ }).
+ SetRedirectTo("/dashboard").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: ask for a verify code for the already-owned email.
+ body := bytes.NewBufferString(`{"email":"owner@example.com"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/send-verify-code", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-send-code-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.SendPendingOAuthVerifyCode(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ // Assert: instead of sending a code, the response moves to the choice
+ // step for the existing account.
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.Equal(t, "pending_session", payload["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, payload["step"])
+ require.Equal(t, "owner@example.com", payload["email"])
+
+ // Assert: the session now targets the existing user with the resolved
+ // email recorded.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, storedSession.Intent)
+ require.NotNil(t, storedSession.TargetUserID)
+ require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
+ require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
+}
+
+// TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser verifies that
+// when backend mode is enabled, OIDC account creation is rejected with 403
+// before any user row is created, and the pending session is not consumed.
+func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
+ // Arrange: handler with a valid verification code cached for the email
+ // and the backend-mode setting turned on.
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-backend-mode-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-create-backend-mode-123").
+ SetBrowserSessionKey("create-account-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: a request that would otherwise succeed (valid code + password).
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ // Assert: backend mode blocks the request outright.
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ // Assert: no user was created for the requested email.
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ // Assert: the pending session survives unconsumed.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestLogoutClearsPendingOAuthAndBindCookies verifies that Logout responds
+// 200 and expires (MaxAge -1) all three OAuth-flow cookies: the pending
+// session cookie, the pending browser cookie, and the bind access-token
+// cookie.
+func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
+ handler, _ := newOAuthPendingFlowTestHandler(t, false)
+
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
+ req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
+ ginCtx.Request = req
+
+ handler.Logout(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ // Parse the recorded response once instead of calling
+ // recorder.Result().Cookies() per assertion, and fail with a clear
+ // message (rather than a nil-pointer panic) if a clearing cookie is
+ // missing from the logout response.
+ cookies := recorder.Result().Cookies()
+ for _, name := range []string{oauthPendingSessionCookieName, oauthPendingBrowserCookieName, oauthBindAccessTokenCookieName} {
+ cleared := findCookie(cookies, name)
+ require.NotNil(t, cleared, "expected clearing cookie %q in logout response", name)
+ require.Equal(t, -1, cleared.MaxAge, "cookie %q should be expired", name)
+ }
+}
+
+// TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails verifies
+// that when the OIDC subject is already bound to another user (identity
+// conflict → 409), account creation is fully rolled back: no new user row,
+// the invitation code is restored to unused, and the pending session is not
+// consumed.
+func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
+ ctx := context.Background()
+
+ // Arrange: another user already owns the OIDC identity that the pending
+ // session will try to bind — this is the conflict that forces rollback.
+ conflictOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(conflictOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetMetadata(map[string]any{
+ "username": "owner-user",
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Arrange: an unused invitation code that the request will redeem and
+ // that rollback must restore.
+ invitation, err := client.RedeemCode.Create().
+ SetCode("INVITE123").
+ SetType(service.RedeemTypeInvitation).
+ SetStatus(service.StatusUnused).
+ SetValue(0).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Same provider subject as the pre-existing identity above.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-conflict-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-conflict-123").
+ SetBrowserSessionKey("create-account-conflict-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","invitation_code":"INVITE123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-conflict-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ // Assert: identity conflict surfaces as 409.
+ require.Equal(t, http.StatusConflict, recorder.Code)
+
+ // Assert: the speculatively created user was rolled back.
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ // Assert: the invitation code is back to unused with no redemption info.
+ storedInvitation, err := client.RedeemCode.Get(ctx, invitation.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.StatusUnused, storedInvitation.Status)
+ require.Nil(t, storedInvitation.UsedBy)
+ require.Nil(t, storedInvitation.UsedAt)
+
+ // Assert: the pending session is still unconsumed.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit
+// forces a failure after the identity bind (via the pre-commit test hook)
+// and verifies the whole operation rolls back: 500 response, no user, no
+// identity, and an unconsumed pending session. The userRepoOptions flag
+// makes the repo reject user deletion while an auth identity still exists,
+// so rollback must remove the identity before (or together with) the user.
+func TestCreateOIDCOAuthAccountRollsBackPostBindFailureBeforeIdentityCanCommit(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ emailVerifyEnabled: true,
+ emailCache: &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ "fresh@example.com": {
+ Code: "246810",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ },
+ userRepoOptions: oauthPendingFlowUserRepoOptions{
+ rejectDeleteWhileAuthIdentityExists: true,
+ },
+ })
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("create-account-finalize-failure-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-finalize-failure-123").
+ SetBrowserSessionKey("create-account-finalize-failure-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Install the package-level test hook to fail just before commit, and
+ // make sure it is reset even if the test fails.
+ pendingOAuthCreateAccountPreCommitHook = func(context.Context, *dbent.PendingAuthSession) error {
+ return errors.New("forced post-bind failure")
+ }
+ t.Cleanup(func() {
+ pendingOAuthCreateAccountPreCommitHook = nil
+ })
+
+ body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-finalize-failure-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.CreateOIDCOAuthAccount(ginCtx)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+
+ // Assert: the created user did not survive the forced failure.
+ userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ // Assert: the bound identity was also rolled back.
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-finalize-failure-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ // Assert: the pending session remains usable (unconsumed).
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession covers the happy
+// path of bind-login: with the correct password for the session's target
+// user, the handler issues a bearer token pair, binds the OIDC identity to
+// the existing user, and consumes the pending session.
+func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ // Use a real hash so the password check in bind-login can succeed.
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Session already resolved to the existing user via the adopt-by-email
+ // intent.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: bind-login with the correct password, declining adoption.
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ // Assert: a bearer token pair is issued.
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ require.Equal(t, "Bearer", payload["token_type"])
+
+ // Assert: the identity now belongs to the existing user.
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ // Assert: the pending session was consumed on success.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+}
+
+// TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue verifies that with
+// backend mode enabled, bind-login is rejected with 403 before any identity
+// is bound or token issued, and the pending session stays unconsumed.
+func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyBackendModeEnabled: "true",
+ },
+ })
+ ctx := context.Background()
+
+ // Correct password hash — the request would succeed if backend mode were
+ // off, so the 403 must come from the backend-mode gate.
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-backend-mode-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-backend-mode-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-backend-mode-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusForbidden, recorder.Code)
+
+ // Assert: no identity binding happened for this subject.
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ // Assert: the pending session is left unconsumed.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession
+// verifies that a wrong password on bind-login yields 401 with reason
+// INVALID_CREDENTIALS, binds no identity, and leaves the pending session
+// unconsumed so the user can retry.
+func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ // The stored hash is for "secret-123"; the request will send a different
+ // password.
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-invalid-password-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-invalid-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-invalid-password-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Act: bind-login with a deliberately wrong password.
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ // Assert: 401 with a machine-readable reason.
+ require.Equal(t, http.StatusUnauthorized, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "INVALID_CREDENTIALS", payload["reason"])
+
+ // Assert: no identity was bound for this subject.
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ // Assert: the pending session is still available for a retry.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestBindOIDCOAuthLoginReclaimsIdentityOwnedBySoftDeletedUser(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ oldOwnerHash, err := handler.authService.HashPassword("old-secret")
+ require.NoError(t, err)
+ oldOwner, err := client.User.Create().
+ SetEmail("old-owner@example.com").
+ SetUsername("old-owner").
+ SetPasswordHash(oldOwnerHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(oldOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-soft-deleted-123").
+ SetMetadata(map[string]any{"username": "old-owner"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.User.Delete().Where(dbuser.IDEQ(oldOwner.ID)).Exec(ctx)
+ require.NoError(t, err)
+
+ newOwnerHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ newOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(newOwnerHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-soft-deleted-owner-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-soft-deleted-123").
+ SetTargetUserID(newOwner.ID).
+ SetResolvedEmail(newOwner.Email).
+ SetBrowserSessionKey("bind-login-soft-deleted-owner-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "suggested_display_name": "Recovered OIDC User",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-soft-deleted-owner-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ identity, err = client.AuthIdentity.Get(ctx, identity.ID)
+ require.NoError(t, err)
+ require.Equal(t, newOwner.ID, identity.UserID)
+}
+
+func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) {
+ defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyAuthSourceDefaultOIDCBalance: "12.5",
+ service.SettingKeyAuthSourceDefaultOIDCConcurrency: "3",
+ service.SettingKeyAuthSourceDefaultOIDCSubscriptions: `[{"group_id":101,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
+ },
+ defaultSubAssigner: defaultSubAssigner,
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("first-bind-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-first-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("first-bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ firstRecorder := httptest.NewRecorder()
+ firstGinCtx, _ := gin.CreateTestContext(firstRecorder)
+ firstReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", firstBody)
+ firstReq.Header.Set("Content-Type", "application/json")
+ firstReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(firstSession.SessionToken)})
+ firstReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("first-bind-browser-session-key")})
+ firstGinCtx.Request = firstReq
+
+ handler.BindOIDCOAuthLogin(firstGinCtx)
+
+ require.Equal(t, http.StatusOK, firstRecorder.Code)
+
+ storedUser, err := client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 17.5, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.Zero(t, storedUser.TotalRecharged)
+ require.Len(t, defaultSubAssigner.calls, 1)
+ require.Equal(t, int64(existingUser.ID), defaultSubAssigner.calls[0].UserID)
+ require.Equal(t, int64(101), defaultSubAssigner.calls[0].GroupID)
+ require.Equal(t, 30, defaultSubAssigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+
+ secondSession, err := client.PendingAuthSession.Create().
+ SetSessionToken("second-bind-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-second-456").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("second-bind-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Second OIDC User",
+ "suggested_avatar_url": "https://cdn.example/second.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ secondBody := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ secondRecorder := httptest.NewRecorder()
+ secondGinCtx, _ := gin.CreateTestContext(secondRecorder)
+ secondReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", secondBody)
+ secondReq.Header.Set("Content-Type", "application/json")
+ secondReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(secondSession.SessionToken)})
+ secondReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("second-bind-browser-session-key")})
+ secondGinCtx.Request = secondReq
+
+ handler.BindOIDCOAuthLogin(secondGinCtx)
+
+ require.Equal(t, http.StatusOK, secondRecorder.Code)
+
+ storedUser, err = client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 17.5, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.Zero(t, storedUser.TotalRecharged)
+ require.Len(t, defaultSubAssigner.calls, 1)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+}
+
+func TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ _ = handler
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail(" Owner@Example.com ").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("resolve-target-session-token").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-target-123").
+ SetResolvedEmail("owner@example.com").
+ SetBrowserSessionKey("resolve-target-browser-session-key").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, resolvedUserID)
+}
+
+func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) {
+ totpCache := &oauthPendingFlowTotpCacheStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyTotpEnabled: "true",
+ },
+ totpCache: totpCache,
+ totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ totpEnabledAt := time.Now().UTC().Add(-time.Hour)
+ secret := "JBSWY3DPEHPK3PXP"
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetTotpEnabled(true).
+ SetTotpSecretEncrypted(secret).
+ SetTotpEnabledAt(totpEnabledAt).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("bind-login-2fa-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-bind-2fa-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("bind-login-2fa-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-2fa-browser-session-key")})
+ ginCtx.Request = req
+
+ handler.BindOIDCOAuthLogin(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ data := decodeJSONResponseData(t, recorder)
+ require.Equal(t, true, data["requires_2fa"])
+ require.Equal(t, "o***r@example.com", data["user_email_masked"])
+ tempToken, ok := data["temp_token"].(string)
+ require.True(t, ok)
+ require.NotEmpty(t, tempToken)
+
+ loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
+ require.NoError(t, err)
+ require.NotNil(t, loginSession)
+ require.NotNil(t, loginSession.PendingOAuthBind)
+ require.Equal(t, session.SessionToken, loginSession.PendingOAuthBind.PendingSessionToken)
+ require.Equal(t, session.BrowserSessionKey, loginSession.PendingOAuthBind.BrowserSessionKey)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-bind-2fa-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
+ totpCache := &oauthPendingFlowTotpCacheStub{}
+ defaultSubAssigner := &oauthPendingFlowDefaultSubAssignerStub{}
+ handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ settingValues: map[string]string{
+ service.SettingKeyTotpEnabled: "true",
+ service.SettingKeyAuthSourceDefaultOIDCBalance: "8",
+ service.SettingKeyAuthSourceDefaultOIDCConcurrency: "2",
+ service.SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "true",
+ },
+ defaultSubAssigner: defaultSubAssigner,
+ totpCache: totpCache,
+ totpEncryptor: oauthPendingFlowTotpEncryptorStub{},
+ })
+ ctx := context.Background()
+
+ passwordHash, err := handler.authService.HashPassword("secret-123")
+ require.NoError(t, err)
+ totpEnabledAt := time.Now().UTC().Add(-time.Hour)
+ secret := "JBSWY3DPEHPK3PXP"
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(4).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetTotpEnabled(true).
+ SetTotpSecretEncrypted(secret).
+ SetTotpEnabledAt(totpEnabledAt).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("login-2fa-pending-session-token").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("oidc-login-2fa-123").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("login-2fa-browser-session-key").
+ SetUpstreamIdentityClaims(map[string]any{
+ "suggested_display_name": "Bound OIDC User",
+ "suggested_avatar_url": "https://cdn.example/bound.png",
+ }).
+ SetRedirectTo("/profile").
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(session.ID).
+ SetAdoptDisplayName(false).
+ SetAdoptAvatar(false).
+ Save(ctx)
+ require.NoError(t, err)
+
+ tempToken, err := handler.totpService.CreatePendingOAuthBindLoginSession(
+ ctx,
+ existingUser.ID,
+ existingUser.Email,
+ session.SessionToken,
+ session.BrowserSessionKey,
+ )
+ require.NoError(t, err)
+
+ code, err := totp.GenerateCode(secret, time.Now().UTC())
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"temp_token":"` + tempToken + `","totp_code":"` + code + `"}`)
+ recorder := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/login/2fa", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue(session.BrowserSessionKey)})
+ ginCtx.Request = req
+
+ handler.Login2FA(ginCtx)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ payload := decodeJSONResponseData(t, recorder)
+ require.NotEmpty(t, payload["access_token"])
+ require.NotEmpty(t, payload["refresh_token"])
+ accessToken, ok := payload["access_token"].(string)
+ require.True(t, ok)
+ claims, err := handler.authService.ValidateToken(accessToken)
+ require.NoError(t, err)
+ reloadedUser, err := handler.userService.GetByID(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("oidc-login-2fa-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, existingUser.ID, identity.UserID)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedSession.ConsumedAt)
+
+ loginSession, err := totpCache.GetLoginSession(ctx, tempToken)
+ require.NoError(t, err)
+ require.Nil(t, loginSession)
+
+ storedUser, err := client.User.Get(ctx, existingUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, 9.5, storedUser.Balance)
+ require.Equal(t, 6, storedUser.Concurrency)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind"))
+ require.Empty(t, defaultSubAssigner.calls)
+}
+
+func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil)
+}
+
+func newOAuthPendingFlowTestHandlerWithEmailVerification(
+ t *testing.T,
+ invitationEnabled bool,
+ email string,
+ code string,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ cache := &oauthPendingFlowEmailCacheStub{
+ verificationCodes: map[string]*service.VerificationCodeData{
+ email: {
+ Code: code,
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+ },
+ },
+ }
+ return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache)
+}
+
+func newOAuthPendingFlowTestHandlerWithOptions(
+ t *testing.T,
+ invitationEnabled bool,
+ emailVerifyEnabled bool,
+ emailCache service.EmailCache,
+) (*AuthHandler, *dbent.Client) {
+ return newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
+ invitationEnabled: invitationEnabled,
+ emailVerifyEnabled: emailVerifyEnabled,
+ emailCache: emailCache,
+ })
+}
+
+type oauthPendingFlowTestHandlerOptions struct {
+ invitationEnabled bool
+ emailVerifyEnabled bool
+ emailCache service.EmailCache
+ settingValues map[string]string
+ defaultSubAssigner service.DefaultSubscriptionAssigner
+ totpCache service.TotpCache
+ totpEncryptor service.SecretEncryptor
+ userRepoOptions oauthPendingFlowUserRepoOptions
+}
+
+func newOAuthPendingFlowTestHandlerWithDependencies(
+ t *testing.T,
+ options oauthPendingFlowTestHandlerOptions,
+) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_avatars (
+ user_id INTEGER PRIMARY KEY,
+ storage_provider TEXT NOT NULL,
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL,
+ content_type TEXT NOT NULL DEFAULT '',
+ byte_size INTEGER NOT NULL DEFAULT 0,
+ sha256 TEXT NOT NULL DEFAULT '',
+ updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ settingValues := map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(options.invitationEnabled),
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
+ }
+ for key, value := range options.settingValues {
+ settingValues[key] = value
+ }
+ settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{values: settingValues}, cfg)
+ userRepo := &oauthPendingFlowUserRepo{
+ client: client,
+ options: options.userRepoOptions,
+ }
+ redeemRepo := &oauthPendingFlowRedeemCodeRepo{client: client}
+ var emailService *service.EmailService
+ if options.emailCache != nil {
+ emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{
+ values: map[string]string{
+ service.SettingKeyEmailVerifyEnabled: boolSettingValue(options.emailVerifyEnabled),
+ },
+ }, options.emailCache)
+ }
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &oauthPendingFlowRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ emailService,
+ nil,
+ nil,
+ nil,
+ options.defaultSubAssigner,
+ nil,
+ )
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ var totpSvc *service.TotpService
+ if options.totpCache != nil || options.totpEncryptor != nil {
+ totpCache := options.totpCache
+ if totpCache == nil {
+ totpCache = &oauthPendingFlowTotpCacheStub{}
+ }
+ totpEncryptor := options.totpEncryptor
+ if totpEncryptor == nil {
+ totpEncryptor = oauthPendingFlowTotpEncryptorStub{}
+ }
+ totpSvc = service.NewTotpService(userRepo, totpEncryptor, totpCache, settingSvc, nil, nil)
+ }
+
+ return &AuthHandler{
+ authService: authSvc,
+ userService: userSvc,
+ settingSvc: settingSvc,
+ totpService: totpSvc,
+ }, client
+}
+
+func boolSettingValue(v bool) string {
+ if v {
+ return "true"
+ }
+ return "false"
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+type oauthPendingFlowSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type oauthPendingFlowRefreshTokenCacheStub struct{}
+
+type oauthPendingFlowEmailCacheStub struct {
+ verificationCodes map[string]*service.VerificationCodeData
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) {
+ if s == nil || s.verificationCodes == nil {
+ return nil, nil
+ }
+ return s.verificationCodes[email], nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error {
+ if s.verificationCodes == nil {
+ s.verificationCodes = map[string]*service.VerificationCodeData{}
+ }
+ s.verificationCodes[email] = data
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error {
+ delete(s.verificationCodes, email)
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+type oauthPendingFlowRedeemCodeRepo struct {
+ client *dbent.Client
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Create(context.Context, *service.RedeemCode) error {
+ panic("unexpected Create call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) CreateBatch(context.Context, []service.RedeemCode) error {
+ panic("unexpected CreateBatch call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByID(context.Context, int64) (*service.RedeemCode, error) {
+ panic("unexpected GetByID call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
+ entity, err := r.client.RedeemCode.Query().Where(redeemcode.CodeEQ(code)).Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, service.ErrRedeemCodeNotFound
+ }
+ return nil, err
+ }
+ notes := ""
+ if entity.Notes != nil {
+ notes = *entity.Notes
+ }
+ return &service.RedeemCode{
+ ID: entity.ID,
+ Code: entity.Code,
+ Type: entity.Type,
+ Value: entity.Value,
+ Status: entity.Status,
+ UsedBy: entity.UsedBy,
+ UsedAt: entity.UsedAt,
+ Notes: notes,
+ CreatedAt: entity.CreatedAt,
+ GroupID: entity.GroupID,
+ ValidityDays: entity.ValidityDays,
+ }, nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
+ if code == nil {
+ return nil
+ }
+ update := r.client.RedeemCode.UpdateOneID(code.ID).
+ SetCode(code.Code).
+ SetType(code.Type).
+ SetValue(code.Value).
+ SetStatus(code.Status).
+ SetNotes(code.Notes).
+ SetValidityDays(code.ValidityDays)
+ if code.UsedBy != nil {
+ update = update.SetUsedBy(*code.UsedBy)
+ } else {
+ update = update.ClearUsedBy()
+ }
+ if code.UsedAt != nil {
+ update = update.SetUsedAt(*code.UsedAt)
+ } else {
+ update = update.ClearUsedAt()
+ }
+ if code.GroupID != nil {
+ update = update.SetGroupID(*code.GroupID)
+ } else {
+ update = update.ClearGroupID()
+ }
+ _, err := update.Save(ctx)
+ return err
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
+ panic("unexpected Delete call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
+ affected, err := r.client.RedeemCode.Update().
+ Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
+ SetStatus(service.StatusUsed).
+ SetUsedBy(userID).
+ SetUsedAt(time.Now().UTC()).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrRedeemCodeUsed
+ }
+ return nil
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) List(context.Context, pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUser(context.Context, int64, int) ([]service.RedeemCode, error) {
+ panic("unexpected ListByUser call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
+ panic("unexpected ListByUserPaginated call")
+}
+
+func (r *oauthPendingFlowRedeemCodeRepo) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+ panic("unexpected SumPositiveBalanceByUser call")
+}
+
+func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var envelope struct {
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope))
+ return envelope.Data
+}
+
+func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any {
+ t.Helper()
+
+ var payload map[string]any
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload))
+ return payload
+}
+
+type oauthPendingFlowAvatarRecord struct {
+ StorageProvider string
+ URL string
+}
+
+func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`,
+ []any{userID},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ require.NoError(t, rows.Err())
+ return nil
+ }
+
+ var record oauthPendingFlowAvatarRecord
+ require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL))
+ require.NoError(t, rows.Err())
+ return &record
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var rows entsql.Rows
+ err := client.Driver().Query(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ []any{userID, providerType, grantReason},
+ &rows,
+ )
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next())
+ var count int
+ require.NoError(t, rows.Scan(&count))
+ require.False(t, rows.Next())
+ return count
+}
+
+type oauthPendingFlowUserRepo struct {
+ client *dbent.Client
+ options oauthPendingFlowUserRepoOptions
+}
+
+type oauthPendingFlowUserRepoOptions struct {
+ rejectDeleteWhileAuthIdentityExists bool
+}
+
+// Create persists a new user through ent and copies the generated ID and
+// created/updated timestamps back onto the passed-in service.User.
+func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error {
+	entity, err := r.client.User.Create().
+		SetEmail(user.Email).
+		SetUsername(user.Username).
+		SetNotes(user.Notes).
+		SetPasswordHash(user.PasswordHash).
+		SetRole(user.Role).
+		SetBalance(user.Balance).
+		SetConcurrency(user.Concurrency).
+		SetStatus(user.Status).
+		SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
+		SetTotpEnabled(user.TotpEnabled).
+		SetNillableTotpEnabledAt(user.TotpEnabledAt).
+		SetTotalRecharged(user.TotalRecharged).
+		SetSignupSource(user.SignupSource).
+		SetNillableLastLoginAt(user.LastLoginAt).
+		SetNillableLastActiveAt(user.LastActiveAt).
+		Save(ctx)
+	if err != nil {
+		return err
+	}
+	// Reflect DB-generated values back to the caller's struct.
+	user.ID = entity.ID
+	user.CreatedAt = entity.CreatedAt
+	user.UpdatedAt = entity.UpdatedAt
+	return nil
+}
+
+// GetByID loads a user by primary key, translating ent's not-found error
+// into service.ErrUserNotFound.
+func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
+	entity, err := r.client.User.Get(ctx, id)
+	if err != nil {
+		if dbent.IsNotFound(err) {
+			return nil, service.ErrUserNotFound
+		}
+		return nil, err
+	}
+	return oauthPendingFlowServiceUser(entity), nil
+}
+
+// GetByEmail loads a user by exact email match (Only: fails if multiple rows
+// match), translating ent's not-found error into service.ErrUserNotFound.
+func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
+	entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+	if err != nil {
+		if dbent.IsNotFound(err) {
+			return nil, service.ErrUserNotFound
+		}
+		return nil, err
+	}
+	return oauthPendingFlowServiceUser(entity), nil
+}
+
+// GetFirstAdmin is not exercised by these tests; panic on unexpected use.
+func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) {
+	panic("unexpected GetFirstAdmin call")
+}
+
+// Update overwrites every mutable user field from the given service.User
+// (full-row update, no partial/dirty tracking) and copies the refreshed
+// UpdatedAt back onto the caller's struct.
+func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error {
+	entity, err := r.client.User.UpdateOneID(user.ID).
+		SetEmail(user.Email).
+		SetUsername(user.Username).
+		SetNotes(user.Notes).
+		SetPasswordHash(user.PasswordHash).
+		SetRole(user.Role).
+		SetBalance(user.Balance).
+		SetConcurrency(user.Concurrency).
+		SetStatus(user.Status).
+		SetNillableTotpSecretEncrypted(user.TotpSecretEncrypted).
+		SetTotpEnabled(user.TotpEnabled).
+		SetNillableTotpEnabledAt(user.TotpEnabledAt).
+		SetTotalRecharged(user.TotalRecharged).
+		SetSignupSource(user.SignupSource).
+		SetNillableLastLoginAt(user.LastLoginAt).
+		SetNillableLastActiveAt(user.LastActiveAt).
+		Save(ctx)
+	if err != nil {
+		return err
+	}
+	user.UpdatedAt = entity.UpdatedAt
+	return nil
+}
+
+// UpdateUserLastActiveAt sets only the last_active_at column for userID.
+func (r *oauthPendingFlowUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+	return r.client.User.UpdateOneID(userID).SetLastActiveAt(activeAt).Exec(ctx)
+}
+
+// Delete removes the user row. When rejectDeleteWhileAuthIdentityExists is
+// enabled, it first refuses the delete if any auth identities still point at
+// the user — simulating a FK/business constraint the tests rely on.
+func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error {
+	if r.options.rejectDeleteWhileAuthIdentityExists {
+		count, err := r.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(id)).Count(ctx)
+		if err != nil {
+			return err
+		}
+		if count > 0 {
+			return errors.New("cannot delete user while auth identities still exist")
+		}
+	}
+	return r.client.User.DeleteOneID(id).Exec(ctx)
+}
+
+// GetUserAvatar loads the avatar row for userID via raw SQL, honoring an
+// ambient ent transaction from the context when present.
+// Returns (nil, nil) when the user has no avatar row (rows.Err() is nil on a
+// clean end-of-rows).
+func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+	// Prefer the transaction's driver so reads see uncommitted writes.
+	driver := r.client.Driver()
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		driver = tx.Client().Driver()
+	}
+
+	var rows entsql.Rows
+	if err := driver.Query(
+		ctx,
+		`SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`,
+		[]any{userID},
+		&rows,
+	); err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+
+	if !rows.Next() {
+		// Either a genuine "no avatar" (err nil) or an iteration error.
+		return nil, rows.Err()
+	}
+
+	var avatar service.UserAvatar
+	if err := rows.Scan(
+		&avatar.StorageProvider,
+		&avatar.StorageKey,
+		&avatar.URL,
+		&avatar.ContentType,
+		&avatar.ByteSize,
+		&avatar.SHA256,
+	); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return &avatar, nil
+}
+
+// UpsertUserAvatar inserts or replaces the avatar row for userID using an
+// ON CONFLICT(user_id) upsert, honoring an ambient ent transaction.
+// NOTE(review): the upsert syntax assumes SQLite-style dialect and a unique
+// constraint on user_avatars.user_id — TODO confirm against the test schema.
+// The returned value is rebuilt from the input, not re-read from the DB, so
+// DB-assigned columns (e.g. updated_at) are not reflected in it.
+func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+	driver := r.client.Driver()
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		driver = tx.Client().Driver()
+	}
+
+	var result entsql.Result
+	if err := driver.Exec(
+		ctx,
+		`INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
+ON CONFLICT(user_id) DO UPDATE SET
+	storage_provider = excluded.storage_provider,
+	storage_key = excluded.storage_key,
+	url = excluded.url,
+	content_type = excluded.content_type,
+	byte_size = excluded.byte_size,
+	sha256 = excluded.sha256,
+	updated_at = CURRENT_TIMESTAMP`,
+		[]any{
+			userID,
+			input.StorageProvider,
+			input.StorageKey,
+			input.URL,
+			input.ContentType,
+			input.ByteSize,
+			input.SHA256,
+		},
+		&result,
+	); err != nil {
+		return nil, err
+	}
+
+	return &service.UserAvatar{
+		StorageProvider: input.StorageProvider,
+		StorageKey:      input.StorageKey,
+		URL:             input.URL,
+		ContentType:     input.ContentType,
+		ByteSize:        input.ByteSize,
+		SHA256:          input.SHA256,
+	}, nil
+}
+
+// DeleteUserAvatar removes the avatar row for userID (no-op if absent),
+// honoring an ambient ent transaction from the context.
+func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+	driver := r.client.Driver()
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		driver = tx.Client().Driver()
+	}
+
+	var result entsql.Result
+	return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result)
+}
+
+// The following repo methods are not expected in the OAuth pending-flow
+// tests; each panics so an unexpected call fails loudly.
+
+func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+	panic("unexpected List call")
+}
+
+func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+	panic("unexpected ListWithFilters call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error {
+	panic("unexpected UpdateBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error {
+	panic("unexpected DeductBalance call")
+}
+
+func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error {
+	panic("unexpected UpdateConcurrency call")
+}
+
+// Usage-timestamp lookups are irrelevant to these tests: return empty/none
+// rather than panicking, since callers may invoke them incidentally.
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+	return map[int64]*time.Time{}, nil
+}
+
+func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+	return nil, nil
+}
+
+// ExistsByEmail reports whether any user row has exactly the given email.
+// On query error the boolean is false and the error is returned.
+func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
+	count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx)
+	return count > 0, err
+}
+
+// Group-management methods are not exercised here; panic on unexpected use.
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+	panic("unexpected RemoveGroupFromAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+	panic("unexpected AddGroupToAllowedGroups call")
+}
+
+func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+	panic("unexpected RemoveGroupFromUserAllowedGroups call")
+}
+
+// ListUserAuthIdentities returns all auth identity rows for userID mapped to
+// service records, skipping any nil entities defensively.
+func (r *oauthPendingFlowUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+	identities, err := r.client.AuthIdentity.Query().
+		Where(authidentity.UserIDEQ(userID)).
+		All(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	records := make([]service.UserAuthIdentityRecord, 0, len(identities))
+	for _, identity := range identities {
+		if identity == nil {
+			continue
+		}
+		records = append(records, service.UserAuthIdentityRecord{
+			ProviderType:    identity.ProviderType,
+			ProviderKey:     identity.ProviderKey,
+			ProviderSubject: identity.ProviderSubject,
+			VerifiedAt:      identity.VerifiedAt,
+			Issuer:          identity.Issuer,
+			Metadata:        identity.Metadata,
+			CreatedAt:       identity.CreatedAt,
+			UpdatedAt:       identity.UpdatedAt,
+		})
+	}
+	return records, nil
+}
+
+// UnbindUserAuthProvider is not exercised by these tests; panic on use.
+func (r *oauthPendingFlowUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+	panic("unexpected UnbindUserAuthProvider call")
+}
+
+// UpdateTotpSecret stores the encrypted TOTP secret for userID; a nil secret
+// clears the column instead.
+func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+	update := r.client.User.UpdateOneID(userID)
+	if encryptedSecret == nil {
+		update = update.ClearTotpSecretEncrypted()
+	} else {
+		update = update.SetTotpSecretEncrypted(*encryptedSecret)
+	}
+	return update.Exec(ctx)
+}
+
+// EnableTotp marks TOTP as enabled and stamps totp_enabled_at with now (UTC).
+func (r *oauthPendingFlowUserRepo) EnableTotp(ctx context.Context, userID int64) error {
+	return r.client.User.UpdateOneID(userID).
+		SetTotpEnabled(true).
+		SetTotpEnabledAt(time.Now().UTC()).
+		Exec(ctx)
+}
+
+// DisableTotp turns TOTP off and clears both the secret and the enabled-at
+// timestamp in one update.
+func (r *oauthPendingFlowUserRepo) DisableTotp(ctx context.Context, userID int64) error {
+	return r.client.User.UpdateOneID(userID).
+		SetTotpEnabled(false).
+		ClearTotpSecretEncrypted().
+		ClearTotpEnabledAt().
+		Exec(ctx)
+}
+
+// oauthPendingFlowServiceUser converts an ent User entity into the service
+// layer's User struct field-by-field; nil in, nil out.
+func oauthPendingFlowServiceUser(entity *dbent.User) *service.User {
+	if entity == nil {
+		return nil
+	}
+	return &service.User{
+		ID:                  entity.ID,
+		Email:               entity.Email,
+		Username:            entity.Username,
+		Notes:               entity.Notes,
+		PasswordHash:        entity.PasswordHash,
+		Role:                entity.Role,
+		Balance:             entity.Balance,
+		Concurrency:         entity.Concurrency,
+		Status:              entity.Status,
+		SignupSource:        entity.SignupSource,
+		LastLoginAt:         entity.LastLoginAt,
+		LastActiveAt:        entity.LastActiveAt,
+		TotpSecretEncrypted: entity.TotpSecretEncrypted,
+		TotpEnabled:         entity.TotpEnabled,
+		TotpEnabledAt:       entity.TotpEnabledAt,
+		TotalRecharged:      entity.TotalRecharged,
+		CreatedAt:           entity.CreatedAt,
+		UpdatedAt:           entity.UpdatedAt,
+	}
+}
+
+// oauthPendingFlowDefaultSubAssignerStub records subscription-assignment
+// calls so tests can assert whether (and with what input) the default
+// subscription assigner was invoked. It never assigns anything.
+type oauthPendingFlowDefaultSubAssignerStub struct {
+	calls []service.AssignSubscriptionInput
+}
+
+// AssignOrExtendSubscription appends a copy of the non-nil input to calls and
+// reports "no subscription, not extended, no error".
+func (s *oauthPendingFlowDefaultSubAssignerStub) AssignOrExtendSubscription(
+	_ context.Context,
+	input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+	if input != nil {
+		s.calls = append(s.calls, *input)
+	}
+	return nil, false, nil
+}
+
+// oauthPendingFlowTotpCacheStub is an in-memory stand-in for the TOTP cache:
+// setup sessions keyed by user ID, login sessions keyed by temp token, and a
+// per-user verify-attempt counter. Maps are lazily allocated by the setters;
+// getters tolerate a nil receiver. TTL arguments are ignored — nothing here
+// ever expires. Not safe for concurrent use.
+type oauthPendingFlowTotpCacheStub struct {
+	setupSessions  map[int64]*service.TotpSetupSession
+	loginSessions  map[string]*service.TotpLoginSession
+	verifyAttempts map[int64]int
+}
+
+// GetSetupSession returns the stored setup session for userID, or nil.
+func (s *oauthPendingFlowTotpCacheStub) GetSetupSession(_ context.Context, userID int64) (*service.TotpSetupSession, error) {
+	if s == nil || s.setupSessions == nil {
+		return nil, nil
+	}
+	return s.setupSessions[userID], nil
+}
+
+// SetSetupSession stores the session for userID; the TTL is ignored.
+func (s *oauthPendingFlowTotpCacheStub) SetSetupSession(_ context.Context, userID int64, session *service.TotpSetupSession, _ time.Duration) error {
+	if s.setupSessions == nil {
+		s.setupSessions = map[int64]*service.TotpSetupSession{}
+	}
+	s.setupSessions[userID] = session
+	return nil
+}
+
+// DeleteSetupSession removes the setup session for userID (no-op if absent;
+// delete on a nil map is safe in Go).
+func (s *oauthPendingFlowTotpCacheStub) DeleteSetupSession(_ context.Context, userID int64) error {
+	delete(s.setupSessions, userID)
+	return nil
+}
+
+// GetLoginSession returns the stored login session for tempToken, or nil.
+func (s *oauthPendingFlowTotpCacheStub) GetLoginSession(_ context.Context, tempToken string) (*service.TotpLoginSession, error) {
+	if s == nil || s.loginSessions == nil {
+		return nil, nil
+	}
+	return s.loginSessions[tempToken], nil
+}
+
+// SetLoginSession stores the session under tempToken; the TTL is ignored.
+func (s *oauthPendingFlowTotpCacheStub) SetLoginSession(_ context.Context, tempToken string, session *service.TotpLoginSession, _ time.Duration) error {
+	if s.loginSessions == nil {
+		s.loginSessions = map[string]*service.TotpLoginSession{}
+	}
+	s.loginSessions[tempToken] = session
+	return nil
+}
+
+// DeleteLoginSession removes the login session for tempToken (no-op if absent).
+func (s *oauthPendingFlowTotpCacheStub) DeleteLoginSession(_ context.Context, tempToken string) error {
+	delete(s.loginSessions, tempToken)
+	return nil
+}
+
+// IncrementVerifyAttempts bumps and returns the attempt count for userID.
+func (s *oauthPendingFlowTotpCacheStub) IncrementVerifyAttempts(_ context.Context, userID int64) (int, error) {
+	if s.verifyAttempts == nil {
+		s.verifyAttempts = map[int64]int{}
+	}
+	s.verifyAttempts[userID]++
+	return s.verifyAttempts[userID], nil
+}
+
+// GetVerifyAttempts returns the current attempt count for userID (0 if none).
+func (s *oauthPendingFlowTotpCacheStub) GetVerifyAttempts(_ context.Context, userID int64) (int, error) {
+	if s == nil || s.verifyAttempts == nil {
+		return 0, nil
+	}
+	return s.verifyAttempts[userID], nil
+}
+
+// ClearVerifyAttempts resets the attempt count for userID.
+func (s *oauthPendingFlowTotpCacheStub) ClearVerifyAttempts(_ context.Context, userID int64) error {
+	delete(s.verifyAttempts, userID)
+	return nil
+}
+
+// oauthPendingFlowTotpEncryptorStub is an identity "encryptor": both Encrypt
+// and Decrypt return their input unchanged, so tests can inspect secrets as
+// plain text.
+type oauthPendingFlowTotpEncryptorStub struct{}
+
+func (oauthPendingFlowTotpEncryptorStub) Encrypt(plaintext string) (string, error) {
+	return plaintext, nil
+}
+
+func (oauthPendingFlowTotpEncryptorStub) Decrypt(ciphertext string) (string, error) {
+	return ciphertext, nil
+}
diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go
new file mode 100644
index 00000000..47bad942
--- /dev/null
+++ b/backend/internal/handler/auth_oauth_test_helpers_test.go
@@ -0,0 +1,57 @@
+package handler
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// buildEncodedOAuthBindUserCookie builds the signed bind-user cookie value
+// for userID using the handler's helper, failing the test on error.
+func buildEncodedOAuthBindUserCookie(t *testing.T, userID int64, secret string) string {
+	t.Helper()
+	value, err := buildOAuthBindUserCookieValue(userID, secret)
+	require.NoError(t, err)
+	return value
+}
+
+// encodedCookie returns an http.Cookie whose value has been run through the
+// handler's encodeCookieValue, rooted at path "/".
+func encodedCookie(name, value string) *http.Cookie {
+	return &http.Cookie{
+		Name:  name,
+		Value: encodeCookieValue(value),
+		Path:  "/",
+	}
+}
+
+// findCookie returns the first cookie with the given name, or nil if absent.
+func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
+	for _, cookie := range cookies {
+		if cookie.Name == name {
+			return cookie
+		}
+	}
+	return nil
+}
+
+// decodeCookieValueForTest decodes a cookie value with the handler's
+// decodeCookieValue, failing the test on error.
+func decodeCookieValueForTest(t *testing.T, value string) string {
+	t.Helper()
+	decoded, err := decodeCookieValue(value)
+	require.NoError(t, err)
+	return decoded
+}
+
+// assertOAuthRedirectError asserts that a redirect Location carries the given
+// OAuth error code/message in its query string — or, when the query is empty,
+// in its URL fragment (some flows return errors after "#").
+func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
+	t.Helper()
+	require.NotEmpty(t, location)
+
+	parsed, err := url.Parse(location)
+	require.NoError(t, err)
+
+	// Fall back to the fragment when no query string is present.
+	rawValues := parsed.RawQuery
+	if rawValues == "" {
+		rawValues = parsed.Fragment
+	}
+	values, err := url.ParseQuery(rawValues)
+	require.NoError(t, err)
+	require.Equal(t, errorCode, values.Get("error"))
+	require.Equal(t, errorMessage, values.Get("error_message"))
+}
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
new file mode 100644
index 00000000..4264002d
--- /dev/null
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -0,0 +1,1191 @@
+package handler
+
+import (
+ "context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rsa"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log"
+ "math/big"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/imroc/req/v3"
+ "github.com/tidwall/gjson"
+)
+
+// Cookie names, path scope, lifetime, and default redirect targets for the
+// OIDC OAuth flow. All flow cookies are scoped to the OIDC callback path.
+const (
+	oidcOAuthCookiePath         = "/api/v1/auth/oauth/oidc"
+	oidcOAuthStateCookieName    = "oidc_oauth_state"    // CSRF state
+	oidcOAuthVerifierCookie     = "oidc_oauth_verifier" // PKCE code verifier
+	oidcOAuthRedirectCookie     = "oidc_oauth_redirect" // post-login frontend path
+	oidcOAuthNonceCookie        = "oidc_oauth_nonce"    // id_token replay nonce
+	oidcOAuthIntentCookieName   = "oidc_oauth_intent"   // login vs bind-current-user
+	oidcOAuthBindUserCookieName = "oidc_oauth_bind_user"
+	oidcOAuthCookieMaxAgeSec    = 10 * 60 // 10 minutes
+	oidcOAuthDefaultRedirectTo  = "/dashboard"
+	oidcOAuthDefaultFrontendCB  = "/auth/oidc/callback"
+)
+
+// oidcTokenResponse mirrors the OAuth 2.0 / OIDC token endpoint JSON response
+// (RFC 6749 §5.1 plus the OIDC id_token field).
+type oidcTokenResponse struct {
+	AccessToken  string `json:"access_token"`
+	TokenType    string `json:"token_type"`
+	ExpiresIn    int64  `json:"expires_in"`
+	RefreshToken string `json:"refresh_token,omitempty"`
+	Scope        string `json:"scope,omitempty"`
+	IDToken      string `json:"id_token,omitempty"`
+}
+
+// oidcTokenExchangeError captures a failed token-endpoint exchange: the HTTP
+// status, the provider's error/error_description fields, and the raw body
+// (kept for logging, not included in Error()).
+type oidcTokenExchangeError struct {
+	StatusCode          int
+	ProviderError       string
+	ProviderDescription string
+	Body                string
+}
+
+// Error renders "token exchange status=N [error=...] [error_description=...]",
+// omitting blank provider fields. A nil receiver yields "".
+func (e *oidcTokenExchangeError) Error() string {
+	if e == nil {
+		return ""
+	}
+	parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)}
+	if strings.TrimSpace(e.ProviderError) != "" {
+		parts = append(parts, "error="+strings.TrimSpace(e.ProviderError))
+	}
+	if strings.TrimSpace(e.ProviderDescription) != "" {
+		parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription))
+	}
+	return strings.Join(parts, " ")
+}
+
+// oidcIDTokenClaims is the subset of id_token claims this flow consumes, on
+// top of the standard JWT registered claims (iss, sub, aud, exp, ...).
+type oidcIDTokenClaims struct {
+	Email             string `json:"email,omitempty"`
+	EmailVerified     *bool  `json:"email_verified,omitempty"` // pointer: absent ≠ false
+	PreferredUsername string `json:"preferred_username,omitempty"`
+	Name              string `json:"name,omitempty"`
+	Nonce             string `json:"nonce,omitempty"`
+	Azp               string `json:"azp,omitempty"` // authorized party
+	jwt.RegisteredClaims
+}
+
+// oidcUserInfoClaims is the normalized result of the userinfo endpoint
+// (plain struct, not JSON-tagged — populated by the fetch helper).
+type oidcUserInfoClaims struct {
+	Email         string
+	Username      string
+	Subject       string
+	EmailVerified *bool // pointer: absent ≠ false
+	DisplayName   string
+	AvatarURL     string
+}
+
+// oidcJWKSet is the JWKS document shape ({"keys": [...]}) per RFC 7517.
+type oidcJWKSet struct {
+	Keys []oidcJWK `json:"keys"`
+}
+
+// oidcJWK is a single JSON Web Key; RSA keys use N/E, EC keys use Crv/X/Y.
+type oidcJWK struct {
+	Kty string `json:"kty"`
+	Kid string `json:"kid"`
+	Use string `json:"use"`
+	Alg string `json:"alg"`
+
+	// RSA parameters (base64url-encoded).
+	N string `json:"n"`
+	E string `json:"e"`
+
+	// EC parameters.
+	Crv string `json:"crv"`
+	X   string `json:"x"`
+	Y   string `json:"y"`
+}
+
+// OIDCOAuthStart starts the generic OIDC OAuth login flow: it generates the
+// CSRF state (plus PKCE verifier and nonce when configured), persists the
+// flow parameters in short-lived cookies, and redirects the browser to the
+// provider's authorization endpoint.
+// GET /api/v1/auth/oauth/oidc/start?redirect=/dashboard
+func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
+	cfg, err := h.getOIDCOAuthConfig(c.Request.Context())
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	state, err := oauth.GenerateState()
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+		return
+	}
+
+	// Only same-site frontend paths survive sanitization.
+	redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
+	if redirectTo == "" {
+		redirectTo = oidcOAuthDefaultRedirectTo
+	}
+
+	browserSessionKey, err := generateOAuthPendingBrowserSession()
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+		return
+	}
+
+	// Stash flow parameters in cookies the callback will read back.
+	secureCookie := isRequestHTTPS(c)
+	oidcSetCookie(c, oidcOAuthStateCookieName, encodeCookieValue(state), oidcOAuthCookieMaxAgeSec, secureCookie)
+	oidcSetCookie(c, oidcOAuthRedirectCookie, encodeCookieValue(redirectTo), oidcOAuthCookieMaxAgeSec, secureCookie)
+	intent := normalizeOAuthIntent(c.Query("intent"))
+	oidcSetCookie(c, oidcOAuthIntentCookieName, encodeCookieValue(intent), oidcOAuthCookieMaxAgeSec, secureCookie)
+	setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+	clearOAuthPendingSessionCookie(c, secureCookie)
+	if intent == oauthIntentBindCurrentUser {
+		// Bind flow: pin the currently-authenticated user into a cookie so
+		// the callback binds the identity to the right account.
+		bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+		if err != nil {
+			response.ErrorFrom(c, err)
+			return
+		}
+		oidcSetCookie(c, oidcOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), oidcOAuthCookieMaxAgeSec, secureCookie)
+	} else {
+		oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
+	}
+
+	// PKCE is optional per provider config.
+	codeChallenge := ""
+	if cfg.UsePKCE {
+		verifier, genErr := oauth.GenerateCodeVerifier()
+		if genErr != nil {
+			response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
+			return
+		}
+		codeChallenge = oauth.GenerateCodeChallenge(verifier)
+		oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
+	}
+
+	// A nonce is only needed when the callback will validate the id_token.
+	nonce := ""
+	if cfg.ValidateIDToken {
+		nonce, err = oauth.GenerateState()
+		if err != nil {
+			response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
+			return
+		}
+		oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
+	}
+
+	redirectURI := strings.TrimSpace(cfg.RedirectURL)
+	if redirectURI == "" {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured"))
+		return
+	}
+
+	authURL, err := buildOIDCAuthorizeURL(cfg, state, nonce, codeChallenge, redirectURI)
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+		return
+	}
+
+	c.Redirect(http.StatusFound, authURL)
+}
+
+// OIDCOAuthCallback handles the OIDC provider callback: it validates state
+// (and, when configured, the id_token nonce), exchanges the code, fetches
+// userinfo, then routes the identity into a pending session — bind to the
+// current user, log in an existing identity, or present an account-choice
+// step — and finally redirects to the frontend callback page.
+// GET /api/v1/auth/oauth/oidc/callback?code=...&state=...
+func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
+	cfg, cfgErr := h.getOIDCOAuthConfig(c.Request.Context())
+	if cfgErr != nil {
+		response.ErrorFrom(c, cfgErr)
+		return
+	}
+
+	frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
+	if frontendCallback == "" {
+		frontendCallback = oidcOAuthDefaultFrontendCB
+	}
+
+	// Provider reported an error (user denied, etc.) — forward it.
+	if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+		redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+		return
+	}
+
+	code := strings.TrimSpace(c.Query("code"))
+	state := strings.TrimSpace(c.Query("state"))
+	if code == "" || state == "" {
+		redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+		return
+	}
+
+	// Flow cookies are single-use: clear them regardless of outcome.
+	secureCookie := isRequestHTTPS(c)
+	defer func() {
+		oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
+		oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
+		oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
+		oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
+		oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
+		oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
+	}()
+
+	// CSRF check: state must match the value set at /start.
+	expectedState, err := readCookieDecoded(c, oidcOAuthStateCookieName)
+	if err != nil || expectedState == "" || state != expectedState {
+		redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+		return
+	}
+
+	redirectTo, _ := readCookieDecoded(c, oidcOAuthRedirectCookie)
+	redirectTo = sanitizeFrontendRedirectPath(redirectTo)
+	if redirectTo == "" {
+		redirectTo = oidcOAuthDefaultRedirectTo
+	}
+	browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+	if strings.TrimSpace(browserSessionKey) == "" {
+		redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+		return
+	}
+	intent, _ := readCookieDecoded(c, oidcOAuthIntentCookieName)
+	intent = normalizeOAuthIntent(intent)
+
+	codeVerifier := ""
+	if cfg.UsePKCE {
+		codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
+		if codeVerifier == "" {
+			redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
+			return
+		}
+	}
+
+	expectedNonce := ""
+	if cfg.ValidateIDToken {
+		expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
+		if expectedNonce == "" {
+			redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
+			return
+		}
+	}
+
+	redirectURI := strings.TrimSpace(cfg.RedirectURL)
+	if redirectURI == "" {
+		redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "")
+		return
+	}
+
+	// Exchange the authorization code for tokens; log provider detail on
+	// failure but forward only a single-line description to the frontend.
+	tokenResp, err := oidcExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier)
+	if err != nil {
+		description := ""
+		var exchangeErr *oidcTokenExchangeError
+		if errors.As(err, &exchangeErr) && exchangeErr != nil {
+			log.Printf(
+				"[OIDC OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s",
+				exchangeErr.StatusCode,
+				exchangeErr.ProviderError,
+				exchangeErr.ProviderDescription,
+				truncateLogValue(exchangeErr.Body, 2048),
+			)
+			description = exchangeErr.Error()
+		} else {
+			log.Printf("[OIDC OAuth] token exchange failed: %v", err)
+			description = err.Error()
+		}
+		redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description))
+		return
+	}
+
+	// Optionally validate the id_token signature/claims and nonce.
+	var idClaims *oidcIDTokenClaims
+	if cfg.ValidateIDToken {
+		if strings.TrimSpace(tokenResp.IDToken) == "" {
+			redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
+			return
+		}
+
+		idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
+		if err != nil {
+			log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
+			redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
+			return
+		}
+	}
+
+	userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
+	if err != nil {
+		log.Printf("[OIDC OAuth] userinfo fetch failed: %v", err)
+		redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
+		return
+	}
+
+	// Resolve subject/issuer: id_token wins, userinfo/config is the fallback.
+	subject := ""
+	if idClaims != nil {
+		subject = strings.TrimSpace(idClaims.Subject)
+	}
+	if subject == "" {
+		subject = strings.TrimSpace(userInfoClaims.Subject)
+	}
+	if subject == "" {
+		redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
+		return
+	}
+	issuer := ""
+	if idClaims != nil {
+		issuer = strings.TrimSpace(idClaims.Issuer)
+	}
+	if issuer == "" {
+		issuer = strings.TrimSpace(cfg.IssuerURL)
+	}
+	if issuer == "" {
+		redirectOAuthError(c, frontendCallback, "missing_issuer", "missing issuer claim", "")
+		return
+	}
+
+	emailVerified := userInfoClaims.EmailVerified
+	if emailVerified == nil && idClaims != nil {
+		emailVerified = idClaims.EmailVerified
+	}
+	// Token-substitution defense: userinfo's sub must match the id_token's.
+	if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
+		redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
+		return
+	}
+
+	// The canonical account email is synthetic (derived from issuer+subject);
+	// the provider-reported email is kept separately as "compat" for matching
+	// pre-existing accounts.
+	identityKey := oidcIdentityKey(issuer, subject)
+	compatEmail := strings.TrimSpace(userInfoClaims.Email)
+	if compatEmail == "" && idClaims != nil {
+		compatEmail = strings.TrimSpace(idClaims.Email)
+	}
+	email := oidcSyntheticEmailFromIdentityKey(identityKey)
+	username := firstNonEmpty(
+		userInfoClaims.Username,
+		func() string {
+			if idClaims != nil {
+				return idClaims.PreferredUsername
+			}
+			return ""
+		}(),
+		func() string {
+			if idClaims != nil {
+				return idClaims.Name
+			}
+			return ""
+		}(),
+		oidcFallbackUsername(subject),
+	)
+	identityRef := service.PendingAuthIdentityKey{
+		ProviderType:    "oidc",
+		ProviderKey:     issuer,
+		ProviderSubject: subject,
+	}
+	upstreamClaims := map[string]any{
+		"email":             email,
+		"username":          username,
+		"subject":           subject,
+		"issuer":            issuer,
+		"email_verified":    emailVerified != nil && *emailVerified,
+		"provider_fallback": strings.TrimSpace(cfg.ProviderName),
+		"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
+			if idClaims != nil {
+				return idClaims.Name
+			}
+			return ""
+		}(), username),
+		"suggested_avatar_url": userInfoClaims.AvatarURL,
+	}
+	if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
+		upstreamClaims["compat_email"] = compatEmail
+	}
+	// Bind intent: attach this identity to the already-authenticated user.
+	if intent == oauthIntentBindCurrentUser {
+		targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName)
+		if err != nil {
+			redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "")
+			return
+		}
+		if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+			Intent:                 oauthIntentBindCurrentUser,
+			Identity:               identityRef,
+			TargetUserID:           &targetUserID,
+			ResolvedEmail:          email,
+			RedirectTo:             redirectTo,
+			BrowserSessionKey:      browserSessionKey,
+			UpstreamIdentityClaims: upstreamClaims,
+			CompletionResponse: map[string]any{
+				"redirect": redirectTo,
+			},
+		}); err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "")
+			return
+		}
+		redirectToFrontendCallback(c, frontendCallback)
+		return
+	}
+
+	// Known identity: straight login pending session.
+	existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+		return
+	}
+	if existingIdentityUser != nil {
+		if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+			Intent:                 oauthIntentLogin,
+			Identity:               identityRef,
+			TargetUserID:           &existingIdentityUser.ID,
+			ResolvedEmail:          existingIdentityUser.Email,
+			RedirectTo:             redirectTo,
+			BrowserSessionKey:      browserSessionKey,
+			UpstreamIdentityClaims: upstreamClaims,
+			CompletionResponse: map[string]any{
+				"redirect": redirectTo,
+			},
+		}); err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+			return
+		}
+		redirectToFrontendCallback(c, frontendCallback)
+		return
+	}
+
+	// Unknown identity: look for an existing account by provider email.
+	compatEmailUser, err := h.findOIDCCompatEmailUser(c.Request.Context(), compatEmail)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+		return
+	}
+
+	if cfg.RequireEmailVerified {
+		if emailVerified == nil || !*emailVerified {
+			redirectOAuthError(c, frontendCallback, "email_not_verified", "email is not verified", "")
+			return
+		}
+	}
+
+	// NOTE(review): the branch below and the fall-through call the same
+	// function with identical arguments except the final force flag (true vs
+	// a re-evaluated isForceEmailOnThirdPartySignup, which is false on the
+	// fall-through path) — the two paths could be collapsed into one call.
+	if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+		if err := h.createOIDCOAuthChoicePendingSession(
+			c,
+			identityRef,
+			email,
+			email,
+			redirectTo,
+			browserSessionKey,
+			upstreamClaims,
+			compatEmail,
+			compatEmailUser,
+			true,
+		); err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+			return
+		}
+		redirectToFrontendCallback(c, frontendCallback)
+		return
+	}
+
+	if err := h.createOIDCOAuthChoicePendingSession(
+		c,
+		identityRef,
+		email,
+		email,
+		redirectTo,
+		browserSessionKey,
+		upstreamClaims,
+		compatEmail,
+		compatEmailUser,
+		h.isForceEmailOnThirdPartySignup(c.Request.Context()),
+	); err != nil {
+		redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+		return
+	}
+	redirectToFrontendCallback(c, frontendCallback)
+}
+
+// findOIDCCompatEmailUser looks up an existing account by the provider's
+// reported email (lowercased). Synthetic emails minted by any connect flow
+// (LinuxDo/OIDC/WeChat domains) are never matched. Returns (nil, nil) when
+// the email is empty, synthetic, or no user exists.
+func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+	client := h.entClient()
+	if client == nil {
+		return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+	}
+
+	email = strings.TrimSpace(strings.ToLower(email))
+	if email == "" ||
+		strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+		strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+		strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+		return nil, nil
+	}
+
+	userEntity, err := findUserByNormalizedEmail(ctx, client, email)
+	if err != nil {
+		// Not-found is a normal outcome, not an error, for this lookup.
+		if errors.Is(err, service.ErrUserNotFound) {
+			return nil, nil
+		}
+		return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+	}
+	return userEntity, nil
+}
+
+// createOIDCOAuthChoicePendingSession creates a login pending session whose
+// completion response drives the frontend "choice" step: bind to an existing
+// account matched by compat email, or create a new account. When a compat
+// match exists, its email becomes both the suggestion and the resolved email,
+// and the session is pre-targeted at that user.
+func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
+	c *gin.Context,
+	identity service.PendingAuthIdentityKey,
+	suggestedEmail string,
+	resolvedEmail string,
+	redirectTo string,
+	browserSessionKey string,
+	upstreamClaims map[string]any,
+	compatEmail string,
+	compatEmailUser *dbent.User,
+	forceEmailOnSignup bool,
+) error {
+	suggestionEmail := strings.TrimSpace(suggestedEmail)
+	canonicalEmail := strings.TrimSpace(resolvedEmail)
+	if suggestionEmail == "" {
+		suggestionEmail = canonicalEmail
+	}
+
+	// Base payload: new-account signup with no bindable existing account.
+	completionResponse := map[string]any{
+		"step":                       oauthPendingChoiceStep,
+		"adoption_required":          true,
+		"redirect":                   strings.TrimSpace(redirectTo),
+		"email":                      suggestionEmail,
+		"resolved_email":             canonicalEmail,
+		"existing_account_email":     "",
+		"existing_account_bindable":  false,
+		"create_account_allowed":     true,
+		"force_email_on_signup":      forceEmailOnSignup,
+		"choice_reason":              "third_party_signup",
+	}
+	if strings.TrimSpace(compatEmail) != "" {
+		completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+	}
+	// A compat-email match upgrades the choice to "bind existing account".
+	if compatEmailUser != nil {
+		completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+		completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+		completionResponse["existing_account_bindable"] = true
+		completionResponse["choice_reason"] = "compat_email_match"
+	}
+	if forceEmailOnSignup && compatEmailUser == nil {
+		completionResponse["choice_reason"] = "force_email_on_signup"
+	}
+
+	resolvedChoiceEmail := suggestionEmail
+	if compatEmailUser != nil {
+		resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+	}
+	var targetUserID *int64
+	if compatEmailUser != nil && compatEmailUser.ID > 0 {
+		targetUserID = &compatEmailUser.ID
+	}
+
+	return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+		Intent:                 oauthIntentLogin,
+		Identity:               identity,
+		TargetUserID:           targetUserID,
+		ResolvedEmail:          resolvedChoiceEmail,
+		RedirectTo:             redirectTo,
+		BrowserSessionKey:      browserSessionKey,
+		UpstreamIdentityClaims: upstreamClaims,
+		CompletionResponse:     completionResponse,
+	})
+}
+
// completeOIDCOAuthRequest is the JSON body for completing a pending OIDC
// OAuth registration.
type completeOIDCOAuthRequest struct {
	// InvitationCode is mandatory for creating the new account.
	InvitationCode string `json:"invitation_code" binding:"required"`
	// AffCode optionally attributes the signup to an affiliate code.
	AffCode string `json:"aff_code,omitempty"`
	// AdoptDisplayName / AdoptAvatar indicate whether the upstream profile's
	// display name and avatar should be adopted; nil defers the choice to
	// the server-side adoption-decision logic (see
	// ensurePendingOAuthAdoptionDecision).
	AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
	AdoptAvatar      *bool `json:"adopt_avatar,omitempty"`
}
+
// CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating
// the invitation code and creating the user account.
// POST /api/v1/auth/oauth/oidc/complete-registration
//
// Flow:
//  1. bind and validate the JSON body (invitation code is required)
//  2. read the pending-session and browser cookies; on any cookie/session
//     failure both flow cookies are cleared so the client restarts cleanly
//  3. load the pending session bound to this browser and type-check it,
//     handling legacy sessions with a status payload
//  4. verify the backend currently allows new-user login
//  5. register (or log in) the account, apply profile adoption, consume the
//     session, record the login and return a Bearer token pair
func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
	var req completeOIDCOAuthRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
		return
	}

	secureCookie := isRequestHTTPS(c)
	sessionToken, err := readOAuthPendingSessionCookie(c)
	if err != nil {
		// Invalid/absent session cookie: wipe both flow cookies before
		// erroring so no stale half-session lingers in the browser.
		clearOAuthPendingSessionCookie(c, secureCookie)
		clearOAuthPendingBrowserCookie(c, secureCookie)
		response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
		return
	}
	browserSessionKey, err := readOAuthPendingBrowserCookie(c)
	if err != nil {
		clearOAuthPendingSessionCookie(c, secureCookie)
		clearOAuthPendingBrowserCookie(c, secureCookie)
		response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
		return
	}
	pendingSvc, err := h.pendingIdentityService()
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	// The browser key binds the session to this browser; a mismatch fails here
	// and also clears the flow cookies.
	session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
	if err != nil {
		clearOAuthPendingSessionCookie(c, secureCookie)
		clearOAuthPendingBrowserCookie(c, secureCookie)
		response.ErrorFrom(c, err)
		return
	}
	if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
		response.ErrorFrom(c, err)
		return
	}
	// Legacy sessions may be answered with a status payload instead of
	// continuing the registration flow.
	if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
		response.ErrorFrom(c, err)
		return
	} else if handled {
		c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
		return
	} else {
		session = updatedSession
	}
	if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
		response.ErrorFrom(c, err)
		return
	}

	// Registration context must carry both a resolved email and an upstream
	// username; otherwise the session is unusable.
	email := strings.TrimSpace(session.ResolvedEmail)
	username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
	if email == "" || username == "" {
		response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
		return
	}

	client := h.entClient()
	if client == nil {
		response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
		return
	}
	// Re-check that the upstream identity has not been bound elsewhere since
	// the session was created.
	if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
		respondPendingOAuthBindingApplyError(c, err)
		return
	}
	decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
		AdoptDisplayName: req.AdoptDisplayName,
		AdoptAvatar:      req.AdoptAvatar,
	})
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	// Adoption (display name/avatar) is applied after the account exists; the
	// pending session is consumed in the same step.
	if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
		respondPendingOAuthBindingApplyError(c, err)
		return
	}
	h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
	// Success: the flow cookies are no longer needed.
	clearOAuthPendingSessionCookie(c, secureCookie)
	clearOAuthPendingBrowserCookie(c, secureCookie)

	c.JSON(http.StatusOK, gin.H{
		"access_token":  tokenPair.AccessToken,
		"refresh_token": tokenPair.RefreshToken,
		"expires_in":    tokenPair.ExpiresIn,
		"token_type":    "Bearer",
	})
}
+
+func (h *AuthHandler) getOIDCOAuthConfig(ctx context.Context) (config.OIDCConnectConfig, error) {
+ if h != nil && h.settingSvc != nil {
+ return h.settingSvc.GetOIDCConnectOAuthConfig(ctx)
+ }
+ if h == nil || h.cfg == nil {
+ return config.OIDCConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
+ }
+ if !h.cfg.OIDC.Enabled {
+ return config.OIDCConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
+ }
+ return h.cfg.OIDC, nil
+}
+
// oidcExchangeCode exchanges the authorization code for tokens at the
// provider's token endpoint. Client authentication follows cfg.TokenAuthMethod:
//   - "" / "client_secret_post": secret sent in the form body (default)
//   - "client_secret_basic":     HTTP Basic authentication
//   - "none":                    no client credentials attached
//
// Non-2xx responses return *oidcTokenExchangeError carrying any provider
// error fields parsed from the body.
func oidcExchangeCode(
	ctx context.Context,
	cfg config.OIDCConnectConfig,
	code string,
	redirectURI string,
	codeVerifier string,
) (*oidcTokenResponse, error) {
	client := req.C().SetTimeout(30 * time.Second)

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", cfg.ClientID)
	form.Set("code", code)
	form.Set("redirect_uri", redirectURI)
	// PKCE: only sent when a verifier was produced at the start of the flow.
	if strings.TrimSpace(codeVerifier) != "" {
		form.Set("code_verifier", codeVerifier)
	}

	r := client.R().
		SetContext(ctx).
		SetHeader("Accept", "application/json")

	switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) {
	case "", "client_secret_post":
		form.Set("client_secret", cfg.ClientSecret)
	case "client_secret_basic":
		r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret)
	case "none":
		// Intentionally no client credentials for this auth method.
	default:
		return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod)
	}

	resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL)
	if err != nil {
		return nil, fmt.Errorf("request token: %w", err)
	}
	body := strings.TrimSpace(resp.String())
	if !resp.IsSuccessState() {
		providerErr, providerDesc := parseOAuthProviderError(body)
		return nil, &oidcTokenExchangeError{
			StatusCode:          resp.StatusCode,
			ProviderError:       providerErr,
			ProviderDescription: providerDesc,
			Body:                body,
		}
	}

	tokenResp, ok := oidcParseTokenResponse(body)
	if !ok {
		return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body}
	}
	// Default a missing token_type to Bearer.
	if strings.TrimSpace(tokenResp.TokenType) == "" {
		tokenResp.TokenType = "Bearer"
	}
	// At least one of access_token / id_token must be present to proceed.
	if strings.TrimSpace(tokenResp.AccessToken) == "" && strings.TrimSpace(tokenResp.IDToken) == "" {
		return nil, &oidcTokenExchangeError{StatusCode: resp.StatusCode, Body: body}
	}
	return tokenResp, nil
}
+
+func oidcParseTokenResponse(body string) (*oidcTokenResponse, bool) {
+ body = strings.TrimSpace(body)
+ if body == "" {
+ return nil, false
+ }
+
+ accessToken := strings.TrimSpace(getGJSON(body, "access_token"))
+ idToken := strings.TrimSpace(getGJSON(body, "id_token"))
+ if accessToken != "" || idToken != "" {
+ tokenType := strings.TrimSpace(getGJSON(body, "token_type"))
+ refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token"))
+ scope := strings.TrimSpace(getGJSON(body, "scope"))
+ expiresIn := gjson.Get(body, "expires_in").Int()
+ return &oidcTokenResponse{
+ AccessToken: accessToken,
+ TokenType: tokenType,
+ ExpiresIn: expiresIn,
+ RefreshToken: refreshToken,
+ Scope: scope,
+ IDToken: idToken,
+ }, true
+ }
+
+ values, err := url.ParseQuery(body)
+ if err != nil {
+ return nil, false
+ }
+ accessToken = strings.TrimSpace(values.Get("access_token"))
+ idToken = strings.TrimSpace(values.Get("id_token"))
+ if accessToken == "" && idToken == "" {
+ return nil, false
+ }
+ expiresIn := int64(0)
+ if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" {
+ if v, parseErr := strconv.ParseInt(raw, 10, 64); parseErr == nil {
+ expiresIn = v
+ }
+ }
+ return &oidcTokenResponse{
+ AccessToken: accessToken,
+ TokenType: strings.TrimSpace(values.Get("token_type")),
+ ExpiresIn: expiresIn,
+ RefreshToken: strings.TrimSpace(values.Get("refresh_token")),
+ Scope: strings.TrimSpace(values.Get("scope")),
+ IDToken: idToken,
+ }, true
+}
+
+func oidcFetchUserInfo(
+ ctx context.Context,
+ cfg config.OIDCConnectConfig,
+ token *oidcTokenResponse,
+) (*oidcUserInfoClaims, error) {
+ if strings.TrimSpace(cfg.UserInfoURL) == "" {
+ return &oidcUserInfoClaims{}, nil
+ }
+ if token == nil || strings.TrimSpace(token.AccessToken) == "" {
+ return nil, errors.New("missing access_token for userinfo request")
+ }
+
+ client := req.C().SetTimeout(30 * time.Second)
+ authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
+ if err != nil {
+ return nil, fmt.Errorf("invalid token for userinfo request: %w", err)
+ }
+
+ resp, err := client.R().
+ SetContext(ctx).
+ SetHeader("Accept", "application/json").
+ SetHeader("Authorization", authorization).
+ Get(cfg.UserInfoURL)
+ if err != nil {
+ return nil, fmt.Errorf("request userinfo: %w", err)
+ }
+ if !resp.IsSuccessState() {
+ return nil, fmt.Errorf("userinfo status=%d", resp.StatusCode)
+ }
+
+ return oidcParseUserInfo(resp.String(), cfg), nil
+}
+
+func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoClaims {
+ claims := &oidcUserInfoClaims{}
+ claims.Email = firstNonEmpty(
+ getGJSON(body, cfg.UserInfoEmailPath),
+ getGJSON(body, "email"),
+ getGJSON(body, "user.email"),
+ getGJSON(body, "data.email"),
+ getGJSON(body, "attributes.email"),
+ )
+ claims.Username = firstNonEmpty(
+ getGJSON(body, cfg.UserInfoUsernamePath),
+ getGJSON(body, "preferred_username"),
+ getGJSON(body, "username"),
+ getGJSON(body, "name"),
+ getGJSON(body, "user.username"),
+ getGJSON(body, "user.name"),
+ )
+ claims.Subject = firstNonEmpty(
+ getGJSON(body, cfg.UserInfoIDPath),
+ getGJSON(body, "sub"),
+ getGJSON(body, "id"),
+ getGJSON(body, "user_id"),
+ getGJSON(body, "uid"),
+ getGJSON(body, "user.id"),
+ )
+ if verified, ok := getGJSONBool(body, "email_verified"); ok {
+ claims.EmailVerified = &verified
+ }
+ claims.DisplayName = firstNonEmpty(
+ getGJSON(body, "name"),
+ getGJSON(body, "nickname"),
+ getGJSON(body, "display_name"),
+ getGJSON(body, "preferred_username"),
+ getGJSON(body, "username"),
+ )
+ claims.AvatarURL = firstNonEmpty(
+ getGJSON(body, "picture"),
+ getGJSON(body, "avatar_url"),
+ getGJSON(body, "avatar"),
+ getGJSON(body, "profile_image_url"),
+ getGJSON(body, "user.avatar"),
+ getGJSON(body, "user.avatar_url"),
+ )
+ claims.Email = strings.TrimSpace(claims.Email)
+ claims.Username = strings.TrimSpace(claims.Username)
+ claims.Subject = strings.TrimSpace(claims.Subject)
+ claims.DisplayName = strings.TrimSpace(claims.DisplayName)
+ claims.AvatarURL = strings.TrimSpace(claims.AvatarURL)
+ return claims
+}
+
+func getGJSONBool(body string, path string) (bool, bool) {
+ path = strings.TrimSpace(path)
+ if path == "" {
+ return false, false
+ }
+ res := gjson.Get(body, path)
+ if !res.Exists() {
+ return false, false
+ }
+ return res.Bool(), true
+}
+
+func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChallenge, redirectURI string) (string, error) {
+ u, err := url.Parse(cfg.AuthorizeURL)
+ if err != nil {
+ return "", fmt.Errorf("parse authorize_url: %w", err)
+ }
+
+ q := u.Query()
+ q.Set("response_type", "code")
+ q.Set("client_id", cfg.ClientID)
+ q.Set("redirect_uri", redirectURI)
+ if strings.TrimSpace(cfg.Scopes) != "" {
+ q.Set("scope", cfg.Scopes)
+ }
+ q.Set("state", state)
+ if strings.TrimSpace(nonce) != "" {
+ q.Set("nonce", nonce)
+ }
+ if strings.TrimSpace(codeChallenge) != "" {
+ q.Set("code_challenge", codeChallenge)
+ q.Set("code_challenge_method", "S256")
+ }
+
+ u.RawQuery = q.Encode()
+ return u.String(), nil
+}
+
// oidcParseAndValidateIDToken verifies the ID token signature against the
// provider's JWKS and validates the standard claims: issuer, audience,
// time claims with the configured clock skew, optional nonce binding, and
// azp for multi-audience tokens (per OIDC Core ID Token validation rules).
//
// NOTE(review): the JWK set is fetched on every call — there is no caching
// visible here; confirm call frequency is acceptable for the provider.
func oidcParseAndValidateIDToken(ctx context.Context, cfg config.OIDCConnectConfig, idToken string, expectedNonce string) (*oidcIDTokenClaims, error) {
	idToken = strings.TrimSpace(idToken)
	if idToken == "" {
		return nil, errors.New("missing id_token")
	}
	// Algorithm allow-list guards against alg-substitution attacks.
	allowed := oidcAllowedSigningAlgs(cfg.AllowedSigningAlgs)
	if len(allowed) == 0 {
		return nil, errors.New("empty allowed signing algorithms")
	}

	jwks, err := oidcFetchJWKSet(ctx, cfg.JWKSURL)
	if err != nil {
		return nil, err
	}
	leeway := time.Duration(cfg.ClockSkewSeconds) * time.Second
	claims := &oidcIDTokenClaims{}

	parsed, err := jwt.ParseWithClaims(
		idToken,
		claims,
		func(token *jwt.Token) (any, error) {
			// Check the alg before key lookup, in addition to
			// jwt.WithValidMethods below (defense in depth).
			alg := strings.TrimSpace(token.Method.Alg())
			if !containsString(allowed, alg) {
				return nil, fmt.Errorf("unexpected signing algorithm: %s", alg)
			}
			kid, _ := token.Header["kid"].(string)
			return oidcFindPublicKey(jwks, strings.TrimSpace(kid), alg)
		},
		jwt.WithValidMethods(allowed),
		jwt.WithAudience(cfg.ClientID),
		jwt.WithIssuer(cfg.IssuerURL),
		jwt.WithLeeway(leeway),
	)
	if err != nil {
		return nil, err
	}
	if !parsed.Valid {
		return nil, errors.New("id_token invalid")
	}
	if strings.TrimSpace(claims.Subject) == "" {
		return nil, errors.New("id_token missing sub")
	}
	// Nonce binding is only enforced when the flow supplied a nonce.
	if expectedNonce != "" && strings.TrimSpace(claims.Nonce) != strings.TrimSpace(expectedNonce) {
		return nil, errors.New("id_token nonce mismatch")
	}
	// With multiple audiences, azp must be present and equal our client_id.
	if len(claims.Audience) > 1 {
		if strings.TrimSpace(claims.Azp) == "" || strings.TrimSpace(claims.Azp) != strings.TrimSpace(cfg.ClientID) {
			return nil, errors.New("id_token azp mismatch")
		}
	}
	return claims, nil
}
+
// oidcAllowedSigningAlgs parses the comma-separated allow-list of JWT
// signing algorithms, upper-casing, trimming and de-duplicating entries.
// A blank input falls back to the asymmetric defaults.
func oidcAllowedSigningAlgs(raw string) []string {
	if strings.TrimSpace(raw) == "" {
		return []string{"RS256", "ES256", "PS256"}
	}
	seen := make(map[string]struct{})
	algs := make([]string, 0, 4)
	for _, piece := range strings.Split(raw, ",") {
		candidate := strings.ToUpper(strings.TrimSpace(piece))
		if candidate == "" {
			continue
		}
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		algs = append(algs, candidate)
	}
	return algs
}
+
+func oidcFetchJWKSet(ctx context.Context, jwksURL string) (*oidcJWKSet, error) {
+ jwksURL = strings.TrimSpace(jwksURL)
+ if jwksURL == "" {
+ return nil, errors.New("missing jwks_url")
+ }
+ resp, err := req.C().
+ SetTimeout(30*time.Second).
+ R().
+ SetContext(ctx).
+ SetHeader("Accept", "application/json").
+ Get(jwksURL)
+ if err != nil {
+ return nil, fmt.Errorf("request jwks: %w", err)
+ }
+ if !resp.IsSuccessState() {
+ return nil, fmt.Errorf("jwks status=%d", resp.StatusCode)
+ }
+ set := &oidcJWKSet{}
+ if err := json.Unmarshal(resp.Bytes(), set); err != nil {
+ return nil, fmt.Errorf("parse jwks: %w", err)
+ }
+ if len(set.Keys) == 0 {
+ return nil, errors.New("jwks empty keys")
+ }
+ return set, nil
+}
+
+func oidcFindPublicKey(set *oidcJWKSet, kid, alg string) (any, error) {
+ if set == nil {
+ return nil, errors.New("jwks not loaded")
+ }
+ alg = strings.ToUpper(strings.TrimSpace(alg))
+ kid = strings.TrimSpace(kid)
+
+ var lastErr error
+ for i := range set.Keys {
+ k := set.Keys[i]
+ if strings.TrimSpace(k.Use) != "" && !strings.EqualFold(strings.TrimSpace(k.Use), "sig") {
+ continue
+ }
+ if kid != "" && strings.TrimSpace(k.Kid) != kid {
+ continue
+ }
+ if strings.TrimSpace(k.Alg) != "" && !strings.EqualFold(strings.TrimSpace(k.Alg), alg) {
+ continue
+ }
+ pk, err := k.publicKey()
+ if err != nil {
+ lastErr = err
+ continue
+ }
+ if pk != nil {
+ return pk, nil
+ }
+ }
+ if lastErr != nil {
+ return nil, lastErr
+ }
+ if kid != "" {
+ return nil, fmt.Errorf("jwk not found for kid=%s", kid)
+ }
+ return nil, errors.New("jwk not found")
+}
+
+func (k oidcJWK) publicKey() (any, error) {
+ switch strings.ToUpper(strings.TrimSpace(k.Kty)) {
+ case "RSA":
+ n, err := decodeBase64URLBigInt(k.N)
+ if err != nil {
+ return nil, fmt.Errorf("decode rsa n: %w", err)
+ }
+ eBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(k.E))
+ if err != nil {
+ return nil, fmt.Errorf("decode rsa e: %w", err)
+ }
+ if len(eBytes) == 0 {
+ return nil, errors.New("empty rsa e")
+ }
+ e := 0
+ for _, b := range eBytes {
+ e = (e << 8) | int(b)
+ }
+ if e <= 0 {
+ return nil, errors.New("invalid rsa exponent")
+ }
+ if n.Sign() <= 0 {
+ return nil, errors.New("invalid rsa modulus")
+ }
+ return &rsa.PublicKey{N: n, E: e}, nil
+ case "EC":
+ var curve elliptic.Curve
+ switch strings.TrimSpace(k.Crv) {
+ case "P-256":
+ curve = elliptic.P256()
+ case "P-384":
+ curve = elliptic.P384()
+ case "P-521":
+ curve = elliptic.P521()
+ default:
+ return nil, fmt.Errorf("unsupported ec curve: %s", k.Crv)
+ }
+ x, err := decodeBase64URLBigInt(k.X)
+ if err != nil {
+ return nil, fmt.Errorf("decode ec x: %w", err)
+ }
+ y, err := decodeBase64URLBigInt(k.Y)
+ if err != nil {
+ return nil, fmt.Errorf("decode ec y: %w", err)
+ }
+ if !curve.IsOnCurve(x, y) {
+ return nil, errors.New("ec point is not on curve")
+ }
+ return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil
+ default:
+ return nil, fmt.Errorf("unsupported jwk kty: %s", k.Kty)
+ }
+}
+
// decodeBase64URLBigInt decodes an unpadded base64url string into a
// big-endian big integer; blank or empty-decoded input is rejected.
func decodeBase64URLBigInt(raw string) (*big.Int, error) {
	decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(raw))
	switch {
	case err != nil:
		return nil, err
	case len(decoded) == 0:
		return nil, errors.New("empty value")
	default:
		return new(big.Int).SetBytes(decoded), nil
	}
}
+
// containsString reports whether target matches any element of values,
// ignoring case and surrounding whitespace on both sides.
func containsString(values []string, target string) bool {
	needle := strings.TrimSpace(target)
	for _, candidate := range values {
		if strings.EqualFold(strings.TrimSpace(candidate), needle) {
			return true
		}
	}
	return false
}
+
// oidcIdentityKey builds the canonical identity key for an (issuer, subject)
// pair: lower-cased trimmed issuer and trimmed subject joined by the ASCII
// unit-separator byte (0x1F).
func oidcIdentityKey(issuer, subject string) string {
	normalizedIssuer := strings.TrimSpace(strings.ToLower(issuer))
	return normalizedIssuer + "\x1f" + strings.TrimSpace(subject)
}
+
+func oidcSyntheticEmailFromIdentityKey(identityKey string) string {
+ identityKey = strings.TrimSpace(identityKey)
+ if identityKey == "" {
+ return ""
+ }
+ sum := sha256.Sum256([]byte(identityKey))
+ return "oidc-" + hex.EncodeToString(sum[:16]) + service.OIDCConnectSyntheticEmailDomain
+}
+
// oidcFallbackUsername derives a deterministic username from the provider
// subject (first 12 hex chars of its SHA-256) when no better name is
// available; blank subjects get a fixed placeholder.
func oidcFallbackUsername(subject string) string {
	trimmed := strings.TrimSpace(subject)
	if trimmed == "" {
		return "oidc_user"
	}
	digest := sha256.Sum256([]byte(trimmed))
	return "oidc_" + hex.EncodeToString(digest[:])[:12]
}
+
+func oidcSetCookie(c *gin.Context, name, value string, maxAgeSec int, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: value,
+ Path: oidcOAuthCookiePath,
+ MaxAge: maxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func oidcClearCookie(c *gin.Context, name string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: "",
+ Path: oidcOAuthCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go
new file mode 100644
index 00000000..3216d51e
--- /dev/null
+++ b/backend/internal/handler/auth_oidc_oauth_test.go
@@ -0,0 +1,1040 @@
+package handler
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "encoding/base64"
+ "encoding/json"
+ "math/big"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOIDCSyntheticEmailStableAndDistinct(t *testing.T) {
+ k1 := oidcIdentityKey("https://issuer.example.com", "subject-a")
+ k2 := oidcIdentityKey("https://issuer.example.com", "subject-b")
+
+ e1 := oidcSyntheticEmailFromIdentityKey(k1)
+ e1Again := oidcSyntheticEmailFromIdentityKey(k1)
+ e2 := oidcSyntheticEmailFromIdentityKey(k2)
+
+ require.Equal(t, e1, e1Again)
+ require.NotEqual(t, e1, e2)
+ require.Contains(t, e1, "@oidc-connect.invalid")
+}
+
+func TestBuildOIDCAuthorizeURLIncludesNonceAndPKCE(t *testing.T) {
+ cfg := config.OIDCConnectConfig{
+ AuthorizeURL: "https://issuer.example.com/auth",
+ ClientID: "cid",
+ Scopes: "openid email profile",
+ }
+
+ u, err := buildOIDCAuthorizeURL(cfg, "state123", "nonce123", "challenge123", "https://app.example.com/callback")
+ require.NoError(t, err)
+ require.Contains(t, u, "nonce=nonce123")
+ require.Contains(t, u, "code_challenge=challenge123")
+ require.Contains(t, u, "code_challenge_method=S256")
+ require.Contains(t, u, "scope=openid+email+profile")
+}
+
// End-to-end ID-token validation against a locally generated RSA key served
// from a stub JWKS endpoint: a matching nonce passes, a mismatched expected
// nonce is rejected.
func TestOIDCParseAndValidateIDToken(t *testing.T) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	kid := "kid-1"
	jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &priv.PublicKey)}}
	// Stub JWKS endpoint serving the freshly generated public key.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.NoError(t, json.NewEncoder(w).Encode(jwks))
	}))
	defer srv.Close()

	now := time.Now()
	// Two audiences force the azp check; azp equals the client ID.
	claims := oidcIDTokenClaims{
		Nonce: "nonce-ok",
		Azp:   "client-1",
		RegisteredClaims: jwt.RegisteredClaims{
			Issuer:    "https://issuer.example.com",
			Subject:   "subject-1",
			Audience:  jwt.ClaimStrings{"client-1", "another-aud"},
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)),
			ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
		},
	}
	tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	tok.Header["kid"] = kid
	signed, err := tok.SignedString(priv)
	require.NoError(t, err)

	cfg := config.OIDCConnectConfig{
		ClientID:           "client-1",
		IssuerURL:          "https://issuer.example.com",
		JWKSURL:            srv.URL,
		AllowedSigningAlgs: "RS256",
		ClockSkewSeconds:   120,
	}

	parsed, err := oidcParseAndValidateIDToken(context.Background(), cfg, signed, "nonce-ok")
	require.NoError(t, err)
	require.Equal(t, "subject-1", parsed.Subject)
	require.Equal(t, "https://issuer.example.com", parsed.Issuer)

	// Same token, wrong expected nonce: must be rejected.
	_, err = oidcParseAndValidateIDToken(context.Background(), cfg, signed, "bad-nonce")
	require.Error(t, err)
}
+
// Userinfo parsing must surface subject, username and the suggested profile
// fields (display name, avatar, verified email) from standard OIDC claim
// names when no custom paths are configured.
func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) {
	cfg := config.OIDCConnectConfig{}

	claims := oidcParseUserInfo(`{
		"sub":"subject-1",
		"preferred_username":"alice",
		"name":"Alice Example",
		"picture":"https://cdn.example/avatar.png",
		"email":"alice@example.com",
		"email_verified":true
	}`, cfg)

	require.Equal(t, "subject-1", claims.Subject)
	require.Equal(t, "alice", claims.Username)
	require.Equal(t, "Alice Example", claims.DisplayName)
	require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL)
	require.NotNil(t, claims.EmailVerified)
	require.True(t, *claims.EmailVerified)
}
+
+func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK {
+ n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes())
+ e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes())
+ return oidcJWK{
+ Kty: "RSA",
+ Kid: kid,
+ Use: "sig",
+ Alg: "RS256",
+ N: n,
+ E: e,
+ }
+}
+
// Bind-intent start: the handler must redirect (302) to the provider's
// authorize endpoint and persist the full flow cookie set, including the
// signed bind-user cookie carrying the authenticated caller's user ID.
func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
	handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
		Enabled:              true,
		ClientID:             "oidc-client",
		ClientSecret:         "oidc-secret",
		IssuerURL:            "https://issuer.example.com",
		AuthorizeURL:         "https://issuer.example.com/oauth/authorize",
		TokenURL:             "https://issuer.example.com/oauth/token",
		UserInfoURL:          "https://issuer.example.com/oauth/userinfo",
		JWKSURL:              "https://issuer.example.com/oauth/jwks",
		Scopes:               "openid profile email",
		RedirectURL:          "https://api.example.com/api/v1/auth/oauth/oidc/callback",
		FrontendRedirectURL:  "/auth/oidc/callback",
		TokenAuthMethod:      "client_secret_post",
		UsePKCE:              true,
		ValidateIDToken:      true,
		AllowedSigningAlgs:   "RS256",
		ClockSkewSeconds:     120,
		RequireEmailVerified: false,
	})

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=/settings/connections", nil)
	c.Request = req
	// Simulate an authenticated user (ID 84) initiating the bind flow.
	c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 84})

	handler.OIDCOAuthStart(c)

	require.Equal(t, http.StatusFound, recorder.Code)
	location := recorder.Header().Get("Location")
	require.Contains(t, location, "issuer.example.com/oauth/authorize")
	require.Contains(t, location, "client_id=oidc-client")
	require.Contains(t, location, "nonce=")

	// All flow cookies must be present: state, redirect, PKCE verifier,
	// nonce and the pending-browser key.
	cookies := recorder.Result().Cookies()
	require.NotNil(t, findCookie(cookies, oidcOAuthStateCookieName))
	require.NotNil(t, findCookie(cookies, oidcOAuthRedirectCookie))
	require.NotNil(t, findCookie(cookies, oidcOAuthVerifierCookie))
	require.NotNil(t, findCookie(cookies, oidcOAuthNonceCookie))
	require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName))

	intentCookie := findCookie(cookies, oidcOAuthIntentCookieName)
	require.NotNil(t, intentCookie)
	require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value))

	// The bind-user cookie is parsed with the test secret and must recover
	// the authenticated user's ID.
	bindCookie := findCookie(cookies, oidcOAuthBindUserCookieName)
	require.NotNil(t, bindCookie)
	userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
	require.NoError(t, err)
	require.Equal(t, int64(84), userID)
}
+
// With PKCE and ID-token validation disabled, the authorize redirect must
// not carry code_challenge or nonce parameters, and the corresponding
// verifier/nonce cookies must not be set.
func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) {
	handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
		Enabled:              true,
		ClientID:             "oidc-client",
		ClientSecret:         "oidc-secret",
		IssuerURL:            "https://issuer.example.com",
		AuthorizeURL:         "https://issuer.example.com/oauth/authorize",
		TokenURL:             "https://issuer.example.com/oauth/token",
		UserInfoURL:          "https://issuer.example.com/oauth/userinfo",
		Scopes:               "openid profile email",
		RedirectURL:          "https://api.example.com/api/v1/auth/oauth/oidc/callback",
		FrontendRedirectURL:  "/auth/oidc/callback",
		TokenAuthMethod:      "client_secret_post",
		UsePKCE:              false,
		ValidateIDToken:      false,
		RequireEmailVerified: false,
	})

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil)

	handler.OIDCOAuthStart(c)

	require.Equal(t, http.StatusFound, recorder.Code)
	location := recorder.Header().Get("Location")
	require.NotContains(t, location, "code_challenge=")
	require.NotContains(t, location, "nonce=")
	require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie))
	require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie))
}
+
// Callback with PKCE and ID-token validation disabled: the token request
// must omit code_verifier, and the handler must still create a pending
// session from userinfo alone, redirecting to the frontend callback page.
func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) {
	// Stub provider: asserts that no code_verifier is posted and serves a
	// minimal token + userinfo pair.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/token":
			require.NoError(t, r.ParseForm())
			require.Empty(t, r.PostForm.Get("code_verifier"))
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`))
		case "/userinfo":
			w.Header().Set("Content-Type", "application/json")
			_, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`))
		default:
			http.NotFound(w, r)
		}
	}))
	defer upstream.Close()

	handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{
		Enabled:              true,
		ClientID:             "oidc-client",
		ClientSecret:         "oidc-secret",
		IssuerURL:            "https://issuer.example.com",
		AuthorizeURL:         upstream.URL + "/authorize",
		TokenURL:             upstream.URL + "/token",
		UserInfoURL:          upstream.URL + "/userinfo",
		Scopes:               "openid profile email",
		RedirectURL:          "https://api.example.com/api/v1/auth/oauth/oidc/callback",
		FrontendRedirectURL:  "/auth/oidc/callback",
		TokenAuthMethod:      "client_secret_post",
		UsePKCE:              false,
		ValidateIDToken:      false,
		RequireEmailVerified: false,
	})
	t.Cleanup(func() { _ = client.Close() })

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	// Note: no verifier/nonce cookies are attached — PKCE is disabled.
	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
	req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
	req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
	req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
	c.Request = req

	handler.OIDCOAuthCallback(c)

	require.Equal(t, http.StatusFound, recorder.Code)
	require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
	require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
+
// Callback for a subject already bound to a local user: the handler must
// create a login-intent pending session targeting that user, carry the
// upstream suggested display name, and keep tokens OUT of the completion
// response (tokens are only issued at completion time).
func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
	cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
		Subject:           "oidc-subject-login",
		PreferredUsername: "oidc_login",
		DisplayName:       "OIDC Login Display",
		AvatarURL:         "https://cdn.example/oidc-login.png",
		Email:             "oidc-login@example.com",
		EmailVerified:     true,
	})
	defer cleanup()

	handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
	t.Cleanup(func() { _ = client.Close() })

	// Seed an existing user plus an auth identity bound to the same
	// provider subject the fixture will return.
	ctx := context.Background()
	existingUser, err := client.User.Create().
		SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-subject-login"))).
		SetUsername("legacy-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err)
	_, err = client.AuthIdentity.Create().
		SetUserID(existingUser.ID).
		SetProviderType("oidc").
		SetProviderKey(cfg.IssuerURL).
		SetProviderSubject("oidc-subject-login").
		SetMetadata(map[string]any{"username": "legacy-user"}).
		Save(ctx)
	require.NoError(t, err)

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
	req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
	req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
	req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-123"))
	req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-login"))
	req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
	c.Request = req

	handler.OIDCOAuthCallback(c)

	require.Equal(t, http.StatusFound, recorder.Code)
	require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))

	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
	require.NotNil(t, sessionCookie)

	// The persisted pending session must target the seeded user.
	session, err := client.PendingAuthSession.Query().
		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
		Only(ctx)
	require.NoError(t, err)
	require.Equal(t, oauthIntentLogin, session.Intent)
	require.NotNil(t, session.TargetUserID)
	require.Equal(t, existingUser.ID, *session.TargetUserID)
	require.Equal(t, cfg.IssuerURL, session.ProviderKey)
	require.Equal(t, "OIDC Login Display", session.UpstreamIdentityClaims["suggested_display_name"])

	// The completion payload must expose the redirect but never tokens.
	completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
	require.True(t, ok)
	require.Equal(t, "/dashboard", completion["redirect"])
	_, hasAccessToken := completion["access_token"]
	require.False(t, hasAccessToken)
	_, hasRefreshToken := completion["refresh_token"]
	require.False(t, hasRefreshToken)
	require.Nil(t, completion["error"])
}
+
// Callback for a subject bound to a DISABLED user: no pending session may
// be created or persisted, and the frontend must receive a session_error
// redirect carrying USER_NOT_ACTIVE.
func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
	cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
		Subject:           "oidc-disabled-subject",
		PreferredUsername: "oidc_disabled",
		DisplayName:       "OIDC Disabled",
	})
	defer cleanup()

	handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
	t.Cleanup(func() { _ = client.Close() })

	// Seed a disabled user bound to the fixture's provider subject.
	ctx := context.Background()
	existingUser, err := client.User.Create().
		SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))).
		SetUsername("disabled-user").
		SetPasswordHash("hash").
		SetRole(service.RoleUser).
		SetStatus(service.StatusDisabled).
		Save(ctx)
	require.NoError(t, err)
	_, err = client.AuthIdentity.Create().
		SetUserID(existingUser.ID).
		SetProviderType("oidc").
		SetProviderKey(cfg.IssuerURL).
		SetProviderSubject("oidc-disabled-subject").
		Save(ctx)
	require.NoError(t, err)

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil)
	req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled"))
	req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
	req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled"))
	req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject"))
	req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
	c.Request = req

	handler.OIDCOAuthCallback(c)

	require.Equal(t, http.StatusFound, recorder.Code)
	// No pending-session cookie and an error redirect instead.
	require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
	assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")

	// Nothing must have been persisted.
	count, err := client.PendingAuthSession.Query().Count(ctx)
	require.NoError(t, err)
	require.Zero(t, count)
}
+
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-compat",
+ PreferredUsername: "oidc_compat",
+ DisplayName: "OIDC Compat Display",
+ AvatarURL: "https://cdn.example/oidc-compat.png",
+ Email: "legacy@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, existingUser.ID, *session.TargetUserID)
+ require.Equal(t, existingUser.Email, session.ResolvedEmail)
+ require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, existingUser.Email, completion["email"])
+ require.Equal(t, existingUser.Email, completion["existing_account_email"])
+ require.Equal(t, true, completion["existing_account_bindable"])
+ require.Equal(t, "compat_email_match", completion["choice_reason"])
+ _, hasAccessToken := completion["access_token"]
+ require.False(t, hasAccessToken)
+}
+
+func TestOIDCOAuthCallbackRejectsCompatEmailBindWhenUpstreamEmailIsUnverified(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-unverified-compat",
+ PreferredUsername: "oidc_unverified",
+ DisplayName: "OIDC Unverified Compat Display",
+ AvatarURL: "https://cdn.example/oidc-unverified.png",
+ Email: "owner@example.com",
+ EmailVerified: false,
+ })
+ defer cleanup()
+ cfg.RequireEmailVerified = true
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ _, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-unverified-compat", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-unverified-compat"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-unverified-compat"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback#error=email_not_verified&error_message=email+is+not+verified", recorder.Header().Get("Location"))
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+func TestOIDCOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-invite",
+ PreferredUsername: "oidc_invite",
+ DisplayName: "OIDC Invite Display",
+ AvatarURL: "https://cdn.example/oidc-invite.png",
+ Email: "oidc-invite@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, true, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-456", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-456"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-456"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-invite"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentLogin, session.Intent)
+ require.Nil(t, session.TargetUserID)
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, oauthPendingChoiceStep, completion["step"])
+ require.Equal(t, "/dashboard", completion["redirect"])
+ require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+func TestOIDCOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) {
+ cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
+ Subject: "oidc-subject-bind",
+ PreferredUsername: "oidc_bind",
+ DisplayName: "OIDC Bind Display",
+ AvatarURL: "https://cdn.example/oidc-bind.png",
+ Email: "oidc-bind@example.com",
+ EmailVerified: true,
+ })
+ defer cleanup()
+
+ handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
+ t.Cleanup(func() { _ = client.Close() })
+
+ ctx := context.Background()
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-bind", nil)
+ req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/settings/connections"))
+ req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-subject-bind"))
+ req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentBindCurrentUser))
+ req.AddCookie(encodedCookie(oidcOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind"))
+ c.Request = req
+
+ handler.OIDCOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthIntentBindCurrentUser, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, cfg.IssuerURL, session.ProviderKey)
+ require.Equal(t, "OIDC Bind Display", session.UpstreamIdentityClaims["suggested_display_name"])
+
+ completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "/settings/connections", completion["redirect"])
+ require.Empty(t, completion["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, userCount)
+}
+
+func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-subject-1").
+ SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ "suggested_display_name": "OIDC Display",
+ "suggested_avatar_url": "https://cdn.example/oidc.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "OIDC Display", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example.com"),
+ authidentity.ProviderSubjectEQ("oidc-subject-1"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+ require.Equal(t, "OIDC Display", identity.Metadata["display_name"])
+ require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"])
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.True(t, decision.AdoptDisplayName)
+ require.True(t, decision.AdoptAvatar)
+
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-invalid-subject-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("oidc-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-choice-subject-1").
+ SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-subject-no-adoption").
+ SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ "suggested_display_name": "OIDC Legacy",
+ "suggested_avatar_url": "https://cdn.example/oidc-legacy.png",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "oidc_user", userEntity.Username)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example.com"),
+ authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ existingOwner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(existingOwner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-conflict-subject").
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("oidc-complete-conflict-session").
+ SetIntent("login").
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example.com").
+ SetProviderSubject("oidc-conflict-subject").
+ SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid").
+ SetBrowserSessionKey("oidc-conflict-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "oidc_user",
+ "issuer": "https://issuer.example.com",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")})
+ c.Request = req
+
+ handler.CompleteOIDCOAuthRegistration(c)
+
+ require.Equal(t, http.StatusConflict, recorder.Code)
+ payload := decodeJSONBody(t, recorder)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
+
+ userCount, err := client.User.Query().
+ Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+type oidcProviderFixture struct {
+ Subject string
+ PreferredUsername string
+ DisplayName string
+ AvatarURL string
+ Email string
+ EmailVerified bool
+}
+
+func newOIDCOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) *AuthHandler {
+ t.Helper()
+ handler, _ := newOIDCOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
+ return handler
+}
+
+func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.OIDCConnectConfig) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+ handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled)
+ handler.settingSvc = nil
+ handler.cfg = &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ OIDC: oauthCfg,
+ }
+ return handler, client
+}
+
+func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) {
+ t.Helper()
+
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ kid := "test-kid"
+ jwks := oidcJWKSet{Keys: []oidcJWK{buildRSAJWK(kid, &privateKey.PublicKey)}}
+ tokenResponse := oidcTokenResponse{
+ AccessToken: "oidc-access-token",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ }
+
+ userInfoPayload := map[string]any{
+ "sub": fixture.Subject,
+ "preferred_username": fixture.PreferredUsername,
+ "name": fixture.DisplayName,
+ "picture": fixture.AvatarURL,
+ "email": fixture.Email,
+ "email_verified": fixture.EmailVerified,
+ }
+
+ var issuer string
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/token":
+ require.NoError(t, json.NewEncoder(w).Encode(tokenResponse))
+ case "/userinfo":
+ require.NoError(t, json.NewEncoder(w).Encode(userInfoPayload))
+ case "/jwks":
+ require.NoError(t, json.NewEncoder(w).Encode(jwks))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+
+ issuer = server.URL
+ now := time.Now()
+ claims := oidcIDTokenClaims{
+ Email: fixture.Email,
+ EmailVerified: boolPtr(fixture.EmailVerified),
+ PreferredUsername: fixture.PreferredUsername,
+ Name: fixture.DisplayName,
+ Nonce: "nonce-" + fixture.Subject,
+ RegisteredClaims: jwt.RegisteredClaims{
+ Issuer: issuer,
+ Subject: fixture.Subject,
+ Audience: jwt.ClaimStrings{"oidc-client"},
+ IssuedAt: jwt.NewNumericDate(now),
+ NotBefore: jwt.NewNumericDate(now.Add(-30 * time.Second)),
+ ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
+ },
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ token.Header["kid"] = kid
+ tokenResponse.IDToken, err = token.SignedString(privateKey)
+ require.NoError(t, err)
+
+ cfg := config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "Test OIDC",
+ ClientID: "oidc-client",
+ ClientSecret: "oidc-secret",
+ IssuerURL: issuer,
+ AuthorizeURL: issuer + "/authorize",
+ TokenURL: issuer + "/token",
+ UserInfoURL: issuer + "/userinfo",
+ JWKSURL: issuer + "/jwks",
+ Scopes: "openid profile email",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256",
+ ClockSkewSeconds: 120,
+ RequireEmailVerified: false,
+ }
+ return cfg, server.Close
+}
diff --git a/backend/internal/handler/auth_session_revocation_test.go b/backend/internal/handler/auth_session_revocation_test.go
new file mode 100644
index 00000000..f1c6d87d
--- /dev/null
+++ b/backend/internal/handler/auth_session_revocation_test.go
@@ -0,0 +1,61 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 29,
+ Email: "session@example.com",
+ Username: "session-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 7,
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
+ handler := &AuthHandler{authService: authService}
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/auth/revoke-all-sessions", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 29})
+
+ handler.RevokeAllSessions(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []int64{29}, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(8), repo.user.TokenVersion)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Message string `json:"message"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "All sessions have been revoked. Please log in again.", resp.Data.Message)
+}
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
new file mode 100644
index 00000000..34e70ed0
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -0,0 +1,1350 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+	// Cookie scope and lifetime for the WeChat login/bind OAuth flow. All
+	// flow cookies are path-restricted to the OAuth endpoints and expire
+	// after ten minutes.
+	wechatOAuthCookiePath         = "/api/v1/auth/oauth/wechat"
+	wechatOAuthCookieMaxAgeSec    = 10 * 60
+	wechatOAuthStateCookieName    = "wechat_oauth_state"
+	wechatOAuthRedirectCookieName = "wechat_oauth_redirect"
+	wechatOAuthIntentCookieName   = "wechat_oauth_intent"
+	wechatOAuthModeCookieName     = "wechat_oauth_mode"
+	wechatOAuthBindUserCookieName = "wechat_oauth_bind_user"
+	wechatOAuthDefaultRedirectTo  = "/dashboard"
+	wechatOAuthDefaultFrontendCB  = "/auth/wechat/callback"
+	// Provider keys stored on auth identities. "wechat-main" is the current
+	// key; "wechat" is kept for rows written by older releases.
+	wechatOAuthProviderKey       = "wechat-main"
+	wechatOAuthLegacyProviderKey = "wechat"
+	// Separate cookie namespace for the payment OAuth flow.
+	wechatPaymentOAuthCookiePath  = "/api/v1/auth/oauth/wechat/payment"
+	wechatPaymentOAuthStateName   = "wechat_payment_oauth_state"
+	wechatPaymentOAuthRedirect    = "wechat_payment_oauth_redirect"
+	wechatPaymentOAuthContextName = "wechat_payment_oauth_context"
+	wechatPaymentOAuthScope       = "wechat_payment_oauth_scope"
+	wechatPaymentOAuthDefaultTo   = "/purchase"
+	wechatPaymentOAuthFrontendCB  = "/auth/wechat/payment/callback"
+
+	// Supported OAuth intents carried through the flow cookies.
+	wechatOAuthIntentLogin      = "login"
+	wechatOAuthIntentBind       = "bind_current_user"
+	wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email"
+)
+
+// WeChat endpoints; declared as vars (not consts) so tests can point them at
+// a stub server.
+var (
+	wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL    = "https://api.weixin.qq.com/sns/userinfo"
+)
+
+// wechatOAuthConfig is the resolved per-mode ("open" website login vs "mp"
+// in-app/official-account login) WeChat OAuth configuration.
+type wechatOAuthConfig struct {
+	mode             string // resolved mode, also recorded as the identity channel
+	appID            string
+	appSecret        string
+	authorizeURL     string
+	scope            string
+	redirectURI      string
+	frontendCallback string
+	openEnabled      bool
+	mpEnabled        bool
+}
+
+// wechatOAuthTokenResponse mirrors the JSON body of the sns/oauth2/access_token
+// endpoint. ErrCode/ErrMsg are populated on provider-side failures.
+type wechatOAuthTokenResponse struct {
+	AccessToken  string `json:"access_token"`
+	ExpiresIn    int64  `json:"expires_in"`
+	RefreshToken string `json:"refresh_token"`
+	OpenID       string `json:"openid"`
+	Scope        string `json:"scope"`
+	UnionID      string `json:"unionid"`
+	ErrCode      int64  `json:"errcode"`
+	ErrMsg       string `json:"errmsg"`
+}
+
+// wechatOAuthUserInfoResponse mirrors the JSON body of the sns/userinfo endpoint.
+type wechatOAuthUserInfoResponse struct {
+	OpenID     string `json:"openid"`
+	Nickname   string `json:"nickname"`
+	HeadImgURL string `json:"headimgurl"`
+	UnionID    string `json:"unionid"`
+	ErrCode    int64  `json:"errcode"`
+	ErrMsg     string `json:"errmsg"`
+}
+
+// wechatPaymentOAuthContext is the purchase context carried through the
+// payment OAuth round-trip via an encoded cookie.
+type wechatPaymentOAuthContext struct {
+	PaymentType string `json:"payment_type"`
+	Amount      string `json:"amount,omitempty"`
+	OrderType   string `json:"order_type,omitempty"`
+	PlanID      int64  `json:"plan_id,omitempty"`
+}
+
+// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived
+// browser cookies required by the rebuild pending-auth bridge.
+//
+// Query params: mode (open/mp), redirect (frontend path, sanitized), intent
+// (login / bind_current_user / adopt_existing_user_by_email). Responds with a
+// 302 to the WeChat authorize URL.
+func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) {
+	cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// Random state for CSRF protection; echoed back by WeChat and checked in
+	// the callback against the state cookie.
+	state, err := oauth.GenerateState()
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+		return
+	}
+
+	// Only same-site relative paths survive sanitization; fall back to the
+	// dashboard when absent or rejected.
+	redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
+	if redirectTo == "" {
+		redirectTo = wechatOAuthDefaultRedirectTo
+	}
+
+	browserSessionKey, err := generateOAuthPendingBrowserSession()
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+		return
+	}
+
+	// Persist the flow context in short-lived cookies so the callback can
+	// validate state and resume with the same intent/mode.
+	intent := normalizeWeChatOAuthIntent(c.Query("intent"))
+	secureCookie := isRequestHTTPS(c)
+	wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie)
+	setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+	// Starting a new flow invalidates any previous pending-session cookie.
+	clearOAuthPendingSessionCookie(c, secureCookie)
+	if intent == oauthIntentBindCurrentUser {
+		// Bind flows snapshot the current user into a cookie so the callback
+		// (which arrives without an Authorization header) knows who to bind.
+		bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+		if err != nil {
+			response.ErrorFrom(c, err)
+			return
+		}
+		wechatSetCookie(c, wechatOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), wechatOAuthCookieMaxAgeSec, secureCookie)
+	} else {
+		wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+	}
+
+	authURL, err := buildWeChatAuthorizeURL(cfg, state)
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+		return
+	}
+
+	c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid,
+// and stores the result in the unified pending-auth flow.
+//
+// All failures redirect back to the frontend callback with an error fragment
+// rather than rendering a backend error page. Flow cookies set by
+// WeChatOAuthStart are always cleared on exit (deferred below).
+func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
+	frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context())
+
+	// Provider-reported error (e.g. the user denied authorization).
+	if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+		redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+		return
+	}
+
+	code := strings.TrimSpace(c.Query("code"))
+	state := strings.TrimSpace(c.Query("state"))
+	if code == "" || state == "" {
+		redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+		return
+	}
+
+	// The flow cookies are single-use: clear them regardless of outcome.
+	secureCookie := isRequestHTTPS(c)
+	defer func() {
+		wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
+		wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
+		wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
+		wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
+		wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
+	}()
+
+	// CSRF check: the state echoed by WeChat must match the cookie we set.
+	expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName)
+	if err != nil || expectedState == "" || state != expectedState {
+		redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+		return
+	}
+
+	// Re-sanitize the redirect from the cookie; never trust stored values blindly.
+	redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName)
+	redirectTo = sanitizeFrontendRedirectPath(redirectTo)
+	if redirectTo == "" {
+		redirectTo = wechatOAuthDefaultRedirectTo
+	}
+	browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+	if strings.TrimSpace(browserSessionKey) == "" {
+		redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
+		return
+	}
+
+	intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName)
+	mode, err := readCookieDecoded(c, wechatOAuthModeCookieName)
+	if err != nil || strings.TrimSpace(mode) == "" {
+		redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "")
+		return
+	}
+
+	cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+		return
+	}
+
+	// Exchange the code for tokens and fetch the user profile from WeChat.
+	tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error()))
+		return
+	}
+
+	// Prefer the cross-app unionid as the stable subject; fall back to the
+	// per-app openid only when the configuration allows it.
+	unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID))
+	openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID))
+	providerSubject := unionid
+	if providerSubject == "" {
+		if cfg.requiresUnionID() {
+			redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "")
+			return
+		}
+		providerSubject = openid
+	}
+	if providerSubject == "" {
+		redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "")
+		return
+	}
+
+	// WeChat provides no email; synthesize one from the subject and capture
+	// profile hints for later adoption (display name / avatar).
+	username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject))
+	email := wechatSyntheticEmail(providerSubject)
+	upstreamClaims := map[string]any{
+		"email":                  email,
+		"username":               username,
+		"subject":                providerSubject,
+		"openid":                 openid,
+		"unionid":                unionid,
+		"mode":                   cfg.mode,
+		"channel":                cfg.mode,
+		"channel_app_id":         strings.TrimSpace(cfg.appID),
+		"channel_subject":        openid,
+		"suggested_display_name": strings.TrimSpace(userInfo.Nickname),
+		"suggested_avatar_url":   strings.TrimSpace(userInfo.HeadImgURL),
+	}
+	identityRef := service.PendingAuthIdentityKey{
+		ProviderType:    "wechat",
+		ProviderKey:     wechatOAuthProviderKey,
+		ProviderSubject: providerSubject,
+	}
+
+	// Bind intent: attach this WeChat identity to the already-logged-in user.
+	normalizedIntent := normalizeWeChatOAuthIntent(intent)
+	if normalizedIntent == wechatOAuthIntentBind {
+		if err := h.createWeChatBindPendingSession(c, cfg, providerSubject, openid, redirectTo, browserSessionKey, upstreamClaims); err != nil {
+			switch infraerrors.Code(err) {
+			case http.StatusConflict:
+				redirectOAuthError(c, frontendCallback, "ownership_conflict", infraerrors.Reason(err), infraerrors.Message(err))
+			case http.StatusUnauthorized, http.StatusForbidden:
+				redirectOAuthError(c, frontendCallback, "auth_required", infraerrors.Reason(err), infraerrors.Message(err))
+			default:
+				redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+			}
+			return
+		}
+		redirectToFrontendCallback(c, frontendCallback)
+		return
+	}
+
+	// Login intent: look for an existing user bound to this identity, first
+	// by the canonical subject, then via legacy openid records.
+	existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+		return
+	}
+	if existingIdentityUser == nil {
+		existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid)
+		if err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+			return
+		}
+	}
+	if existingIdentityUser != nil {
+		// Known user: repair/upsert the runtime identity binding, then hand
+		// off to the pending-session flow to mint tokens.
+		if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+			return
+		}
+		if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+			return
+		}
+		redirectToFrontendCallback(c, frontendCallback)
+		return
+	}
+
+	// Unknown identity: present the signup/bind choice screen. The only
+	// difference between the two branches below is the force-email flag.
+	if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
+		if err := h.createWeChatChoicePendingSession(
+			c,
+			identityRef,
+			email,
+			email,
+			redirectTo,
+			browserSessionKey,
+			upstreamClaims,
+			"",
+			nil,
+			true,
+		); err != nil {
+			redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+			return
+		}
+		redirectToFrontendCallback(c, frontendCallback)
+		return
+	}
+
+	if err := h.createWeChatChoicePendingSession(
+		c,
+		identityRef,
+		email,
+		email,
+		redirectTo,
+		browserSessionKey,
+		upstreamClaims,
+		"",
+		nil,
+		false,
+	); err != nil {
+		redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
+		return
+	}
+	redirectToFrontendCallback(c, frontendCallback)
+}
+
+// WeChatPaymentOAuthStart starts the WeChat payment OAuth flow.
+// GET /api/v1/auth/oauth/wechat/payment/start?payment_type=wxpay&redirect=/purchase
+//
+// The payment flow always uses "mp" mode (official-account login) because the
+// resulting openid is what the payment provider requires. The purchase
+// context (amount/order type/plan) rides along in an encoded cookie.
+func (h *AuthHandler) WeChatPaymentOAuthStart(c *gin.Context) {
+	cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	paymentType := normalizeWeChatPaymentType(c.Query("payment_type"))
+	if paymentType == "" {
+		response.BadRequest(c, "Invalid payment type")
+		return
+	}
+
+	// Anti-CSRF state, validated in the callback.
+	state, err := oauth.GenerateState()
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+		return
+	}
+
+	redirectTo := normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(c.Query("redirect")))
+	if redirectTo == "" {
+		redirectTo = wechatPaymentOAuthDefaultTo
+	}
+	// Snapshot the purchase parameters so the callback can resume the order.
+	rawContext, err := encodeWeChatPaymentOAuthContext(wechatPaymentOAuthContext{
+		PaymentType: paymentType,
+		Amount:      strings.TrimSpace(c.Query("amount")),
+		OrderType:   strings.TrimSpace(c.Query("order_type")),
+		PlanID:      parseWeChatPaymentPlanID(c.Query("plan_id")),
+	})
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONTEXT_ENCODE_FAILED", "failed to encode oauth context").WithCause(err))
+		return
+	}
+
+	scope := normalizeWeChatPaymentScope(c.Query("scope"))
+	secureCookie := isRequestHTTPS(c)
+	wechatPaymentSetCookie(c, wechatPaymentOAuthStateName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatPaymentSetCookie(c, wechatPaymentOAuthRedirect, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatPaymentSetCookie(c, wechatPaymentOAuthContextName, encodeCookieValue(rawContext), wechatOAuthCookieMaxAgeSec, secureCookie)
+	wechatPaymentSetCookie(c, wechatPaymentOAuthScope, encodeCookieValue(scope), wechatOAuthCookieMaxAgeSec, secureCookie)
+
+	// The payment flow has its own callback URL and caller-chosen scope.
+	cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c)
+	cfg.scope = scope
+	authURL, err := buildWeChatAuthorizeURL(cfg, state)
+	if err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err))
+		return
+	}
+
+	c.Redirect(http.StatusFound, authURL)
+}
+
+// WeChatPaymentOAuthCallback exchanges a payment OAuth code for an OpenID and
+// forwards the browser back to the frontend callback route.
+//
+// On success the frontend receives a signed resume token (URL fragment) that
+// carries the openid plus the original purchase context.
+func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
+	frontendCallback := wechatPaymentOAuthFrontendCB
+
+	if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+		redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+		return
+	}
+
+	code := strings.TrimSpace(c.Query("code"))
+	state := strings.TrimSpace(c.Query("state"))
+	if code == "" || state == "" {
+		redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+		return
+	}
+
+	// Flow cookies are single-use; clear them on every exit path.
+	secureCookie := isRequestHTTPS(c)
+	defer func() {
+		wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
+		wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
+		wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
+		wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
+	}()
+
+	// CSRF check against the state cookie set by WeChatPaymentOAuthStart.
+	expectedState, err := readCookieDecoded(c, wechatPaymentOAuthStateName)
+	if err != nil || expectedState == "" || state != expectedState {
+		redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "")
+		return
+	}
+
+	redirectTo, _ := readCookieDecoded(c, wechatPaymentOAuthRedirect)
+	redirectTo = normalizeWeChatPaymentRedirectPath(sanitizeFrontendRedirectPath(redirectTo))
+	if redirectTo == "" {
+		redirectTo = wechatPaymentOAuthDefaultTo
+	}
+
+	// Restore the purchase context snapshotted at flow start.
+	rawContext, _ := readCookieDecoded(c, wechatPaymentOAuthContextName)
+	paymentContext, err := decodeWeChatPaymentOAuthContext(rawContext)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "invalid_context", "invalid oauth context", "")
+		return
+	}
+	if paymentContext.PaymentType == "" {
+		paymentContext.PaymentType = payment.TypeWxpay
+	}
+
+	scope, _ := readCookieDecoded(c, wechatPaymentOAuthScope)
+	scope = normalizeWeChatPaymentScope(scope)
+
+	cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), "mp", c)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err))
+		return
+	}
+	// The redirect URI must match the one used at authorize time.
+	cfg.redirectURI = h.resolveWeChatPaymentOAuthCallbackURL(c.Request.Context(), c)
+	tokenResp, err := exchangeWeChatOAuthCode(c.Request.Context(), cfg, code)
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", err.Error())
+		return
+	}
+
+	openid := strings.TrimSpace(tokenResp.OpenID)
+	if openid == "" {
+		redirectOAuthError(c, frontendCallback, "missing_openid", "missing openid", "")
+		return
+	}
+	// Prefer the scope the provider actually granted over the requested one.
+	if strings.TrimSpace(tokenResp.Scope) != "" {
+		scope = strings.TrimSpace(tokenResp.Scope)
+	}
+
+	// Mint a resume token that binds the openid to the purchase context so
+	// the frontend can complete the order without re-authenticating.
+	resumeToken, err := h.wechatPaymentResumeService().CreateWeChatPaymentResumeToken(service.WeChatPaymentResumeClaims{
+		OpenID:      openid,
+		PaymentType: paymentContext.PaymentType,
+		Amount:      paymentContext.Amount,
+		OrderType:   paymentContext.OrderType,
+		PlanID:      paymentContext.PlanID,
+		RedirectTo:  redirectTo,
+		Scope:       scope,
+	})
+	if err != nil {
+		redirectOAuthError(c, frontendCallback, "invalid_context", "failed to encode payment resume context", "")
+		return
+	}
+
+	// Pass the token in the URL fragment so it never reaches server logs.
+	fragment := url.Values{}
+	fragment.Set("wechat_resume_token", resumeToken)
+	fragment.Set("redirect", redirectTo)
+	redirectWithFragment(c, frontendCallback, fragment)
+}
+
+// wechatPaymentResumeService constructs the payment-resume token service.
+// Key derivation is best-effort: if the payment encryption key cannot be
+// provided from config, the service is built without a legacy key (nil).
+func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
+	var legacy []byte
+	if key, err := payment.ProvideEncryptionKey(h.cfg); err == nil {
+		legacy = []byte(key)
+	}
+	return service.NewLegacyAwarePaymentResumeService(legacy)
+}
+
+// completeWeChatOAuthRequest is the JSON body for completing a pending WeChat
+// OAuth registration.
+type completeWeChatOAuthRequest struct {
+	InvitationCode string `json:"invitation_code" binding:"required"` // mandatory invitation code
+	AffCode        string `json:"aff_code,omitempty"`                 // optional affiliate/referral code
+	AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`      // nil = no decision yet
+	AdoptAvatar      *bool `json:"adopt_avatar,omitempty"`            // nil = no decision yet
+}
+
+// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by
+// validating the invitation code and consuming the current pending browser session.
+// POST /api/v1/auth/oauth/wechat/complete-registration
+//
+// Sequence: validate body -> load pending session from cookies -> guard the
+// session state -> register/login -> apply profile adoption -> consume the
+// session -> return the token pair. Cookie cleanup happens on every failure
+// path that invalidates the session.
+func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
+	var req completeWeChatOAuthRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
+		return
+	}
+
+	// Both cookies must be present: the pending-session token and the
+	// browser key that pins the session to this browser.
+	secureCookie := isRequestHTTPS(c)
+	sessionToken, err := readOAuthPendingSessionCookie(c)
+	if err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+		return
+	}
+	browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+	if err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+		return
+	}
+	pendingSvc, err := h.pendingIdentityService()
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+	if err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, err)
+		return
+	}
+	if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	// Legacy sessions may be resolved (handled) without running the new
+	// registration path; in that case the status payload is returned as-is.
+	if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	} else if handled {
+		c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+		return
+	} else {
+		session = updatedSession
+	}
+	if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	email := strings.TrimSpace(session.ResolvedEmail)
+	username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+	if email == "" || username == "" {
+		response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+		return
+	}
+
+	tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	// Record (or look up a prior) decision about adopting the WeChat display
+	// name / avatar, then apply it to the freshly created user.
+	decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+		AdoptDisplayName: req.AdoptDisplayName,
+		AdoptAvatar:      req.AdoptAvatar,
+	})
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
+		response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
+		return
+	}
+	h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+	// Consume the single-use pending session; the cookies are cleared either way.
+	if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
+		clearOAuthPendingSessionCookie(c, secureCookie)
+		clearOAuthPendingBrowserCookie(c, secureCookie)
+		response.ErrorFrom(c, err)
+		return
+	}
+	clearOAuthPendingSessionCookie(c, secureCookie)
+	clearOAuthPendingBrowserCookie(c, secureCookie)
+
+	c.JSON(http.StatusOK, gin.H{
+		"access_token":  tokenPair.AccessToken,
+		"refresh_token": tokenPair.RefreshToken,
+		"expires_in":    tokenPair.ExpiresIn,
+		"token_type":    "Bearer",
+	})
+}
+
+// createWeChatPendingSession records a pending-auth session for a resolved
+// WeChat identity.
+//
+// Exactly one of the outcome inputs is expected: tokenPair (successful login,
+// tokens embedded in the completion response), authErr (only
+// ErrOAuthInvitationRequired is tolerated and surfaces as
+// "invitation_required"; any other error aborts), or neither (bind / choice
+// flows). targetUserID, when non-nil, pins the session to an existing user.
+func (h *AuthHandler) createWeChatPendingSession(
+	c *gin.Context,
+	intent string,
+	providerSubject string,
+	email string,
+	redirectTo string,
+	browserSessionKey string,
+	upstreamClaims map[string]any,
+	tokenPair *service.TokenPair,
+	authErr error,
+	targetUserID *int64,
+) error {
+	completionResponse := map[string]any{
+		"redirect": redirectTo,
+	}
+	if authErr != nil {
+		if errors.Is(authErr, service.ErrOAuthInvitationRequired) {
+			// Recoverable: the frontend prompts for an invitation code.
+			completionResponse["error"] = "invitation_required"
+		} else {
+			// Any other auth error is fatal to session creation.
+			return authErr
+		}
+	} else if tokenPair != nil {
+		completionResponse["access_token"] = tokenPair.AccessToken
+		completionResponse["refresh_token"] = tokenPair.RefreshToken
+		completionResponse["expires_in"] = tokenPair.ExpiresIn
+		completionResponse["token_type"] = "Bearer"
+	}
+
+	return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+		Intent: intent,
+		Identity: service.PendingAuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     wechatOAuthProviderKey,
+			ProviderSubject: providerSubject,
+		},
+		TargetUserID:           targetUserID,
+		ResolvedEmail:          email,
+		RedirectTo:             redirectTo,
+		BrowserSessionKey:      browserSessionKey,
+		UpstreamIdentityClaims: upstreamClaims,
+		CompletionResponse:     completionResponse,
+	})
+}
+
+// createWeChatChoicePendingSession records a pending session that sends the
+// frontend to the signup/bind choice step for a WeChat identity with no
+// existing account.
+//
+// compatEmail/compatEmailUser support a legacy email-matching path: when a
+// compat user is supplied, the choice offers binding to that account instead
+// of creating a new one. forceEmailOnSignup propagates the instance-wide
+// policy that third-party signups must supply a real email.
+func (h *AuthHandler) createWeChatChoicePendingSession(
+	c *gin.Context,
+	identity service.PendingAuthIdentityKey,
+	suggestedEmail string,
+	resolvedEmail string,
+	redirectTo string,
+	browserSessionKey string,
+	upstreamClaims map[string]any,
+	compatEmail string,
+	compatEmailUser *dbent.User,
+	forceEmailOnSignup bool,
+) error {
+	suggestionEmail := strings.TrimSpace(suggestedEmail)
+	canonicalEmail := strings.TrimSpace(resolvedEmail)
+	if suggestionEmail == "" {
+		suggestionEmail = canonicalEmail
+	}
+
+	// Base payload: new-account creation allowed, no bindable existing account.
+	completionResponse := map[string]any{
+		"step":                       oauthPendingChoiceStep,
+		"adoption_required":          true,
+		"redirect":                   strings.TrimSpace(redirectTo),
+		"email":                      suggestionEmail,
+		"resolved_email":             canonicalEmail,
+		"existing_account_email":     "",
+		"existing_account_bindable":  false,
+		"create_account_allowed":     true,
+		"force_email_on_signup":      forceEmailOnSignup,
+		"choice_reason":              "third_party_signup",
+	}
+	if strings.TrimSpace(compatEmail) != "" {
+		completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+	}
+	// A matched compat user upgrades the choice to "bind existing account".
+	if compatEmailUser != nil {
+		completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+		completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+		completionResponse["existing_account_bindable"] = true
+		completionResponse["choice_reason"] = "compat_email_match"
+	}
+	// Policy flag takes precedence over the other choice reasons.
+	if forceEmailOnSignup {
+		completionResponse["choice_reason"] = "force_email_on_signup"
+	}
+
+	resolvedChoiceEmail := suggestionEmail
+	if compatEmailUser != nil {
+		resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+	}
+
+	return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+		Intent:                 oauthIntentLogin,
+		Identity:               identity,
+		ResolvedEmail:          resolvedChoiceEmail,
+		RedirectTo:             redirectTo,
+		BrowserSessionKey:      browserSessionKey,
+		UpstreamIdentityClaims: upstreamClaims,
+		CompletionResponse:     completionResponse,
+	})
+}
+
+func (h *AuthHandler) createWeChatBindPendingSession(
+ c *gin.Context,
+ cfg wechatOAuthConfig,
+ providerSubject string,
+ channelSubject string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+) error {
+ currentUser, err := h.readOAuthBindTargetUser(c, wechatOAuthBindUserCookieName)
+ if err != nil {
+ return err
+ }
+ if err := h.ensureWeChatBindOwnership(c.Request.Context(), currentUser.ID, providerSubject, cfg, channelSubject); err != nil {
+ return err
+ }
+ return h.createWeChatPendingSession(
+ c,
+ wechatOAuthIntentBind,
+ providerSubject,
+ currentUser.Email,
+ redirectTo,
+ browserSessionKey,
+ upstreamClaims,
+ nil,
+ nil,
+ ¤tUser.ID,
+ )
+}
+
+// readOAuthBindTargetUser resolves the user who started a bind flow from the
+// named cookie and loads the corresponding user entity.
+//
+// Both a missing/invalid cookie and a vanished user map to the same
+// AUTH_REQUIRED error so the caller cannot distinguish (and leak) which case
+// occurred.
+func (h *AuthHandler) readOAuthBindTargetUser(c *gin.Context, cookieName string) (*dbent.User, error) {
+	client := h.entClient()
+	if client == nil {
+		return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+	}
+	userID, err := h.readOAuthBindUserIDFromCookie(c, cookieName)
+	if err != nil {
+		return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account")
+	}
+	userEntity, err := client.User.Get(c.Request.Context(), userID)
+	if err != nil {
+		if dbent.IsNotFound(err) {
+			return nil, infraerrors.Unauthorized("AUTH_REQUIRED", "current user is required to bind wechat account")
+		}
+		return nil, infraerrors.InternalServer("WECHAT_BIND_USER_LOOKUP_FAILED", "failed to load current user").WithCause(err)
+	}
+	return userEntity, nil
+}
+
+// ensureWeChatBindOwnership rejects a bind when the WeChat identity — either
+// by canonical provider subject or by channel-level (mode + appID + openid)
+// subject — already belongs to a different ACTIVE user. Identities owned by
+// inactive/deleted users do not block the bind.
+func (h *AuthHandler) ensureWeChatBindOwnership(
+	ctx context.Context,
+	userID int64,
+	providerSubject string,
+	cfg wechatOAuthConfig,
+	channelSubject string,
+) error {
+	client := h.entClient()
+	if client == nil {
+		return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+	}
+
+	// Pass 1: identities matching the canonical subject under both the
+	// current and legacy provider keys.
+	identities, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("wechat"),
+			authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
+			authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)),
+		).
+		All(ctx)
+	if err != nil {
+		return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err)
+	}
+	for _, identity := range identities {
+		if identity != nil && identity.UserID != userID {
+			// Conflict only when the other owner is still an active user.
+			activeOwner, lookupErr := findActiveUserByID(ctx, client, identity.UserID)
+			if lookupErr != nil {
+				return lookupErr
+			}
+			if activeOwner != nil {
+				return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+			}
+		}
+	}
+
+	// Pass 2: channel-level ownership. Skipped when there is no per-channel
+	// subject or app id to compare against.
+	channelSubject = strings.TrimSpace(channelSubject)
+	channelAppID := strings.TrimSpace(cfg.appID)
+	if channelSubject == "" || channelAppID == "" {
+		return nil
+	}
+
+	channels, err := client.AuthIdentityChannel.Query().
+		Where(
+			authidentitychannel.ProviderTypeEQ("wechat"),
+			authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...),
+			authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)),
+			authidentitychannel.ChannelAppIDEQ(channelAppID),
+			authidentitychannel.ChannelSubjectEQ(channelSubject),
+		).
+		WithIdentity().
+		All(ctx)
+	if err != nil {
+		return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err)
+	}
+	for _, channel := range channels {
+		if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
+			activeOwner, lookupErr := findActiveUserByID(ctx, client, channel.Edges.Identity.UserID)
+			if lookupErr != nil {
+				return lookupErr
+			}
+			if activeOwner != nil {
+				return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+			}
+		}
+	}
+	return nil
+}
+
+// findWeChatUserByLegacyOpenID resolves an existing user for a WeChat login
+// via three successive lookups, returning the first hit:
+//  1. auth identities keyed by the canonical provider subject (unionid),
+//  2. channel records keyed by (mode, appID, openid),
+//  3. auth identities keyed directly by the openid (oldest data shape).
+// A hit is additionally filtered through findActiveUserByID, so a match
+// owned by an inactive user yields (nil, nil). Returns (nil, nil) when no
+// record matches at all.
+func (h *AuthHandler) findWeChatUserByLegacyOpenID(
+	ctx context.Context,
+	identity service.PendingAuthIdentityKey,
+	cfg wechatOAuthConfig,
+	openid string,
+) (*dbent.User, error) {
+	client := h.entClient()
+	if client == nil {
+		return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+	}
+
+	providerType := strings.TrimSpace(identity.ProviderType)
+	providerSubject := strings.TrimSpace(identity.ProviderSubject)
+	providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey)
+	// Lookup 1: canonical provider subject across current + legacy keys.
+	if providerSubject != "" {
+		records, err := client.AuthIdentity.Query().
+			Where(
+				authidentity.ProviderTypeEQ(providerType),
+				authidentity.ProviderKeyIn(providerKeys...),
+				authidentity.ProviderSubjectEQ(providerSubject),
+			).
+			WithUser().
+			All(ctx)
+		if err != nil {
+			return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+		}
+		// Outer condition enters on either error or hit; inner condition
+		// returns the error (or a nil user) and only falls through to the
+		// active-user check on a clean hit.
+		if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
+			if err != nil || user == nil {
+				return user, err
+			}
+			return findActiveUserByID(ctx, client, user.ID)
+		}
+	}
+
+	// Lookup 2: channel record keyed by (mode, appID, openid). Requires all
+	// three components.
+	openid = strings.TrimSpace(openid)
+	channel := strings.TrimSpace(cfg.mode)
+	channelAppID := strings.TrimSpace(cfg.appID)
+	if openid != "" && channel != "" && channelAppID != "" {
+		records, err := client.AuthIdentityChannel.Query().
+			Where(
+				authidentitychannel.ProviderTypeEQ(providerType),
+				authidentitychannel.ProviderKeyIn(providerKeys...),
+				authidentitychannel.ChannelEQ(channel),
+				authidentitychannel.ChannelAppIDEQ(channelAppID),
+				authidentitychannel.ChannelSubjectEQ(openid),
+			).
+			WithIdentity(func(q *dbent.AuthIdentityQuery) {
+				q.WithUser()
+			}).
+			All(ctx)
+		if err != nil {
+			return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+		}
+		if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
+			if err != nil || user == nil {
+				return user, err
+			}
+			return findActiveUserByID(ctx, client, user.ID)
+		}
+	}
+
+	if openid == "" {
+		return nil, nil
+	}
+
+	// Lookup 3: oldest data shape — the openid stored directly as the
+	// identity's provider subject.
+	records, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ(providerType),
+			authidentity.ProviderKeyIn(providerKeys...),
+			authidentity.ProviderSubjectEQ(openid),
+		).
+		WithUser().
+		All(ctx)
+	if err != nil {
+		return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+	}
+	user, err := singleWeChatIdentityUser(records)
+	if err != nil || user == nil {
+		return user, err
+	}
+	return findActiveUserByID(ctx, client, user.ID)
+}
+
+// wechatCompatibleProviderKeys returns the provider keys to match when
+// querying WeChat identities: the supplied key (defaulting to the canonical
+// "wechat-main") plus the legacy "wechat" key, unless they coincide.
+func wechatCompatibleProviderKeys(providerKey string) []string {
+	primary := strings.TrimSpace(providerKey)
+	if primary == "" {
+		primary = wechatOAuthProviderKey
+	}
+	// Skip the legacy key when the primary already equals it (case-insensitive).
+	if strings.EqualFold(primary, wechatOAuthLegacyProviderKey) {
+		return []string{primary}
+	}
+	return []string{primary, wechatOAuthLegacyProviderKey}
+}
+
+// singleWeChatIdentityUser reduces a set of identity records to their single
+// shared owner. Records without a loaded user edge are skipped; two distinct
+// owners constitute an ownership conflict. Returns (nil, nil) when no record
+// carries an owner.
+func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) {
+	var owner *dbent.User
+	for _, rec := range records {
+		if rec == nil || rec.Edges.User == nil {
+			continue
+		}
+		candidate := rec.Edges.User
+		switch {
+		case owner == nil:
+			owner = candidate
+		case owner.ID != candidate.ID:
+			return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+		}
+	}
+	return owner, nil
+}
+
+// singleWeChatChannelUser collapses a set of auth-identity-channel records to
+// their single owning user, traversing channel -> identity -> user edges.
+// Records with any missing edge are skipped; two distinct owners is reported
+// as a conflict. Returns (nil, nil) when no record resolves to a user.
+func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) {
+	var owner *dbent.User
+	for _, rec := range records {
+		if rec == nil || rec.Edges.Identity == nil {
+			continue
+		}
+		candidate := rec.Edges.Identity.Edges.User
+		if candidate == nil {
+			continue
+		}
+		switch {
+		case owner == nil:
+			owner = candidate
+		case owner.ID != candidate.ID:
+			return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+		}
+	}
+	return owner, nil
+}
+
+// ensureWeChatRuntimeIdentityBinding ensures an auth-identity binding exists
+// between userID and the given WeChat identity key, performing the repair in
+// its own transaction. upstreamClaims are cloned and attached to the
+// synthetic pending session used for the binding.
+//
+// Returns ServiceUnavailable when the ent client is not ready and
+// InternalServer when the transaction cannot be started; other errors come
+// from ensurePendingOAuthIdentityForUser.
+func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding(
+	ctx context.Context,
+	userID int64,
+	identity service.PendingAuthIdentityKey,
+	upstreamClaims map[string]any,
+) error {
+	client := h.entClient()
+	if client == nil {
+		return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+	}
+
+	tx, err := client.Tx(ctx)
+	if err != nil {
+		return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err)
+	}
+	// Deferred rollback is a no-op once Commit below succeeds.
+	defer func() { _ = tx.Rollback() }()
+
+	// Reuse the pending-OAuth binding path by handing it a synthetic session
+	// carrying only the trimmed identity key and a copy of the claims.
+	_, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{
+		ProviderType:           strings.TrimSpace(identity.ProviderType),
+		ProviderKey:            strings.TrimSpace(identity.ProviderKey),
+		ProviderSubject:        strings.TrimSpace(identity.ProviderSubject),
+		UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims),
+	}, userID)
+	if err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// getWeChatOAuthConfig assembles the effective WeChat OAuth configuration for
+// the requested mode ("open" = QR login on open.weixin.qq.com, "mp" =
+// official-account login inside the WeChat browser).
+//
+// Returns NotFound when the resolved mode is disabled in settings and
+// InternalServer when no redirect URL can be determined.
+func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) {
+	mode, err := resolveWeChatOAuthMode(rawMode, c)
+	if err != nil {
+		return wechatOAuthConfig{}, err
+	}
+
+	if h == nil || h.settingSvc == nil {
+		return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready")
+	}
+
+	// Best-effort lookup of the API base URL used to build an absolute
+	// callback URL; failure is non-fatal (a request-derived fallback is used).
+	// Note: the nil checks on h/settingSvc were already done above, so no
+	// re-check is needed here.
+	apiBaseURL := ""
+	if settings, settingsErr := h.settingSvc.GetAllSettings(ctx); settingsErr == nil && settings != nil {
+		apiBaseURL = strings.TrimSpace(settings.APIBaseURL)
+	}
+
+	effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+	if err != nil {
+		return wechatOAuthConfig{}, err
+	}
+	if !effective.SupportsMode(mode) {
+		return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled")
+	}
+
+	cfg := wechatOAuthConfig{
+		mode:             mode,
+		appID:            strings.TrimSpace(effective.AppIDForMode(mode)),
+		appSecret:        strings.TrimSpace(effective.AppSecretForMode(mode)),
+		redirectURI:      firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")),
+		frontendCallback: firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB),
+		scope:            effective.ScopeForMode(mode),
+		openEnabled:      effective.OpenEnabled,
+		mpEnabled:        effective.MPEnabled,
+	}
+
+	// MP mode uses the official-account authorize endpoint; open mode uses
+	// the QR-connect endpoint.
+	switch mode {
+	case "mp":
+		cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize"
+	default:
+		cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect"
+	}
+	if strings.TrimSpace(cfg.redirectURI) == "" {
+		return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured")
+	}
+
+	return cfg, nil
+}
+
+// requiresUnionID reports whether both the open-platform and MP channels are
+// enabled, in which case the cross-channel unionid must identify the user.
+func (cfg wechatOAuthConfig) requiresUnionID() bool {
+	if !cfg.openEnabled {
+		return false
+	}
+	return cfg.mpEnabled
+}
+
+// wechatOAuthFrontendCallback returns the admin-configured frontend redirect
+// URL, falling back to the built-in default when unavailable or blank.
+func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string {
+	if h == nil || h.settingSvc == nil {
+		return wechatOAuthDefaultFrontendCB
+	}
+	cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx)
+	if err != nil {
+		return wechatOAuthDefaultFrontendCB
+	}
+	if cb := strings.TrimSpace(cfg.FrontendRedirectURL); cb != "" {
+		return cb
+	}
+	return wechatOAuthDefaultFrontendCB
+}
+
+// resolveWeChatOAuthMode validates rawMode ("open" or "mp"). An empty mode is
+// auto-detected from the request: the in-app WeChat browser gets the MP flow,
+// everything else the open (QR) flow.
+func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) {
+	switch mode := strings.ToLower(strings.TrimSpace(rawMode)); mode {
+	case "open", "mp":
+		return mode, nil
+	case "":
+		if isWeChatBrowserRequest(c) {
+			return "mp", nil
+		}
+		return "open", nil
+	default:
+		return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp")
+	}
+}
+
+// isWeChatBrowserRequest reports whether the request came from the WeChat
+// in-app browser, which identifies itself with "MicroMessenger" in the
+// User-Agent header.
+func isWeChatBrowserRequest(c *gin.Context) bool {
+	if c == nil || c.Request == nil {
+		return false
+	}
+	ua := strings.TrimSpace(c.GetHeader("User-Agent"))
+	return strings.Contains(strings.ToLower(ua), "micromessenger")
+}
+
+// normalizeWeChatOAuthIntent maps a raw intent string to one of the known
+// intent constants; unknown or empty values default to login.
+func normalizeWeChatOAuthIntent(raw string) string {
+	normalized := strings.ToLower(strings.TrimSpace(raw))
+	if normalized == "bind" || normalized == "bind_current_user" {
+		return wechatOAuthIntentBind
+	}
+	if normalized == "adopt" || normalized == "adopt_existing_user_by_email" {
+		return wechatOAuthIntentAdoptEmail
+	}
+	return wechatOAuthIntentLogin
+}
+
+// buildWeChatAuthorizeURL composes the WeChat authorize redirect URL for the
+// configured endpoint, carrying app id, callback, scope and CSRF state.
+func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) {
+	base, err := url.Parse(cfg.authorizeURL)
+	if err != nil {
+		return "", fmt.Errorf("parse authorize url: %w", err)
+	}
+	params := base.Query()
+	for key, value := range map[string]string{
+		"appid":         cfg.appID,
+		"redirect_uri":  cfg.redirectURI,
+		"response_type": "code",
+		"scope":         cfg.scope,
+		"state":         state,
+	} {
+		params.Set(key, value)
+	}
+	// Encode sorts keys deterministically, so map iteration order is moot.
+	base.RawQuery = params.Encode()
+	// WeChat requires the literal "#wechat_redirect" fragment suffix.
+	base.Fragment = "wechat_redirect"
+	return base.String(), nil
+}
+
+// resolveWeChatOAuthAbsoluteURL builds an absolute URL for callbackPath.
+// Preference order: (1) the configured apiBaseURL (its path is merged with
+// callbackPath, deduplicating a shared "/api/v1" prefix), then (2) the scheme
+// and host derived from the incoming request. Returns "" when neither source
+// yields a usable URL or callbackPath is blank.
+func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string {
+	callbackPath = strings.TrimSpace(callbackPath)
+	if callbackPath == "" {
+		return ""
+	}
+
+	if raw := strings.TrimSpace(apiBaseURL); raw != "" {
+		if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+			basePath := strings.TrimRight(parsed.EscapedPath(), "/")
+			targetPath := callbackPath
+			// If the base URL already ends in /api/v1 and the callback also
+			// starts with it, join without repeating the prefix.
+			if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") {
+				targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1")
+			} else if basePath != "" {
+				targetPath = basePath + callbackPath
+			}
+			return parsed.Scheme + "://" + parsed.Host + targetPath
+		}
+	}
+
+	if c == nil || c.Request == nil {
+		return ""
+	}
+	scheme := "http"
+	if isRequestHTTPS(c) {
+		scheme = "https"
+	}
+	host := strings.TrimSpace(c.Request.Host)
+	// NOTE(review): X-Forwarded-Host is trusted unconditionally here —
+	// presumably the service always sits behind a proxy that sets it;
+	// confirm, since a client-supplied value would end up in the OAuth
+	// redirect_uri when no apiBaseURL is configured.
+	if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" {
+		host = forwardedHost
+	}
+	if host == "" {
+		return ""
+	}
+	return scheme + "://" + host + callbackPath
+}
+
+// fetchWeChatOAuthIdentity runs the two-step WeChat flow: exchange the
+// authorization code for an access token, then fetch the user profile with
+// that token. Either step failing aborts the whole operation.
+func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) {
+	token, exchangeErr := exchangeWeChatOAuthCode(ctx, cfg, code)
+	if exchangeErr != nil {
+		return nil, nil, exchangeErr
+	}
+	profile, profileErr := fetchWeChatUserInfo(ctx, token)
+	if profileErr != nil {
+		return nil, nil, profileErr
+	}
+	return token, profile, nil
+}
+
+// exchangeWeChatOAuthCode trades a one-time authorization code for a WeChat
+// access token via the sns/oauth2/access_token endpoint. WeChat reports API
+// failures as errcode/errmsg inside a 200 body, so both the HTTP status and
+// the payload are checked.
+func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) {
+	target, err := url.Parse(wechatOAuthAccessTokenURL)
+	if err != nil {
+		return nil, fmt.Errorf("parse wechat access token url: %w", err)
+	}
+
+	params := target.Query()
+	params.Set("appid", cfg.appID)
+	params.Set("secret", cfg.appSecret)
+	params.Set("code", strings.TrimSpace(code))
+	params.Set("grant_type", "authorization_code")
+	target.RawQuery = params.Encode()
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodGet, target.String(), nil)
+	if err != nil {
+		return nil, fmt.Errorf("build wechat access token request: %w", err)
+	}
+
+	httpClient := &http.Client{Timeout: 30 * time.Second}
+	response, err := httpClient.Do(request)
+	if err != nil {
+		return nil, fmt.Errorf("request wechat access token: %w", err)
+	}
+	defer func() { _ = response.Body.Close() }()
+
+	payload, err := io.ReadAll(response.Body)
+	if err != nil {
+		return nil, fmt.Errorf("read wechat access token response: %w", err)
+	}
+	if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices {
+		return nil, fmt.Errorf("wechat access token status=%d", response.StatusCode)
+	}
+
+	var parsed wechatOAuthTokenResponse
+	if err := json.Unmarshal(payload, &parsed); err != nil {
+		return nil, fmt.Errorf("decode wechat access token response: %w", err)
+	}
+	if parsed.ErrCode != 0 {
+		return nil, fmt.Errorf("wechat access token error=%d %s", parsed.ErrCode, strings.TrimSpace(parsed.ErrMsg))
+	}
+	if strings.TrimSpace(parsed.AccessToken) == "" {
+		return nil, fmt.Errorf("wechat access token missing access_token")
+	}
+	return &parsed, nil
+}
+
+// fetchWeChatUserInfo retrieves the WeChat profile (openid/unionid, nickname,
+// avatar) for the holder of tokenResp via the sns/userinfo endpoint. As with
+// the token exchange, WeChat signals API failures via errcode/errmsg in the
+// body, so both transport status and payload are validated.
+func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) {
+	if tokenResp == nil {
+		return nil, fmt.Errorf("wechat token response is nil")
+	}
+
+	target, err := url.Parse(wechatOAuthUserInfoURL)
+	if err != nil {
+		return nil, fmt.Errorf("parse wechat userinfo url: %w", err)
+	}
+	params := target.Query()
+	params.Set("access_token", strings.TrimSpace(tokenResp.AccessToken))
+	params.Set("openid", strings.TrimSpace(tokenResp.OpenID))
+	params.Set("lang", "zh_CN")
+	target.RawQuery = params.Encode()
+
+	request, err := http.NewRequestWithContext(ctx, http.MethodGet, target.String(), nil)
+	if err != nil {
+		return nil, fmt.Errorf("build wechat userinfo request: %w", err)
+	}
+
+	httpClient := &http.Client{Timeout: 30 * time.Second}
+	response, err := httpClient.Do(request)
+	if err != nil {
+		return nil, fmt.Errorf("request wechat userinfo: %w", err)
+	}
+	defer func() { _ = response.Body.Close() }()
+
+	payload, err := io.ReadAll(response.Body)
+	if err != nil {
+		return nil, fmt.Errorf("read wechat userinfo response: %w", err)
+	}
+	if response.StatusCode < http.StatusOK || response.StatusCode >= http.StatusMultipleChoices {
+		return nil, fmt.Errorf("wechat userinfo status=%d", response.StatusCode)
+	}
+
+	var profile wechatOAuthUserInfoResponse
+	if err := json.Unmarshal(payload, &profile); err != nil {
+		return nil, fmt.Errorf("decode wechat userinfo response: %w", err)
+	}
+	if profile.ErrCode != 0 {
+		return nil, fmt.Errorf("wechat userinfo error=%d %s", profile.ErrCode, strings.TrimSpace(profile.ErrMsg))
+	}
+	return &profile, nil
+}
+
+func wechatSyntheticEmail(subject string) string {
+ subject = strings.TrimSpace(subject)
+ if subject == "" {
+ return ""
+ }
+ return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain
+}
+
+// wechatFallbackUsername derives a username from the WeChat provider subject,
+// falling back to a generic placeholder when the subject is blank.
+func wechatFallbackUsername(subject string) string {
+	trimmed := strings.TrimSpace(subject)
+	if trimmed == "" {
+		return "wechat_user"
+	}
+	// Truncate long subjects to keep the username a manageable length.
+	return "wechat_" + truncateFragmentValue(trimmed)
+}
+
+// wechatSetCookie writes an HttpOnly, SameSite=Lax cookie scoped to the
+// WeChat OAuth cookie path on the response.
+func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+	cookie := &http.Cookie{
+		Path:     wechatOAuthCookiePath,
+		Name:     name,
+		Value:    value,
+		MaxAge:   maxAgeSec,
+		Secure:   secure,
+		HttpOnly: true,
+		SameSite: http.SameSiteLaxMode,
+	}
+	http.SetCookie(c.Writer, cookie)
+}
+
+// wechatClearCookie expires the named WeChat OAuth cookie; MaxAge < 0 tells
+// the browser to delete it immediately.
+func wechatClearCookie(c *gin.Context, name string, secure bool) {
+	expired := &http.Cookie{
+		Path:     wechatOAuthCookiePath,
+		Name:     name,
+		Value:    "",
+		MaxAge:   -1,
+		Secure:   secure,
+		HttpOnly: true,
+		SameSite: http.SameSiteLaxMode,
+	}
+	http.SetCookie(c.Writer, expired)
+}
+
+// normalizeWeChatPaymentType returns the trimmed payment type when it is one
+// of the recognized WeChat Pay types, or "" otherwise.
+// (TrimSpace is hoisted so it is computed once instead of twice.)
+func normalizeWeChatPaymentType(raw string) string {
+	trimmed := strings.TrimSpace(raw)
+	switch trimmed {
+	case payment.TypeWxpay, payment.TypeWxpayDirect:
+		return trimmed
+	default:
+		return ""
+	}
+}
+
+// normalizeWeChatPaymentScope picks the first recognized OAuth scope out of a
+// comma/whitespace-separated list, defaulting to "snsapi_base".
+// FieldsFunc already splits on every separator (including all whitespace), so
+// the per-part TrimSpace and outer TrimSpace of the original were redundant
+// and have been removed — parts can never contain separator runes.
+func normalizeWeChatPaymentScope(raw string) string {
+	isSep := func(r rune) bool {
+		return r == ',' || r == ' ' || r == '\t' || r == '\n' || r == '\r'
+	}
+	for _, part := range strings.FieldsFunc(raw, isSep) {
+		switch part {
+		case "snsapi_userinfo", "snsapi_base":
+			return part
+		}
+	}
+	return "snsapi_base"
+}
+
+// normalizeWeChatPaymentRedirectPath maps the legacy /payment route (bare or
+// with a query string) to /purchase; an empty path falls back to the default
+// post-payment destination.
+func normalizeWeChatPaymentRedirectPath(path string) string {
+	trimmed := strings.TrimSpace(path)
+	switch {
+	case trimmed == "":
+		return wechatPaymentOAuthDefaultTo
+	case trimmed == "/payment":
+		return "/purchase"
+	case strings.HasPrefix(trimmed, "/payment?"):
+		// Preserve the query string when rewriting the route.
+		return "/purchase" + strings.TrimPrefix(trimmed, "/payment")
+	default:
+		return trimmed
+	}
+}
+
+// resolveWeChatPaymentOAuthCallbackURL builds the absolute payment-callback
+// URL, preferring the configured API base URL and otherwise falling back to
+// request-derived host information inside the shared resolver.
+func (h *AuthHandler) resolveWeChatPaymentOAuthCallbackURL(ctx context.Context, c *gin.Context) string {
+	base := ""
+	if h != nil && h.settingSvc != nil {
+		settings, err := h.settingSvc.GetAllSettings(ctx)
+		if err == nil && settings != nil {
+			base = strings.TrimSpace(settings.APIBaseURL)
+		}
+	}
+	return resolveWeChatOAuthAbsoluteURL(base, c, "/api/v1/auth/oauth/wechat/payment/callback")
+}
+
+// encodeWeChatPaymentOAuthContext serializes the payment context to compact
+// JSON for storage in a cookie.
+func encodeWeChatPaymentOAuthContext(ctx wechatPaymentOAuthContext) (string, error) {
+	encoded, err := json.Marshal(ctx)
+	if err != nil {
+		return "", err
+	}
+	return string(encoded), nil
+}
+
+// decodeWeChatPaymentOAuthContext parses the JSON payment context stored in a
+// cookie. A blank value decodes to the zero-value context without error.
+func decodeWeChatPaymentOAuthContext(raw string) (wechatPaymentOAuthContext, error) {
+	var decoded wechatPaymentOAuthContext
+	trimmed := strings.TrimSpace(raw)
+	if trimmed == "" {
+		return decoded, nil
+	}
+	if err := json.Unmarshal([]byte(trimmed), &decoded); err != nil {
+		return wechatPaymentOAuthContext{}, err
+	}
+	return decoded, nil
+}
+
+// parseWeChatPaymentPlanID parses a plan id best-effort: malformed input
+// yields 0 (meaning "no plan selected") rather than an error.
+func parseWeChatPaymentPlanID(raw string) int64 {
+	planID, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64)
+	if err != nil {
+		return 0
+	}
+	return planID
+}
+
+// wechatPaymentSetCookie writes an HttpOnly, SameSite=Lax cookie scoped to
+// the payment-OAuth cookie path on the response.
+func wechatPaymentSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+	cookie := &http.Cookie{
+		Path:     wechatPaymentOAuthCookiePath,
+		Name:     name,
+		Value:    value,
+		MaxAge:   maxAgeSec,
+		Secure:   secure,
+		HttpOnly: true,
+		SameSite: http.SameSiteLaxMode,
+	}
+	http.SetCookie(c.Writer, cookie)
+}
+
+// wechatPaymentClearCookie expires the named payment-OAuth cookie; MaxAge < 0
+// instructs the browser to delete it immediately.
+func wechatPaymentClearCookie(c *gin.Context, name string, secure bool) {
+	expired := &http.Cookie{
+		Path:     wechatPaymentOAuthCookiePath,
+		Name:     name,
+		Value:    "",
+		MaxAge:   -1,
+		Secure:   secure,
+		HttpOnly: true,
+		SameSite: http.SameSiteLaxMode,
+	}
+	http.SetCookie(c.Writer, expired)
+}
diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go
new file mode 100644
index 00000000..b3c7786d
--- /dev/null
+++ b/backend/internal/handler/auth_wechat_oauth_test.go
@@ -0,0 +1,1498 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+// TestWeChatOAuthStartRedirectsAndSetsPendingCookies verifies the start
+// endpoint 302-redirects to the WeChat authorize URL (carrying the configured
+// app id and scope) and sets the state/redirect/mode/pending-browser cookies.
+func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{
+		service.SettingKeyWeChatConnectEnabled:             "true",
+		service.SettingKeyWeChatConnectAppID:               "wx-open-app",
+		service.SettingKeyWeChatConnectAppSecret:           "wx-open-secret",
+		service.SettingKeyWeChatConnectMode:                "open",
+		service.SettingKeyWeChatConnectScopes:              "snsapi_login",
+		service.SettingKeyWeChatConnectRedirectURL:         "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+		service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+	})
+	defer client.Close()
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+	c.Request.Host = "api.example.com"
+
+	handler.WeChatOAuthStart(c)
+
+	// Redirect target must be WeChat's authorize endpoint with our app/scope.
+	require.Equal(t, http.StatusFound, recorder.Code)
+	location := recorder.Header().Get("Location")
+	require.NotEmpty(t, location)
+	require.Contains(t, location, "open.weixin.qq.com")
+	require.Contains(t, location, "appid=wx-open-app")
+	require.Contains(t, location, "scope=snsapi_login")
+
+	// All four flow cookies must be present for the callback to succeed.
+	cookies := recorder.Result().Cookies()
+	require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName))
+	require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName))
+	require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName))
+	require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName))
+}
+
+// TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled verifies
+// that explicitly requesting mode=open is honored (QR-connect endpoint,
+// snsapi_login scope) even when settings default to MP mode, as long as both
+// the open and MP capabilities are enabled.
+func TestWeChatOAuthStart_AllowsOpenModeWhenBothCapabilitiesEnabled(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{
+		service.SettingKeyWeChatConnectEnabled:             "true",
+		service.SettingKeyWeChatConnectAppID:               "wx-shared-app",
+		service.SettingKeyWeChatConnectAppSecret:           "wx-shared-secret",
+		service.SettingKeyWeChatConnectMode:                "mp",
+		service.SettingKeyWeChatConnectScopes:              "snsapi_base",
+		service.SettingKeyWeChatConnectOpenEnabled:         "true",
+		service.SettingKeyWeChatConnectMPEnabled:           "true",
+		service.SettingKeyWeChatConnectRedirectURL:         "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+		service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+	})
+	defer client.Close()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil)
+	c.Request.Host = "api.example.com"
+
+	handler.WeChatOAuthStart(c)
+
+	require.Equal(t, http.StatusFound, recorder.Code)
+	location := recorder.Header().Get("Location")
+	require.NotEmpty(t, location)
+	require.Contains(t, location, "open.weixin.qq.com")
+	// Open mode uses the QR endpoint with the login scope regardless of the
+	// MP-oriented scope configured above.
+	require.Contains(t, location, "connect/qrconnect")
+	require.Contains(t, location, "scope=snsapi_login")
+}
+
+// TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow verifies that a
+// successful callback (token + userinfo both returning a unionid) redirects
+// to the frontend callback and persists a pending auth session keyed by the
+// unionid, carrying the suggested display name/avatar claims.
+func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) {
+	// The handler reads package-level endpoint URLs; swap them to a local
+	// stub and restore after the test.
+	originalAccessTokenURL := wechatOAuthAccessTokenURL
+	originalUserInfoURL := wechatOAuthUserInfoURL
+	t.Cleanup(func() {
+		wechatOAuthAccessTokenURL = originalAccessTokenURL
+		wechatOAuthUserInfoURL = originalUserInfoURL
+	})
+
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+		case strings.Contains(r.URL.Path, "/sns/userinfo"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer upstream.Close()
+	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+	handler, client := newWeChatOAuthTestHandler(t, false)
+	defer client.Close()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+	req.Host = "api.example.com"
+	// Cookies planted by the start endpoint; state must match the query.
+	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+	c.Request = req
+
+	handler.WeChatOAuthCallback(c)
+
+	require.Equal(t, http.StatusFound, recorder.Code)
+	require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+	require.NotNil(t, sessionCookie)
+
+	// The pending session must be keyed by unionid (unified flow) and carry
+	// the suggested profile claims from userinfo.
+	ctx := context.Background()
+	session, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, "wechat", session.ProviderType)
+	require.Equal(t, "wechat-main", session.ProviderKey)
+	require.Equal(t, "union-456", session.ProviderSubject)
+	require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail)
+	require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"])
+	require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"])
+	require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+	require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"])
+}
+
+// TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode
+// verifies that when the upstream responses omit unionid and only one channel
+// is configured, the callback keys the pending session by openid instead and
+// routes the user into the third-party signup choice step.
+func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMode(t *testing.T) {
+	// Swap the package-level endpoint URLs to a local stub; restore after.
+	originalAccessTokenURL := wechatOAuthAccessTokenURL
+	originalUserInfoURL := wechatOAuthUserInfoURL
+	t.Cleanup(func() {
+		wechatOAuthAccessTokenURL = originalAccessTokenURL
+		wechatOAuthUserInfoURL = originalUserInfoURL
+	})
+
+	// Stub responses deliberately omit "unionid".
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`))
+		case strings.Contains(r.URL.Path, "/sns/userinfo"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer upstream.Close()
+	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
+	defer client.Close()
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+	req.Host = "api.example.com"
+	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+	c.Request = req
+
+	handler.WeChatOAuthCallback(c)
+
+	require.Equal(t, http.StatusFound, recorder.Code)
+	require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
+
+	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+	require.NotNil(t, sessionCookie)
+
+	session, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+		Only(context.Background())
+	require.NoError(t, err)
+	// With no unionid available, the openid becomes the provider subject.
+	require.Equal(t, oauthIntentLogin, session.Intent)
+	require.Equal(t, "openid-123", session.ProviderSubject)
+	require.Equal(t, wechatSyntheticEmail("openid-123"), session.ResolvedEmail)
+
+	// A brand-new identity routes into the signup-choice step.
+	completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+	require.Equal(t, oauthPendingChoiceStep, completion["step"])
+	require.Equal(t, "third_party_signup", completion["choice_reason"])
+}
+
+// TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens
+// verifies that when the WeChat identity already belongs to an active user,
+// the callback produces a login pending session targeting that user, and that
+// no access/refresh tokens are embedded in the completion payload.
+func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
+	// Swap the package-level endpoint URLs to a local stub; restore after.
+	originalAccessTokenURL := wechatOAuthAccessTokenURL
+	originalUserInfoURL := wechatOAuthUserInfoURL
+	t.Cleanup(func() {
+		wechatOAuthAccessTokenURL = originalAccessTokenURL
+		wechatOAuthUserInfoURL = originalUserInfoURL
+	})
+
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+		case strings.Contains(r.URL.Path, "/sns/userinfo"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer upstream.Close()
+	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
+	defer client.Close()
+
+	// Seed an active user already bound to unionid "union-456".
+	ctx := context.Background()
+	existingUser, err := client.User.Create().
+		SetEmail(wechatSyntheticEmail("union-456")).
+		SetUsername("wechat-existing-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusActive).
+		Save(ctx)
+	require.NoError(t, err)
+	_, err = client.AuthIdentity.Create().
+		SetUserID(existingUser.ID).
+		SetProviderType("wechat").
+		SetProviderKey(wechatOAuthProviderKey).
+		SetProviderSubject("union-456").
+		SetMetadata(map[string]any{"username": "wechat-existing-user"}).
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+	req.Host = "api.example.com"
+	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+	c.Request = req
+
+	handler.WeChatOAuthCallback(c)
+
+	require.Equal(t, http.StatusFound, recorder.Code)
+	require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
+
+	sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+	require.NotNil(t, sessionCookie)
+
+	// The pending session must target the existing user for login.
+	session, err := client.PendingAuthSession.Query().
+		Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+		Only(ctx)
+	require.NoError(t, err)
+	require.Equal(t, oauthIntentLogin, session.Intent)
+	require.NotNil(t, session.TargetUserID)
+	require.Equal(t, existingUser.ID, *session.TargetUserID)
+	require.Equal(t, existingUser.Email, session.ResolvedEmail)
+
+	// Tokens must NOT be stored in the completion payload.
+	completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+	require.Equal(t, "/dashboard", completion["redirect"])
+	_, hasAccessToken := completion["access_token"]
+	require.False(t, hasAccessToken)
+	_, hasRefreshToken := completion["refresh_token"]
+	require.False(t, hasRefreshToken)
+}
+
+// TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser verifies that a
+// callback resolving to a disabled account is rejected with a USER_NOT_ACTIVE
+// redirect error: no session cookie is set and no pending session is created.
+func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
+	// Swap the package-level endpoint URLs to a local stub; restore after.
+	originalAccessTokenURL := wechatOAuthAccessTokenURL
+	originalUserInfoURL := wechatOAuthUserInfoURL
+	t.Cleanup(func() {
+		wechatOAuthAccessTokenURL = originalAccessTokenURL
+		wechatOAuthUserInfoURL = originalUserInfoURL
+	})
+
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`))
+		case strings.Contains(r.URL.Path, "/sns/userinfo"):
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`))
+		default:
+			http.NotFound(w, r)
+		}
+	}))
+	defer upstream.Close()
+	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+	wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+	handler, client := newWeChatOAuthTestHandler(t, false)
+	defer client.Close()
+
+	// Seed a DISABLED user bound to the unionid returned by the stub.
+	ctx := context.Background()
+	existingUser, err := client.User.Create().
+		SetEmail(wechatSyntheticEmail("union-disabled")).
+		SetUsername("disabled-user").
+		SetPasswordHash("hash").
+		SetRole(service.RoleUser).
+		SetStatus(service.StatusDisabled).
+		Save(ctx)
+	require.NoError(t, err)
+	_, err = client.AuthIdentity.Create().
+		SetUserID(existingUser.ID).
+		SetProviderType("wechat").
+		SetProviderKey(wechatOAuthProviderKey).
+		SetProviderSubject("union-disabled").
+		Save(ctx)
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil)
+	req.Host = "api.example.com"
+	req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled"))
+	req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+	req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+	req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
+	c.Request = req
+
+	handler.WeChatOAuthCallback(c)
+
+	// Redirect carries the error; no session cookie and no DB session.
+	require.Equal(t, http.StatusFound, recorder.Code)
+	require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+	assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
+
+	count, err := client.PendingAuthSession.Query().Count(ctx)
+	require.NoError(t, err)
+	require.Zero(t, count)
+}
+
+// TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken verifies the
+// payment callback redirects with only an opaque resume token in the URL
+// fragment — no raw openid/payment details leak into the fragment — and that
+// the token round-trips to the original payment context claims.
+func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
+	// Only the token-exchange endpoint is stubbed: snsapi_base needs no
+	// userinfo call. Restore the package-level URL afterwards.
+	originalAccessTokenURL := wechatOAuthAccessTokenURL
+	t.Cleanup(func() {
+		wechatOAuthAccessTokenURL = originalAccessTokenURL
+	})
+
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_base"}`))
+			return
+		}
+		http.NotFound(w, r)
+	}))
+	defer upstream.Close()
+	wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+	handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
+	defer client.Close()
+	// The resume token is signed; provide a configured encryption key.
+	handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+	handler.cfg.Totp.EncryptionKeyConfigured = true
+
+	recorder := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(recorder)
+	req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-123", nil)
+	req.Host = "api.example.com"
+	req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-123"))
+	req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+	req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"12.5","order_type":"subscription","plan_id":7}`))
+	req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+	c.Request = req
+
+	handler.WeChatPaymentOAuthCallback(c)
+
+	require.Equal(t, http.StatusFound, recorder.Code)
+	location := recorder.Header().Get("Location")
+	parsed, err := url.Parse(location)
+	require.NoError(t, err)
+	fragment, err := url.ParseQuery(parsed.Fragment)
+	require.NoError(t, err)
+	// The fragment exposes only the redirect target and the opaque token.
+	require.Equal(t, "/purchase?from=wechat", fragment.Get("redirect"))
+	require.NotEmpty(t, fragment.Get("wechat_resume_token"))
+	require.Empty(t, fragment.Get("openid"))
+	require.Empty(t, fragment.Get("payment_type"))
+	require.Empty(t, fragment.Get("amount"))
+	require.Empty(t, fragment.Get("order_type"))
+	require.Empty(t, fragment.Get("plan_id"))
+
+	// The token must decode back to the full payment context.
+	claims, err := handler.wechatPaymentResumeService().ParseWeChatPaymentResumeToken(fragment.Get("wechat_resume_token"))
+	require.NoError(t, err)
+	require.Equal(t, "openid-123", claims.OpenID)
+	require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+	require.Equal(t, "12.5", claims.Amount)
+	require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+	require.EqualValues(t, 7, claims.PlanID)
+	require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+}
+
+// TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured
+// verifies that when both a legacy TOTP encryption key and an explicit
+// PAYMENT_RESUME_SIGNING_KEY are configured, the resume token issued by the
+// payment callback is signed with the explicit key, and a token parser built
+// from a different key rejects it.
+func TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured(t *testing.T) {
+ // Save and restore the package-level token-endpoint URL mutated below.
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ })
+
+ // Stub WeChat's access-token endpoint with a fixed openid/scope payload.
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`))
+ return
+ }
+ http.NotFound(w, r)
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
+ defer client.Close()
+
+ // Configure BOTH key sources; the explicit env-provided signing key must win.
+ legacyKeyHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+ explicitSigningKey := "explicit-payment-resume-signing-key"
+ t.Setenv("PAYMENT_RESUME_SIGNING_KEY", explicitSigningKey)
+ handler.cfg.Totp.EncryptionKey = legacyKeyHex
+ handler.cfg.Totp.EncryptionKeyConfigured = true
+
+ // Drive the payment callback with matching state cookie and payment context.
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-mixed"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`))
+ req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
+ c.Request = req
+
+ handler.WeChatPaymentOAuthCallback(c)
+
+ // The redirect carries the resume token in the URL fragment.
+ require.Equal(t, http.StatusFound, recorder.Code)
+ location := recorder.Header().Get("Location")
+ parsed, err := url.Parse(location)
+ require.NoError(t, err)
+ fragment, err := url.ParseQuery(parsed.Fragment)
+ require.NoError(t, err)
+
+ token := fragment.Get("wechat_resume_token")
+ require.NotEmpty(t, token)
+
+ // A parser constructed from the explicit key accepts the token and the
+ // claims round-trip the payment context from the cookie.
+ claims, err := service.NewPaymentResumeService([]byte(explicitSigningKey)).ParseWeChatPaymentResumeToken(token)
+ require.NoError(t, err)
+ require.Equal(t, "openid-mixed-key", claims.OpenID)
+ require.Equal(t, payment.TypeWxpay, claims.PaymentType)
+ require.Equal(t, "18.8", claims.Amount)
+ require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
+ require.EqualValues(t, 9, claims.PlanID)
+ require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
+
+ // Any other key (here a 32-byte hex-looking string, presumably what a
+ // legacy-derived key would be — TODO confirm derivation) must fail to parse.
+ _, err = service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")).ParseWeChatPaymentResumeToken(token)
+ require.Error(t, err)
+}
+
+// TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels verifies
+// that a bind-intent callback, for both "open" and "mp" channels, records the
+// WeChat unionid as the canonical provider subject on the pending auth session
+// while preserving per-channel claims (openid, channel, channel_app_id), and
+// that no access token is stored in the completion response.
+func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
+ testCases := []struct {
+ name string
+ mode string
+ appID string
+ appSecret string
+ openID string
+ }{
+ {
+ name: "open",
+ mode: "open",
+ appID: "wx-open-app",
+ appSecret: "wx-open-secret",
+ openID: "openid-open-123",
+ },
+ {
+ name: "mp",
+ mode: "mp",
+ appID: "wx-mp-app",
+ appSecret: "wx-mp-secret",
+ openID: "openid-mp-123",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Save/restore the package-level endpoint URLs per subtest.
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ // Stub both WeChat endpoints; both return the shared unionid union-456
+ // with a channel-specific openid.
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"` + tc.openID + `","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"` + tc.openID + `","unionid":"union-456","nickname":"Bind Nick","headimgurl":"https://cdn.example/bind.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback"))
+ defer client.Close()
+
+ // The logged-in user performing the bind.
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(context.Background())
+ require.NoError(t, err)
+
+ // Callback request with bind intent, channel mode, and signed bind-user cookie.
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, tc.mode))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ // The pending session must target the current user, use the unionid as
+ // canonical subject, and keep channel-scoped details in the claims.
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, wechatOAuthIntentBind, session.Intent)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, currentUser.ID, *session.TargetUserID)
+ require.Equal(t, currentUser.Email, session.ResolvedEmail)
+ require.Equal(t, "union-456", session.ProviderSubject)
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["subject"])
+ require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"])
+ require.Equal(t, tc.openID, session.UpstreamIdentityClaims["openid"])
+ require.Equal(t, tc.mode, session.UpstreamIdentityClaims["channel"])
+ require.Equal(t, tc.appID, session.UpstreamIdentityClaims["channel_app_id"])
+ require.Equal(t, tc.openID, session.UpstreamIdentityClaims["channel_subject"])
+
+ // Completion response keeps the redirect but must not leak tokens.
+ completionResponse := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
+ require.Equal(t, "/dashboard", completionResponse["redirect"])
+ _, hasAccessToken := completionResponse["access_token"]
+ require.False(t, hasAccessToken)
+ })
+ }
+}
+
+// TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict verifies that a
+// bind attempt is rejected with AUTH_IDENTITY_OWNERSHIP_CONFLICT when another
+// user already owns the canonical identity (same provider key, subject
+// union-456), and that no pending auth session is created.
+func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) {
+ // Save/restore the package-level endpoint URLs mutated below.
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ // Stub WeChat endpoints returning the contested unionid union-456.
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ // owner already holds the canonical identity; currentUser tries to bind it.
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Bind-intent callback as currentUser.
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ // Redirect carries the conflict error; no session cookie is set.
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
+
+ // Nothing persisted: no pending auth sessions at all.
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+// TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict verifies that a
+// bind attempt is rejected with AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT when
+// another user's identity (different unionid) already owns the channel-level
+// record for the same channel/app/openid, and that no pending session is made.
+func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) {
+ // Save/restore the package-level endpoint URLs mutated below.
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ // Stub WeChat endpoints; openid-123 is the contested channel subject.
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // owner's canonical identity has a DIFFERENT unionid (union-owner) so the
+ // conflict is only detectable at the channel level.
+ ownerIdentity, err := client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("union-owner").
+ SetMetadata(map[string]any{"unionid": "union-owner"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Channel record binding openid-123 (the openid the stub returns) to owner.
+ _, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(ownerIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetChannel("open").
+ SetChannelAppID("wx-open-app").
+ SetChannelSubject("openid-123").
+ SetMetadata(map[string]any{"openid": "openid-123"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Bind-intent callback as currentUser.
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ // Channel-level conflict error; no session cookie, no pending session rows.
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+// TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict verifies
+// that ownership conflicts are also detected against identities stored under
+// the LEGACY provider key (wechatOAuthLegacyProviderKey), producing the same
+// AUTH_IDENTITY_OWNERSHIP_CONFLICT error and creating no pending session.
+func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) {
+ // Save/restore the package-level endpoint URLs mutated below.
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ // Stub WeChat endpoints returning the contested unionid union-456.
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ currentUser, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("current").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // owner's identity is stored under the legacy provider key, not the
+ // current one — the conflict check must still find it.
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthLegacyProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Bind-intent callback as currentUser.
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ // Conflict error; no session cookie, no pending session rows.
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
+ assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT")
+
+ count, err := client.PendingAuthSession.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, count)
+}
+
+// TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession
+// drives the full flow — callback, then complete-registration with a valid
+// invitation code and adoption flags — and verifies that when the flow still
+// requires a user choice (adoption), the endpoint returns a pending_session
+// response WITHOUT creating any user, identity, channel, or adoption-decision
+// rows, and without consuming the pending session.
+func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
+ // Save/restore the package-level endpoint URLs mutated below.
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ // Stub WeChat endpoints for a fresh (unknown) union-456 identity.
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ // second arg true — presumably enables invitation-required registration;
+ // TODO confirm against the helper's definition.
+ handler, client := newWeChatOAuthTestHandler(t, true)
+ defer client.Close()
+
+ // Seed an unused invitation code to redeem during completion.
+ ctx := context.Background()
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{
+ Code: "invite-1",
+ Type: service.RedeemTypeInvitation,
+ Status: service.StatusUnused,
+ }))
+
+ // Phase 1: OAuth callback creates the pending session + cookie.
+ callbackRecorder := httptest.NewRecorder()
+ callbackCtx, _ := gin.CreateTestContext(callbackRecorder)
+ callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ callbackReq.Host = "api.example.com"
+ callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ callbackCtx.Request = callbackReq
+
+ handler.WeChatOAuthCallback(callbackCtx)
+
+ require.Equal(t, http.StatusFound, callbackRecorder.Code)
+ require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+ sessionToken := decodeCookieValueForTest(t, sessionCookie.Value)
+
+ // The stored session is parked at the choice step.
+ pendingSession, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, oauthPendingChoiceStep, pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["step"])
+
+ // Phase 2: complete-registration with invitation + adoption flags.
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`)
+ completeRecorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(completeRecorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ // Response is a pending_session at the choice step, requiring adoption —
+ // and crucially contains no access token.
+ require.Equal(t, http.StatusOK, completeRecorder.Code)
+ responseData := decodeJSONBody(t, completeRecorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["adoption_required"])
+ require.Empty(t, responseData["access_token"])
+
+ // The session must NOT have been consumed.
+ consumed, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.IDEQ(pendingSession.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Nil(t, consumed.ConsumedAt)
+
+ // No side effects yet: zero users, identities, channels, or decisions.
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, channelCount)
+
+ decisionCount, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, decisionCount)
+}
+
+// TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags verifies
+// that completing registration from a pre-seeded login-intent pending session,
+// with no adoption flags in the request body, issues tokens, creates the user
+// and identity, and records an adoption decision with both flags false.
+//
+// NOTE(review): unlike sibling tests this one never calls client.Close() —
+// presumably newOAuthPendingFlowTestHandler registers cleanup via t.Cleanup;
+// verify against the helper's definition.
+func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
+ handler, client := newOAuthPendingFlowTestHandler(t, false)
+ ctx := context.Background()
+
+ // Seed a pending session directly (no callback phase); claims include
+ // channel details and suggested profile fields that are NOT adopted here.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-no-adoption-session").
+ SetIntent("login").
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("wechat-subject-no-adoption").
+ SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid").
+ SetBrowserSessionKey("wechat-browser-no-adoption").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ "suggested_display_name": "WeChat Legacy",
+ "suggested_avatar_url": "https://cdn.example/wechat-legacy.png",
+ "mode": "open",
+ "channel": "open",
+ "channel_app_id": "wx-open-app",
+ "channel_subject": "openid-legacy",
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Request body deliberately omits adopt_display_name / adopt_avatar.
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ completeReq.Header.Set("Content-Type", "application/json")
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")})
+ completeCtx.Request = completeReq
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ // Success path: tokens are issued.
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.NotEmpty(t, responseData["access_token"])
+ require.NotEmpty(t, responseData["refresh_token"])
+
+ // A user was created from the session's resolved email and claimed username.
+ userEntity, err := client.User.Query().
+ Where(dbuser.EmailEQ(session.ResolvedEmail)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "wechat_user", userEntity.Username)
+
+ // The identity was bound to that user under the canonical provider key.
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, userEntity.ID, identity.UserID)
+
+ // An adoption decision is recorded with both flags defaulted to false.
+ decision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+ require.False(t, decision.AdoptDisplayName)
+ require.False(t, decision.AdoptAvatar)
+}
+
+// TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity verifies that when an
+// existing identity was stored with the openid as its subject (pre-unionid
+// data), the callback repairs it IN PLACE: the same identity row is re-keyed to
+// the unionid, no openid-keyed identity remains, and a channel record linking
+// the openid to the repaired identity is created.
+func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
+ // Save/restore the package-level endpoint URLs mutated below.
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ // Stub WeChat endpoints: openid-123 now reports unionid union-456.
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ // Existing user whose identity predates unionid support.
+ ctx := context.Background()
+ legacyUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Legacy identity keyed by openid instead of unionid.
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(legacyUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthProviderKey).
+ SetProviderSubject("openid-123").
+ SetMetadata(map[string]any{"openid": "openid-123"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Plain login callback (no bind intent).
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ // The pending session resolves to the legacy user.
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, legacyUser.ID, *session.TargetUserID)
+ require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+ // Same identity row (same ID), now keyed by the unionid.
+ repairedIdentity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+ require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+ // No identity keyed by the old openid subject remains.
+ openIDIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("openid-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, openIDIdentityCount)
+
+ // The openid survives as a channel record under the repaired identity.
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
+// TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession verifies
+// that complete-registration refuses a pending session whose intent is
+// "adopt_existing_user_by_email" (that flow requires a login instead),
+// returning 400 and leaving the session unconsumed.
+func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ // The existing account the session would adopt.
+ ctx := context.Background()
+ existingUser, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetUsername("owner-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Seed a session with the disallowed intent, parked at bind_login_required.
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-invalid-session").
+ SetIntent("adopt_existing_user_by_email").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-invalid-1").
+ SetTargetUserID(existingUser.ID).
+ SetResolvedEmail(existingUser.Email).
+ SetBrowserSessionKey("wechat-invalid-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": "bind_login_required",
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Attempt completion with matching session and browser cookies.
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")})
+ completeCtx.Request = req
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ // The rejection must not consume the session.
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+// TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired
+// verifies that when the stored flow state is still at the choice step, the
+// complete-registration endpoint echoes a pending_session response (including
+// force_email_on_signup) instead of creating a user, issuing tokens, or
+// consuming the session.
+func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ // Seed a login-intent session parked at the choice step with an email
+ // resolution that forces email entry on signup.
+ ctx := context.Background()
+ session, err := client.PendingAuthSession.Create().
+ SetSessionToken("wechat-complete-choice-session").
+ SetIntent("login").
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("wechat-choice-subject-1").
+ SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
+ SetBrowserSessionKey("wechat-choice-browser").
+ SetUpstreamIdentityClaims(map[string]any{
+ "username": "wechat_user",
+ }).
+ SetLocalFlowState(map[string]any{
+ oauthCompletionResponseKey: map[string]any{
+ "step": oauthPendingChoiceStep,
+ "redirect": "/dashboard",
+ "email": "fresh@example.com",
+ "resolved_email": "fresh@example.com",
+ "force_email_on_signup": true,
+ },
+ }).
+ SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ // Attempt completion with matching session and browser cookies.
+ body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
+ recorder := httptest.NewRecorder()
+ completeCtx, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
+ req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
+ completeCtx.Request = req
+
+ handler.CompleteWeChatOAuthRegistration(completeCtx)
+
+ // Pending-session echo: choice step, force flag set, no tokens.
+ require.Equal(t, http.StatusOK, recorder.Code)
+ responseData := decodeJSONBody(t, recorder)
+ require.Equal(t, "pending_session", responseData["auth_result"])
+ require.Equal(t, oauthPendingChoiceStep, responseData["step"])
+ require.Equal(t, true, responseData["force_email_on_signup"])
+ require.Empty(t, responseData["access_token"])
+
+ // No user created; session still unconsumed.
+ userCount, err := client.User.Query().Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, userCount)
+
+ storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+ require.Nil(t, storedSession.ConsumedAt)
+}
+
+func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
+ originalAccessTokenURL := wechatOAuthAccessTokenURL
+ originalUserInfoURL := wechatOAuthUserInfoURL
+ t.Cleanup(func() {
+ wechatOAuthAccessTokenURL = originalAccessTokenURL
+ wechatOAuthUserInfoURL = originalUserInfoURL
+ })
+
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
+ case strings.Contains(r.URL.Path, "/sns/userinfo"):
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`))
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer upstream.Close()
+ wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
+ wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
+
+ handler, client := newWeChatOAuthTestHandler(t, false)
+ defer client.Close()
+
+ ctx := context.Background()
+ legacyUser, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(legacyUser.ID).
+ SetProviderType("wechat").
+ SetProviderKey(wechatOAuthLegacyProviderKey).
+ SetProviderSubject("union-456").
+ SetMetadata(map[string]any{"unionid": "union-456"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
+ req.Host = "api.example.com"
+ req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
+ req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
+ req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
+ req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
+ c.Request = req
+
+ handler.WeChatOAuthCallback(c)
+
+ require.Equal(t, http.StatusFound, recorder.Code)
+ require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location"))
+
+ sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
+ require.NotNil(t, sessionCookie)
+
+ session, err := client.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, legacyUser.ID, *session.TargetUserID)
+ require.Equal(t, legacyUser.Email, session.ResolvedEmail)
+
+ repairedIdentity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, legacyIdentity.ID, repairedIdentity.ID)
+ require.Equal(t, legacyUser.ID, repairedIdentity.UserID)
+
+ legacyIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey),
+ authidentity.ProviderSubjectEQ("union-456"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, legacyIdentityCount)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open-app"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, repairedIdentity.ID, channel.IdentityID)
+}
+
+func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) {
+ return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil)
+}
+
+func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string {
+ return map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: appID,
+ service.SettingKeyWeChatConnectAppSecret: secret,
+ service.SettingKeyWeChatConnectMode: mode,
+ service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode),
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect,
+ }
+}
+
+func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+
+ userRepo := &oauthPendingFlowUserRepo{client: client}
+ redeemRepo := repository.NewRedeemCodeRepository(client)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 0,
+ UserConcurrency: 1,
+ },
+ }
+ values := map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled),
+ }
+ for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") {
+ values[key] = value
+ }
+ for key, value := range extraSettings {
+ values[key] = value
+ }
+ settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg)
+
+ authSvc := service.NewAuthService(
+ client,
+ userRepo,
+ redeemRepo,
+ &wechatOAuthRefreshTokenCacheStub{},
+ cfg,
+ settingSvc,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ return &AuthHandler{
+ authService: authSvc,
+ settingSvc: settingSvc,
+ cfg: cfg,
+ }, client
+}
+
+type wechatOAuthSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ return nil, service.ErrSettingNotFound
+}
+
+func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ result[key] = value
+ }
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ return nil
+}
+
+func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ result := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ result[key] = value
+ }
+ return result, nil
+}
+
+func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error {
+ return nil
+}
+
+type wechatOAuthRefreshTokenCacheStub struct{}
+
+func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
diff --git a/backend/internal/handler/available_channel_handler.go b/backend/internal/handler/available_channel_handler.go
new file mode 100644
index 00000000..8982b80d
--- /dev/null
+++ b/backend/internal/handler/available_channel_handler.go
@@ -0,0 +1,283 @@
+package handler
+
+import (
+ "sort"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AvailableChannelHandler 处理用户侧「可用渠道」查询。
+//
+// 用户侧接口委托 ChannelService.ListAvailable,并在返回前做四层过滤:
+// 1. 行过滤:只保留状态为 Active 且与当前用户可访问分组有交集的渠道;
+// 2. 分组过滤:渠道的 Groups 只保留用户可访问的那些;
+// 3. 平台过滤:渠道的 SupportedModels 只保留平台在用户可见 Groups 中出现过的模型,
+// 防止"渠道同时挂在 antigravity / anthropic 两个平台的分组上,用户只访问
+// antigravity,却看到 anthropic 模型"这类跨平台信息泄漏;
+// 4. 字段白名单:仅返回用户需要的字段(省略 BillingModelSource / RestrictModels
+// / 内部 ID / Status 等管理字段)。
+type AvailableChannelHandler struct {
+ channelService *service.ChannelService
+ apiKeyService *service.APIKeyService
+ settingService *service.SettingService
+}
+
+// NewAvailableChannelHandler 创建用户侧可用渠道 handler。
+func NewAvailableChannelHandler(
+ channelService *service.ChannelService,
+ apiKeyService *service.APIKeyService,
+ settingService *service.SettingService,
+) *AvailableChannelHandler {
+ return &AvailableChannelHandler{
+ channelService: channelService,
+ apiKeyService: apiKeyService,
+ settingService: settingService,
+ }
+}
+
+// featureEnabled 返回 available-channels 开关是否启用。默认关闭(opt-in)。
+func (h *AvailableChannelHandler) featureEnabled(c *gin.Context) bool {
+ if h.settingService == nil {
+ return false
+ }
+ return h.settingService.GetAvailableChannelsRuntime(c.Request.Context()).Enabled
+}
+
+// userAvailableGroup 用户可见的分组概要(白名单字段)。
+//
+// 前端据此区分专属 vs 公开分组(IsExclusive)、订阅 vs 标准分组(SubscriptionType,
+// 订阅视觉加深),并用 RateMultiplier 作为默认倍率;用户专属倍率前端走
+// /groups/rates,和 API 密钥页面保持一致。
+type userAvailableGroup struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+ SubscriptionType string `json:"subscription_type"`
+ RateMultiplier float64 `json:"rate_multiplier"`
+ IsExclusive bool `json:"is_exclusive"`
+}
+
+// userSupportedModelPricing 用户可见的定价字段白名单。
+type userSupportedModelPricing struct {
+ BillingMode string `json:"billing_mode"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ ImageOutputPrice *float64 `json:"image_output_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+ Intervals []userPricingIntervalDTO `json:"intervals"`
+}
+
+// userPricingIntervalDTO 定价区间白名单(去掉内部 ID、SortOrder 等前端不渲染的字段)。
+type userPricingIntervalDTO struct {
+ MinTokens int `json:"min_tokens"`
+ MaxTokens *int `json:"max_tokens"`
+ TierLabel string `json:"tier_label,omitempty"`
+ InputPrice *float64 `json:"input_price"`
+ OutputPrice *float64 `json:"output_price"`
+ CacheWritePrice *float64 `json:"cache_write_price"`
+ CacheReadPrice *float64 `json:"cache_read_price"`
+ PerRequestPrice *float64 `json:"per_request_price"`
+}
+
+// userSupportedModel 用户可见的支持模型条目。
+type userSupportedModel struct {
+ Name string `json:"name"`
+ Platform string `json:"platform"`
+ Pricing *userSupportedModelPricing `json:"pricing"`
+}
+
+// userChannelPlatformSection 单渠道内某个平台的子视图:用户可见的分组 + 该平台
+// 支持的模型。按 platform 聚合后让前端可以把渠道名作为 row-group 一次渲染,
+// 后面的平台行按 sections 顺序铺开。
+type userChannelPlatformSection struct {
+ Platform string `json:"platform"`
+ Groups []userAvailableGroup `json:"groups"`
+ SupportedModels []userSupportedModel `json:"supported_models"`
+}
+
+// userAvailableChannel 用户可见的渠道条目(白名单字段)。
+//
+// 每个渠道聚合为一条记录,内嵌 platforms 子数组:每个 section 对应一个平台,
+// 包含该平台的 groups 和 supported_models。
+type userAvailableChannel struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Platforms []userChannelPlatformSection `json:"platforms"`
+}
+
+// List 列出当前用户可见的「可用渠道」。
+// GET /api/v1/channels/available
+func (h *AvailableChannelHandler) List(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ // Feature 未启用时返回空数组(不暴露渠道信息)。检查放在认证之后,
+ // 保持与未开关前的 401 行为一致:未登录先 401,登录后再按开关决定。
+ if !h.featureEnabled(c) {
+ response.Success(c, []userAvailableChannel{})
+ return
+ }
+
+ userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ allowedGroupIDs := make(map[int64]struct{}, len(userGroups))
+ for i := range userGroups {
+ allowedGroupIDs[userGroups[i].ID] = struct{}{}
+ }
+
+ channels, err := h.channelService.ListAvailable(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ out := make([]userAvailableChannel, 0, len(channels))
+ for _, ch := range channels {
+ if ch.Status != service.StatusActive {
+ continue
+ }
+ visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs)
+ if len(visibleGroups) == 0 {
+ continue
+ }
+ sections := buildPlatformSections(ch, visibleGroups)
+ if len(sections) == 0 {
+ continue
+ }
+ out = append(out, userAvailableChannel{
+ Name: ch.Name,
+ Description: ch.Description,
+ Platforms: sections,
+ })
+ }
+
+ response.Success(c, out)
+}
+
+// buildPlatformSections 把一个渠道按 visibleGroups 的平台集合拆成有序的 section 列表:
+// 每个 section 对应一个平台,只包含该平台的 groups 和 supported_models。
+// 输出按 platform 字母序稳定排序,便于前端等效比较与回归测试。
+func buildPlatformSections(
+ ch service.AvailableChannel,
+ visibleGroups []userAvailableGroup,
+) []userChannelPlatformSection {
+ groupsByPlatform := make(map[string][]userAvailableGroup, 4)
+ for _, g := range visibleGroups {
+ if g.Platform == "" {
+ continue
+ }
+ groupsByPlatform[g.Platform] = append(groupsByPlatform[g.Platform], g)
+ }
+ if len(groupsByPlatform) == 0 {
+ return nil
+ }
+
+ platforms := make([]string, 0, len(groupsByPlatform))
+ for p := range groupsByPlatform {
+ platforms = append(platforms, p)
+ }
+ sort.Strings(platforms)
+
+ sections := make([]userChannelPlatformSection, 0, len(platforms))
+ for _, platform := range platforms {
+ platformSet := map[string]struct{}{platform: {}}
+ sections = append(sections, userChannelPlatformSection{
+ Platform: platform,
+ Groups: groupsByPlatform[platform],
+ SupportedModels: toUserSupportedModels(ch.SupportedModels, platformSet),
+ })
+ }
+ return sections
+}
+
+// filterUserVisibleGroups 仅保留用户可访问的分组。
+func filterUserVisibleGroups(
+ groups []service.AvailableGroupRef,
+ allowed map[int64]struct{},
+) []userAvailableGroup {
+ visible := make([]userAvailableGroup, 0, len(groups))
+ for _, g := range groups {
+ if _, ok := allowed[g.ID]; !ok {
+ continue
+ }
+ visible = append(visible, userAvailableGroup{
+ ID: g.ID,
+ Name: g.Name,
+ Platform: g.Platform,
+ SubscriptionType: g.SubscriptionType,
+ RateMultiplier: g.RateMultiplier,
+ IsExclusive: g.IsExclusive,
+ })
+ }
+ return visible
+}
+
+// toUserSupportedModels 将 service 层支持模型转换为用户 DTO(字段白名单)。
+// 仅保留平台在 allowedPlatforms 中的条目,防止跨平台模型信息泄漏。
+// allowedPlatforms 为 nil 时不做平台过滤(保留全部,供测试或明确无过滤场景使用)。
+func toUserSupportedModels(
+ src []service.SupportedModel,
+ allowedPlatforms map[string]struct{},
+) []userSupportedModel {
+ out := make([]userSupportedModel, 0, len(src))
+ for i := range src {
+ m := src[i]
+ if allowedPlatforms != nil {
+ if _, ok := allowedPlatforms[m.Platform]; !ok {
+ continue
+ }
+ }
+ out = append(out, userSupportedModel{
+ Name: m.Name,
+ Platform: m.Platform,
+ Pricing: toUserPricing(m.Pricing),
+ })
+ }
+ return out
+}
+
+// toUserPricing 将 service 层定价转换为用户 DTO;入参为 nil 时返回 nil。
+func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing {
+ if p == nil {
+ return nil
+ }
+ intervals := make([]userPricingIntervalDTO, 0, len(p.Intervals))
+ for _, iv := range p.Intervals {
+ intervals = append(intervals, userPricingIntervalDTO{
+ MinTokens: iv.MinTokens,
+ MaxTokens: iv.MaxTokens,
+ TierLabel: iv.TierLabel,
+ InputPrice: iv.InputPrice,
+ OutputPrice: iv.OutputPrice,
+ CacheWritePrice: iv.CacheWritePrice,
+ CacheReadPrice: iv.CacheReadPrice,
+ PerRequestPrice: iv.PerRequestPrice,
+ })
+ }
+ billingMode := string(p.BillingMode)
+ if billingMode == "" {
+ billingMode = string(service.BillingModeToken)
+ }
+ return &userSupportedModelPricing{
+ BillingMode: billingMode,
+ InputPrice: p.InputPrice,
+ OutputPrice: p.OutputPrice,
+ CacheWritePrice: p.CacheWritePrice,
+ CacheReadPrice: p.CacheReadPrice,
+ ImageOutputPrice: p.ImageOutputPrice,
+ PerRequestPrice: p.PerRequestPrice,
+ Intervals: intervals,
+ }
+}
diff --git a/backend/internal/handler/available_channel_handler_test.go b/backend/internal/handler/available_channel_handler_test.go
new file mode 100644
index 00000000..0a7ce6c4
--- /dev/null
+++ b/backend/internal/handler/available_channel_handler_test.go
@@ -0,0 +1,157 @@
+//go:build unit
+
+package handler
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserAvailableChannel_Unauthenticated401(t *testing.T) {
+ // 没有 AuthSubject 注入时,handler 应返回 401 且不触达 service 依赖。
+ gin.SetMode(gin.TestMode)
+ h := &AvailableChannelHandler{} // nil services — 401 路径不会调用它们
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/channels/available", nil)
+
+ h.List(c)
+
+ require.Equal(t, http.StatusUnauthorized, w.Code)
+}
+
+func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) {
+ // 渠道挂在 {g1, g2, g3},用户只允许 {g1, g3} —— 响应必须仅含 g1/g3。
+ groups := []service.AvailableGroupRef{
+ {ID: 1, Name: "g1", Platform: "anthropic"},
+ {ID: 2, Name: "g2", Platform: "anthropic"},
+ {ID: 3, Name: "g3", Platform: "openai"},
+ }
+ allowed := map[int64]struct{}{1: {}, 3: {}}
+
+ visible := filterUserVisibleGroups(groups, allowed)
+ require.Len(t, visible, 2)
+ ids := []int64{visible[0].ID, visible[1].ID}
+ require.ElementsMatch(t, []int64{1, 3}, ids)
+}
+
+func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) {
+ // 用户可访问分组只覆盖 anthropic;anthropic 平台的模型保留,openai 模型被剔除。
+ src := []service.SupportedModel{
+ {Name: "claude-sonnet-4-6", Platform: "anthropic", Pricing: nil},
+ {Name: "gpt-4o", Platform: "openai", Pricing: nil},
+ }
+ allowed := map[string]struct{}{"anthropic": {}}
+ out := toUserSupportedModels(src, allowed)
+ require.Len(t, out, 1)
+ require.Equal(t, "claude-sonnet-4-6", out[0].Name)
+}
+
+func TestToUserSupportedModels_NilAllowedPlatformsKeepsAll(t *testing.T) {
+ // 显式传 nil allowedPlatforms 表示不做过滤。
+ src := []service.SupportedModel{
+ {Name: "a", Platform: "anthropic"},
+ {Name: "b", Platform: "openai"},
+ }
+ require.Len(t, toUserSupportedModels(src, nil), 2)
+}
+
+func TestUserAvailableChannel_FieldWhitelist(t *testing.T) {
+ // 通过序列化 userAvailableChannel 结构体验证响应形状:
+ // 只有 name / description / platforms;不含管理端字段。
+ row := userAvailableChannel{
+ Name: "ch",
+ Description: "d",
+ Platforms: []userChannelPlatformSection{
+ {
+ Platform: "anthropic",
+ Groups: []userAvailableGroup{{ID: 1, Name: "g1", Platform: "anthropic"}},
+ SupportedModels: []userSupportedModel{},
+ },
+ },
+ }
+ raw, err := json.Marshal(row)
+ require.NoError(t, err)
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(raw, &decoded))
+
+ for _, key := range []string{"id", "status", "billing_model_source", "restrict_models"} {
+ _, exists := decoded[key]
+ require.Falsef(t, exists, "user DTO must not expose %q", key)
+ }
+ for _, key := range []string{"name", "description", "platforms"} {
+ _, exists := decoded[key]
+ require.Truef(t, exists, "user DTO must expose %q", key)
+ }
+
+ // 验证 section 的字段(platform / groups / supported_models)。
+ rawSection, err := json.Marshal(row.Platforms[0])
+ require.NoError(t, err)
+ var sectionDecoded map[string]any
+ require.NoError(t, json.Unmarshal(rawSection, &sectionDecoded))
+ for _, key := range []string{"platform", "groups", "supported_models"} {
+ _, exists := sectionDecoded[key]
+ require.Truef(t, exists, "platform section must expose %q", key)
+ }
+
+ // Group DTO 暴露区分专属/公开、订阅类型、默认倍率所需的字段,
+ // 前端据此渲染 GroupBadge 并与 API 密钥页保持一致的视觉。
+ rawGroup, err := json.Marshal(row.Platforms[0].Groups[0])
+ require.NoError(t, err)
+ var groupDecoded map[string]any
+ require.NoError(t, json.Unmarshal(rawGroup, &groupDecoded))
+ for _, key := range []string{"id", "name", "platform", "subscription_type", "rate_multiplier", "is_exclusive"} {
+ _, exists := groupDecoded[key]
+ require.Truef(t, exists, "group DTO must expose %q", key)
+ }
+
+ // pricing interval 白名单:不应暴露 id / sort_order。
+ pricing := toUserPricing(&service.ChannelModelPricing{
+ BillingMode: service.BillingModeToken,
+ Intervals: []service.PricingInterval{
+ {ID: 7, MinTokens: 0, MaxTokens: nil, SortOrder: 3},
+ },
+ })
+ require.NotNil(t, pricing)
+ require.Len(t, pricing.Intervals, 1)
+ rawIv, err := json.Marshal(pricing.Intervals[0])
+ require.NoError(t, err)
+ var ivDecoded map[string]any
+ require.NoError(t, json.Unmarshal(rawIv, &ivDecoded))
+ for _, key := range []string{"id", "pricing_id", "sort_order"} {
+ _, exists := ivDecoded[key]
+ require.Falsef(t, exists, "user pricing interval must not expose %q", key)
+ }
+}
+
+func TestBuildPlatformSections_GroupsByPlatform(t *testing.T) {
+ // 一个渠道横跨 anthropic / openai / 空平台:应该生成 2 个 section,
+ // 按 platform 字母序排序,各自 groups 和 supported_models 只含同平台条目。
+ ch := service.AvailableChannel{
+ Name: "ch",
+ SupportedModels: []service.SupportedModel{
+ {Name: "claude-sonnet-4-6", Platform: "anthropic"},
+ {Name: "gpt-4o", Platform: "openai"},
+ },
+ }
+ visible := []userAvailableGroup{
+ {ID: 1, Name: "g-openai", Platform: "openai"},
+ {ID: 2, Name: "g-ant", Platform: "anthropic"},
+ {ID: 3, Name: "g-empty", Platform: ""},
+ }
+ sections := buildPlatformSections(ch, visible)
+ require.Len(t, sections, 2)
+ require.Equal(t, "anthropic", sections[0].Platform)
+ require.Equal(t, "openai", sections[1].Platform)
+ require.Len(t, sections[0].Groups, 1)
+ require.Equal(t, int64(2), sections[0].Groups[0].ID)
+ require.Len(t, sections[0].SupportedModels, 1)
+ require.Equal(t, "claude-sonnet-4-6", sections[0].SupportedModels[0].Name)
+}
diff --git a/backend/internal/handler/channel_monitor_user_handler.go b/backend/internal/handler/channel_monitor_user_handler.go
new file mode 100644
index 00000000..cc36b334
--- /dev/null
+++ b/backend/internal/handler/channel_monitor_user_handler.go
@@ -0,0 +1,176 @@
+package handler
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/admin"
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// ChannelMonitorUserHandler 渠道监控用户只读 handler。
+type ChannelMonitorUserHandler struct {
+ monitorService *service.ChannelMonitorService
+ settingService *service.SettingService
+}
+
+// NewChannelMonitorUserHandler 创建 handler。
+// settingService 用于每次请求前读取功能开关;关闭时 List/GetStatus 直接返回空/404。
+func NewChannelMonitorUserHandler(
+ monitorService *service.ChannelMonitorService,
+ settingService *service.SettingService,
+) *ChannelMonitorUserHandler {
+ return &ChannelMonitorUserHandler{
+ monitorService: monitorService,
+ settingService: settingService,
+ }
+}
+
+// featureEnabled 返回当前渠道监控功能是否开启。
+// settingService 为 nil(测试场景)视为启用。
+func (h *ChannelMonitorUserHandler) featureEnabled(c *gin.Context) bool {
+ if h.settingService == nil {
+ return true
+ }
+ return h.settingService.GetChannelMonitorRuntime(c.Request.Context()).Enabled
+}
+
+// --- Response ---
+
+type channelMonitorUserListItem struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ GroupName string `json:"group_name"`
+ PrimaryModel string `json:"primary_model"`
+ PrimaryStatus string `json:"primary_status"`
+ PrimaryLatencyMs *int `json:"primary_latency_ms"`
+ PrimaryPingLatencyMs *int `json:"primary_ping_latency_ms"`
+ Availability7d float64 `json:"availability_7d"`
+ ExtraModels []dto.ChannelMonitorExtraModelStatus `json:"extra_models"`
+ Timeline []channelMonitorUserTimelinePoint `json:"timeline"`
+}
+
+// channelMonitorUserTimelinePoint 主模型最近一次检测的 timeline 点。
+// 仅用于用户视图 list 响应,admin 视图不使用。
+type channelMonitorUserTimelinePoint struct {
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+ PingLatencyMs *int `json:"ping_latency_ms"`
+ CheckedAt string `json:"checked_at"`
+}
+
+type channelMonitorUserDetailResponse struct {
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Provider string `json:"provider"`
+ GroupName string `json:"group_name"`
+ Models []channelMonitorUserModelStat `json:"models"`
+}
+
+type channelMonitorUserModelStat struct {
+ Model string `json:"model"`
+ LatestStatus string `json:"latest_status"`
+ LatestLatencyMs *int `json:"latest_latency_ms"`
+ Availability7d float64 `json:"availability_7d"`
+ Availability15d float64 `json:"availability_15d"`
+ Availability30d float64 `json:"availability_30d"`
+ AvgLatency7dMs *int `json:"avg_latency_7d_ms"`
+}
+
+func userMonitorViewToItem(v *service.UserMonitorView) channelMonitorUserListItem {
+ extras := make([]dto.ChannelMonitorExtraModelStatus, 0, len(v.ExtraModels))
+ for _, e := range v.ExtraModels {
+ extras = append(extras, dto.ChannelMonitorExtraModelStatus{
+ Model: e.Model,
+ Status: e.Status,
+ LatencyMs: e.LatencyMs,
+ })
+ }
+ timeline := make([]channelMonitorUserTimelinePoint, 0, len(v.Timeline))
+ for _, p := range v.Timeline {
+ timeline = append(timeline, channelMonitorUserTimelinePoint{
+ Status: p.Status,
+ LatencyMs: p.LatencyMs,
+ PingLatencyMs: p.PingLatencyMs,
+ CheckedAt: p.CheckedAt.UTC().Format(time.RFC3339),
+ })
+ }
+ return channelMonitorUserListItem{
+ ID: v.ID,
+ Name: v.Name,
+ Provider: v.Provider,
+ GroupName: v.GroupName,
+ PrimaryModel: v.PrimaryModel,
+ PrimaryStatus: v.PrimaryStatus,
+ PrimaryLatencyMs: v.PrimaryLatencyMs,
+ PrimaryPingLatencyMs: v.PrimaryPingLatencyMs,
+ Availability7d: v.Availability7d,
+ ExtraModels: extras,
+ Timeline: timeline,
+ }
+}
+
+func userMonitorDetailToResponse(d *service.UserMonitorDetail) *channelMonitorUserDetailResponse {
+ models := make([]channelMonitorUserModelStat, 0, len(d.Models))
+ for _, m := range d.Models {
+ models = append(models, channelMonitorUserModelStat{
+ Model: m.Model,
+ LatestStatus: m.LatestStatus,
+ LatestLatencyMs: m.LatestLatencyMs,
+ Availability7d: m.Availability7d,
+ Availability15d: m.Availability15d,
+ Availability30d: m.Availability30d,
+ AvgLatency7dMs: m.AvgLatency7dMs,
+ })
+ }
+ return &channelMonitorUserDetailResponse{
+ ID: d.ID,
+ Name: d.Name,
+ Provider: d.Provider,
+ GroupName: d.GroupName,
+ Models: models,
+ }
+}
+
+// --- Handlers ---
+
+// List GET /api/v1/channel-monitors
+func (h *ChannelMonitorUserHandler) List(c *gin.Context) {
+ if !h.featureEnabled(c) {
+ response.Success(c, gin.H{"items": []channelMonitorUserListItem{}})
+ return
+ }
+ views, err := h.monitorService.ListUserView(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ items := make([]channelMonitorUserListItem, 0, len(views))
+ for _, v := range views {
+ items = append(items, userMonitorViewToItem(v))
+ }
+ response.Success(c, gin.H{"items": items})
+}
+
+// GetStatus GET /api/v1/channel-monitors/:id/status
+func (h *ChannelMonitorUserHandler) GetStatus(c *gin.Context) {
+ if !h.featureEnabled(c) {
+ response.ErrorFrom(c, service.ErrChannelMonitorNotFound)
+ return
+ }
+ // 复用 admin.ParseChannelMonitorID 保持错误码与日志一致。
+ id, ok := admin.ParseChannelMonitorID(c)
+ if !ok {
+ return
+ }
+ detail, err := h.monitorService.GetUserDetail(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, userMonitorDetailToResponse(detail))
+}
diff --git a/backend/internal/handler/dto/channel_monitor.go b/backend/internal/handler/dto/channel_monitor.go
new file mode 100644
index 00000000..3c0c5e11
--- /dev/null
+++ b/backend/internal/handler/dto/channel_monitor.go
@@ -0,0 +1,10 @@
+package dto
+
+// ChannelMonitorExtraModelStatus 渠道监控附加模型最近一次状态。
+// 同时被 admin handler(List 响应)与 user handler(List 响应)复用,
+// 字段必须保持一致以保证前端拿到统一结构。
+type ChannelMonitorExtraModelStatus struct {
+ Model string `json:"model"`
+ Status string `json:"status"`
+ LatencyMs *int `json:"latency_ms"`
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index a8da92c0..f7503c2e 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -13,16 +13,23 @@ func UserFromServiceShallow(u *service.User) *User {
return nil
}
return &User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- AllowedGroups: u.AllowedGroups,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ AllowedGroups: u.AllowedGroups,
+ LastActiveAt: u.LastActiveAt,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
+ TotalRecharged: u.TotalRecharged,
+ RPMLimit: u.RPMLimit,
}
}
@@ -59,11 +66,10 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil
}
return &AdminUser{
- User: *base,
- Notes: u.Notes,
- GroupRates: u.GroupRates,
- SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
- SoraStorageUsedBytes: u.SoraStorageUsedBytes,
+ User: *base,
+ Notes: u.Notes,
+ LastUsedAt: u.LastUsedAt,
+ GroupRates: u.GroupRates,
}
}
@@ -135,16 +141,17 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
return nil
}
out := &AdminGroup{
- Group: groupFromServiceBase(g),
- ModelRouting: g.ModelRouting,
- ModelRoutingEnabled: g.ModelRoutingEnabled,
- MCPXMLInject: g.MCPXMLInject,
- DefaultMappedModel: g.DefaultMappedModel,
- SupportedModelScopes: g.SupportedModelScopes,
- AccountCount: g.AccountCount,
- ActiveAccountCount: g.ActiveAccountCount,
- RateLimitedAccountCount: g.RateLimitedAccountCount,
- SortOrder: g.SortOrder,
+ Group: groupFromServiceBase(g),
+ ModelRouting: g.ModelRouting,
+ ModelRoutingEnabled: g.ModelRoutingEnabled,
+ MCPXMLInject: g.MCPXMLInject,
+ DefaultMappedModel: g.DefaultMappedModel,
+ MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
+ SupportedModelScopes: g.SupportedModelScopes,
+ AccountCount: g.AccountCount,
+ ActiveAccountCount: g.ActiveAccountCount,
+ RateLimitedAccountCount: g.RateLimitedAccountCount,
+ SortOrder: g.SortOrder,
}
if len(g.AccountGroups) > 0 {
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
@@ -172,17 +179,13 @@ func groupFromServiceBase(g *service.Group) Group {
ImagePrice1K: g.ImagePrice1K,
ImagePrice2K: g.ImagePrice2K,
ImagePrice4K: g.ImagePrice4K,
- SoraImagePrice360: g.SoraImagePrice360,
- SoraImagePrice540: g.SoraImagePrice540,
- SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
- SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
AllowMessagesDispatch: g.AllowMessagesDispatch,
RequireOAuthOnly: g.RequireOAuthOnly,
RequirePrivacySet: g.RequirePrivacySet,
+ RPMLimit: g.RPMLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
@@ -328,6 +331,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
out.QuotaWeeklyResetAt = &v
}
}
+
+ // 配额通知配置
+ if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
+ out.QuotaNotifyDailyEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
+ out.QuotaNotifyDailyThreshold = &threshold
+ }
+ if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
+ out.QuotaNotifyWeeklyEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
+ out.QuotaNotifyWeeklyThreshold = &threshold
+ }
+ if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
+ out.QuotaNotifyTotalEnabled = &enabled
+ }
+ if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
+ out.QuotaNotifyTotalThreshold = &threshold
+ }
}
return out
@@ -577,6 +600,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
MediaType: l.MediaType,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,
+ BillingMode: l.BillingMode,
CreatedAt: l.CreatedAt,
User: UserFromServiceShallow(l.User),
APIKey: APIKeyFromService(l.APIKey),
@@ -604,7 +628,11 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
return &AdminUsageLog{
UsageLog: usageLogFromServiceUser(l),
UpstreamModel: l.UpstreamModel,
+ ChannelID: l.ChannelID,
+ ModelMappingChain: l.ModelMappingChain,
+ BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier,
+ AccountStatsCost: l.AccountStatsCost,
IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account),
}
diff --git a/backend/internal/handler/dto/notify_email_entry.go b/backend/internal/handler/dto/notify_email_entry.go
new file mode 100644
index 00000000..78641005
--- /dev/null
+++ b/backend/internal/handler/dto/notify_email_entry.go
@@ -0,0 +1,43 @@
+package dto
+
+import "github.com/Wei-Shaw/sub2api/internal/service"
+
+// NotifyEmailEntry represents a notification email with enable/disable and verification state.
+// All emails are user-managed; maximum 3 entries per user.
+type NotifyEmailEntry struct {
+ Email string `json:"email"`
+ Disabled bool `json:"disabled"`
+ Verified bool `json:"verified"`
+}
+
+// NotifyEmailEntriesFromService converts service entries to DTO entries.
+func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
+ if entries == nil {
+ return nil
+ }
+ result := make([]NotifyEmailEntry, len(entries))
+ for i, e := range entries {
+ result[i] = NotifyEmailEntry{
+ Email: e.Email,
+ Disabled: e.Disabled,
+ Verified: e.Verified,
+ }
+ }
+ return result
+}
+
+// NotifyEmailEntriesToService converts DTO entries to service entries.
+func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
+ if entries == nil {
+ return nil
+ }
+ result := make([]service.NotifyEmailEntry, len(entries))
+ for i, e := range entries {
+ result[i] = service.NotifyEmailEntry{
+ Email: e.Email,
+ Disabled: e.Disabled,
+ Verified: e.Verified,
+ }
+ }
+ return result
+}
diff --git a/backend/internal/handler/dto/public_settings_injection_schema_test.go b/backend/internal/handler/dto/public_settings_injection_schema_test.go
new file mode 100644
index 00000000..428fed3d
--- /dev/null
+++ b/backend/internal/handler/dto/public_settings_injection_schema_test.go
@@ -0,0 +1,70 @@
+package dto
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// TestPublicSettingsInjectionPayload_SchemaDoesNotDrift guarantees the SSR
+// injection struct exposes every JSON field consumed by the frontend.
+//
+// Why this test exists: before we extracted a named PublicSettingsInjectionPayload
+// type, the inline struct was manually kept in sync with dto.PublicSettings and
+// drifted — ChannelMonitorEnabled / AvailableChannelsEnabled were missing, which
+// made the frontend read `undefined` on refresh and hide the "可用渠道" menu
+// until the async /api/v1/settings/public round-trip finished.
+//
+// This test compares the two JSON-tag sets and fails if injection is missing
+// any field that dto.PublicSettings exposes. Adding a new feature flag with
+// only a DTO entry will fail this test until the injection struct is updated.
+//
+// Intentional exclusions (fields present on dto.PublicSettings that SSR does
+// not need to inject) are listed in `dtoOnlyFields` below with a reason.
+func TestPublicSettingsInjectionPayload_SchemaDoesNotDrift(t *testing.T) {
+ injection := jsonTags(reflect.TypeOf(service.PublicSettingsInjectionPayload{}))
+ dtoKeys := jsonTags(reflect.TypeOf(PublicSettings{}))
+
+ // Fields that legitimately live only on the DTO. Keep tiny; document each.
+ dtoOnlyFields := map[string]string{
+ // sora_client_enabled is an upstream-only field the fork does not surface.
+ "sora_client_enabled": "upstream-only field, not used on this fork",
+ // force_email_on_third_party_signup lives on the DTO but is not injected via SSR.
+ "force_email_on_third_party_signup": "auth-source default, not a feature flag",
+ }
+
+ var missing []string
+ for key := range dtoKeys {
+ if _, ok := injection[key]; ok {
+ continue
+ }
+ if _, allowed := dtoOnlyFields[key]; allowed {
+ continue
+ }
+ missing = append(missing, key)
+ }
+ if len(missing) > 0 {
+ t.Fatalf("service.PublicSettingsInjectionPayload is missing JSON fields present on dto.PublicSettings: %s\n"+
+ "add the field to PublicSettingsInjectionPayload (and GetPublicSettingsForInjection), or "+
+ "document the exclusion in dtoOnlyFields with a reason.", strings.Join(missing, ", "))
+ }
+}
+
+func jsonTags(t reflect.Type) map[string]struct{} {
+ out := make(map[string]struct{})
+ for i := 0; i < t.NumField(); i++ {
+ f := t.Field(i)
+ tag := f.Tag.Get("json")
+ if tag == "" || tag == "-" {
+ continue
+ }
+ name := strings.SplitN(tag, ",", 2)[0]
+ if name == "" {
+ continue
+ }
+ out[name] = struct{}{}
+ }
+ return out
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 47bab091..92ae4dc6 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -51,6 +51,46 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
+ WeChatConnectAppID string `json:"wechat_connect_app_id"`
+ WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
+ WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
+ WeChatConnectOpenAppSecretConfigured bool `json:"wechat_connect_open_app_secret_configured"`
+ WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
+ WeChatConnectMPAppSecretConfigured bool `json:"wechat_connect_mp_app_secret_configured"`
+ WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
+ WeChatConnectMobileAppSecretConfigured bool `json:"wechat_connect_mobile_app_secret_configured"`
+ WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
+ WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
+ WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
+ WeChatConnectMode string `json:"wechat_connect_mode"`
+ WeChatConnectScopes string `json:"wechat_connect_scopes"`
+ WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
+ WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
+
+ OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
+ OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
+ OIDCConnectClientID string `json:"oidc_connect_client_id"`
+ OIDCConnectClientSecretConfigured bool `json:"oidc_connect_client_secret_configured"`
+ OIDCConnectIssuerURL string `json:"oidc_connect_issuer_url"`
+ OIDCConnectDiscoveryURL string `json:"oidc_connect_discovery_url"`
+ OIDCConnectAuthorizeURL string `json:"oidc_connect_authorize_url"`
+ OIDCConnectTokenURL string `json:"oidc_connect_token_url"`
+ OIDCConnectUserInfoURL string `json:"oidc_connect_userinfo_url"`
+ OIDCConnectJWKSURL string `json:"oidc_connect_jwks_url"`
+ OIDCConnectScopes string `json:"oidc_connect_scopes"`
+ OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
+ OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
+ OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
+ OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
+ OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
+ OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
+ OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
+ OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
+ OIDCConnectUserInfoEmailPath string `json:"oidc_connect_userinfo_email_path"`
+ OIDCConnectUserInfoIDPath string `json:"oidc_connect_userinfo_id_path"`
+ OIDCConnectUserInfoUsernamePath string `json:"oidc_connect_userinfo_username_path"`
+
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
@@ -61,13 +101,19 @@ type SystemSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
+ TableDefaultPageSize int `json:"table_default_page_size"`
+ TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate float64 `json:"affiliate_rebate_rate"`
+ AffiliateRebateFreezeHours int `json:"affiliate_rebate_freeze_hours"`
+ AffiliateRebateDurationDays int `json:"affiliate_rebate_duration_days"`
+ AffiliateRebatePerInviteeCap float64 `json:"affiliate_rebate_per_invitee_cap"`
+ DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
+ DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -98,6 +144,60 @@ type SystemSettings struct {
// Gateway forwarding behavior
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
+ EnableCCHSigning bool `json:"enable_cch_signing"`
+
+ // Web Search Emulation
+ WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
+
+ // Payment visible method routing
+ PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"`
+ PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"`
+ PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"`
+ PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"`
+
+ // OpenAI account scheduling
+ OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"`
+
+ // Payment configuration
+ PaymentEnabled bool `json:"payment_enabled"`
+ PaymentMinAmount float64 `json:"payment_min_amount"`
+ PaymentMaxAmount float64 `json:"payment_max_amount"`
+ PaymentDailyLimit float64 `json:"payment_daily_limit"`
+ PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"`
+ PaymentMaxPendingOrders int `json:"payment_max_pending_orders"`
+ PaymentEnabledTypes []string `json:"payment_enabled_types"`
+ PaymentBalanceDisabled bool `json:"payment_balance_disabled"`
+ PaymentBalanceRechargeMultiplier float64 `json:"payment_balance_recharge_multiplier"`
+ PaymentRechargeFeeRate float64 `json:"payment_recharge_fee_rate"`
+ PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"`
+ PaymentProductNamePrefix string `json:"payment_product_name_prefix"`
+ PaymentProductNameSuffix string `json:"payment_product_name_suffix"`
+ PaymentHelpImageURL string `json:"payment_help_image_url"`
+ PaymentHelpText string `json:"payment_help_text"`
+
+ // Cancel rate limit
+ PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"`
+ PaymentCancelRateLimitMax int `json:"payment_cancel_rate_limit_max"`
+ PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
+ PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
+ PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
+
+ // Balance low notification
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
+ AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
+ AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
+
+ // Channel Monitor feature switch
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
+
+ // Available Channels feature switch (user-facing aggregate view)
+ AvailableChannelsEnabled bool `json:"available_channels_enabled"`
+
+ // Affiliate (邀请返利) feature switch
+ AffiliateEnabled bool `json:"affiliate_enabled"`
}
type DefaultSubscriptionSetting struct {
@@ -108,6 +208,7 @@ type DefaultSubscriptionSetting struct {
type PublicSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"`
PromoCodeEnabled bool `json:"promo_code_enabled"`
PasswordResetEnabled bool `json:"password_reset_enabled"`
@@ -125,50 +226,32 @@ type PublicSettings struct {
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL string `json:"purchase_subscription_url"`
+ TableDefaultPageSize int `json:"table_default_page_size"`
+ TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ WeChatOAuthMobileEnabled bool `json:"wechat_oauth_mobile_enabled"`
+ OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
+ OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
+ PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"`
-}
+ BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
+ AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
+ BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
+ BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
-// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段)
-type SoraS3Settings struct {
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
-}
+ ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
+ ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"`
-// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段)
-type SoraS3Profile struct {
- ProfileID string `json:"profile_id"`
- Name string `json:"name"`
- IsActive bool `json:"is_active"`
- Enabled bool `json:"enabled"`
- Endpoint string `json:"endpoint"`
- Region string `json:"region"`
- Bucket string `json:"bucket"`
- AccessKeyID string `json:"access_key_id"`
- SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
- Prefix string `json:"prefix"`
- ForcePathStyle bool `json:"force_path_style"`
- CDNURL string `json:"cdn_url"`
- DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
- UpdatedAt string `json:"updated_at"`
-}
+ AvailableChannelsEnabled bool `json:"available_channels_enabled"`
-// ListSoraS3ProfilesResponse Sora S3 配置列表响应
-type ListSoraS3ProfilesResponse struct {
- ActiveProfileID string `json:"active_profile_id"`
- Items []SoraS3Profile `json:"items"`
+ AffiliateEnabled bool `json:"affiliate_enabled"`
}
// OverloadCooldownSettings 529过载冷却配置 DTO
@@ -197,10 +280,13 @@ type RectifierSettings struct {
// BetaPolicyRule Beta 策略规则 DTO
type BetaPolicyRule struct {
- BetaToken string `json:"beta_token"`
- Action string `json:"action"`
- Scope string `json:"scope"`
- ErrorMessage string `json:"error_message,omitempty"`
+ BetaToken string `json:"beta_token"`
+ Action string `json:"action"`
+ Scope string `json:"scope"`
+ ErrorMessage string `json:"error_message,omitempty"`
+ ModelWhitelist []string `json:"model_whitelist,omitempty"`
+ FallbackAction string `json:"fallback_action,omitempty"`
+ FallbackErrorMessage string `json:"fallback_error_message,omitempty"`
}
// BetaPolicySettings Beta 策略配置 DTO
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 46984044..5cc2f8e4 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -1,18 +1,33 @@
package dto
-import "time"
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/domain"
+)
type User struct {
- ID int64 `json:"id"`
- Email string `json:"email"`
- Username string `json:"username"`
- Role string `json:"role"`
- Balance float64 `json:"balance"`
- Concurrency int `json:"concurrency"`
- Status string `json:"status"`
- AllowedGroups []int64 `json:"allowed_groups"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
+ ID int64 `json:"id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ Role string `json:"role"`
+ Balance float64 `json:"balance"`
+ Concurrency int `json:"concurrency"`
+ Status string `json:"status"`
+ AllowedGroups []int64 `json:"allowed_groups"`
+ LastActiveAt *time.Time `json:"last_active_at,omitempty"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+
+ // 余额不足通知
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+ BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
+ TotalRecharged float64 `json:"total_recharged"`
+
+ // RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。
+ RPMLimit int `json:"rpm_limit"`
APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
@@ -23,12 +38,11 @@ type User struct {
type AdminUser struct {
User
- Notes string `json:"notes"`
+ Notes string `json:"notes"`
+ LastUsedAt *time.Time `json:"last_used_at"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
- GroupRates map[int64]float64 `json:"group_rates,omitempty"`
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
- SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"`
+ GroupRates map[int64]float64 `json:"group_rates,omitempty"`
}
type APIKey struct {
@@ -84,21 +98,12 @@ type Group struct {
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
- // Sora 按次计费配置
- SoraImagePrice360 *float64 `json:"sora_image_price_360"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
-
// Claude Code 客户端限制
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
// 无效请求兜底分组
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
- // Sora 存储配额
- SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
-
// OpenAI Messages 调度开关(用户侧需要此字段判断是否展示 Claude Code 教程)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
@@ -106,6 +111,9 @@ type Group struct {
RequireOAuthOnly bool `json:"require_oauth_only"`
RequirePrivacySet bool `json:"require_privacy_set"`
+ // RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。
+ RPMLimit int `json:"rpm_limit"`
+
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -123,7 +131,8 @@ type AdminGroup struct {
MCPXMLInject bool `json:"mcp_xml_inject"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
- DefaultMappedModel string `json:"default_mapped_model"`
+ DefaultMappedModel string `json:"default_mapped_model"`
+ MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
@@ -224,6 +233,14 @@ type Account struct {
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
+ // 配额通知配置
+ QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
+ QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
+ QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
+ QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
+ QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
+ QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
+
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -390,6 +407,9 @@ type UsageLog struct {
// Cache TTL Override 标记
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
+ // BillingMode 计费模式:token/image
+ BillingMode *string `json:"billing_mode,omitempty"`
+
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`
@@ -406,8 +426,17 @@ type AdminUsageLog struct {
// Omitted when no mapping was applied (requested model was used as-is).
UpstreamModel *string `json:"upstream_model,omitempty"`
+ // ChannelID 渠道 ID
+ ChannelID *int64 `json:"channel_id,omitempty"`
+ // ModelMappingChain 模型映射链,如 "a→b→c"
+ ModelMappingChain *string `json:"model_mapping_chain,omitempty"`
+ // BillingTier 计费层级标签(per_request/image 模式)
+ BillingTier *string `json:"billing_tier,omitempty"`
+
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
+ // AccountStatsCost 自定义定价规则计算的账号统计费用(nil 表示使用默认公式)
+ AccountStatsCost *float64 `json:"account_stats_cost,omitempty"`
// IPAddress 用户请求 IP(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
diff --git a/backend/internal/handler/dto/user_mapper_activity_test.go b/backend/internal/handler/dto/user_mapper_activity_test.go
new file mode 100644
index 00000000..a17f0ce4
--- /dev/null
+++ b/backend/internal/handler/dto/user_mapper_activity_test.go
@@ -0,0 +1,33 @@
+package dto
+
+import (
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserFromServiceAdmin_MapsActivityTimestamps(t *testing.T) {
+ t.Parallel()
+
+ lastLoginAt := time.Date(2026, time.April, 20, 10, 0, 0, 0, time.UTC)
+ lastActiveAt := lastLoginAt.Add(15 * time.Minute)
+ lastUsedAt := lastLoginAt.Add(45 * time.Minute)
+
+ out := UserFromServiceAdmin(&service.User{
+ ID: 42,
+ Email: "admin@example.com",
+ Username: "admin",
+ Role: service.RoleAdmin,
+ Status: service.StatusActive,
+ LastActiveAt: &lastActiveAt,
+ LastUsedAt: &lastUsedAt,
+ })
+
+ require.NotNil(t, out)
+ require.NotNil(t, out.LastActiveAt)
+ require.NotNil(t, out.LastUsedAt)
+ require.WithinDuration(t, lastActiveAt, *out.LastActiveAt, time.Second)
+ require.WithinDuration(t, lastUsedAt, *out.LastUsedAt, time.Second)
+}
diff --git a/backend/internal/handler/endpoint.go b/backend/internal/handler/endpoint.go
index b1200988..db29618a 100644
--- a/backend/internal/handler/endpoint.go
+++ b/backend/internal/handler/endpoint.go
@@ -15,10 +15,12 @@ import (
// ──────────────────────────────────────────────────────────
const (
- EndpointMessages = "/v1/messages"
- EndpointChatCompletions = "/v1/chat/completions"
- EndpointResponses = "/v1/responses"
- EndpointGeminiModels = "/v1beta/models"
+ EndpointMessages = "/v1/messages"
+ EndpointChatCompletions = "/v1/chat/completions"
+ EndpointResponses = "/v1/responses"
+ EndpointImagesGenerations = "/v1/images/generations"
+ EndpointImagesEdits = "/v1/images/edits"
+ EndpointGeminiModels = "/v1beta/models"
)
// gin.Context keys used by the middleware and helpers below.
@@ -31,7 +33,7 @@ const (
// ──────────────────────────────────────────────────────────
// NormalizeInboundEndpoint maps a raw request path (which may carry
-// prefixes like /antigravity, /openai, /sora) to its canonical form.
+// prefixes like /antigravity, /openai) to its canonical form.
//
// "/antigravity/v1/messages" → "/v1/messages"
// "/v1/chat/completions" → "/v1/chat/completions"
@@ -44,6 +46,10 @@ func NormalizeInboundEndpoint(path string) string {
return EndpointChatCompletions
case strings.Contains(path, EndpointMessages):
return EndpointMessages
+ case strings.Contains(path, EndpointImagesGenerations) || strings.Contains(path, "/images/generations"):
+ return EndpointImagesGenerations
+ case strings.Contains(path, EndpointImagesEdits) || strings.Contains(path, "/images/edits"):
+ return EndpointImagesEdits
case strings.Contains(path, EndpointResponses):
return EndpointResponses
case strings.Contains(path, EndpointGeminiModels):
@@ -61,7 +67,7 @@ func NormalizeInboundEndpoint(path string) string {
// such as /v1/responses/compact preserved from the raw URL).
// - Anthropic → /v1/messages
// - Gemini → /v1beta/models
-// - Sora → /v1/chat/completions
+// - Antigravity → /v1/messages (Claude) or gemini (Gemini)
// - Antigravity routes may target either Claude or Gemini, so the
// inbound endpoint is used to distinguish.
func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
@@ -69,6 +75,9 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
switch platform {
case service.PlatformOpenAI:
+ if inbound == EndpointImagesGenerations || inbound == EndpointImagesEdits {
+ return inbound
+ }
// OpenAI forwards everything to the Responses API.
// Preserve subresource suffix (e.g. /v1/responses/compact).
if suffix := responsesSubpathSuffix(rawRequestPath); suffix != "" {
@@ -82,9 +91,6 @@ func DeriveUpstreamEndpoint(inbound, rawRequestPath, platform string) string {
case service.PlatformGemini:
return EndpointGeminiModels
- case service.PlatformSora:
- return EndpointChatCompletions
-
case service.PlatformAntigravity:
// Antigravity accounts serve both Claude and Gemini.
if inbound == EndpointGeminiModels {
diff --git a/backend/internal/handler/endpoint_test.go b/backend/internal/handler/endpoint_test.go
index a3767ac4..369c5fa7 100644
--- a/backend/internal/handler/endpoint_test.go
+++ b/backend/internal/handler/endpoint_test.go
@@ -25,13 +25,16 @@ func TestNormalizeInboundEndpoint(t *testing.T) {
{"/v1/messages", EndpointMessages},
{"/v1/chat/completions", EndpointChatCompletions},
{"/v1/responses", EndpointResponses},
+ {"/v1/images/generations", EndpointImagesGenerations},
+ {"/v1/images/edits", EndpointImagesEdits},
{"/v1beta/models", EndpointGeminiModels},
- // Prefixed paths (antigravity, openai, sora).
+ // Prefixed paths (antigravity, openai).
{"/antigravity/v1/messages", EndpointMessages},
{"/openai/v1/responses", EndpointResponses},
{"/openai/v1/responses/compact", EndpointResponses},
- {"/sora/v1/chat/completions", EndpointChatCompletions},
+ {"/openai/v1/images/generations", EndpointImagesGenerations},
+ {"/openai/v1/images/edits", EndpointImagesEdits},
{"/antigravity/v1beta/models/gemini:generateContent", EndpointGeminiModels},
// Gin route patterns with wildcards.
@@ -68,15 +71,14 @@ func TestDeriveUpstreamEndpoint(t *testing.T) {
// Gemini.
{"gemini models", EndpointGeminiModels, "/v1beta/models/gemini:gen", service.PlatformGemini, EndpointGeminiModels},
- // Sora.
- {"sora completions", EndpointChatCompletions, "/sora/v1/chat/completions", service.PlatformSora, EndpointChatCompletions},
-
// OpenAI — always /v1/responses.
{"openai responses root", EndpointResponses, "/v1/responses", service.PlatformOpenAI, EndpointResponses},
{"openai responses compact", EndpointResponses, "/openai/v1/responses/compact", service.PlatformOpenAI, "/v1/responses/compact"},
{"openai responses nested", EndpointResponses, "/openai/v1/responses/compact/detail", service.PlatformOpenAI, "/v1/responses/compact/detail"},
{"openai from messages", EndpointMessages, "/v1/messages", service.PlatformOpenAI, EndpointResponses},
{"openai from completions", EndpointChatCompletions, "/v1/chat/completions", service.PlatformOpenAI, EndpointResponses},
+ {"openai image generations", EndpointImagesGenerations, "/v1/images/generations", service.PlatformOpenAI, EndpointImagesGenerations},
+ {"openai image edits", EndpointImagesEdits, "/openai/v1/images/edits", service.PlatformOpenAI, EndpointImagesEdits},
// Antigravity — uses inbound to pick Claude vs Gemini upstream.
{"antigravity claude", EndpointMessages, "/antigravity/v1/messages", service.PlatformAntigravity, EndpointMessages},
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index a0d8b2e9..ef532559 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -158,6 +158,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
+ // 解析渠道级模型映射
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
+
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
@@ -240,11 +243,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
+ // 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟)
+ parsedReq.GroupID = apiKey.GroupID
+
// 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c),
@@ -292,9 +301,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ reqLog.Warn("gateway.select_account_no_available",
+ zap.String("model", reqModel),
+ zap.Int64p("group_id", apiKey.GroupID),
+ zap.String("platform", platform),
+ zap.Error(err),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
@@ -338,6 +353,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
+ zap.Int64("account_id", account.ID),
+ zap.String("model", reqModel),
+ zap.String("platform", platform),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@@ -467,6 +487,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
+ ParsedRequest: parsedReq,
APIKey: apiKey,
User: apiKey.User,
Account: account,
@@ -478,6 +499,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
@@ -514,9 +536,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID)
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ reqLog.Warn("gateway.select_account_no_available",
+ zap.String("model", reqModel),
+ zap.Int64p("group_id", currentAPIKey.GroupID),
+ zap.String("platform", platform),
+ zap.Bool("fallback_used", fallbackUsed),
+ zap.Error(err),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
@@ -560,6 +589,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
+ zap.Int64("account_id", account.ID),
+ zap.String("model", reqModel),
+ zap.String("platform", platform),
+ )
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@@ -660,7 +694,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
parsedReq.OnUpstreamAccepted = queueRelease
// ===== 用户消息串行队列 END =====
+ // 应用渠道模型映射到请求
+ if channelMapping.Mapped {
+ parsedReq.Model = channelMapping.MappedModel
+ parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
+ body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
+ }
+
// 转发请求 - 根据账号平台分流
+ c.Set("parsed_request", parsedReq)
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
@@ -719,7 +761,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil {
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -799,6 +844,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
+ ParsedRequest: parsedReq,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: account,
@@ -810,6 +856,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
RequestPayloadHash: requestPayloadHash,
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
@@ -847,14 +894,6 @@ func (h *GatewayHandler) Models(c *gin.Context) {
platform = forcedPlatform
}
- if platform == service.PlatformSora {
- c.JSON(http.StatusOK, gin.H{
- "object": "list",
- "data": service.DefaultSoraModels(h.cfg),
- })
- return
- }
-
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
@@ -1431,7 +1470,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility(订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.errorResponse(c, status, code, message)
return
}
@@ -1674,25 +1716,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
c.JSON(http.StatusOK, response)
}
-func billingErrorDetails(err error) (status int, code, message string) {
+func billingErrorDetails(err error) (status int, code, message string, retryAfter int) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
- return http.StatusServiceUnavailable, "billing_service_error", msg
+ return http.StatusServiceUnavailable, "billing_service_error", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) {
msg := pkgerrors.Message(err)
- return http.StatusTooManyRequests, "rate_limit_exceeded", msg
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) {
msg := pkgerrors.Message(err)
- return http.StatusTooManyRequests, "rate_limit_exceeded", msg
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
}
if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) {
msg := pkgerrors.Message(err)
- return http.StatusTooManyRequests, "rate_limit_exceeded", msg
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0
+ }
+ // 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。
+ // 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。
+ if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) {
+ msg := pkgerrors.Message(err)
+ retrySeconds := 60 - int(time.Now().Unix()%60)
+ return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds
}
msg := pkgerrors.Message(err)
if msg == "" {
@@ -1702,7 +1751,7 @@ func billingErrorDetails(err error) (status int, code, message string) {
).Warn("gateway.billing_error_missing_message")
msg = "Billing error"
}
- return http.StatusForbidden, "billing_error", msg
+ return http.StatusForbidden, "billing_error", msg, 0
}
func (h *GatewayHandler) metadataBridgeEnabled() bool {
diff --git a/backend/internal/handler/gateway_handler_billing_error_test.go b/backend/internal/handler/gateway_handler_billing_error_test.go
new file mode 100644
index 00000000..e8a88802
--- /dev/null
+++ b/backend/internal/handler/gateway_handler_billing_error_test.go
@@ -0,0 +1,54 @@
+package handler
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) {
+ status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded)
+ require.Equal(t, http.StatusTooManyRequests, status)
+ require.Equal(t, "rate_limit_exceeded", code)
+ require.NotEmpty(t, msg)
+ require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
+ require.LessOrEqual(t, retryAfter, 60)
+}
+
+func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) {
+ status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded)
+ require.Equal(t, http.StatusTooManyRequests, status)
+ require.Equal(t, "rate_limit_exceeded", code)
+ require.NotEmpty(t, msg)
+ require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After")
+ require.LessOrEqual(t, retryAfter, 60)
+}
+
+func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) {
+ // 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。
+ for _, err := range []error{
+ service.ErrAPIKeyRateLimit5hExceeded,
+ service.ErrAPIKeyRateLimit1dExceeded,
+ service.ErrAPIKeyRateLimit7dExceeded,
+ } {
+ status, code, _, _ := billingErrorDetails(err)
+ require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err)
+ require.Equal(t, "rate_limit_exceeded", code)
+ }
+}
+
+func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) {
+ status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable)
+ require.Equal(t, http.StatusServiceUnavailable, status)
+ require.Equal(t, "billing_service_error", code)
+ require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After")
+}
+
+func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) {
+ status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance)
+ require.Equal(t, http.StatusForbidden, status)
+ require.Equal(t, "billing_error", code)
+ require.NotEmpty(t, msg)
+}
diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go
index da376036..4290e54b 100644
--- a/backend/internal/handler/gateway_handler_chat_completions.go
+++ b/backend/internal/handler/gateway_handler_chat_completions.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
+ "strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -80,6 +81,9 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
+ // 解析渠道级模型映射
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
+
// Claude Code only restriction
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
@@ -133,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.chatCompletionsErrorResponse(c, status, code, message)
return
}
@@ -154,7 +161,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
@@ -203,7 +210,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
- result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
+ forwardBody := body
+ if channelMapping.Mapped {
+ forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
+ }
+ result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -255,6 +266,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("gateway.cc.record_usage_failed",
zap.Int64("account_id", account.ID),
diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go
index d146d724..683cf2b7 100644
--- a/backend/internal/handler/gateway_handler_responses.go
+++ b/backend/internal/handler/gateway_handler_responses.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
+ "strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -80,6 +81,9 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
+ // 解析渠道级模型映射
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
+
// Claude Code only restriction:
// /v1/responses is never a Claude Code endpoint.
// When claude_code_only is enabled, this endpoint is rejected.
@@ -138,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.responsesErrorResponse(c, status, code, message)
return
}
@@ -159,7 +166,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
fs := NewFailoverState(h.maxAccountSwitches, false)
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "")
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
@@ -208,7 +215,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
- result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
+ forwardBody := body
+ if channelMapping.Mapped {
+ forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
+ }
+ result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -261,6 +272,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("gateway.responses.record_usage_failed",
zap.Int64("account_id", account.ID),
diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
index 69c8d1d5..71030140 100644
--- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
+++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
@@ -34,7 +34,12 @@ func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerB
func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error {
return nil
}
-func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) {
+func (f *fakeSchedulerCache) GetAccount(_ context.Context, id int64) (*service.Account, error) {
+ for _, account := range f.accounts {
+ if account != nil && account.ID == id {
+ return account, nil
+ }
+ }
return nil, nil
}
func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil }
@@ -161,11 +166,14 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // digestStore
nil, // settingService
nil, // tlsFPProfileService
+ nil, // channelService
+ nil, // resolver
+ nil, // balanceNotifyService
)
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
cfg := &config.Config{RunMode: config.RunModeSimple}
- billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
+ billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg)
concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{})
concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0)
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 524c6b6d..2a34e3f0 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -9,6 +9,7 @@ import (
"errors"
"net/http"
"regexp"
+ "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
@@ -184,6 +185,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
setOpsRequestContext(c, modelName, stream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
+ // 解析渠道级模型映射
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
+ reqModel := modelName // 保存映射前的原始模型名
+ if channelMapping.Mapped {
+ modelName = channelMapping.MappedModel
+ }
+
// Get subscription (may be nil)
subscription, _ := middleware.GetSubscriptionFromContext(c)
@@ -234,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err))
- status, _, message := billingErrorDetails(err)
+ status, _, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
googleError(c, status, message)
return
}
@@ -353,7 +364,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
@@ -523,6 +534,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
LongContextMultiplier: 2.0, // 超出部分双倍计费
ForceCacheBilling: fs.ForceCacheBilling,
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.gemini_v1beta.models"),
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index b2467eac..13e3ac88 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -6,48 +6,55 @@ import (
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
- Dashboard *admin.DashboardHandler
- User *admin.UserHandler
- Group *admin.GroupHandler
- Account *admin.AccountHandler
- Announcement *admin.AnnouncementHandler
- DataManagement *admin.DataManagementHandler
- Backup *admin.BackupHandler
- OAuth *admin.OAuthHandler
- OpenAIOAuth *admin.OpenAIOAuthHandler
- GeminiOAuth *admin.GeminiOAuthHandler
- AntigravityOAuth *admin.AntigravityOAuthHandler
- Proxy *admin.ProxyHandler
- Redeem *admin.RedeemHandler
- Promo *admin.PromoHandler
- Setting *admin.SettingHandler
- Ops *admin.OpsHandler
- System *admin.SystemHandler
- Subscription *admin.SubscriptionHandler
- Usage *admin.UsageHandler
- UserAttribute *admin.UserAttributeHandler
- ErrorPassthrough *admin.ErrorPassthroughHandler
- TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
- APIKey *admin.AdminAPIKeyHandler
- ScheduledTest *admin.ScheduledTestHandler
+ Dashboard *admin.DashboardHandler
+ User *admin.UserHandler
+ Group *admin.GroupHandler
+ Account *admin.AccountHandler
+ Announcement *admin.AnnouncementHandler
+ DataManagement *admin.DataManagementHandler
+ Backup *admin.BackupHandler
+ OAuth *admin.OAuthHandler
+ OpenAIOAuth *admin.OpenAIOAuthHandler
+ GeminiOAuth *admin.GeminiOAuthHandler
+ AntigravityOAuth *admin.AntigravityOAuthHandler
+ Proxy *admin.ProxyHandler
+ Redeem *admin.RedeemHandler
+ Promo *admin.PromoHandler
+ Setting *admin.SettingHandler
+ Ops *admin.OpsHandler
+ System *admin.SystemHandler
+ Subscription *admin.SubscriptionHandler
+ Usage *admin.UsageHandler
+ UserAttribute *admin.UserAttributeHandler
+ ErrorPassthrough *admin.ErrorPassthroughHandler
+ TLSFingerprintProfile *admin.TLSFingerprintProfileHandler
+ APIKey *admin.AdminAPIKeyHandler
+ ScheduledTest *admin.ScheduledTestHandler
+ Channel *admin.ChannelHandler
+ ChannelMonitor *admin.ChannelMonitorHandler
+ ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler
+ Payment *admin.PaymentHandler
+ Affiliate *admin.AffiliateHandler
}
// Handlers contains all HTTP handlers
type Handlers struct {
- Auth *AuthHandler
- User *UserHandler
- APIKey *APIKeyHandler
- Usage *UsageHandler
- Redeem *RedeemHandler
- Subscription *SubscriptionHandler
- Announcement *AnnouncementHandler
- Admin *AdminHandlers
- Gateway *GatewayHandler
- OpenAIGateway *OpenAIGatewayHandler
- SoraGateway *SoraGatewayHandler
- SoraClient *SoraClientHandler
- Setting *SettingHandler
- Totp *TotpHandler
+ Auth *AuthHandler
+ User *UserHandler
+ APIKey *APIKeyHandler
+ Usage *UsageHandler
+ Redeem *RedeemHandler
+ Subscription *SubscriptionHandler
+ Announcement *AnnouncementHandler
+ ChannelMonitor *ChannelMonitorUserHandler
+ Admin *AdminHandlers
+ Gateway *GatewayHandler
+ OpenAIGateway *OpenAIGatewayHandler
+ Setting *SettingHandler
+ Totp *TotpHandler
+ Payment *PaymentHandler
+ PaymentWebhook *PaymentWebhookHandler
+ AvailableChannel *AvailableChannelHandler
}
// BuildInfo contains build-time information
diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go
index 0c94aa21..f395970a 100644
--- a/backend/internal/handler/openai_chat_completions.go
+++ b/backend/internal/handler/openai_chat_completions.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net/http"
+ "strconv"
"time"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
@@ -79,6 +80,9 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
+ // 解析渠道级模型映射
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
+
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
@@ -98,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -123,6 +130,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ false,
)
if err != nil {
reqLog.Warn("openai_chat_completions.account_select_failed",
@@ -146,6 +154,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
defaultModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ false,
)
if err == nil && selection != nil {
c.Set("openai_chat_completions_fallback_model", defaultModel)
@@ -183,7 +192,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardStart := time.Now()
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
- result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
+ forwardBody := body
+ if channelMapping.Mapped {
+ forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
+ }
+ result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
@@ -257,16 +270,17 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: account,
- Subscription: subscription,
- InboundEndpoint: GetInboundEndpoint(c),
- UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
- UserAgent: userAgent,
- IPAddress: clientIP,
- APIKeyService: h.apiKeyService,
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ InboundEndpoint: GetInboundEndpoint(c),
+ UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.chat_completions"),
diff --git a/backend/internal/handler/openai_gateway_compact_log_test.go b/backend/internal/handler/openai_gateway_compact_log_test.go
index 062f318b..e18509b4 100644
--- a/backend/internal/handler/openai_gateway_compact_log_test.go
+++ b/backend/internal/handler/openai_gateway_compact_log_test.go
@@ -116,7 +116,7 @@ func TestLogOpenAIRemoteCompactOutcome_Succeeded(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Set(opsModelKey, "gpt-5.3-codex")
c.Set(opsAccountIDKey, int64(123))
c.Header("x-request-id", "rid-compact-ok")
@@ -142,7 +142,7 @@ func TestLogOpenAIRemoteCompactOutcome_Failed(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/responses/compact", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
c.Status(http.StatusBadGateway)
h := &OpenAIGatewayHandler{}
@@ -180,7 +180,7 @@ func TestOpenAIResponses_CompactUnauthorizedLogsFailed(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(`{"model":"gpt-5.3-codex"}`))
c.Request.Header.Set("Content-Type", "application/json")
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.104.0")
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.125.0")
h := &OpenAIGatewayHandler{}
h.Responses(c)
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index ae70cee4..7676ffa3 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -47,6 +47,13 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode
return strings.TrimSpace(apiKey.Group.DefaultMappedModel)
}
+func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string {
+ if apiKey == nil || apiKey.Group == nil {
+ return ""
+ }
+ return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel))
+}
+
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
@@ -180,11 +187,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
return
}
+ reqLog.Warn("openai.request_validation_failed",
+ zap.String("reason", "previous_response_id_requires_wsv2"),
+ )
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id is only supported on Responses WebSocket v2")
+ return
}
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
+ // 解析渠道级模型映射
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
+
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return
@@ -213,13 +228,17 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
// Generate session hash (header first; fallback to prompt_cache_key)
sessionHash := h.gatewayService.GenerateSessionHash(c, sessionHashBody)
+ requireCompact := isOpenAIRemoteCompactPath(c)
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
@@ -238,6 +257,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
reqModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ requireCompact,
)
if err != nil {
reqLog.Warn("openai.account_select_failed",
@@ -245,6 +265,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
+ if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
+ return
+ }
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
@@ -284,7 +308,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
- result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
+ // 应用渠道模型映射到请求体
+ forwardBody := body
+ if channelMapping.Mapped {
+ forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
+ }
+ result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -379,6 +408,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -542,6 +572,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
}
reqModel := modelResult.String()
routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel)
+ preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel)
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
@@ -549,6 +580,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
+ // 解析渠道级模型映射
+ channelMappingMsg, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
+
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
@@ -569,7 +603,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
h.anthropicStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -597,48 +634,30 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{})
sameAccountRetryCount := make(map[int64]int)
var lastFailoverErr *service.UpstreamFailoverError
+ effectiveMappedModel := preferredMappedModel
for {
- // 清除上一次迭代的降级模型标记,避免残留影响本次迭代
- c.Set("openai_messages_fallback_model", "")
+ currentRoutingModel := routingModel
+ if effectiveMappedModel != "" {
+ currentRoutingModel = effectiveMappedModel
+ }
reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
c.Request.Context(),
apiKey.GroupID,
"", // no previous_response_id
sessionHash,
- routingModel,
+ currentRoutingModel,
failedAccountIDs,
service.OpenAIUpstreamTransportAny,
+ false,
)
if err != nil {
reqLog.Warn("openai_messages.account_select_failed",
zap.Error(err),
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
- // 首次调度失败 + 有默认映射模型 → 用默认模型重试
if len(failedAccountIDs) == 0 {
- defaultModel := ""
- if apiKey.Group != nil {
- defaultModel = apiKey.Group.DefaultMappedModel
- }
- if defaultModel != "" && defaultModel != routingModel {
- reqLog.Info("openai_messages.fallback_to_default_model",
- zap.String("default_mapped_model", defaultModel),
- )
- selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
- c.Request.Context(),
- apiKey.GroupID,
- "",
- sessionHash,
- defaultModel,
- failedAccountIDs,
- service.OpenAIUpstreamTransportAny,
- )
- if err == nil && selection != nil {
- c.Set("openai_messages_fallback_model", defaultModel)
- }
- }
if err != nil {
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
@@ -670,10 +689,13 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
- // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的
- // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。
- defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model"))
- result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
+ defaultMappedModel := strings.TrimSpace(effectiveMappedModel)
+ // 应用渠道模型映射到请求体
+ forwardBody := body
+ if channelMappingMsg.Mapped {
+ forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
+ }
+ result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
@@ -759,6 +781,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: requestPayloadHash,
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.messages"),
@@ -851,7 +874,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_call_id"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
if validation.HasItemReferenceForAllCallIDs {
@@ -861,7 +884,7 @@ func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context,
reqLog.Warn("openai.request_validation_failed",
zap.String("reason", "function_call_output_missing_item_reference"),
)
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id on HTTP requests; continuation via previous_response_id is only supported on Responses WebSocket v2")
return false
}
@@ -1101,6 +1124,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
+ // 解析渠道级模型映射
+ channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel)
+
var currentUserRelease func()
var currentAccountRelease func()
releaseTurnSlots := func() {
@@ -1148,6 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqModel,
nil,
service.OpenAIUpstreamTransportResponsesWebsocketV2,
+ false,
)
if err != nil {
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
@@ -1259,6 +1286,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
IPAddress: clientIP,
RequestPayloadHash: service.HashUsageRequestPayload(firstMessage),
APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel),
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
@@ -1270,7 +1298,13 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
},
}
- if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
+ // 应用渠道模型映射到 WebSocket 首条消息
+ wsFirstMessage := firstMessage
+ if channelMappingWS.Mapped {
+ wsFirstMessage = h.gatewayService.ReplaceModelInBody(firstMessage, channelMappingWS.MappedModel)
+ }
+
+ if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, wsFirstMessage, hooks); err != nil {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index 7bbf94ec..8ecee59a 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -360,7 +360,7 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 "))
})
- t.Run("uses_group_default_on_normal_path", func(t *testing.T) {
+ t.Run("uses_group_default_when_explicit_fallback_absent", func(t *testing.T) {
apiKey := &service.APIKey{
Group: &service.Group{DefaultMappedModel: "gpt-5.4"},
}
@@ -376,6 +376,45 @@ func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) {
})
}
+func TestResolveOpenAIMessagesDispatchMappedModel(t *testing.T) {
+ t.Run("exact_claude_model_override_wins", func(t *testing.T) {
+ apiKey := &service.APIKey{
+ Group: &service.Group{
+ MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
+ SonnetMappedModel: "gpt-5.2",
+ ExactModelMappings: map[string]string{
+ "claude-sonnet-4-5-20250929": "gpt-5.4-mini-high",
+ },
+ },
+ },
+ }
+ require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
+ })
+
+ t.Run("uses_family_default_when_no_override", func(t *testing.T) {
+ apiKey := &service.APIKey{Group: &service.Group{}}
+ require.Equal(t, "gpt-5.4", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-opus-4-6"))
+ require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
+ require.Equal(t, "gpt-5.4-mini", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-haiku-4-5-20251001"))
+ })
+
+ t.Run("returns_empty_for_non_claude_or_missing_group", func(t *testing.T) {
+ require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(nil, "claude-sonnet-4-5-20250929"))
+ require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{}, "claude-sonnet-4-5-20250929"))
+ require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(&service.APIKey{Group: &service.Group{}}, "gpt-5.4"))
+ })
+
+ t.Run("does_not_fall_back_to_group_default_mapped_model", func(t *testing.T) {
+ apiKey := &service.APIKey{
+ Group: &service.Group{
+ DefaultMappedModel: "gpt-5.4",
+ },
+ }
+ require.Empty(t, resolveOpenAIMessagesDispatchMappedModel(apiKey, "gpt-5.4"))
+ require.Equal(t, "gpt-5.3-codex", resolveOpenAIMessagesDispatchMappedModel(apiKey, "claude-sonnet-4-5-20250929"))
+ })
+}
+
func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -455,6 +494,64 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
+func TestOpenAIResponses_RejectsHTTPContinuationPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123456","input":[{"type":"input_text","text":"hello"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.Contains(t, w.Body.String(), "previous_response_id")
+}
+
+func TestOpenAIResponses_FunctionCallOutputHTTPGuidanceDoesNotSuggestPreviousResponseReuse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","output":"{}"}]}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Responses WebSocket v2")
+ require.NotContains(t, w.Body.String(), "reuse previous_response_id")
+}
+
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go
new file mode 100644
index 00000000..4d0078a7
--- /dev/null
+++ b/backend/internal/handler/openai_images.go
@@ -0,0 +1,299 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "go.uber.org/zap"
+)
+
+// Images handles OpenAI Images API requests.
+// POST /v1/images/generations
+// POST /v1/images/edits
+func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
+ streamStarted := false
+ defer h.recoverResponsesPanic(c, &streamStarted)
+
+ requestStart := time.Now()
+
+ apiKey, ok := middleware2.GetAPIKeyFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
+ return
+ }
+
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
+ return
+ }
+ reqLog := requestLogger(
+ c,
+ "handler.openai_gateway.images",
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ )
+ if !h.ensureResponsesDependencies(c, reqLog) {
+ return
+ }
+
+ body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
+ if err != nil {
+ if maxErr, ok := extractMaxBytesError(err); ok {
+ h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
+ return
+ }
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
+ return
+ }
+ if len(body) == 0 {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
+ return
+ }
+
+ if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
+ setOpsRequestContext(c, "", false, nil)
+ } else {
+ setOpsRequestContext(c, "", false, body)
+ }
+
+ parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
+ if err != nil {
+ h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+ return
+ }
+
+ reqLog = reqLog.With(
+ zap.String("model", parsed.Model),
+ zap.Bool("stream", parsed.Stream),
+ zap.Bool("multipart", parsed.Multipart),
+ zap.String("capability", string(parsed.RequiredCapability)),
+ )
+
+ if parsed.Multipart {
+ setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
+ } else {
+ setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
+ }
+ setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
+
+ channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, parsed.Model)
+
+ if h.errorPassthroughService != nil {
+ service.BindErrorPassthroughService(c, h.errorPassthroughService)
+ }
+
+ subscription, _ := middleware2.GetSubscriptionFromContext(c)
+
+ service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
+ routingStart := time.Now()
+
+ userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, parsed.Stream, &streamStarted, reqLog)
+ if !acquired {
+ return
+ }
+ if userReleaseFunc != nil {
+ defer userReleaseFunc()
+ }
+
+ if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
+ reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err))
+ status, code, message, retryAfter := billingErrorDetails(err)
+ if retryAfter > 0 {
+ c.Header("Retry-After", strconv.Itoa(retryAfter))
+ }
+ h.handleStreamingAwareError(c, status, code, message, streamStarted)
+ return
+ }
+
+ sessionHash := h.gatewayService.GenerateExplicitSessionHash(c, body)
+
+ maxAccountSwitches := h.maxAccountSwitches
+ switchCount := 0
+ failedAccountIDs := make(map[int64]struct{})
+ sameAccountRetryCount := make(map[int64]int)
+ var lastFailoverErr *service.UpstreamFailoverError
+
+ for {
+ reqLog.Debug("openai.images.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
+ selection, scheduleDecision, err := h.gatewayService.SelectAccountWithSchedulerForImages(
+ c.Request.Context(),
+ apiKey.GroupID,
+ sessionHash,
+ parsed.Model,
+ failedAccountIDs,
+ parsed.RequiredCapability,
+ )
+ if err != nil {
+ reqLog.Warn("openai.images.account_select_failed",
+ zap.Error(err),
+ zap.Int("excluded_account_count", len(failedAccountIDs)),
+ )
+ if len(failedAccountIDs) == 0 {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
+ return
+ }
+ if lastFailoverErr != nil {
+ h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
+ } else {
+ h.handleFailoverExhaustedSimple(c, 502, streamStarted)
+ }
+ return
+ }
+ if selection == nil || selection.Account == nil {
+ h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
+ return
+ }
+
+ reqLog.Debug("openai.images.account_schedule_decision",
+ zap.String("layer", scheduleDecision.Layer),
+ zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
+ zap.Int("candidate_count", scheduleDecision.CandidateCount),
+ zap.Int("top_k", scheduleDecision.TopK),
+ zap.Int64("latency_ms", scheduleDecision.LatencyMs),
+ zap.Float64("load_skew", scheduleDecision.LoadSkew),
+ )
+
+ account := selection.Account
+ sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
+ reqLog.Debug("openai.images.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
+ setOpsSelectedAccount(c, account.ID, account.Platform)
+
+ accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, parsed.Stream, &streamStarted, reqLog)
+ if !acquired {
+ return
+ }
+
+ service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
+ forwardStart := time.Now()
+ result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
+ forwardDurationMs := time.Since(forwardStart).Milliseconds()
+ if accountReleaseFunc != nil {
+ accountReleaseFunc()
+ }
+ upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
+ responseLatencyMs := forwardDurationMs
+ if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
+ responseLatencyMs = forwardDurationMs - upstreamLatencyMs
+ }
+ service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
+ if err == nil && result != nil && result.FirstTokenMs != nil {
+ service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
+ }
+ if err != nil {
+ var failoverErr *service.UpstreamFailoverError
+ if errors.As(err, &failoverErr) {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ if failoverErr.RetryableOnSameAccount {
+ retryLimit := account.GetPoolModeRetryCount()
+ if sameAccountRetryCount[account.ID] < retryLimit {
+ sameAccountRetryCount[account.ID]++
+ reqLog.Warn("openai.images.pool_mode_same_account_retry",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("retry_limit", retryLimit),
+ zap.Int("retry_count", sameAccountRetryCount[account.ID]),
+ )
+ select {
+ case <-c.Request.Context().Done():
+ return
+ case <-time.After(sameAccountRetryDelay):
+ }
+ continue
+ }
+ }
+ h.gatewayService.RecordOpenAIAccountSwitch()
+ failedAccountIDs[account.ID] = struct{}{}
+ lastFailoverErr = failoverErr
+ if switchCount >= maxAccountSwitches {
+ h.handleFailoverExhausted(c, failoverErr, streamStarted)
+ return
+ }
+ switchCount++
+ reqLog.Warn("openai.images.upstream_failover_switching",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", failoverErr.StatusCode),
+ zap.Int("switch_count", switchCount),
+ zap.Int("max_switches", maxAccountSwitches),
+ )
+ continue
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
+ fields := []zap.Field{
+ zap.Int64("account_id", account.ID),
+ zap.Bool("fallback_error_response_written", wroteFallback),
+ zap.Error(err),
+ }
+ if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
+ reqLog.Warn("openai.images.forward_failed", fields...)
+ return
+ }
+ reqLog.Error("openai.images.forward_failed", fields...)
+ return
+ }
+
+ if result != nil {
+ if account.Type == service.AccountTypeOAuth {
+ h.gatewayService.UpdateCodexUsageSnapshotFromHeaders(c.Request.Context(), account.ID, result.ResponseHeaders)
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ } else {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
+ }
+
+ userAgent := c.GetHeader("User-Agent")
+ clientIP := ip.GetClientIP(c)
+ requestPayloadHash := service.HashUsageRequestPayload(body)
+ if parsed.Multipart {
+ requestPayloadHash = service.HashUsageRequestPayload([]byte(parsed.StickySessionSeed()))
+ }
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ InboundEndpoint: GetInboundEndpoint(c),
+ UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ RequestPayloadHash: requestPayloadHash,
+ APIKeyService: h.apiKeyService,
+ ChannelUsageFields: channelMapping.ToUsageFields(parsed.Model, result.UpstreamModel),
+ }); err != nil {
+ logger.L().With(
+ zap.String("component", "handler.openai_gateway.images"),
+ zap.Int64("user_id", subject.UserID),
+ zap.Int64("api_key_id", apiKey.ID),
+ zap.Any("group_id", apiKey.GroupID),
+ zap.String("model", parsed.Model),
+ zap.Int64("account_id", account.ID),
+ ).Error("openai.images.record_usage_failed", zap.Error(err))
+ }
+ })
+
+ reqLog.Debug("openai.images.request_completed",
+ zap.Int64("account_id", account.ID),
+ zap.Int("switch_count", switchCount),
+ )
+ return
+ }
+}
+
+func isMultipartImagesContentType(contentType string) bool {
+ return strings.HasPrefix(strings.ToLower(strings.TrimSpace(contentType)), "multipart/form-data")
+}
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index 90e90dd0..93554912 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -1068,7 +1068,7 @@ func guessPlatformFromPath(path string) string {
return service.PlatformAntigravity
case strings.HasPrefix(p, "/v1beta/"):
return service.PlatformGemini
- case strings.Contains(p, "/responses"):
+ case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
return service.PlatformOpenAI
default:
return ""
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
new file mode 100644
index 00000000..09580442
--- /dev/null
+++ b/backend/internal/handler/payment_handler.go
@@ -0,0 +1,579 @@
+package handler
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
// PaymentHandler handles user-facing payment requests.
type PaymentHandler struct {
	// channelService lists payment channels (see GetChannels).
	channelService *service.ChannelService
	// paymentService performs order lifecycle operations: create, verify,
	// cancel, and refund requests.
	paymentService *service.PaymentService
	// configService exposes payment configuration, plans for sale, and
	// per-method limits.
	configService *service.PaymentConfigService
}
+
+// NewPaymentHandler creates a new PaymentHandler.
+func NewPaymentHandler(paymentService *service.PaymentService, configService *service.PaymentConfigService, channelService *service.ChannelService) *PaymentHandler {
+ return &PaymentHandler{
+ channelService: channelService,
+ paymentService: paymentService,
+ configService: configService,
+ }
+}
+
+// GetPaymentConfig returns the payment system configuration.
+// GET /api/v1/payment/config
+func (h *PaymentHandler) GetPaymentConfig(c *gin.Context) {
+ cfg, err := h.configService.GetPaymentConfig(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, cfg)
+}
+
+// GetPlans returns subscription plans available for sale.
+// GET /api/v1/payment/plans
+func (h *PaymentHandler) GetPlans(c *gin.Context) {
+ plans, err := h.configService.ListPlansForSale(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ // Enrich plans with group platform for frontend color coding
+ type planWithPlatform struct {
+ ID int64 `json:"id"`
+ GroupID int64 `json:"group_id"`
+ GroupPlatform string `json:"group_platform"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Price float64 `json:"price"`
+ OriginalPrice *float64 `json:"original_price,omitempty"`
+ ValidityDays int `json:"validity_days"`
+ ValidityUnit string `json:"validity_unit"`
+ Features string `json:"features"`
+ ProductName string `json:"product_name"`
+ ForSale bool `json:"for_sale"`
+ SortOrder int `json:"sort_order"`
+ }
+ platformMap := h.configService.GetGroupPlatformMap(c.Request.Context(), plans)
+ result := make([]planWithPlatform, 0, len(plans))
+ for _, p := range plans {
+ result = append(result, planWithPlatform{
+ ID: int64(p.ID), GroupID: p.GroupID, GroupPlatform: platformMap[p.GroupID],
+ Name: p.Name, Description: p.Description, Price: p.Price, OriginalPrice: p.OriginalPrice,
+ ValidityDays: p.ValidityDays, ValidityUnit: p.ValidityUnit, Features: p.Features,
+ ProductName: p.ProductName, ForSale: p.ForSale, SortOrder: p.SortOrder,
+ })
+ }
+ response.Success(c, result)
+}
+
+// GetChannels returns enabled payment channels.
+// GET /api/v1/payment/channels
+func (h *PaymentHandler) GetChannels(c *gin.Context) {
+ channels, _, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: 1, PageSize: 1000}, "active", "")
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, channels)
+}
+
+// GetCheckoutInfo returns all data the payment page needs in a single call:
+// payment methods with limits, subscription plans, and configuration.
+// GET /api/v1/payment/checkout-info
+func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
+ ctx := c.Request.Context()
+
+ // Fetch limits (methods + global range)
+ limitsResp, err := h.configService.GetAvailableMethodLimits(ctx)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Fetch payment config
+ cfg, err := h.configService.GetPaymentConfig(ctx)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Fetch plans with group info
+ plans, _ := h.configService.ListPlansForSale(ctx)
+ groupInfo := h.configService.GetGroupInfoMap(ctx, plans)
+ planList := make([]checkoutPlan, 0, len(plans))
+ for _, p := range plans {
+ gi := groupInfo[p.GroupID]
+ planList = append(planList, checkoutPlan{
+ ID: int64(p.ID), GroupID: p.GroupID,
+ GroupPlatform: gi.Platform, GroupName: gi.Name,
+ RateMultiplier: gi.RateMultiplier, DailyLimitUSD: gi.DailyLimitUSD,
+ WeeklyLimitUSD: gi.WeeklyLimitUSD, MonthlyLimitUSD: gi.MonthlyLimitUSD,
+ ModelScopes: gi.ModelScopes,
+ Name: p.Name, Description: p.Description, Price: p.Price, OriginalPrice: p.OriginalPrice,
+ ValidityDays: p.ValidityDays, ValidityUnit: p.ValidityUnit, Features: parseFeatures(p.Features),
+ ProductName: p.ProductName,
+ })
+ }
+
+ response.Success(c, checkoutInfoResponse{
+ Methods: limitsResp.Methods,
+ GlobalMin: limitsResp.GlobalMin,
+ GlobalMax: limitsResp.GlobalMax,
+ Plans: planList,
+ BalanceDisabled: cfg.BalanceDisabled,
+ BalanceRechargeMultiplier: cfg.BalanceRechargeMultiplier,
+ RechargeFeeRate: cfg.RechargeFeeRate,
+ HelpText: cfg.HelpText,
+ HelpImageURL: cfg.HelpImageURL,
+ StripePublishableKey: cfg.StripePublishableKey,
+ })
+}
+
// checkoutInfoResponse is the aggregate payload returned by GetCheckoutInfo:
// payment method limits, plans for sale, and display/config values the
// checkout page needs in a single round trip.
type checkoutInfoResponse struct {
	Methods   map[string]service.MethodLimits `json:"methods"`
	GlobalMin float64                         `json:"global_min"`
	GlobalMax float64                         `json:"global_max"`
	Plans     []checkoutPlan                  `json:"plans"`
	// BalanceDisabled and the multipliers/rates below mirror the payment
	// config returned by PaymentConfigService.GetPaymentConfig.
	BalanceDisabled           bool    `json:"balance_disabled"`
	BalanceRechargeMultiplier float64 `json:"balance_recharge_multiplier"`
	RechargeFeeRate           float64 `json:"recharge_fee_rate"`
	HelpText                  string  `json:"help_text"`
	HelpImageURL              string  `json:"help_image_url"`
	StripePublishableKey      string  `json:"stripe_publishable_key"`
}
+
// checkoutPlan is a subscription plan enriched with its group's platform,
// rate multiplier, spend limits, and model scopes for the checkout page.
// Note the Features field here is a parsed list (see parseFeatures), unlike
// the raw newline-separated string on the plan entity.
type checkoutPlan struct {
	ID             int64    `json:"id"`
	GroupID        int64    `json:"group_id"`
	GroupPlatform  string   `json:"group_platform"`
	GroupName      string   `json:"group_name"`
	RateMultiplier float64  `json:"rate_multiplier"`
	DailyLimitUSD  *float64 `json:"daily_limit_usd"`
	WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
	MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
	// Serialized as "supported_model_scopes" for the frontend contract.
	ModelScopes []string `json:"supported_model_scopes"`
	Name          string   `json:"name"`
	Description   string   `json:"description"`
	Price         float64  `json:"price"`
	OriginalPrice *float64 `json:"original_price,omitempty"`
	ValidityDays  int      `json:"validity_days"`
	ValidityUnit  string   `json:"validity_unit"`
	Features      []string `json:"features"`
	ProductName   string   `json:"product_name"`
}
+
// parseFeatures splits a newline-separated features string into a slice of
// trimmed, non-empty lines. It always returns a non-nil slice so the value
// JSON-encodes as [] rather than null.
func parseFeatures(raw string) []string {
	features := make([]string, 0)
	for _, line := range strings.Split(raw, "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed != "" {
			features = append(features, trimmed)
		}
	}
	return features
}
+
+// GetLimits returns per-payment-type limits derived from enabled provider instances.
+// GET /api/v1/payment/limits
+func (h *PaymentHandler) GetLimits(c *gin.Context) {
+ resp, err := h.configService.GetAvailableMethodLimits(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, resp)
+}
+
// CreateOrderRequest is the request body for creating a payment order.
type CreateOrderRequest struct {
	Amount float64 `json:"amount"`
	// PaymentType selects the provider (e.g. wxpay/alipay); required.
	PaymentType string `json:"payment_type" binding:"required"`
	// OpenID is the WeChat user identifier; may be injected from a resume token.
	OpenID string `json:"openid"`
	// WechatResumeToken, when present, is a signed token whose claims override
	// payment type, openid, amount, order type, and plan id (see CreateOrder).
	WechatResumeToken string `json:"wechat_resume_token"`
	ReturnURL         string `json:"return_url"`
	PaymentSource     string `json:"payment_source"`
	OrderType         string `json:"order_type"`
	PlanID            int64  `json:"plan_id"`
	// IsMobile lets the frontend declare its mobile status directly. When
	// nil we fall back to User-Agent heuristics (which miss iPadOS / some
	// embedded browsers that strip the "Mobile" keyword).
	IsMobile *bool `json:"is_mobile,omitempty"`
}
+
// CreateOrder creates a new payment order.
// POST /api/v1/payment/orders
func (h *PaymentHandler) CreateOrder(c *gin.Context) {
	subject, ok := requireAuth(c)
	if !ok {
		return
	}

	var req CreateOrderRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}
	// A WeChat resume token (issued mid-flow) takes precedence: its claims
	// overwrite payment type, openid, amount, order type, and plan id before
	// the order is created.
	if strings.TrimSpace(req.WechatResumeToken) != "" {
		claims, err := h.paymentService.ParseWeChatPaymentResumeToken(req.WechatResumeToken)
		if err != nil {
			response.ErrorFrom(c, err)
			return
		}
		if err := applyWeChatPaymentResumeClaims(&req, claims); err != nil {
			response.ErrorFrom(c, err)
			return
		}
	}

	// Client-declared mobile status overrides User-Agent sniffing (see the
	// IsMobile field comment on CreateOrderRequest).
	mobile := isMobile(c)
	if req.IsMobile != nil {
		mobile = *req.IsMobile
	}
	result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{
		UserID:          subject.UserID,
		Amount:          req.Amount,
		PaymentType:     req.PaymentType,
		OpenID:          req.OpenID,
		ClientIP:        c.ClientIP(),
		IsMobile:        mobile,
		IsWeChatBrowser: isWeChatBrowser(c),
		SrcHost:         c.Request.Host,
		SrcURL:          c.Request.Referer(),
		ReturnURL:       req.ReturnURL,
		PaymentSource:   req.PaymentSource,
		OrderType:       req.OrderType,
		PlanID:          req.PlanID,
	})
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, result)
}
+
+func applyWeChatPaymentResumeClaims(req *CreateOrderRequest, claims *service.WeChatPaymentResumeClaims) error {
+ if req == nil || claims == nil {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume context is missing")
+ }
+ openid := strings.TrimSpace(claims.OpenID)
+ if openid == "" {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token missing openid")
+ }
+
+ paymentType := service.NormalizeVisibleMethod(claims.PaymentType)
+ if paymentType == "" {
+ paymentType = payment.TypeWxpay
+ }
+ if req.PaymentType != "" {
+ requestPaymentType := service.NormalizeVisibleMethod(req.PaymentType)
+ if requestPaymentType != "" && requestPaymentType != paymentType {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", "wechat payment resume token payment type mismatch")
+ }
+ }
+ req.PaymentType = paymentType
+ req.OpenID = openid
+
+ if strings.TrimSpace(claims.Amount) != "" {
+ amount, err := strconv.ParseFloat(strings.TrimSpace(claims.Amount), 64)
+ if err != nil || amount <= 0 {
+ return infraerrors.BadRequest("INVALID_WECHAT_PAYMENT_RESUME_TOKEN", fmt.Sprintf("invalid resume amount: %s", claims.Amount))
+ }
+ req.Amount = amount
+ }
+ if claims.OrderType != "" {
+ req.OrderType = claims.OrderType
+ }
+ if claims.PlanID > 0 {
+ req.PlanID = claims.PlanID
+ }
+ return nil
+}
+
// GetMyOrders returns the authenticated user's orders.
// GET /api/v1/payment/orders/my
// Optional query filters: status, order_type, payment_type.
func (h *PaymentHandler) GetMyOrders(c *gin.Context) {
	subject, ok := requireAuth(c)
	if !ok {
		return
	}

	page, pageSize := response.ParsePagination(c)
	orders, total, err := h.paymentService.GetUserOrders(c.Request.Context(), subject.UserID, service.OrderListParams{
		Page:        page,
		PageSize:    pageSize,
		Status:      c.Query("status"),
		OrderType:   c.Query("order_type"),
		PaymentType: c.Query("payment_type"),
	})
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	// Provider snapshots are stripped before the orders leave the API.
	response.Paginated(c, sanitizePaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
+
+// GetOrder returns a single order for the authenticated user.
+// GET /api/v1/payment/orders/:id
+func (h *PaymentHandler) GetOrder(c *gin.Context) {
+ subject, ok := requireAuth(c)
+ if !ok {
+ return
+ }
+
+ orderID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid order ID")
+ return
+ }
+
+ order, err := h.paymentService.GetOrder(c.Request.Context(), orderID, subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, sanitizePaymentOrderForResponse(order))
+}
+
+// CancelOrder cancels a pending order for the authenticated user.
+// POST /api/v1/payment/orders/:id/cancel
+func (h *PaymentHandler) CancelOrder(c *gin.Context) {
+ subject, ok := requireAuth(c)
+ if !ok {
+ return
+ }
+
+ orderID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid order ID")
+ return
+ }
+
+ msg, err := h.paymentService.CancelOrder(c.Request.Context(), orderID, subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"message": msg})
+}
+
// RefundRequestBody is the request body for requesting a refund.
type RefundRequestBody struct {
	// Reason is the user-supplied justification; optional (no binding tag).
	Reason string `json:"reason"`
}
+
// RequestRefund submits a refund request for a completed order.
// POST /api/v1/payment/orders/:id/refund-request
func (h *PaymentHandler) RequestRefund(c *gin.Context) {
	subject, ok := requireAuth(c)
	if !ok {
		return
	}

	orderID, err := strconv.ParseInt(c.Param("id"), 10, 64)
	if err != nil {
		response.BadRequest(c, "Invalid order ID")
		return
	}

	var req RefundRequestBody
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}

	// Ownership (orderID vs subject.UserID) is enforced by the service.
	if err := h.paymentService.RequestRefund(c.Request.Context(), orderID, subject.UserID, req.Reason); err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, gin.H{"message": "refund requested"})
}
+
+// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
+func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
+ ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"provider_instance_ids": ids})
+}
+
// VerifyOrderRequest is the request body for verifying a payment order.
type VerifyOrderRequest struct {
	OutTradeNo string `json:"out_trade_no" binding:"required"`
}

// ResolveOrderByResumeTokenRequest is the request body for resolving a payment
// order from a signed resume token (see ResolveOrderPublicByResumeToken).
type ResolveOrderByResumeTokenRequest struct {
	ResumeToken string `json:"resume_token" binding:"required"`
}
+
// VerifyOrder actively queries the upstream payment provider to check
// if payment was made, and processes it if so.
// POST /api/v1/payment/orders/verify
func (h *PaymentHandler) VerifyOrder(c *gin.Context) {
	subject, ok := requireAuth(c)
	if !ok {
		return
	}

	var req VerifyOrderRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}

	// Scoped to the caller's user id, unlike the public verify endpoint.
	order, err := h.paymentService.VerifyOrderByOutTradeNo(c.Request.Context(), req.OutTradeNo, subject.UserID)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, sanitizePaymentOrderForResponse(order))
}
+
// PublicOrderResult is the limited order info returned by the public verify endpoint.
// No user details are exposed — only payment status information.
// Fields mirror the corresponding PaymentOrder columns (see buildPublicOrderResult).
type PublicOrderResult struct {
	ID                  int64      `json:"id"`
	OutTradeNo          string     `json:"out_trade_no"`
	Amount              float64    `json:"amount"`
	PayAmount           float64    `json:"pay_amount"`
	FeeRate             float64    `json:"fee_rate"`
	PaymentType         string     `json:"payment_type"`
	OrderType           string     `json:"order_type"`
	Status              string     `json:"status"`
	CreatedAt           time.Time  `json:"created_at"`
	ExpiresAt           time.Time  `json:"expires_at"`
	PaidAt              *time.Time `json:"paid_at,omitempty"`
	CompletedAt         *time.Time `json:"completed_at,omitempty"`
	RefundAmount        float64    `json:"refund_amount"`
	RefundReason        *string    `json:"refund_reason,omitempty"`
	RefundRequestedAt   *time.Time `json:"refund_requested_at,omitempty"`
	RefundRequestedBy   *string    `json:"refund_requested_by,omitempty"`
	RefundRequestReason *string    `json:"refund_request_reason,omitempty"`
	PlanID              *int64     `json:"plan_id,omitempty"`
}
+
// buildPublicOrderResult projects a PaymentOrder entity onto the anonymized
// PublicOrderResult. It is a straight field copy; user-identifying columns
// (email, name, client IP, provider snapshot) are deliberately omitted.
func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult {
	return PublicOrderResult{
		ID:                  order.ID,
		OutTradeNo:          order.OutTradeNo,
		Amount:              order.Amount,
		PayAmount:           order.PayAmount,
		FeeRate:             order.FeeRate,
		PaymentType:         order.PaymentType,
		OrderType:           order.OrderType,
		Status:              order.Status,
		CreatedAt:           order.CreatedAt,
		ExpiresAt:           order.ExpiresAt,
		PaidAt:              order.PaidAt,
		CompletedAt:         order.CompletedAt,
		RefundAmount:        order.RefundAmount,
		RefundReason:        order.RefundReason,
		RefundRequestedAt:   order.RefundRequestedAt,
		RefundRequestedBy:   order.RefundRequestedBy,
		RefundRequestReason: order.RefundRequestReason,
		PlanID:              order.PlanID,
	}
}
+
+// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as
+// a compatibility path for older result pages and staggered deploys.
+// POST /api/v1/payment/public/orders/verify
+func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
+ var req VerifyOrderRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, buildPublicOrderResult(order))
+}
+
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
// POST /api/v1/payment/public/orders/resolve
// Token validation (signature, user/order binding) happens in the service.
func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
	var req ResolveOrderByResumeTokenRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, "Invalid request: "+err.Error())
		return
	}

	order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken)
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	// Same anonymized projection as the public verify endpoint.
	response.Success(c, buildPublicOrderResult(order))
}
+
+// requireAuth extracts the authenticated subject from the context.
+// Returns the subject and true on success; on failure it writes an Unauthorized response and returns false.
+func requireAuth(c *gin.Context) (middleware2.AuthSubject, bool) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return middleware2.AuthSubject{}, false
+ }
+ return subject, true
+}
+
+// isMobile detects mobile user agents.
+func isMobile(c *gin.Context) bool {
+ ua := strings.ToLower(c.GetHeader("User-Agent"))
+ for _, kw := range []string{"mobile", "android", "iphone", "ipad", "ipod"} {
+ if strings.Contains(ua, kw) {
+ return true
+ }
+ }
+ return false
+}
+
+func sanitizePaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
+ if len(orders) == 0 {
+ return orders
+ }
+ out := make([]*dbent.PaymentOrder, 0, len(orders))
+ for _, order := range orders {
+ out = append(out, sanitizePaymentOrderForResponse(order))
+ }
+ return out
+}
+
+func sanitizePaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
+ if order == nil {
+ return nil
+ }
+ cloned := *order
+ cloned.ProviderSnapshot = nil
+ return &cloned
+}
+
// isWeChatBrowser reports whether the request comes from WeChat's embedded
// browser, identified by "micromessenger" in the (case-folded) User-Agent.
func isWeChatBrowser(c *gin.Context) bool {
	return strings.Contains(strings.ToLower(c.GetHeader("User-Agent")), "micromessenger")
}
diff --git a/backend/internal/handler/payment_handler_resume_test.go b/backend/internal/handler/payment_handler_resume_test.go
new file mode 100644
index 00000000..377f432e
--- /dev/null
+++ b/backend/internal/handler/payment_handler_resume_test.go
@@ -0,0 +1,368 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
// TestApplyWeChatPaymentResumeClaims verifies that resume-token claims
// overwrite the request's openid, amount, order type, and plan id.
func TestApplyWeChatPaymentResumeClaims(t *testing.T) {
	t.Parallel()

	req := CreateOrderRequest{
		Amount:      0,
		PaymentType: payment.TypeWxpay,
		OrderType:   payment.OrderTypeBalance,
	}

	err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
		OpenID:      "openid-123",
		PaymentType: payment.TypeWxpay,
		Amount:      "12.50",
		OrderType:   payment.OrderTypeSubscription,
		PlanID:      7,
	})
	if err != nil {
		t.Fatalf("applyWeChatPaymentResumeClaims returned error: %v", err)
	}
	if req.OpenID != "openid-123" {
		t.Fatalf("openid = %q, want %q", req.OpenID, "openid-123")
	}
	if req.Amount != 12.5 {
		t.Fatalf("amount = %v, want 12.5", req.Amount)
	}
	if req.OrderType != payment.OrderTypeSubscription {
		t.Fatalf("order_type = %q, want %q", req.OrderType, payment.OrderTypeSubscription)
	}
	if req.PlanID != 7 {
		t.Fatalf("plan_id = %d, want 7", req.PlanID)
	}
}
+
// TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch verifies that a
// request whose payment type conflicts with the token's claim is rejected.
func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T) {
	t.Parallel()

	// Request says alipay, token says wxpay — must error.
	req := CreateOrderRequest{
		PaymentType: payment.TypeAlipay,
	}

	err := applyWeChatPaymentResumeClaims(&req, &service.WeChatPaymentResumeClaims{
		OpenID:      "openid-123",
		PaymentType: payment.TypeWxpay,
		Amount:      "12.50",
		OrderType:   payment.OrderTypeBalance,
	})
	if err == nil {
		t.Fatal("applyWeChatPaymentResumeClaims should reject mismatched payment types")
	}
}
+
// TestVerifyOrderPublicReturnsLegacyOrderState verifies the anonymous
// out_trade_no verify endpoint returns the expected status-level fields for a
// pending order, with no user details in the payload.
func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
	t.Parallel()

	gin.SetMode(gin.TestMode)

	// Shared-cache in-memory SQLite keeps the schema visible across connections
	// for the duration of this test.
	db, err := sql.Open("sqlite", "file:payment_handler_public_verify?mode=memory&cache=shared")
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })

	_, err = db.Exec("PRAGMA foreign_keys = ON")
	require.NoError(t, err)

	drv := entsql.OpenDB(dialect.SQLite, db)
	client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
	t.Cleanup(func() { _ = client.Close() })

	// Seed a user and a pending order addressed only by its out_trade_no.
	user, err := client.User.Create().
		SetEmail("public-verify@example.com").
		SetPasswordHash("hash").
		SetUsername("public-verify-user").
		Save(context.Background())
	require.NoError(t, err)

	order, err := client.PaymentOrder.Create().
		SetUserID(user.ID).
		SetUserEmail(user.Email).
		SetUserName(user.Username).
		SetAmount(88).
		SetPayAmount(90.64).
		SetFeeRate(0.03).
		SetRechargeCode("PUBLIC-VERIFY").
		SetOutTradeNo("legacy-order-no").
		SetPaymentType(payment.TypeAlipay).
		SetPaymentTradeNo("trade-public-verify").
		SetOrderType(payment.OrderTypeBalance).
		SetStatus(service.OrderStatusPending).
		SetExpiresAt(time.Now().Add(time.Hour)).
		SetClientIP("127.0.0.1").
		SetSrcHost("api.example.com").
		Save(context.Background())
	require.NoError(t, err)

	paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
	h := NewPaymentHandler(paymentSvc, nil, nil)

	// Invoke the handler directly with a recorded gin context.
	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(
		http.MethodPost,
		"/api/v1/payment/public/orders/verify",
		bytes.NewBufferString(`{"out_trade_no":"legacy-order-no"}`),
	)
	ctx.Request.Header.Set("Content-Type", "application/json")

	h.VerifyOrderPublic(ctx)

	require.Equal(t, http.StatusOK, recorder.Code)

	var resp struct {
		Code int `json:"code"`
		Data struct {
			ID           int64   `json:"id"`
			OutTradeNo   string  `json:"out_trade_no"`
			Amount       float64 `json:"amount"`
			PayAmount    float64 `json:"pay_amount"`
			FeeRate      float64 `json:"fee_rate"`
			PaymentType  string  `json:"payment_type"`
			OrderType    string  `json:"order_type"`
			Status       string  `json:"status"`
			RefundAmount float64 `json:"refund_amount"`
			CreatedAt    string  `json:"created_at"`
			ExpiresAt    string  `json:"expires_at"`
		} `json:"data"`
	}
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, 0, resp.Code)
	require.Equal(t, order.ID, resp.Data.ID)
	require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
	require.Equal(t, 90.64, resp.Data.PayAmount)
	require.Equal(t, 0.03, resp.Data.FeeRate)
	require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
	require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
	require.Equal(t, service.OrderStatusPending, resp.Data.Status)
	require.Equal(t, 0.0, resp.Data.RefundAmount)
	require.NotEmpty(t, resp.Data.CreatedAt)
	require.NotEmpty(t, resp.Data.ExpiresAt)
}
+
// TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields verifies
// that resolving a paid order via a valid signed resume token yields every
// field the frontend result page depends on.
// Not parallel: t.Setenv mutates process environment.
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// Signing key for resume tokens; the same key bytes are passed to the
	// services constructed below.
	t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")

	db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })

	_, err = db.Exec("PRAGMA foreign_keys = ON")
	require.NoError(t, err)

	drv := entsql.OpenDB(dialect.SQLite, db)
	client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
	t.Cleanup(func() { _ = client.Close() })

	// Seed a user and an already-paid order for the token to resolve.
	user, err := client.User.Create().
		SetEmail("public-resolve@example.com").
		SetPasswordHash("hash").
		SetUsername("public-resolve-user").
		Save(context.Background())
	require.NoError(t, err)

	order, err := client.PaymentOrder.Create().
		SetUserID(user.ID).
		SetUserEmail(user.Email).
		SetUserName(user.Username).
		SetAmount(100).
		SetPayAmount(103).
		SetFeeRate(0.03).
		SetRechargeCode("PUBLIC-RESOLVE").
		SetOutTradeNo("resolve-order-no").
		SetPaymentType(payment.TypeAlipay).
		SetPaymentTradeNo("trade-public-resolve").
		SetOrderType(payment.OrderTypeBalance).
		SetStatus(service.OrderStatusPaid).
		SetExpiresAt(time.Now().Add(time.Hour)).
		SetPaidAt(time.Now()).
		SetClientIP("127.0.0.1").
		SetSrcHost("api.example.com").
		Save(context.Background())
	require.NoError(t, err)

	// Mint a resume token bound to this order and user.
	resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
	token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
		OrderID:            order.ID,
		UserID:             user.ID,
		PaymentType:        payment.TypeAlipay,
		CanonicalReturnURL: "https://app.example.com/payment/result",
	})
	require.NoError(t, err)

	configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
	paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
	h := NewPaymentHandler(paymentSvc, nil, nil)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(
		http.MethodPost,
		"/api/v1/payment/public/orders/resolve",
		bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
	)
	ctx.Request.Header.Set("Content-Type", "application/json")

	h.ResolveOrderPublicByResumeToken(ctx)

	require.Equal(t, http.StatusOK, recorder.Code)

	var resp struct {
		Code int            `json:"code"`
		Data map[string]any `json:"data"`
	}
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, 0, resp.Code)
	// JSON numbers unmarshal into any as float64, hence the conversion.
	require.Equal(t, float64(order.ID), resp.Data["id"])
	require.Equal(t, "resolve-order-no", resp.Data["out_trade_no"])
	require.Equal(t, 100.0, resp.Data["amount"])
	require.Equal(t, 103.0, resp.Data["pay_amount"])
	require.Equal(t, 0.03, resp.Data["fee_rate"])
	require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
	require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
	require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
	require.Contains(t, resp.Data, "created_at")
	require.Contains(t, resp.Data, "expires_at")
	require.Contains(t, resp.Data, "refund_amount")
}
+
// TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken
// verifies that a token minted for a different user id is rejected with
// HTTP 400 and reason INVALID_RESUME_TOKEN.
// Not parallel: t.Setenv mutates process environment.
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
	gin.SetMode(gin.TestMode)
	t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")

	db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })

	_, err = db.Exec("PRAGMA foreign_keys = ON")
	require.NoError(t, err)

	drv := entsql.OpenDB(dialect.SQLite, db)
	client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
	t.Cleanup(func() { _ = client.Close() })

	user, err := client.User.Create().
		SetEmail("public-resolve-mismatch@example.com").
		SetPasswordHash("hash").
		SetUsername("public-resolve-mismatch-user").
		Save(context.Background())
	require.NoError(t, err)

	order, err := client.PaymentOrder.Create().
		SetUserID(user.ID).
		SetUserEmail(user.Email).
		SetUserName(user.Username).
		SetAmount(100).
		SetPayAmount(103).
		SetFeeRate(0.03).
		SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
		SetOutTradeNo("resolve-order-mismatch-no").
		SetPaymentType(payment.TypeAlipay).
		SetPaymentTradeNo("trade-public-resolve-mismatch").
		SetOrderType(payment.OrderTypeBalance).
		SetStatus(service.OrderStatusPaid).
		SetExpiresAt(time.Now().Add(time.Hour)).
		SetPaidAt(time.Now()).
		SetClientIP("127.0.0.1").
		SetSrcHost("api.example.com").
		Save(context.Background())
	require.NoError(t, err)

	// Deliberately bind the token to a non-existent user id so resolution fails.
	resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
	token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
		OrderID:            order.ID,
		UserID:             user.ID + 999,
		PaymentType:        payment.TypeAlipay,
		CanonicalReturnURL: "https://app.example.com/payment/result",
	})
	require.NoError(t, err)

	configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
	paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil, nil)
	h := NewPaymentHandler(paymentSvc, nil, nil)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(
		http.MethodPost,
		"/api/v1/payment/public/orders/resolve",
		bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
	)
	ctx.Request.Header.Set("Content-Type", "application/json")

	h.ResolveOrderPublicByResumeToken(ctx)

	require.Equal(t, http.StatusBadRequest, recorder.Code)

	var resp struct {
		Code    int    `json:"code"`
		Reason  string `json:"reason"`
		Message string `json:"message"`
	}
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, http.StatusBadRequest, resp.Code)
	require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
}
+
// TestVerifyOrderPublicRejectsBlankOutTradeNo verifies that a whitespace-only
// out_trade_no is rejected with HTTP 400 and reason INVALID_OUT_TRADE_NO
// (the required binding alone would not catch a blank-but-present value).
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
	gin.SetMode(gin.TestMode)

	db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
	require.NoError(t, err)
	t.Cleanup(func() { _ = db.Close() })

	_, err = db.Exec("PRAGMA foreign_keys = ON")
	require.NoError(t, err)

	drv := entsql.OpenDB(dialect.SQLite, db)
	client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
	t.Cleanup(func() { _ = client.Close() })

	paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil, nil)
	h := NewPaymentHandler(paymentSvc, nil, nil)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(
		http.MethodPost,
		"/api/v1/payment/public/orders/verify",
		bytes.NewBufferString(`{"out_trade_no":" "}`),
	)
	ctx.Request.Header.Set("Content-Type", "application/json")

	h.VerifyOrderPublic(ctx)

	require.Equal(t, http.StatusBadRequest, recorder.Code)

	var resp struct {
		Code   int    `json:"code"`
		Reason string `json:"reason"`
	}
	require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
	require.Equal(t, http.StatusBadRequest, resp.Code)
	require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
}
diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go
new file mode 100644
index 00000000..9ae799fd
--- /dev/null
+++ b/backend/internal/handler/payment_webhook_handler.go
@@ -0,0 +1,198 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// PaymentWebhookHandler handles payment provider webhook callbacks.
+type PaymentWebhookHandler struct {
+ paymentService *service.PaymentService
+ registry *payment.Registry
+}
+
+// maxWebhookBodySize is the maximum allowed webhook request body size (1 MB).
+const maxWebhookBodySize = 1 << 20
+
+// webhookLogTruncateLen is the maximum length of raw body logged on verify failure.
+const webhookLogTruncateLen = 200
+
+// NewPaymentWebhookHandler creates a new PaymentWebhookHandler.
+func NewPaymentWebhookHandler(paymentService *service.PaymentService, registry *payment.Registry) *PaymentWebhookHandler {
+ return &PaymentWebhookHandler{
+ paymentService: paymentService,
+ registry: registry,
+ }
+}
+
+// EasyPayNotify handles EasyPay payment notifications.
+// POST /api/v1/payment/webhook/easypay
+func (h *PaymentWebhookHandler) EasyPayNotify(c *gin.Context) {
+ h.handleNotify(c, payment.TypeEasyPay)
+}
+
+// AlipayNotify handles Alipay payment notifications.
+// POST /api/v1/payment/webhook/alipay
+func (h *PaymentWebhookHandler) AlipayNotify(c *gin.Context) {
+ h.handleNotify(c, payment.TypeAlipay)
+}
+
+// WxpayNotify handles WeChat Pay payment notifications.
+// POST /api/v1/payment/webhook/wxpay
+func (h *PaymentWebhookHandler) WxpayNotify(c *gin.Context) {
+ h.handleNotify(c, payment.TypeWxpay)
+}
+
+// StripeWebhook handles Stripe webhook events.
+// POST /api/v1/payment/webhook/stripe
+func (h *PaymentWebhookHandler) StripeWebhook(c *gin.Context) {
+ h.handleNotify(c, payment.TypeStripe)
+}
+
+// handleNotify is the shared logic for all provider webhook handlers.
+func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) {
+ var rawBody string
+ if c.Request.Method == http.MethodGet {
+ // GET callbacks (e.g. EasyPay) pass params as URL query string
+ rawBody = c.Request.URL.RawQuery
+ } else {
+ body, err := io.ReadAll(io.LimitReader(c.Request.Body, maxWebhookBodySize))
+ if err != nil {
+ slog.Error("[Payment Webhook] failed to read body", "provider", providerKey, "error", err)
+ c.String(http.StatusBadRequest, "failed to read body")
+ return
+ }
+ rawBody = string(body)
+ }
+
+ // Extract out_trade_no to look up the order's specific provider instance.
+ // This is needed when multiple instances of the same provider exist (e.g. multiple EasyPay accounts).
+ outTradeNo := extractOutTradeNo(rawBody, providerKey)
+
+ providers, err := h.paymentService.GetWebhookProviders(c.Request.Context(), providerKey, outTradeNo)
+ if err != nil {
+ slog.Warn("[Payment Webhook] provider not found", "provider", providerKey, "outTradeNo", outTradeNo, "error", err)
+ if providerKey == payment.TypeWxpay {
+ c.String(http.StatusBadRequest, "verify failed")
+ return
+ }
+ writeSuccessResponse(c, providerKey)
+ return
+ }
+
+ headers := make(map[string]string)
+ for k := range c.Request.Header {
+ headers[strings.ToLower(k)] = c.GetHeader(k)
+ }
+
+ resolvedProviderKey, notification, err := verifyNotificationWithProviders(c.Request.Context(), providers, rawBody, headers)
+ if err != nil {
+ truncatedBody := rawBody
+ if len(truncatedBody) > webhookLogTruncateLen {
+ truncatedBody = truncatedBody[:webhookLogTruncateLen] + "...(truncated)"
+ }
+ slog.Error("[Payment Webhook] verify failed", "provider", providerKey, "error", err, "method", c.Request.Method, "bodyLen", len(rawBody))
+ slog.Debug("[Payment Webhook] verify failed body", "provider", providerKey, "rawBody", truncatedBody)
+ c.String(http.StatusBadRequest, "verify failed")
+ return
+ }
+
+ // nil notification means irrelevant event (e.g. Stripe non-payment event); return success.
+ if notification == nil {
+ writeSuccessResponse(c, resolvedProviderKey)
+ return
+ }
+
+ if err := h.paymentService.HandlePaymentNotification(c.Request.Context(), notification, resolvedProviderKey); err != nil {
+ // Unknown order: ack with 2xx so the provider stops retrying. This
+ // guards against foreign environments whose webhook endpoints are
+ // (mis)configured to point at us — without a 2xx, the provider will
+ // retry for days and spam our error logs. We still emit a WARN so the
+ // event is discoverable in logs.
+ if errors.Is(err, service.ErrOrderNotFound) {
+ slog.Warn("[Payment Webhook] unknown order, acking to stop retries",
+ "provider", resolvedProviderKey,
+ "outTradeNo", notification.OrderID,
+ "tradeNo", notification.TradeNo,
+ )
+ writeSuccessResponse(c, resolvedProviderKey)
+ return
+ }
+ slog.Error("[Payment Webhook] handle notification failed", "provider", resolvedProviderKey, "error", err)
+ c.String(http.StatusInternalServerError, "handle failed")
+ return
+ }
+
+ writeSuccessResponse(c, resolvedProviderKey)
+}
+
+// extractOutTradeNo parses the webhook body to find the out_trade_no.
+// This allows looking up the correct provider instance before verification.
+func extractOutTradeNo(rawBody, providerKey string) string {
+ switch providerKey {
+ case payment.TypeEasyPay, payment.TypeAlipay:
+ values, err := url.ParseQuery(rawBody)
+ if err == nil {
+ return values.Get("out_trade_no")
+ }
+ }
+ // For other providers (Stripe, Alipay direct, WxPay direct), the registry
+ // typically has only one instance, so no instance lookup is needed.
+ return ""
+}
+
+func verifyNotificationWithProviders(ctx context.Context, providers []payment.Provider, rawBody string, headers map[string]string) (string, *payment.PaymentNotification, error) {
+ var lastErr error
+ for _, provider := range providers {
+ if provider == nil {
+ continue
+ }
+ notification, err := provider.VerifyNotification(ctx, rawBody, headers)
+ if err != nil {
+ lastErr = err
+ continue
+ }
+ return provider.ProviderKey(), notification, nil
+ }
+ if lastErr != nil {
+ return "", nil, lastErr
+ }
+ return "", nil, fmt.Errorf("no webhook provider could verify notification")
+}
+
+// wxpaySuccessResponse is the JSON response expected by WeChat Pay webhook.
+type wxpaySuccessResponse struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+}
+
+// WeChat Pay webhook success response constants.
+const (
+ wxpaySuccessCode = "SUCCESS"
+ wxpaySuccessMessage = "成功"
+)
+
+// writeSuccessResponse sends the provider-specific success response.
+// WeChat Pay requires JSON {"code":"SUCCESS","message":"成功"};
+// Stripe expects an empty 200; others accept plain text "success".
+func writeSuccessResponse(c *gin.Context, providerKey string) {
+ switch providerKey {
+ case payment.TypeWxpay:
+ c.JSON(http.StatusOK, wxpaySuccessResponse{Code: wxpaySuccessCode, Message: wxpaySuccessMessage})
+ case payment.TypeStripe:
+ c.String(http.StatusOK, "")
+ default:
+ c.String(http.StatusOK, "success")
+ }
+}
diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go
new file mode 100644
index 00000000..7551fc83
--- /dev/null
+++ b/backend/internal/handler/payment_webhook_handler_test.go
@@ -0,0 +1,242 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestWriteSuccessResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ tests := []struct {
+ name string
+ providerKey string
+ wantCode int
+ wantContentType string
+ wantBody string
+ checkJSON bool
+ wantJSONCode string
+ wantJSONMessage string
+ }{
+ {
+ name: "wxpay returns JSON with code SUCCESS",
+ providerKey: "wxpay",
+ wantCode: http.StatusOK,
+ wantContentType: "application/json",
+ checkJSON: true,
+ wantJSONCode: "SUCCESS",
+ wantJSONMessage: "成功",
+ },
+ {
+ name: "stripe returns empty 200",
+ providerKey: "stripe",
+ wantCode: http.StatusOK,
+ wantContentType: "text/plain",
+ wantBody: "",
+ },
+ {
+ name: "easypay returns plain text success",
+ providerKey: "easypay",
+ wantCode: http.StatusOK,
+ wantContentType: "text/plain",
+ wantBody: "success",
+ },
+ {
+ name: "alipay returns plain text success",
+ providerKey: "alipay",
+ wantCode: http.StatusOK,
+ wantContentType: "text/plain",
+ wantBody: "success",
+ },
+ {
+ name: "unknown provider returns plain text success",
+ providerKey: "unknown_provider",
+ wantCode: http.StatusOK,
+ wantContentType: "text/plain",
+ wantBody: "success",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ writeSuccessResponse(c, tt.providerKey)
+
+ assert.Equal(t, tt.wantCode, w.Code)
+ assert.Contains(t, w.Header().Get("Content-Type"), tt.wantContentType)
+
+ if tt.checkJSON {
+ var resp wxpaySuccessResponse
+ err := json.Unmarshal(w.Body.Bytes(), &resp)
+ require.NoError(t, err, "response body should be valid JSON")
+ assert.Equal(t, tt.wantJSONCode, resp.Code)
+ assert.Equal(t, tt.wantJSONMessage, resp.Message)
+ } else {
+ assert.Equal(t, tt.wantBody, w.Body.String())
+ }
+ })
+ }
+}
+
+// TestUnknownOrderWebhookAcksWithSuccess exercises the response contract that
+// handleNotify relies on when HandlePaymentNotification returns ErrOrderNotFound:
+// we still need to emit the provider-specific 2xx so the provider stops
+// retrying. We can't easily drive handleNotify end-to-end without mocking the
+// concrete *service.PaymentService, so this test locks down the two ingredients
+// the fix depends on:
+// 1. errors.Is recognises the sentinel through fmt.Errorf %w wrapping (which
+// is how service layer wraps it with the out_trade_no context).
+// 2. writeSuccessResponse produces the provider-specific body for Stripe
+// (empty 200) — matching what handleNotify calls on the ack path.
+//
+// If either contract breaks, the Stripe "unknown order → 500 loop" regresses.
+func TestUnknownOrderWebhookAcksWithSuccess(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ // 1) Sentinel recognition through wrapping.
+ wrapped := fmt.Errorf("%w: out_trade_no=sub2_missing_42", service.ErrOrderNotFound)
+ require.True(t, errors.Is(wrapped, service.ErrOrderNotFound),
+ "handleNotify uses errors.Is on the wrapped service error; regression here "+
+ "would mean unknown-order webhooks go back to returning 500 and looping forever")
+
+ // A distinct error must NOT match — otherwise a DB failure would be silently
+ // swallowed as an ack.
+ other := errors.New("lookup order failed: connection refused")
+ require.False(t, errors.Is(other, service.ErrOrderNotFound))
+
+ // 2) Provider-specific success body is what handleNotify emits on the
+ // ack path. Asserted again here because this is the shape Stripe expects
+ // to consider the webhook acknowledged.
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ writeSuccessResponse(c, payment.TypeStripe)
+ require.Equal(t, http.StatusOK, w.Code,
+ "Stripe requires 2xx to stop retrying; anything else restarts the retry loop")
+ require.Empty(t, w.Body.String(), "Stripe expects an empty body on the ack path")
+}
+
+func TestWebhookConstants(t *testing.T) {
+ t.Run("maxWebhookBodySize is 1MB", func(t *testing.T) {
+ assert.Equal(t, int64(1<<20), int64(maxWebhookBodySize))
+ })
+
+ t.Run("webhookLogTruncateLen is 200", func(t *testing.T) {
+ assert.Equal(t, 200, webhookLogTruncateLen)
+ })
+}
+
+func TestExtractOutTradeNo(t *testing.T) {
+ tests := []struct {
+ name string
+ providerKey string
+ rawBody string
+ want string
+ }{
+ {
+ name: "easypay query payload",
+ providerKey: "easypay",
+ rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
+ want: "sub2_123",
+ },
+ {
+ name: "alipay query payload",
+ providerKey: "alipay",
+ rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
+ want: "sub2_456",
+ },
+ {
+ name: "unknown provider",
+ providerKey: "wxpay",
+ rawBody: "{}",
+ want: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
+ })
+ }
+}
+
+func TestVerifyNotificationWithProvidersReturnsMatchedProvider(t *testing.T) {
+ firstErr := errors.New("wrong provider")
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: firstErr,
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ notification: &payment.PaymentNotification{
+ OrderID: "sub2_42",
+ TradeNo: "trade-42",
+ Status: payment.NotificationStatusSuccess,
+ },
+ },
+ }
+
+ providerKey, notification, err := verifyNotificationWithProviders(context.Background(), providers, "{}", map[string]string{"wechatpay-signature": "sig"})
+ require.NoError(t, err)
+ require.Equal(t, payment.TypeWxpay, providerKey)
+ require.NotNil(t, notification)
+ require.Equal(t, "sub2_42", notification.OrderID)
+}
+
+func TestVerifyNotificationWithProvidersFailsWhenAllProvidersReject(t *testing.T) {
+ providers := []payment.Provider{
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed a"),
+ },
+ webhookHandlerProviderStub{
+ key: payment.TypeWxpay,
+ verifyErr: errors.New("verify failed b"),
+ },
+ }
+
+ _, _, err := verifyNotificationWithProviders(context.Background(), providers, "{}", nil)
+ require.Error(t, err)
+}
+
+type webhookHandlerProviderStub struct {
+ key string
+ notification *payment.PaymentNotification
+ verifyErr error
+}
+
+func (p webhookHandlerProviderStub) Name() string { return p.key }
+func (p webhookHandlerProviderStub) ProviderKey() string { return p.key }
+func (p webhookHandlerProviderStub) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.PaymentType(p.key)}
+}
+func (p webhookHandlerProviderStub) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) {
+ panic("unexpected call")
+}
+func (p webhookHandlerProviderStub) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) {
+ if p.verifyErr != nil {
+ return nil, p.verifyErr
+ }
+ return p.notification, nil
+}
+func (p webhookHandlerProviderStub) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) {
+ panic("unexpected call")
+}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 2c999cf1..22f2aa15 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
response.Success(c, dto.PublicSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
@@ -51,11 +52,30 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
+ TableDefaultPageSize: settings.TableDefaultPageSize,
+ TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
- SoraClientEnabled: settings.SoraClientEnabled,
+ WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
+ WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
+ WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
+ WeChatOAuthMobileEnabled: settings.WeChatOAuthMobileEnabled,
+ OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
+ OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled,
+ PaymentEnabled: settings.PaymentEnabled,
Version: h.version,
+ BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
+ AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
+ BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
+ BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
+
+ ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
+ ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
+
+ AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
+
+ AffiliateEnabled: settings.AffiliateEnabled,
})
}
diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go
new file mode 100644
index 00000000..45d66f8e
--- /dev/null
+++ b/backend/internal/handler/setting_handler_public_test.go
@@ -0,0 +1,122 @@
+//go:build unit
+
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerPublicRepoStub struct {
+ values map[string]string
+}
+
+func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyForceEmailOnThirdPartySignup: "true",
+ },
+ }
+ h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version")
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+ h.GetPublicSettings(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.ForceEmailOnThirdPartySignup)
+}
+
+func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{
+ values: map[string]string{
+ service.SettingKeyWeChatConnectEnabled: "true",
+ service.SettingKeyWeChatConnectAppID: "wx-mp-app",
+ service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret",
+ service.SettingKeyWeChatConnectMode: "mp",
+ service.SettingKeyWeChatConnectScopes: "snsapi_base",
+ service.SettingKeyWeChatConnectOpenEnabled: "true",
+ service.SettingKeyWeChatConnectMPEnabled: "true",
+ service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback",
+ service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback",
+ },
+ }, &config.Config{}), "test-version")
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil)
+
+ h.GetPublicSettings(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
+ WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
+ WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.WeChatOAuthEnabled)
+ require.True(t, resp.Data.WeChatOAuthOpenEnabled)
+ require.True(t, resp.Data.WeChatOAuthMPEnabled)
+}
diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go
deleted file mode 100644
index 80acc833..00000000
--- a/backend/internal/handler/sora_client_handler.go
+++ /dev/null
@@ -1,979 +0,0 @@
-package handler
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "strconv"
- "strings"
- "sync"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
- "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
- "github.com/Wei-Shaw/sub2api/internal/pkg/response"
- middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
- "github.com/gin-gonic/gin"
-)
-
-const (
- // 上游模型缓存 TTL
- modelCacheTTL = 1 * time.Hour // 上游获取成功
- modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地)
-)
-
-// SoraClientHandler 处理 Sora 客户端 API 请求。
-type SoraClientHandler struct {
- genService *service.SoraGenerationService
- quotaService *service.SoraQuotaService
- s3Storage *service.SoraS3Storage
- soraGatewayService *service.SoraGatewayService
- gatewayService *service.GatewayService
- mediaStorage *service.SoraMediaStorage
- apiKeyService *service.APIKeyService
-
- // 上游模型缓存
- modelCacheMu sync.RWMutex
- cachedFamilies []service.SoraModelFamily
- modelCacheTime time.Time
- modelCacheUpstream bool // 是否来自上游(决定 TTL)
-}
-
-// NewSoraClientHandler 创建 Sora 客户端 Handler。
-func NewSoraClientHandler(
- genService *service.SoraGenerationService,
- quotaService *service.SoraQuotaService,
- s3Storage *service.SoraS3Storage,
- soraGatewayService *service.SoraGatewayService,
- gatewayService *service.GatewayService,
- mediaStorage *service.SoraMediaStorage,
- apiKeyService *service.APIKeyService,
-) *SoraClientHandler {
- return &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- s3Storage: s3Storage,
- soraGatewayService: soraGatewayService,
- gatewayService: gatewayService,
- mediaStorage: mediaStorage,
- apiKeyService: apiKeyService,
- }
-}
-
-// GenerateRequest 生成请求。
-type GenerateRequest struct {
- Model string `json:"model" binding:"required"`
- Prompt string `json:"prompt" binding:"required"`
- MediaType string `json:"media_type"` // video / image,默认 video
- VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3)
- ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL)
- APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID
-}
-
-// Generate 异步生成 — 创建 pending 记录后立即返回。
-// POST /api/v1/sora/generate
-func (h *SoraClientHandler) Generate(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- var req GenerateRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error())
- return
- }
-
- if req.MediaType == "" {
- req.MediaType = "video"
- }
- req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount)
-
- // 并发数检查(最多 3 个)
- activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- if activeCount >= 3 {
- response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
- return
- }
-
- // 配额检查(粗略检查,实际文件大小在上传后才知道)
- if h.quotaService != nil {
- if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil {
- var quotaErr *service.QuotaExceededError
- if errors.As(err, "aErr) {
- response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
- return
- }
- response.Error(c, http.StatusForbidden, err.Error())
- return
- }
- }
-
- // 获取 API Key ID 和 Group ID
- var apiKeyID *int64
- var groupID *int64
-
- if req.APIKeyID != nil && h.apiKeyService != nil {
- // 前端传递了 api_key_id,需要校验
- apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "API Key 不存在")
- return
- }
- if apiKey.UserID != userID {
- response.Error(c, http.StatusForbidden, "API Key 不属于当前用户")
- return
- }
- if apiKey.Status != service.StatusAPIKeyActive {
- response.Error(c, http.StatusForbidden, "API Key 不可用")
- return
- }
- apiKeyID = &apiKey.ID
- groupID = apiKey.GroupID
- } else if id, ok := c.Get("api_key_id"); ok {
- // 兼容 API Key 认证路径(/sora/v1/ 网关路由)
- if v, ok := id.(int64); ok {
- apiKeyID = &v
- }
- }
-
- gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType)
- if err != nil {
- if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) {
- response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个")
- return
- }
- response.ErrorFrom(c, err)
- return
- }
-
- // 启动后台异步生成 goroutine
- go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount)
-
- response.Success(c, gin.H{
- "generation_id": gen.ID,
- "status": gen.Status,
- })
-}
-
-// processGeneration 后台异步执行 Sora 生成任务。
-// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。
-func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) {
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
- defer cancel()
-
- // 标记为生成中
- if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil {
- if errors.Is(err, service.ErrSoraGenerationStateConflict) {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID)
- return
- }
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err)
- return
- }
-
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d",
- genID,
- userID,
- groupIDForLog(groupID),
- model,
- mediaType,
- videoCount,
- strings.TrimSpace(imageInput) != "",
- len(strings.TrimSpace(prompt)),
- )
-
- // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底
- if groupID == nil {
- ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
- }
-
- if h.gatewayService == nil {
- _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化")
- return
- }
-
- // 选择 Sora 账号
- account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model)
- if err != nil {
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v",
- genID,
- userID,
- groupIDForLog(groupID),
- model,
- err,
- )
- _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error())
- return
- }
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s",
- genID,
- userID,
- groupIDForLog(groupID),
- model,
- account.ID,
- account.Name,
- account.Platform,
- account.Type,
- )
-
- // 构建 chat completions 请求体(非流式)
- body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount))
-
- if h.soraGatewayService == nil {
- _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化")
- return
- }
-
- // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL)
- recorder := httptest.NewRecorder()
- mockGinCtx, _ := gin.CreateTestContext(recorder)
- mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil)
-
- // 调用 Forward(非流式)
- result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false)
- if err != nil {
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v",
- genID,
- account.ID,
- model,
- recorder.Code,
- trimForLog(recorder.Body.String(), 400),
- err,
- )
- // 检查是否已取消
- gen, _ := h.genService.GetByID(ctx, genID, userID)
- if gen != nil && gen.Status == service.SoraGenStatusCancelled {
- return
- }
- _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error())
- return
- }
-
- // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析)
- mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder)
- if mediaURL == "" {
- logger.LegacyPrintf(
- "handler.sora_client",
- "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s",
- genID,
- account.ID,
- model,
- recorder.Code,
- trimForLog(recorder.Body.String(), 400),
- )
- _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL")
- return
- }
-
- // 检查任务是否已被取消
- gen, _ := h.genService.GetByID(ctx, genID, userID)
- if gen != nil && gen.Status == service.SoraGenStatusCancelled {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID)
- return
- }
-
- // 三层降级存储:S3 → 本地 → 上游临时 URL
- storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs)
-
- usageAdded := false
- if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil {
- if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil {
- h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
- var quotaErr *service.QuotaExceededError
- if errors.As(err, "aErr) {
- _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间")
- return
- }
- _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error())
- return
- }
- usageAdded = true
- }
-
- // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。
- gen, _ = h.genService.GetByID(ctx, genID, userID)
- if gen != nil && gen.Status == service.SoraGenStatusCancelled {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID)
- h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
- if usageAdded && h.quotaService != nil {
- _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
- }
- return
- }
-
- // 标记完成
- if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil {
- if errors.Is(err, service.ErrSoraGenerationStateConflict) {
- h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs)
- if usageAdded && h.quotaService != nil {
- _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize)
- }
- return
- }
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err)
- return
- }
-
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize)
-}
-
-// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。
-func (h *SoraClientHandler) storeMediaWithDegradation(
- ctx context.Context, userID int64, mediaType string,
- mediaURL string, mediaURLs []string,
-) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) {
- urls := mediaURLs
- if len(urls) == 0 {
- urls = []string{mediaURL}
- }
-
- // 第一层:尝试 S3
- if h.s3Storage != nil && h.s3Storage.Enabled(ctx) {
- keys := make([]string, 0, len(urls))
- var totalSize int64
- allOK := true
- for _, u := range urls {
- key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u)
- if err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err)
- allOK = false
- // 清理已上传的文件
- if len(keys) > 0 {
- _ = h.s3Storage.DeleteObjects(ctx, keys)
- }
- break
- }
- keys = append(keys, key)
- totalSize += size
- }
- if allOK && len(keys) > 0 {
- accessURLs := make([]string, 0, len(keys))
- for _, key := range keys {
- accessURL, err := h.s3Storage.GetAccessURL(ctx, key)
- if err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err)
- _ = h.s3Storage.DeleteObjects(ctx, keys)
- allOK = false
- break
- }
- accessURLs = append(accessURLs, accessURL)
- }
- if allOK && len(accessURLs) > 0 {
- return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize
- }
- }
- }
-
- // 第二层:尝试本地存储
- if h.mediaStorage != nil && h.mediaStorage.Enabled() {
- storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls)
- if err == nil && len(storedPaths) > 0 {
- firstPath := storedPaths[0]
- totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths)
- if sizeErr != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr)
- }
- return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize
- }
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err)
- }
-
- // 第三层:保留上游临时 URL
- return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0
-}
-
-// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。
-func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
- body := map[string]any{
- "model": model,
- "messages": []map[string]string{
- {"role": "user", "content": prompt},
- },
- "stream": false,
- }
- if imageInput != "" {
- body["image_input"] = imageInput
- }
- if videoCount > 1 {
- body["video_count"] = videoCount
- }
- b, _ := json.Marshal(body)
- return b
-}
-
-func normalizeVideoCount(mediaType string, videoCount int) int {
- if mediaType != "video" {
- return 1
- }
- if videoCount <= 0 {
- return 1
- }
- if videoCount > 3 {
- return 3
- }
- return videoCount
-}
-
-// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。
-// OAuth 路径:ForwardResult.MediaURL 已填充。
-// APIKey 路径:需从响应体解析 media_url / media_urls 字段。
-func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) {
- // 优先从 ForwardResult 获取(OAuth 路径)
- if result != nil && result.MediaURL != "" {
- // 尝试从响应体获取完整 URL 列表
- if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
- return urls[0], urls
- }
- return result.MediaURL, []string{result.MediaURL}
- }
-
- // 从响应体解析(APIKey 路径)
- if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 {
- return urls[0], urls
- }
-
- return "", nil
-}
-
-// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。
-func parseMediaURLsFromBody(body []byte) []string {
- if len(body) == 0 {
- return nil
- }
- var resp map[string]any
- if err := json.Unmarshal(body, &resp); err != nil {
- return nil
- }
-
- // 优先 media_urls(多图数组)
- if rawURLs, ok := resp["media_urls"]; ok {
- if arr, ok := rawURLs.([]any); ok && len(arr) > 0 {
- urls := make([]string, 0, len(arr))
- for _, item := range arr {
- if s, ok := item.(string); ok && s != "" {
- urls = append(urls, s)
- }
- }
- if len(urls) > 0 {
- return urls
- }
- }
- }
-
- // 回退到 media_url(单个 URL)
- if url, ok := resp["media_url"].(string); ok && url != "" {
- return []string{url}
- }
-
- return nil
-}
-
-// ListGenerations 查询生成记录列表。
-// GET /api/v1/sora/generations
-func (h *SoraClientHandler) ListGenerations(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
- pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
-
- params := service.SoraGenerationListParams{
- UserID: userID,
- Status: c.Query("status"),
- StorageType: c.Query("storage_type"),
- MediaType: c.Query("media_type"),
- Page: page,
- PageSize: pageSize,
- }
-
- gens, total, err := h.genService.List(c.Request.Context(), params)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- // 为 S3 记录动态生成预签名 URL
- for _, gen := range gens {
- _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
- }
-
- response.Success(c, gin.H{
- "data": gens,
- "total": total,
- "page": page,
- })
-}
-
-// GetGeneration 查询生成记录详情。
-// GET /api/v1/sora/generations/:id
-func (h *SoraClientHandler) GetGeneration(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen)
- response.Success(c, gen)
-}
-
-// DeleteGeneration 删除生成记录。
-// DELETE /api/v1/sora/generations/:id
-func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。
- if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil {
- paths := gen.MediaURLs
- if len(paths) == 0 && gen.MediaURL != "" {
- paths = []string{gen.MediaURL}
- }
- if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err)
- }
- }
-
- if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- response.Success(c, gin.H{"message": "已删除"})
-}
-
-// GetQuota 查询用户存储配额。
-// GET /api/v1/sora/quota
-func (h *SoraClientHandler) GetQuota(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- if h.quotaService == nil {
- response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"})
- return
- }
-
- quota, err := h.quotaService.GetQuota(c.Request.Context(), userID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, quota)
-}
-
-// CancelGeneration 取消生成任务。
-// POST /api/v1/sora/generations/:id/cancel
-func (h *SoraClientHandler) CancelGeneration(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- // 权限校验
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
- _ = gen
-
- if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil {
- if errors.Is(err, service.ErrSoraGenerationNotActive) {
- response.Error(c, http.StatusConflict, "任务已结束,无法取消")
- return
- }
- response.Error(c, http.StatusBadRequest, err.Error())
- return
- }
-
- response.Success(c, gin.H{"message": "已取消"})
-}
-
-// SaveToStorage 手动保存 upstream 记录到 S3。
-// POST /api/v1/sora/generations/:id/save
-func (h *SoraClientHandler) SaveToStorage(c *gin.Context) {
- userID := getUserIDFromContext(c)
- if userID == 0 {
- response.Error(c, http.StatusUnauthorized, "未登录")
- return
- }
-
- id, err := strconv.ParseInt(c.Param("id"), 10, 64)
- if err != nil {
- response.Error(c, http.StatusBadRequest, "无效的 ID")
- return
- }
-
- gen, err := h.genService.GetByID(c.Request.Context(), id, userID)
- if err != nil {
- response.Error(c, http.StatusNotFound, err.Error())
- return
- }
-
- if gen.StorageType != service.SoraStorageTypeUpstream {
- response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存")
- return
- }
- if gen.MediaURL == "" {
- response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
- return
- }
-
- if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) {
- response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员")
- return
- }
-
- sourceURLs := gen.MediaURLs
- if len(sourceURLs) == 0 && gen.MediaURL != "" {
- sourceURLs = []string{gen.MediaURL}
- }
- if len(sourceURLs) == 0 {
- response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期")
- return
- }
-
- uploadedKeys := make([]string, 0, len(sourceURLs))
- accessURLs := make([]string, 0, len(sourceURLs))
- var totalSize int64
-
- for _, sourceURL := range sourceURLs {
- objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL)
- if uploadErr != nil {
- if len(uploadedKeys) > 0 {
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- }
- var upstreamErr *service.UpstreamDownloadError
- if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) {
- response.Error(c, http.StatusGone, "媒体链接已过期,无法保存")
- return
- }
- response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error())
- return
- }
- accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey)
- if err != nil {
- uploadedKeys = append(uploadedKeys, objectKey)
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error())
- return
- }
- uploadedKeys = append(uploadedKeys, objectKey)
- accessURLs = append(accessURLs, accessURL)
- totalSize += fileSize
- }
-
- usageAdded := false
- if totalSize > 0 && h.quotaService != nil {
- if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil {
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- var quotaErr *service.QuotaExceededError
- if errors.As(err, "aErr) {
- response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间")
- return
- }
- response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error())
- return
- }
- usageAdded = true
- }
-
- if err := h.genService.UpdateStorageForCompleted(
- c.Request.Context(),
- id,
- accessURLs[0],
- accessURLs,
- service.SoraStorageTypeS3,
- uploadedKeys,
- totalSize,
- ); err != nil {
- _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys)
- if usageAdded && h.quotaService != nil {
- _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize)
- }
- response.ErrorFrom(c, err)
- return
- }
-
- response.Success(c, gin.H{
- "message": "已保存到 S3",
- "object_key": uploadedKeys[0],
- "object_keys": uploadedKeys,
- })
-}
-
-// GetStorageStatus 返回存储状态。
-// GET /api/v1/sora/storage-status
-func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) {
- s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context())
- s3Healthy := false
- if s3Enabled {
- s3Healthy = h.s3Storage.IsHealthy(c.Request.Context())
- }
- localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled()
- response.Success(c, gin.H{
- "s3_enabled": s3Enabled,
- "s3_healthy": s3Healthy,
- "local_enabled": localEnabled,
- })
-}
-
-func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) {
- switch storageType {
- case service.SoraStorageTypeS3:
- if h.s3Storage != nil && len(s3Keys) > 0 {
- if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err)
- }
- }
- case service.SoraStorageTypeLocal:
- if h.mediaStorage != nil && len(localPaths) > 0 {
- if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err)
- }
- }
- }
-}
-
-// getUserIDFromContext 从 gin 上下文中提取用户 ID。
-func getUserIDFromContext(c *gin.Context) int64 {
- if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
- return subject.UserID
- }
-
- if id, ok := c.Get("user_id"); ok {
- switch v := id.(type) {
- case int64:
- return v
- case float64:
- return int64(v)
- case string:
- n, _ := strconv.ParseInt(v, 10, 64)
- return n
- }
- }
- // 尝试从 JWT claims 获取
- if id, ok := c.Get("userID"); ok {
- if v, ok := id.(int64); ok {
- return v
- }
- }
- return 0
-}
-
-func groupIDForLog(groupID *int64) int64 {
- if groupID == nil {
- return 0
- }
- return *groupID
-}
-
-func trimForLog(raw string, maxLen int) string {
- trimmed := strings.TrimSpace(raw)
- if maxLen <= 0 || len(trimmed) <= maxLen {
- return trimmed
- }
- return trimmed[:maxLen] + "...(truncated)"
-}
-
-// GetModels 获取可用 Sora 模型家族列表。
-// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。
-// GET /api/v1/sora/models
-func (h *SoraClientHandler) GetModels(c *gin.Context) {
- families := h.getModelFamilies(c.Request.Context())
- response.Success(c, families)
-}
-
-// getModelFamilies 获取模型家族列表(带缓存)。
-func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily {
- // 读锁检查缓存
- h.modelCacheMu.RLock()
- ttl := modelCacheTTL
- if !h.modelCacheUpstream {
- ttl = modelCacheFailedTTL
- }
- if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
- families := h.cachedFamilies
- h.modelCacheMu.RUnlock()
- return families
- }
- h.modelCacheMu.RUnlock()
-
- // 写锁更新缓存
- h.modelCacheMu.Lock()
- defer h.modelCacheMu.Unlock()
-
- // double-check
- ttl = modelCacheTTL
- if !h.modelCacheUpstream {
- ttl = modelCacheFailedTTL
- }
- if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl {
- return h.cachedFamilies
- }
-
- // 尝试从上游获取
- families, err := h.fetchUpstreamModels(ctx)
- if err != nil {
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err)
- families = service.BuildSoraModelFamilies()
- h.cachedFamilies = families
- h.modelCacheTime = time.Now()
- h.modelCacheUpstream = false
- return families
- }
-
- logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families))
- h.cachedFamilies = families
- h.modelCacheTime = time.Now()
- h.modelCacheUpstream = true
- return families
-}
-
-// fetchUpstreamModels 从上游 Sora API 获取模型列表。
-func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) {
- if h.gatewayService == nil {
- return nil, fmt.Errorf("gatewayService 未初始化")
- }
-
- // 设置 ForcePlatform 用于 Sora 账号选择
- ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora)
-
- // 选择一个 Sora 账号
- account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s")
- if err != nil {
- return nil, fmt.Errorf("选择 Sora 账号失败: %w", err)
- }
-
- // 仅支持 API Key 类型账号
- if account.Type != service.AccountTypeAPIKey {
- return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type)
- }
-
- apiKey := account.GetCredential("api_key")
- if apiKey == "" {
- return nil, fmt.Errorf("账号缺少 api_key")
- }
-
- baseURL := account.GetBaseURL()
- if baseURL == "" {
- return nil, fmt.Errorf("账号缺少 base_url")
- }
-
- // 构建上游模型列表请求
- modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models"
-
- reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
- defer cancel()
-
- req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil)
- if err != nil {
- return nil, fmt.Errorf("创建请求失败: %w", err)
- }
- req.Header.Set("Authorization", "Bearer "+apiKey)
-
- client := &http.Client{Timeout: 10 * time.Second}
- resp, err := client.Do(req)
- if err != nil {
- return nil, fmt.Errorf("请求上游失败: %w", err)
- }
- defer func() {
- _ = resp.Body.Close()
- }()
-
- if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode)
- }
-
- body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024))
- if err != nil {
- return nil, fmt.Errorf("读取响应失败: %w", err)
- }
-
- // 解析 OpenAI 格式的模型列表
- var modelsResp struct {
- Data []struct {
- ID string `json:"id"`
- } `json:"data"`
- }
- if err := json.Unmarshal(body, &modelsResp); err != nil {
- return nil, fmt.Errorf("解析响应失败: %w", err)
- }
-
- if len(modelsResp.Data) == 0 {
- return nil, fmt.Errorf("上游返回空模型列表")
- }
-
- // 提取模型 ID
- modelIDs := make([]string, 0, len(modelsResp.Data))
- for _, m := range modelsResp.Data {
- modelIDs = append(modelIDs, m.ID)
- }
-
- // 转换为模型家族
- families := service.BuildSoraModelFamiliesFromIDs(modelIDs)
- if len(families) == 0 {
- return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族")
- }
-
- return families, nil
-}
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
deleted file mode 100644
index fe035b6f..00000000
--- a/backend/internal/handler/sora_client_handler_test.go
+++ /dev/null
@@ -1,3178 +0,0 @@
-//go:build unit
-
-package handler
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "os"
- "strings"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
- middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
-)
-
-func init() {
- gin.SetMode(gin.TestMode)
-}
-
-// ==================== Stub: SoraGenerationRepository ====================
-
-var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil)
-
-type stubSoraGenRepo struct {
- gens map[int64]*service.SoraGeneration
- nextID int64
- createErr error
- getErr error
- updateErr error
- deleteErr error
- listErr error
- countErr error
- countValue int64
-
- // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败
- updateCallCount *int32
- updateFailAfterN int32
-
- // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus
- getByIDCallCount int32
- getByIDOverrideAfterN int32 // 0 = 不覆盖
- getByIDOverrideStatus string
-}
-
-func newStubSoraGenRepo() *stubSoraGenRepo {
- return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1}
-}
-
-func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error {
- if r.createErr != nil {
- return r.createErr
- }
- gen.ID = r.nextID
- r.nextID++
- r.gens[gen.ID] = gen
- return nil
-}
-func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) {
- if r.getErr != nil {
- return nil, r.getErr
- }
- gen, ok := r.gens[id]
- if !ok {
- return nil, fmt.Errorf("not found")
- }
- // 条件性状态覆盖:模拟外部取消等场景
- if r.getByIDOverrideAfterN > 0 {
- n := atomic.AddInt32(&r.getByIDCallCount, 1)
- if n > r.getByIDOverrideAfterN {
- cp := *gen
- cp.Status = r.getByIDOverrideStatus
- return &cp, nil
- }
- }
- return gen, nil
-}
-func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error {
- // 条件性失败:前 N 次成功,之后失败
- if r.updateCallCount != nil {
- n := atomic.AddInt32(r.updateCallCount, 1)
- if n > r.updateFailAfterN {
- return fmt.Errorf("conditional update error (call #%d)", n)
- }
- }
- if r.updateErr != nil {
- return r.updateErr
- }
- r.gens[gen.ID] = gen
- return nil
-}
-func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error {
- if r.deleteErr != nil {
- return r.deleteErr
- }
- delete(r.gens, id)
- return nil
-}
-func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
- if r.listErr != nil {
- return nil, 0, r.listErr
- }
- var result []*service.SoraGeneration
- for _, gen := range r.gens {
- if gen.UserID != params.UserID {
- continue
- }
- result = append(result, gen)
- }
- return result, int64(len(result)), nil
-}
-func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) {
- if r.countErr != nil {
- return 0, r.countErr
- }
- return r.countValue, nil
-}
-
-// ==================== 辅助函数 ====================
-
-func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler {
- genService := service.NewSoraGenerationService(repo, nil, nil)
- return &SoraClientHandler{genService: genService}
-}
-
-func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) {
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- if body != "" {
- c.Request = httptest.NewRequest(method, path, strings.NewReader(body))
- c.Request.Header.Set("Content-Type", "application/json")
- } else {
- c.Request = httptest.NewRequest(method, path, nil)
- }
- if userID > 0 {
- c.Set("user_id", userID)
- }
- return c, rec
-}
-
-func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any {
- t.Helper()
- var resp map[string]any
- require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
- return resp
-}
-
-// ==================== 纯函数测试: buildAsyncRequestBody ====================
-
-func TestBuildAsyncRequestBody(t *testing.T) {
- body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, "sora2-landscape-10s", parsed["model"])
- require.Equal(t, false, parsed["stream"])
-
- msgs := parsed["messages"].([]any)
- require.Len(t, msgs, 1)
- msg := msgs[0].(map[string]any)
- require.Equal(t, "user", msg["role"])
- require.Equal(t, "一只猫在跳舞", msg["content"])
-}
-
-func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) {
- body := buildAsyncRequestBody("gpt-image", "", "", 1)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, "gpt-image", parsed["model"])
- msgs := parsed["messages"].([]any)
- msg := msgs[0].(map[string]any)
- require.Equal(t, "", msg["content"])
-}
-
-func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) {
- body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, "https://example.com/ref.png", parsed["image_input"])
-}
-
-func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) {
- body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(body, &parsed))
- require.Equal(t, float64(3), parsed["video_count"])
-}
-
-func TestNormalizeVideoCount(t *testing.T) {
- require.Equal(t, 1, normalizeVideoCount("video", 0))
- require.Equal(t, 2, normalizeVideoCount("video", 2))
- require.Equal(t, 3, normalizeVideoCount("video", 5))
- require.Equal(t, 1, normalizeVideoCount("image", 3))
-}
-
-// ==================== 纯函数测试: parseMediaURLsFromBody ====================
-
-func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) {
- urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`))
- require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
-}
-
-func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) {
- urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`))
- require.Equal(t, []string{"https://a.com/video.mp4"}, urls)
-}
-
-func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody(nil))
- require.Nil(t, parseMediaURLsFromBody([]byte{}))
-}
-
-func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte("not json")))
-}
-
-func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`)))
-}
-
-func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`)))
-}
-
-func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`)))
-}
-
-func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) {
- body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}`
- urls := parseMediaURLsFromBody([]byte(body))
- require.Len(t, urls, 2)
- require.Equal(t, "https://multi.com/a.mp4", urls[0])
-}
-
-func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) {
- urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`))
- require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
-}
-
-func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`)))
-}
-
-func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) {
- // media_urls 不是 string 数组
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`)))
-}
-
-func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) {
- require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`)))
-}
-
-// ==================== 纯函数测试: extractMediaURLsFromResult ====================
-
-func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) {
- result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
- recorder := httptest.NewRecorder()
- url, urls := extractMediaURLsFromResult(result, recorder)
- require.Equal(t, "https://oauth.com/video.mp4", url)
- require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls)
-}
-
-func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) {
- result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"}
- recorder := httptest.NewRecorder()
- _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`))
- url, urls := extractMediaURLsFromResult(result, recorder)
- require.Equal(t, "https://body.com/1.mp4", url)
- require.Len(t, urls, 2)
-}
-
-func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) {
- recorder := httptest.NewRecorder()
- _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`))
- url, urls := extractMediaURLsFromResult(nil, recorder)
- require.Equal(t, "https://upstream.com/video.mp4", url)
- require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls)
-}
-
-func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) {
- recorder := httptest.NewRecorder()
- url, urls := extractMediaURLsFromResult(nil, recorder)
- require.Empty(t, url)
- require.Nil(t, urls)
-}
-
-func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) {
- result := &service.ForwardResult{MediaURL: ""}
- recorder := httptest.NewRecorder()
- url, urls := extractMediaURLsFromResult(result, recorder)
- require.Empty(t, url)
- require.Nil(t, urls)
-}
-
-// ==================== getUserIDFromContext ====================
-
-func TestGetUserIDFromContext_Int64(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", int64(42))
- require.Equal(t, int64(42), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_AuthSubject(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777})
- require.Equal(t, int64(777), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_Float64(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", float64(99))
- require.Equal(t, int64(99), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_String(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", "123")
- require.Equal(t, int64(123), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_UserIDFallback(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("userID", int64(55))
- require.Equal(t, int64(55), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_NoID(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- require.Equal(t, int64(0), getUserIDFromContext(c))
-}
-
-func TestGetUserIDFromContext_InvalidString(t *testing.T) {
- c, _ := gin.CreateTestContext(httptest.NewRecorder())
- c.Request = httptest.NewRequest("GET", "/", nil)
- c.Set("user_id", "not-a-number")
- require.Equal(t, int64(0), getUserIDFromContext(c))
-}
-
-// ==================== Handler: Generate ====================
-
-func TestGenerate_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0)
- h.Generate(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestGenerate_BadRequest_MissingModel(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGenerate_BadRequest_MissingPrompt(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGenerate_BadRequest_InvalidJSON(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGenerate_TooManyRequests(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.countValue = 3
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
-}
-
-func TestGenerate_CountError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.countErr = fmt.Errorf("db error")
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-func TestGenerate_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.NotZero(t, data["generation_id"])
- require.Equal(t, "pending", data["status"])
-}
-
-func TestGenerate_DefaultMediaType(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "video", repo.gens[1].MediaType)
-}
-
-func TestGenerate_ImageMediaType(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "image", repo.gens[1].MediaType)
-}
-
-func TestGenerate_CreatePendingError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.createErr = fmt.Errorf("create failed")
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-func TestGenerate_APIKeyInContext(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- c.Set("api_key_id", int64(42))
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_NoAPIKeyInContext(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Nil(t, repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_ConcurrencyBoundary(t *testing.T) {
- // activeCount == 2 应该允许
- repo := newStubSoraGenRepo()
- repo.countValue = 2
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-// ==================== Handler: ListGenerations ====================
-
-func TestListGenerations_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0)
- h.ListGenerations(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestListGenerations_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"}
- repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"}
- repo.nextID = 3
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1)
- h.ListGenerations(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- items := data["data"].([]any)
- require.Len(t, items, 2)
- require.Equal(t, float64(2), data["total"])
-}
-
-func TestListGenerations_ListError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.listErr = fmt.Errorf("db error")
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
- h.ListGenerations(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-func TestListGenerations_DefaultPagination(t *testing.T) {
- repo := newStubSoraGenRepo()
- h := newTestSoraClientHandler(repo)
- // 不传分页参数,应默认 page=1 page_size=20
- c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1)
- h.ListGenerations(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, float64(1), data["page"])
-}
-
-// ==================== Handler: GetGeneration ====================
-
-func TestGetGeneration_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestGetGeneration_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestGetGeneration_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestGetGeneration_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestGetGeneration_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.GetGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, float64(1), data["id"])
-}
-
-// ==================== Handler: DeleteGeneration ====================
-
-func TestDeleteGeneration_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestDeleteGeneration_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestDeleteGeneration_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestDeleteGeneration_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestDeleteGeneration_Success(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- _, exists := repo.gens[1]
- require.False(t, exists)
-}
-
-// ==================== Handler: CancelGeneration ====================
-
-func TestCancelGeneration_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestCancelGeneration_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestCancelGeneration_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestCancelGeneration_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestCancelGeneration_Pending(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "cancelled", repo.gens[1].Status)
-}
-
-func TestCancelGeneration_Generating(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Equal(t, "cancelled", repo.gens[1].Status)
-}
-
-func TestCancelGeneration_Completed(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
-
-func TestCancelGeneration_Failed(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
-
-func TestCancelGeneration_Cancelled(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
-
-// ==================== Handler: GetQuota ====================
-
-func TestGetQuota_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0)
- h.GetQuota(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestGetQuota_NilQuotaService(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
- h.GetQuota(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, "unlimited", data["source"])
-}
-
-// ==================== Handler: GetModels ====================
-
-func TestGetModels(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0)
- h.GetModels(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].([]any)
- require.Len(t, data, 4)
- // 验证类型分布
- videoCount, imageCount := 0, 0
- for _, item := range data {
- m := item.(map[string]any)
- if m["type"] == "video" {
- videoCount++
- } else if m["type"] == "image" {
- imageCount++
- }
- }
- require.Equal(t, 3, videoCount)
- require.Equal(t, 1, imageCount)
-}
-
-// ==================== Handler: GetStorageStatus ====================
-
-func TestGetStorageStatus_NilS3(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, false, data["s3_enabled"])
- require.Equal(t, false, data["s3_healthy"])
- require.Equal(t, false, data["local_enabled"])
-}
-
-func TestGetStorageStatus_LocalEnabled(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-storage-status-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, false, data["s3_enabled"])
- require.Equal(t, false, data["s3_healthy"])
- require.Equal(t, true, data["local_enabled"])
-}
-
-// ==================== Handler: SaveToStorage ====================
-
-func TestSaveToStorage_Unauthorized(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusUnauthorized, rec.Code)
-}
-
-func TestSaveToStorage_InvalidID(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "abc"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestSaveToStorage_NotFound(t *testing.T) {
- h := newTestSoraClientHandler(newStubSoraGenRepo())
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "999"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-func TestSaveToStorage_NotUpstream(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestSaveToStorage_EmptyMediaURL(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-func TestSaveToStorage_S3Nil(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusServiceUnavailable, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "云存储")
-}
-
-func TestSaveToStorage_WrongUser(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"}
- h := newTestSoraClientHandler(repo)
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-// ==================== storeMediaWithDegradation — nil guard 路径 ====================
-
-func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
- h := &SoraClientHandler{}
- url, urls, storageType, keys, size := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Equal(t, "https://upstream.com/v.mp4", url)
- require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
- require.Nil(t, keys)
- require.Equal(t, int64(0), size)
-}
-
-func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
- h := &SoraClientHandler{}
- url, urls, storageType, keys, size := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Equal(t, "https://a.com/1.mp4", url)
- require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
- require.Nil(t, keys)
- require.Equal(t, int64(0), size)
-}
-
-func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
- h := &SoraClientHandler{}
- url, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Equal(t, "https://upstream.com/v.mp4", url)
-}
-
-// ==================== Stub: UserRepository (用于 SoraQuotaService) ====================
-
-var _ service.UserRepository = (*stubUserRepoForHandler)(nil)
-
-type stubUserRepoForHandler struct {
- users map[int64]*service.User
- updateErr error
-}
-
-func newStubUserRepoForHandler() *stubUserRepoForHandler {
- return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
-}
-
-func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
- if u, ok := r.users[id]; ok {
- return u, nil
- }
- return nil, fmt.Errorf("user not found")
-}
-func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
- if r.updateErr != nil {
- return r.updateErr
- }
- r.users[user.ID] = user
- return nil
-}
-func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
-func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
- return nil, nil
-}
-func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
- return nil, nil
-}
-func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
-func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
-func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
-func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
- return false, nil
-}
-func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubUserRepoForHandler) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
-func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
-
-// ==================== NewSoraClientHandler ====================
-
-func TestNewSoraClientHandler(t *testing.T) {
- h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
- require.NotNil(t, h)
-}
-
-func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) {
- h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil)
- require.NotNil(t, h)
- require.Nil(t, h.apiKeyService)
-}
-
-// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ====================
-
-var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil)
-
-type stubAPIKeyRepoForHandler struct {
- keys map[int64]*service.APIKey
- getErr error
-}
-
-func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler {
- return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)}
-}
-
-func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) {
- if r.getErr != nil {
- return nil, r.getErr
- }
- if k, ok := r.keys[id]; ok {
- return k, nil
- }
- return nil, fmt.Errorf("api key not found: %d", id)
-}
-func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil }
-func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) {
- return "", 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil }
-func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil }
-func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) {
- return false, nil
-}
-func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) UpdateGroupIDByUserAndGroup(_ context.Context, userID, oldGroupID, newGroupID int64) (int64, error) {
- var updated int64
- for id, key := range r.keys {
- if key.UserID != userID || key.GroupID == nil || *key.GroupID != oldGroupID {
- continue
- }
- clone := *key
- gid := newGroupID
- clone.GroupID = &gid
- r.keys[id] = &clone
- updated++
- }
- return updated, nil
-}
-func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) {
- return nil, nil
-}
-func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) {
- return 0, nil
-}
-func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error {
- return nil
-}
-func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error {
- return nil
-}
-func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error {
- return nil
-}
-func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) {
- return nil, nil
-}
-
-// newTestAPIKeyService 创建测试用的 APIKeyService
-func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService {
- return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{})
-}
-
-// ==================== Generate: API Key 校验(前端传递 api_key_id)====================
-
-func TestGenerate_WithAPIKeyID_Success(t *testing.T) {
- // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- groupID := int64(5)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyActive,
- GroupID: &groupID,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.NotZero(t, data["generation_id"])
-
- // 验证 api_key_id 已关联到生成记录
- gen := repo.gens[1]
- require.NotNil(t, gen.APIKeyID)
- require.Equal(t, int64(42), *gen.APIKeyID)
-}
-
-func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) {
- // 前端传递不存在的 api_key_id → 400
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "不存在")
-}
-
-func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) {
- // 前端传递别人的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 999, // 属于 user 999
- Status: service.StatusAPIKeyActive,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "不属于")
-}
-
-func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) {
- // 前端传递已禁用的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyDisabled,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "不可用")
-}
-
-func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) {
- // 前端传递配额耗尽的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyQuotaExhausted,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
-}
-
-func TestGenerate_WithAPIKeyID_Expired(t *testing.T) {
- // 前端传递已过期的 api_key_id → 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyExpired,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
-}
-
-func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) {
- // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- h := &SoraClientHandler{genService: genService} // apiKeyService = nil
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录
- require.Nil(t, repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) {
- // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyActive,
- GroupID: nil, // 无分组
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) {
- // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- require.Nil(t, repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) {
- // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- groupID := int64(10)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyRepo.keys[42] = &service.APIKey{
- ID: 42,
- UserID: 1,
- Status: service.StatusAPIKeyActive,
- GroupID: &groupID,
- }
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1)
- c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // 应使用 body 中的 api_key_id=42,而不是 context 中的 99
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(42), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) {
- // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由)
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- c.Set("api_key_id", int64(99))
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // 应使用 context 中的 api_key_id=99
- require.NotNil(t, repo.gens[1].APIKeyID)
- require.Equal(t, int64(99), *repo.gens[1].APIKeyID)
-}
-
-func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
- // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- apiKeyRepo := newStubAPIKeyRepoForHandler()
- apiKeyService := newTestAPIKeyService(apiKeyRepo)
-
- h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
- // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验
- // api_key_id=0 不存在 → 400
- c, rec := makeGinContext("POST", "/api/v1/sora/generate",
- `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
-}
-
-// ==================== processGeneration: groupID 传递与 ForcePlatform ====================
-
-func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
- // groupID 不为 nil → 不设置 ForcePlatform
- // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- gid := int64(5)
- h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
-}
-
-func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
- // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
-}
-
-func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
- // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled
- require.Equal(t, "cancelled", repo.gens[1].Status)
-}
-
-// ==================== GenerateRequest JSON 解析 ====================
-
-func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
- // 验证 api_key_id 在 JSON 中正确解析为 *int64
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
- require.NoError(t, err)
- require.NotNil(t, req.APIKeyID)
- require.Equal(t, int64(42), *req.APIKeyID)
-}
-
-func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
- // 不传 api_key_id → 解析后为 nil
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
- require.NoError(t, err)
- require.Nil(t, req.APIKeyID)
-}
-
-func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
- // api_key_id: null → 解析后为 nil
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
- require.NoError(t, err)
- require.Nil(t, req.APIKeyID)
-}
-
-func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
- // 全字段解析
- var req GenerateRequest
- err := json.Unmarshal([]byte(`{
- "model":"sora2-landscape-10s",
- "prompt":"test prompt",
- "media_type":"video",
- "video_count":2,
- "image_input":"data:image/png;base64,abc",
- "api_key_id":100
- }`), &req)
- require.NoError(t, err)
- require.Equal(t, "sora2-landscape-10s", req.Model)
- require.Equal(t, "test prompt", req.Prompt)
- require.Equal(t, "video", req.MediaType)
- require.Equal(t, 2, req.VideoCount)
- require.Equal(t, "data:image/png;base64,abc", req.ImageInput)
- require.NotNil(t, req.APIKeyID)
- require.Equal(t, int64(100), *req.APIKeyID)
-}
-
-func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) {
- // api_key_id 为 nil 时 JSON 序列化应省略
- req := GenerateRequest{Model: "sora2", Prompt: "test"}
- b, err := json.Marshal(req)
- require.NoError(t, err)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(b, &parsed))
- _, hasAPIKeyID := parsed["api_key_id"]
- require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略")
-}
-
-func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) {
- // api_key_id 不为 nil 时 JSON 序列化应包含
- id := int64(42)
- req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id}
- b, err := json.Marshal(req)
- require.NoError(t, err)
- var parsed map[string]any
- require.NoError(t, json.Unmarshal(b, &parsed))
- require.Equal(t, float64(42), parsed["api_key_id"])
-}
-
-// ==================== GetQuota: 有配额服务 ====================
-
-func TestGetQuota_WithQuotaService_Success(t *testing.T) {
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 10 * 1024 * 1024,
- SoraStorageUsedBytes: 3 * 1024 * 1024,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1)
- h.GetQuota(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, "user", data["source"])
- require.Equal(t, float64(10*1024*1024), data["quota_bytes"])
- require.Equal(t, float64(3*1024*1024), data["used_bytes"])
-}
-
-func TestGetQuota_WithQuotaService_Error(t *testing.T) {
- // 用户不存在时 GetQuota 返回错误
- userRepo := newStubUserRepoForHandler()
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999)
- h.GetQuota(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== Generate: 配额检查 ====================
-
-func TestGenerate_QuotaCheckFailed(t *testing.T) {
- // 配额超限时返回 429
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 1024,
- SoraStorageUsedBytes: 1025, // 已超限
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
-}
-
-func TestGenerate_QuotaCheckPassed(t *testing.T) {
- // 配额充足时允许生成
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 10 * 1024 * 1024,
- SoraStorageUsedBytes: 0,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{
- genService: genService,
- quotaService: quotaService,
- }
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1)
- h.Generate(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-// ==================== Stub: SettingRepository (用于 S3 存储测试) ====================
-
-var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil)
-
-type stubSettingRepoForHandler struct {
- values map[string]string
-}
-
-func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler {
- if values == nil {
- values = make(map[string]string)
- }
- return &stubSettingRepoForHandler{values: values}
-}
-
-func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) {
- if v, ok := r.values[key]; ok {
- return &service.Setting{Key: key, Value: v}, nil
- }
- return nil, service.ErrSettingNotFound
-}
-func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) {
- if v, ok := r.values[key]; ok {
- return v, nil
- }
- return "", service.ErrSettingNotFound
-}
-func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error {
- r.values[key] = value
- return nil
-}
-func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
- result := make(map[string]string)
- for _, k := range keys {
- if v, ok := r.values[k]; ok {
- result[k] = v
- }
- }
- return result, nil
-}
-func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error {
- for k, v := range settings {
- r.values[k] = v
- }
- return nil
-}
-func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) {
- return r.values, nil
-}
-func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error {
- delete(r.values, key)
- return nil
-}
-
-// ==================== S3 / MediaStorage 辅助函数 ====================
-
-// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。
-func newS3StorageForHandler(endpoint string) *service.SoraS3Storage {
- settingRepo := newStubSettingRepoForHandler(map[string]string{
- "sora_s3_enabled": "true",
- "sora_s3_endpoint": endpoint,
- "sora_s3_region": "us-east-1",
- "sora_s3_bucket": "test-bucket",
- "sora_s3_access_key_id": "AKIATEST",
- "sora_s3_secret_access_key": "test-secret",
- "sora_s3_prefix": "sora",
- "sora_s3_force_path_style": "true",
- })
- settingService := service.NewSettingService(settingRepo, &config.Config{})
- return service.NewSoraS3Storage(settingService)
-}
-
-// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。
-func newFakeSourceServer() *httptest.Server {
- return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "video/mp4")
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte("fake video data for test"))
- }))
-}
-
-// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。
-// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。
-func newFakeS3Server(mode string) *httptest.Server {
- var counter atomic.Int32
- return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, _ = io.Copy(io.Discard, r.Body)
- _ = r.Body.Close()
-
- switch mode {
- case "ok":
- w.Header().Set("ETag", `"test-etag"`)
- w.WriteHeader(http.StatusOK)
- case "fail":
- w.WriteHeader(http.StatusForbidden)
- _, _ = w.Write([]byte(`AccessDenied `))
- case "fail-second":
- n := counter.Add(1)
- if n <= 1 {
- w.Header().Set("ETag", `"test-etag"`)
- w.WriteHeader(http.StatusOK)
- } else {
- w.WriteHeader(http.StatusForbidden)
- _, _ = w.Write([]byte(`AccessDenied `))
- }
- }
- }))
-}
-
-// ==================== processGeneration 直接调用测试 ====================
-
-func TestProcessGeneration_MarkGeneratingFails(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- repo.updateErr = fmt.Errorf("db error")
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- // 直接调用(非 goroutine),MarkGenerating 失败 → 早退
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating"
- // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed
- // 因此 ErrorMessage 为空(证明未调用 MarkFailed)
- require.Equal(t, "generating", repo.gens[1].Status)
- require.Empty(t, repo.gens[1].ErrorMessage)
-}
-
-func TestProcessGeneration_GatewayServiceNil(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
- // gatewayService 未设置 → MarkFailed
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
-}
-
-// ==================== storeMediaWithDegradation: S3 路径 ====================
-
-func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeS3, storageType)
- require.Len(t, s3Keys, 1)
- require.NotEmpty(t, s3Keys[0])
- require.Len(t, storedURLs, 1)
- require.Equal(t, storedURL, storedURLs[0])
- require.Contains(t, storedURL, fakeS3.URL)
- require.Contains(t, storedURL, "/test-bucket/")
- require.Greater(t, fileSize, int64(0))
-}
-
-func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
- storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
- )
- require.Equal(t, service.SoraStorageTypeS3, storageType)
- require.Len(t, s3Keys, 2)
- require.Len(t, storedURLs, 2)
- require.Equal(t, storedURL, storedURLs[0])
- require.Contains(t, storedURLs[0], fakeS3.URL)
- require.Contains(t, storedURLs[1], fakeS3.URL)
- require.Greater(t, fileSize, int64(0))
-}
-
-func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) {
- // 上游返回 404 → 下载失败 → S3 上传不会开始
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
- badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNotFound)
- }))
- defer badSource.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- _, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
-}
-
-func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- // S3 失败,降级到 upstream
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Nil(t, s3Keys)
-}
-
-func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail-second")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"}
- _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls,
- )
- // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
- require.Nil(t, s3Keys)
-}
-
-// ==================== storeMediaWithDegradation: 本地存储路径 ====================
-
-func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) {
- // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: "/dev/null/invalid_dir",
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- _, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
- )
- // 本地存储失败,降级到 upstream
- require.Equal(t, service.SoraStorageTypeUpstream, storageType)
-}
-
-func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- DownloadTimeoutSeconds: 5,
- MaxDownloadBytes: 10 * 1024 * 1024,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- require.Equal(t, service.SoraStorageTypeLocal, storageType)
- require.Nil(t, s3Keys) // 本地存储不返回 S3 keys
-}
-
-func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-handler-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- DownloadTimeoutSeconds: 5,
- MaxDownloadBytes: 10 * 1024 * 1024,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{
- s3Storage: s3Storage,
- mediaStorage: mediaStorage,
- }
-
- _, _, storageType, _, _ := h.storeMediaWithDegradation(
- context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil,
- )
- // S3 失败 → 本地存储成功
- require.Equal(t, service.SoraStorageTypeLocal, storageType)
-}
-
-// ==================== SaveToStorage: S3 路径 ====================
-
-func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, resp["message"], "S3")
-}
-
-func TestSaveToStorage_UpstreamURLExpired(t *testing.T) {
- expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
- w.WriteHeader(http.StatusForbidden)
- }))
- defer expiredServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: expiredServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusGone, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, fmt.Sprint(resp["message"]), "过期")
-}
-
-func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Contains(t, data["message"], "S3")
- require.NotEmpty(t, data["object_key"])
- // 验证记录已更新为 S3 存储
- require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
-}
-
-func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v1.mp4",
- MediaURLs: []string{
- sourceServer.URL + "/v1.mp4",
- sourceServer.URL + "/v2.mp4",
- },
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Len(t, data["object_keys"].([]any), 2)
- require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
- require.Len(t, repo.gens[1].S3ObjectKeys, 2)
- require.Len(t, repo.gens[1].MediaURLs, 2)
-}
-
-func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 100 * 1024 * 1024,
- SoraStorageUsedBytes: 0,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusOK, rec.Code)
- // 验证配额已累加
- require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
-}
-
-func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败
- repo.updateErr = fmt.Errorf("db error")
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== GetStorageStatus: S3 路径 ====================
-
-func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) {
- // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket)
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, true, data["s3_enabled"])
- require.Equal(t, false, data["s3_healthy"])
-}
-
-func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) {
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0)
- h.GetStorageStatus(c)
- require.Equal(t, http.StatusOK, rec.Code)
- resp := parseResponse(t, rec)
- data := resp["data"].(map[string]any)
- require.Equal(t, true, data["s3_enabled"])
- require.Equal(t, true, data["s3_healthy"])
-}
-
-// ==================== Stub: AccountRepository (用于 GatewayService) ====================
-
-var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil)
-
-type stubAccountRepoForHandler struct {
- accounts []service.Account
-}
-
-func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil }
-func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) {
- for i := range r.accounts {
- if r.accounts[i].ID == id {
- return &r.accounts[i], nil
- }
- }
- return nil, fmt.Errorf("account not found")
-}
-func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) {
- return false, nil
-}
-func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil }
-func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]service.Account, *pagination.PaginationResult, error) {
- return nil, nil, nil
-}
-func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
- return nil, nil
-}
-func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil }
-func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) {
- return 0, nil
-}
-func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil }
-func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) {
- return r.accounts, nil
-}
-func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil }
-func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error {
- return nil
-}
-func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) {
- return 0, nil
-}
-
-func (r *stubAccountRepoForHandler) IncrementQuotaUsed(context.Context, int64, float64) error {
- return nil
-}
-
-func (r *stubAccountRepoForHandler) ResetQuotaUsed(context.Context, int64) error {
- return nil
-}
-
-// ==================== Stub: SoraClient (用于 SoraGatewayService) ====================
-
-var _ service.SoraClient = (*stubSoraClientForHandler)(nil)
-
-type stubSoraClientForHandler struct {
- videoStatus *service.SoraVideoTaskStatus
-}
-
-func (s *stubSoraClientForHandler) Enabled() bool { return true }
-func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) {
- return "task-image", nil
-}
-func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) {
- return "task-video", nil
-}
-func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) {
- return "task-video", nil
-}
-func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) {
- return nil, nil
-}
-func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) {
- return nil, nil
-}
-func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error {
- return nil
-}
-func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error {
- return nil
-}
-func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error {
- return nil
-}
-func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) {
- return "", nil
-}
-func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) {
- return nil, nil
-}
-func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) {
- return s.videoStatus, nil
-}
-
-// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ====================
-
-// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。
-func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
- return service.NewGatewayService(
- accountRepo, nil, nil, nil, nil, nil, nil, nil, nil,
- nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
- )
-}
-
-// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。
-func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService {
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- PollIntervalSeconds: 1,
- MaxPollAttempts: 1,
- },
- },
- }
- return service.NewSoraGatewayService(soraClient, nil, nil, cfg)
-}
-
-// ==================== processGeneration: 更多路径测试 ====================
-
-func TestProcessGeneration_SelectAccountError(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts"
- accountRepo := &stubAccountRepoForHandler{accounts: nil}
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
-}
-
-func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- // 提供可用账号使 SelectAccountForModel 成功
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- // soraGatewayService 为 nil
- h := &SoraClientHandler{genService: genService, gatewayService: gatewayService}
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService")
-}
-
-func TestProcessGeneration_ForwardError(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- // SoraClient 返回视频任务失败
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "failed",
- ErrorMsg: "content policy violation",
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "生成失败")
-}
-
-func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration
- // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。
- repo.getByIDOverrideAfterN = 1
- repo.getByIDOverrideStatus = "cancelled"
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"},
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating)
- require.Equal(t, "generating", repo.gens[1].Status)
-}
-
-func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- // SoraClient 返回 completed 但无 URL
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: nil, // 无 URL
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL")
-}
-
-func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次)
- // 第 2 次返回 "cancelled" 状态,模拟外部取消
- repo.getByIDOverrideAfterN = 1
- repo.getByIDOverrideStatus = "cancelled"
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating)
- require.Equal(t, "generating", repo.gens[1].Status)
-}
-
-func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- // 无 S3 和本地存储,降级到 upstream
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- require.Equal(t, "completed", repo.gens[1].Status)
- require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType)
- require.NotEmpty(t, repo.gens[1].MediaURL)
-}
-
-func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{sourceServer.URL + "/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- s3Storage := newS3StorageForHandler(fakeS3.URL)
-
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- s3Storage: s3Storage,
- quotaService: quotaService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- require.Equal(t, "completed", repo.gens[1].Status)
- require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType)
- require.NotEmpty(t, repo.gens[1].S3ObjectKeys)
- require.Greater(t, repo.gens[1].FileSizeBytes, int64(0))
- // 验证配额已累加
- require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0))
-}
-
-func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
- repo.updateCallCount = new(int32)
- repo.updateFailAfterN = 1
- genService := service.NewSoraGenerationService(repo, nil, nil)
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- soraClient := &stubSoraClientForHandler{
- videoStatus: &service.SoraVideoTaskStatus{
- Status: "completed",
- URLs: []string{"https://example.com/video.mp4"},
- },
- }
- soraGatewayService := newMinimalSoraGatewayService(soraClient)
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1)
- // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。
- // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。
- // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。
- require.Equal(t, "completed", repo.gens[1].Status)
-}
-
-// ==================== cleanupStoredMedia 直接测试 ====================
-
-func TestCleanupStoredMedia_S3Path(t *testing.T) {
- // S3 清理路径:s3Storage 为 nil 时不 panic
- h := &SoraClientHandler{}
- // 不应 panic
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
-}
-
-func TestCleanupStoredMedia_LocalPath(t *testing.T) {
- // 本地清理路径:mediaStorage 为 nil 时不 panic
- h := &SoraClientHandler{}
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"})
-}
-
-func TestCleanupStoredMedia_UpstreamPath(t *testing.T) {
- // upstream 类型不清理
- h := &SoraClientHandler{}
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil)
-}
-
-func TestCleanupStoredMedia_EmptyKeys(t *testing.T) {
- // 空 keys 不触发清理
- h := &SoraClientHandler{}
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil)
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil)
-}
-
-// ==================== DeleteGeneration: 本地存储清理路径 ====================
-
-func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-delete-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1,
- UserID: 1,
- Status: "completed",
- StorageType: service.SoraStorageTypeLocal,
- MediaURL: "video/test.mp4",
- MediaURLs: []string{"video/test.mp4"},
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
- _, exists := repo.gens[1]
- require.False(t, exists)
-}
-
-func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) {
- // MediaURLs 为空,使用 MediaURL 作为清理路径
- tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1,
- UserID: 1,
- Status: "completed",
- StorageType: service.SoraStorageTypeLocal,
- MediaURL: "video/test.mp4",
- MediaURLs: nil, // 空
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) {
- // 非本地存储类型 → 跳过清理
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1,
- UserID: 1,
- Status: "completed",
- StorageType: service.SoraStorageTypeUpstream,
- MediaURL: "https://upstream.com/v.mp4",
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-func TestDeleteGeneration_DeleteError(t *testing.T) {
- // repo.Delete 出错
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"}
- repo.deleteErr = fmt.Errorf("delete failed")
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusNotFound, rec.Code)
-}
-
-// ==================== fetchUpstreamModels 测试 ====================
-
-func TestFetchUpstreamModels_NilGateway(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- h := &SoraClientHandler{}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "gatewayService 未初始化")
-}
-
-func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- accountRepo := &stubAccountRepoForHandler{accounts: nil}
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "选择 Sora 账号失败")
-}
-
-func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "不支持模型同步")
-}
-
-func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"base_url": "https://sora.test"}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "api_key")
-}
-
-func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
- // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test"}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
-}
-
-func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusInternalServerError)
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "状态码 500")
-}
-
-func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte("not json"))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "解析响应失败")
-}
-
-func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "空模型列表")
-}
-
-func TestFetchUpstreamModels_Success(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // 验证请求头
- require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
- require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models"))
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- families, err := h.fetchUpstreamModels(context.Background())
- require.NoError(t, err)
- require.NotEmpty(t, families)
-}
-
-func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
- _, err := h.fetchUpstreamModels(context.Background())
- require.Error(t, err)
- require.Contains(t, err.Error(), "未能从上游模型列表中识别")
-}
-
-// ==================== getModelFamilies 缓存测试 ====================
-
-func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
- // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置
- h := &SoraClientHandler{}
- families := h.getModelFamilies(context.Background())
- require.NotEmpty(t, families)
-
- // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL)
- families2 := h.getModelFamilies(context.Background())
- require.Equal(t, families, families2)
- require.False(t, h.modelCacheUpstream)
-}
-
-func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
- t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
- }))
- defer ts.Close()
-
- accountRepo := &stubAccountRepoForHandler{
- accounts: []service.Account{
- {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
- Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}},
- },
- }
- gatewayService := newMinimalGatewayService(accountRepo)
- h := &SoraClientHandler{gatewayService: gatewayService}
-
- families := h.getModelFamilies(context.Background())
- require.NotEmpty(t, families)
- require.True(t, h.modelCacheUpstream)
-
- // 第二次调用命中缓存
- families2 := h.getModelFamilies(context.Background())
- require.Equal(t, families, families2)
-}
-
-func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) {
- // 预设过期的缓存(modelCacheUpstream=false → 短 TTL)
- h := &SoraClientHandler{
- cachedFamilies: []service.SoraModelFamily{{ID: "old"}},
- modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期
- modelCacheUpstream: false,
- }
- // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存
- families := h.getModelFamilies(context.Background())
- require.NotEmpty(t, families)
- // 缓存已刷新,不再是 "old"
- found := false
- for _, f := range families {
- if f.ID == "old" {
- found = true
- }
- }
- require.False(t, found, "过期缓存应被刷新")
-}
-
-// ==================== processGeneration: groupID 与 ForcePlatform ====================
-
-func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) {
- // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 空账号列表 → SelectAccountForModel 失败
- accountRepo := &stubAccountRepoForHandler{accounts: nil}
- gatewayService := newMinimalGatewayService(accountRepo)
-
- h := &SoraClientHandler{
- genService: genService,
- gatewayService: gatewayService,
- }
-
- h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
- require.Equal(t, "failed", repo.gens[1].Status)
- require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败")
-}
-
-// ==================== Generate: 配额检查非 QuotaExceeded 错误 ====================
-
-func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) {
- // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403
- repo := newStubSoraGenRepo()
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error
- userRepo := newStubUserRepoForHandler()
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
-
- h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil)
-
- body := `{"model":"sora2-landscape-10s","prompt":"test"}`
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
- h.Generate(c)
- require.Equal(t, http.StatusForbidden, rec.Code)
-}
-
-// ==================== Generate: CreatePending 并发限制错误 ====================
-
-// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口
-type stubSoraGenRepoWithAtomicCreate struct {
- stubSoraGenRepo
- limitErr error
-}
-
-func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error {
- if r.limitErr != nil {
- return r.limitErr
- }
- return r.stubSoraGenRepo.Create(context.Background(), gen)
-}
-
-func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) {
- repo := &stubSoraGenRepoWithAtomicCreate{
- stubSoraGenRepo: *newStubSoraGenRepo(),
- limitErr: service.ErrSoraGenerationConcurrencyLimit,
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil)
-
- body := `{"model":"sora2-landscape-10s","prompt":"test"}`
- c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1)
- h.Generate(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, resp["message"], "3")
-}
-
-// ==================== SaveToStorage: 配额超限 ====================
-
-func TestSaveToStorage_QuotaExceeded(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 用户配额已满
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 10,
- SoraStorageUsedBytes: 10,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusTooManyRequests, rec.Code)
-}
-
-// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ====================
-
-func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error
- userRepo := newStubUserRepoForHandler()
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== SaveToStorage: MediaURLs 全为空 ====================
-
-func TestSaveToStorage_EmptyMediaURLs(t *testing.T) {
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: "",
- MediaURLs: []string{},
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusBadRequest, rec.Code)
- resp := parseResponse(t, rec)
- require.Contains(t, resp["message"], "已过期")
-}
-
-// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ====================
-
-func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("fail-second")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v1.mp4",
- MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"},
- }
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ====================
-
-func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) {
- sourceServer := newFakeSourceServer()
- defer sourceServer.Close()
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: "upstream",
- MediaURL: sourceServer.URL + "/v.mp4",
- }
- repo.updateErr = fmt.Errorf("db error")
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- genService := service.NewSoraGenerationService(repo, nil, nil)
-
- userRepo := newStubUserRepoForHandler()
- userRepo.users[1] = &service.User{
- ID: 1,
- SoraStorageQuotaBytes: 100 * 1024 * 1024,
- SoraStorageUsedBytes: 0,
- }
- quotaService := service.NewSoraQuotaService(userRepo, nil, nil)
- h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.SaveToStorage(c)
- require.Equal(t, http.StatusInternalServerError, rec.Code)
-}
-
-// ==================== cleanupStoredMedia: 实际 S3 删除路径 ====================
-
-func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) {
- fakeS3 := newFakeS3Server("ok")
- defer fakeS3.Close()
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil)
-}
-
-func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) {
- fakeS3 := newFakeS3Server("fail")
- defer fakeS3.Close()
- s3Storage := newS3StorageForHandler(fakeS3.URL)
- h := &SoraClientHandler{s3Storage: s3Storage}
-
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil)
-}
-
-func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
- h := &SoraClientHandler{mediaStorage: mediaStorage}
-
- h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"})
-}
-
-// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ====================
-
-func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) {
- tmpDir, err := os.MkdirTemp("", "sora-del-test-*")
- require.NoError(t, err)
- defer os.RemoveAll(tmpDir)
-
- cfg := &config.Config{
- Sora: config.SoraConfig{
- Storage: config.SoraStorageConfig{
- Type: "local",
- LocalPath: tmpDir,
- },
- },
- }
- mediaStorage := service.NewSoraMediaStorage(cfg)
-
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{
- ID: 1, UserID: 1, Status: "completed",
- StorageType: service.SoraStorageTypeLocal,
- MediaURL: "nonexistent/video.mp4",
- MediaURLs: []string{"nonexistent/video.mp4"},
- }
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage}
-
- c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.DeleteGeneration(c)
- require.Equal(t, http.StatusOK, rec.Code)
-}
-
-// ==================== CancelGeneration: 任务已结束冲突 ====================
-
-func TestCancelGeneration_AlreadyCompleted(t *testing.T) {
- repo := newStubSoraGenRepo()
- repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"}
- genService := service.NewSoraGenerationService(repo, nil, nil)
- h := &SoraClientHandler{genService: genService}
-
- c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1)
- c.Params = gin.Params{{Key: "id", Value: "1"}}
- h.CancelGeneration(c)
- require.Equal(t, http.StatusConflict, rec.Code)
-}
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
deleted file mode 100644
index 5e505409..00000000
--- a/backend/internal/handler/sora_gateway_handler.go
+++ /dev/null
@@ -1,695 +0,0 @@
-package handler
-
-import (
- "context"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "os"
- "path"
- "path/filepath"
- "strconv"
- "strings"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
- "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
- "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
- middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
- "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
-
- "github.com/gin-gonic/gin"
- "github.com/tidwall/gjson"
- "github.com/tidwall/sjson"
- "go.uber.org/zap"
-)
-
-// SoraGatewayHandler handles Sora chat completions requests
-type SoraGatewayHandler struct {
- gatewayService *service.GatewayService
- soraGatewayService *service.SoraGatewayService
- billingCacheService *service.BillingCacheService
- usageRecordWorkerPool *service.UsageRecordWorkerPool
- concurrencyHelper *ConcurrencyHelper
- maxAccountSwitches int
- streamMode string
- soraTLSEnabled bool
- soraMediaSigningKey string
- soraMediaRoot string
-}
-
-// NewSoraGatewayHandler creates a new SoraGatewayHandler
-func NewSoraGatewayHandler(
- gatewayService *service.GatewayService,
- soraGatewayService *service.SoraGatewayService,
- concurrencyService *service.ConcurrencyService,
- billingCacheService *service.BillingCacheService,
- usageRecordWorkerPool *service.UsageRecordWorkerPool,
- cfg *config.Config,
-) *SoraGatewayHandler {
- pingInterval := time.Duration(0)
- maxAccountSwitches := 3
- streamMode := "force"
- soraTLSEnabled := true
- signKey := ""
- mediaRoot := "/app/data/sora"
- if cfg != nil {
- pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
- if cfg.Gateway.MaxAccountSwitches > 0 {
- maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
- }
- if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
- streamMode = mode
- }
- soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
- signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
- if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
- mediaRoot = root
- }
- }
- return &SoraGatewayHandler{
- gatewayService: gatewayService,
- soraGatewayService: soraGatewayService,
- billingCacheService: billingCacheService,
- usageRecordWorkerPool: usageRecordWorkerPool,
- concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
- maxAccountSwitches: maxAccountSwitches,
- streamMode: strings.ToLower(streamMode),
- soraTLSEnabled: soraTLSEnabled,
- soraMediaSigningKey: signKey,
- soraMediaRoot: mediaRoot,
- }
-}
-
-// ChatCompletions handles Sora /v1/chat/completions endpoint
-func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
- apiKey, ok := middleware2.GetAPIKeyFromContext(c)
- if !ok {
- h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
- return
- }
-
- subject, ok := middleware2.GetAuthSubjectFromContext(c)
- if !ok {
- h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
- return
- }
- reqLog := requestLogger(
- c,
- "handler.sora_gateway.chat_completions",
- zap.Int64("user_id", subject.UserID),
- zap.Int64("api_key_id", apiKey.ID),
- zap.Any("group_id", apiKey.GroupID),
- )
-
- body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
- if err != nil {
- if maxErr, ok := extractMaxBytesError(err); ok {
- h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
- return
- }
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
- return
- }
- if len(body) == 0 {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
- return
- }
-
- setOpsRequestContext(c, "", false, body)
-
- // 校验请求体 JSON 合法性
- if !gjson.ValidBytes(body) {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
- return
- }
-
- // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
- modelResult := gjson.GetBytes(body, "model")
- if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
- return
- }
- reqModel := modelResult.String()
-
- msgsResult := gjson.GetBytes(body, "messages")
- if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
- return
- }
-
- clientStream := gjson.GetBytes(body, "stream").Bool()
- reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream))
- if !clientStream {
- if h.streamMode == "error" {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
- return
- }
- var err error
- body, err = sjson.SetBytes(body, "stream", true)
- if err != nil {
- h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
- return
- }
- }
-
- setOpsRequestContext(c, reqModel, clientStream, body)
- setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(clientStream, false)))
-
- platform := ""
- if forced, ok := middleware2.GetForcePlatformFromContext(c); ok {
- platform = forced
- } else if apiKey.Group != nil {
- platform = apiKey.Group.Platform
- }
- if platform != service.PlatformSora {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform")
- return
- }
-
- streamStarted := false
- subscription, _ := middleware2.GetSubscriptionFromContext(c)
-
- maxWait := service.CalculateMaxWait(subject.Concurrency)
- canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
- waitCounted := false
- if err != nil {
- reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err))
- } else if !canWait {
- reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait))
- h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
- return
- }
- if err == nil && canWait {
- waitCounted = true
- }
- defer func() {
- if waitCounted {
- h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
- }
- }()
-
- userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted)
- if err != nil {
- reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err))
- h.handleConcurrencyError(c, err, "user", streamStarted)
- return
- }
- if waitCounted {
- h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
- waitCounted = false
- }
- userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
- if userReleaseFunc != nil {
- defer userReleaseFunc()
- }
-
- if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
- reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err))
- status, code, message := billingErrorDetails(err)
- h.handleStreamingAwareError(c, status, code, message, streamStarted)
- return
- }
-
- sessionHash := generateOpenAISessionHash(c, body)
-
- maxAccountSwitches := h.maxAccountSwitches
- switchCount := 0
- failedAccountIDs := make(map[int64]struct{})
- lastFailoverStatus := 0
- var lastFailoverBody []byte
- var lastFailoverHeaders http.Header
-
- for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
- if err != nil {
- reqLog.Warn("sora.account_select_failed",
- zap.Error(err),
- zap.Int("excluded_account_count", len(failedAccountIDs)),
- )
- if len(failedAccountIDs) == 0 {
- h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
- return
- }
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
- fields := []zap.Field{
- zap.Int("last_upstream_status", lastFailoverStatus),
- }
- if rayID != "" {
- fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
- }
- if mitigated != "" {
- fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
- }
- if contentType != "" {
- fields = append(fields, zap.String("last_upstream_content_type", contentType))
- }
- reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
- h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
- return
- }
- account := selection.Account
- setOpsSelectedAccount(c, account.ID, account.Platform)
- proxyBound := account.ProxyID != nil
- proxyID := int64(0)
- if account.ProxyID != nil {
- proxyID = *account.ProxyID
- }
- tlsFingerprintEnabled := h.soraTLSEnabled
-
- accountReleaseFunc := selection.ReleaseFunc
- if !selection.Acquired {
- if selection.WaitPlan == nil {
- h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
- return
- }
- accountWaitCounted := false
- canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
- if err != nil {
- reqLog.Warn("sora.account_wait_counter_increment_failed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Error(err),
- )
- } else if !canWait {
- reqLog.Info("sora.account_wait_queue_full",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
- )
- h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
- return
- }
- if err == nil && canWait {
- accountWaitCounted = true
- }
- defer func() {
- if accountWaitCounted {
- h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
- }
- }()
-
- accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
- c,
- account.ID,
- selection.WaitPlan.MaxConcurrency,
- selection.WaitPlan.Timeout,
- clientStream,
- &streamStarted,
- )
- if err != nil {
- reqLog.Warn("sora.account_slot_acquire_failed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Error(err),
- )
- h.handleConcurrencyError(c, err, "account", streamStarted)
- return
- }
- if accountWaitCounted {
- h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
- accountWaitCounted = false
- }
- }
- accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
-
- result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream)
- if accountReleaseFunc != nil {
- accountReleaseFunc()
- }
- if err != nil {
- var failoverErr *service.UpstreamFailoverError
- if errors.As(err, &failoverErr) {
- failedAccountIDs[account.ID] = struct{}{}
- if switchCount >= maxAccountSwitches {
- lastFailoverStatus = failoverErr.StatusCode
- lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
- lastFailoverBody = failoverErr.ResponseBody
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
- fields := []zap.Field{
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
- }
- if rayID != "" {
- fields = append(fields, zap.String("upstream_cf_ray", rayID))
- }
- if mitigated != "" {
- fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
- }
- if contentType != "" {
- fields = append(fields, zap.String("upstream_content_type", contentType))
- }
- reqLog.Warn("sora.upstream_failover_exhausted", fields...)
- h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
- return
- }
- lastFailoverStatus = failoverErr.StatusCode
- lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
- lastFailoverBody = failoverErr.ResponseBody
- switchCount++
- upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
- fields := []zap.Field{
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("upstream_status", failoverErr.StatusCode),
- zap.String("upstream_error_code", upstreamErrCode),
- zap.String("upstream_error_message", upstreamErrMsg),
- zap.Int("switch_count", switchCount),
- zap.Int("max_switches", maxAccountSwitches),
- }
- if rayID != "" {
- fields = append(fields, zap.String("upstream_cf_ray", rayID))
- }
- if mitigated != "" {
- fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
- }
- if contentType != "" {
- fields = append(fields, zap.String("upstream_content_type", contentType))
- }
- reqLog.Warn("sora.upstream_failover_switching", fields...)
- continue
- }
- reqLog.Error("sora.forward_failed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Error(err),
- )
- return
- }
-
- userAgent := c.GetHeader("User-Agent")
- clientIP := ip.GetClientIP(c)
- requestPayloadHash := service.HashUsageRequestPayload(body)
- inboundEndpoint := GetInboundEndpoint(c)
- upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform)
-
- // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
- h.submitUsageRecordTask(func(ctx context.Context) {
- if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: account,
- Subscription: subscription,
- InboundEndpoint: inboundEndpoint,
- UpstreamEndpoint: upstreamEndpoint,
- UserAgent: userAgent,
- IPAddress: clientIP,
- RequestPayloadHash: requestPayloadHash,
- }); err != nil {
- logger.L().With(
- zap.String("component", "handler.sora_gateway.chat_completions"),
- zap.Int64("user_id", subject.UserID),
- zap.Int64("api_key_id", apiKey.ID),
- zap.Any("group_id", apiKey.GroupID),
- zap.String("model", reqModel),
- zap.Int64("account_id", account.ID),
- ).Error("sora.record_usage_failed", zap.Error(err))
- }
- })
- reqLog.Debug("sora.request_completed",
- zap.Int64("account_id", account.ID),
- zap.Int64("proxy_id", proxyID),
- zap.Bool("proxy_bound", proxyBound),
- zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
- zap.Int("switch_count", switchCount),
- )
- return
- }
-}
-
-func generateOpenAISessionHash(c *gin.Context, body []byte) string {
- if c == nil {
- return ""
- }
- sessionID := strings.TrimSpace(c.GetHeader("session_id"))
- if sessionID == "" {
- sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
- }
- if sessionID == "" && len(body) > 0 {
- sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
- }
- if sessionID == "" {
- return ""
- }
- hash := sha256.Sum256([]byte(sessionID))
- return hex.EncodeToString(hash[:])
-}
-
-func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.sora_gateway.chat_completions"),
- zap.Any("panic", recovered),
- ).Error("sora.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
-}
-
-func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
- h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
- fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
-}
-
-func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
- upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody)
- service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "")
-
- status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
- h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
-}
-
-func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
- if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
- baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
- return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
- }
-
- upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
- if strings.EqualFold(upstreamCode, "cf_shield_429") {
- baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
- return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
- }
- if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
- switch statusCode {
- case 401, 403, 404, 500, 502, 503, 504:
- return http.StatusBadGateway, "upstream_error", upstreamMessage
- case 429:
- return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
- }
- }
-
- switch statusCode {
- case 401:
- return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
- case 403:
- return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
- case 404:
- if strings.EqualFold(upstreamCode, "unsupported_country_code") {
- return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
- }
- return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
- case 429:
- return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
- case 529:
- return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
- case 500, 502, 503, 504:
- return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
- default:
- return http.StatusBadGateway, "upstream_error", "Upstream request failed"
- }
-}
-
-func cloneHTTPHeaders(headers http.Header) http.Header {
- if headers == nil {
- return nil
- }
- return headers.Clone()
-}
-
-func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
- if headers != nil {
- mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
- contentType = strings.TrimSpace(headers.Get("content-type"))
- if contentType == "" {
- contentType = strings.TrimSpace(headers.Get("Content-Type"))
- }
- }
- rayID = soraerror.ExtractCloudflareRayID(headers, body)
- return rayID, mitigated, contentType
-}
-
-func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
- return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
-}
-
-func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
- message = strings.TrimSpace(message)
- if message == "" {
- return false
- }
- if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
- lower := strings.ToLower(message)
- if strings.Contains(lower, "Just a moment... `)
-
- h := &SoraGatewayHandler{}
- h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
-
- lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
- require.Len(t, lines, 2)
- jsonStr := strings.TrimPrefix(lines[1], "data: ")
-
- var parsed map[string]any
- require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
-
- errorObj, ok := parsed["error"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "upstream_error", errorObj["type"])
- msg, _ := errorObj["message"].(string)
- require.Contains(t, msg, "Cloudflare challenge")
- require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
-}
-
-func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
- gin.SetMode(gin.TestMode)
- w := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(w)
- c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
-
- headers := http.Header{}
- headers.Set("cf-ray", "9d03b68c086027a1-SEA")
- body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
-
- h := &SoraGatewayHandler{}
- h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
-
- lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
- require.Len(t, lines, 2)
- jsonStr := strings.TrimPrefix(lines[1], "data: ")
-
- var parsed map[string]any
- require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
-
- errorObj, ok := parsed["error"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "rate_limit_error", errorObj["type"])
- msg, _ := errorObj["message"].(string)
- require.Contains(t, msg, "Cloudflare shield")
- require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
-}
-
-func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
- headers := http.Header{}
- headers.Set("cf-mitigated", "challenge")
- headers.Set("content-type", "text/html")
- body := []byte(``)
-
- rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
- require.Equal(t, "9cff2d62d83bb98d", rayID)
- require.Equal(t, "challenge", mitigated)
- require.Equal(t, "text/html", contentType)
-}
diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go
index 483f5105..b8506154 100644
--- a/backend/internal/handler/usage_handler.go
+++ b/backend/internal/handler/usage_handler.go
@@ -119,7 +119,12 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t
}
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+ params := pagination.PaginationParams{
+ Page: page,
+ PageSize: pageSize,
+ SortBy: c.DefaultQuery("sort_by", "created_at"),
+ SortOrder: c.DefaultQuery("sort_order", "desc"),
+ }
filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID,
diff --git a/backend/internal/handler/usage_handler_request_type_test.go b/backend/internal/handler/usage_handler_request_type_test.go
index 7c4c7913..b49ed59b 100644
--- a/backend/internal/handler/usage_handler_request_type_test.go
+++ b/backend/internal/handler/usage_handler_request_type_test.go
@@ -16,10 +16,12 @@ import (
type userUsageRepoCapture struct {
service.UsageLogRepository
+ listParams pagination.PaginationParams
listFilters usagestats.UsageLogFilters
}
func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
+ s.listParams = params
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
diff --git a/backend/internal/handler/usage_handler_sort_test.go b/backend/internal/handler/usage_handler_sort_test.go
new file mode 100644
index 00000000..1af313b0
--- /dev/null
+++ b/backend/internal/handler/usage_handler_sort_test.go
@@ -0,0 +1,35 @@
+package handler
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestUserUsageListSortParams(t *testing.T) {
+ repo := &userUsageRepoCapture{}
+ router := newUserUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/usage?sort_by=model&sort_order=ASC", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "model", repo.listParams.SortBy)
+ require.Equal(t, "ASC", repo.listParams.SortOrder)
+}
+
+func TestUserUsageListSortDefaults(t *testing.T) {
+ repo := &userUsageRepoCapture{}
+ router := newUserUsageRequestTypeTestRouter(repo)
+
+ req := httptest.NewRequest(http.MethodGet, "/usage", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "created_at", repo.listParams.SortBy)
+ require.Equal(t, "desc", repo.listParams.SortOrder)
+}
diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go
index c7c48e14..5c945815 100644
--- a/backend/internal/handler/usage_record_submit_task_test.go
+++ b/backend/internal/handler/usage_record_submit_task_test.go
@@ -129,56 +129,3 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovere
})
require.True(t, called.Load(), "panic 后后续任务应仍可执行")
}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
- pool := newUsageRecordTestPool(t)
- h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
-
- done := make(chan struct{})
- h.submitUsageRecordTask(func(ctx context.Context) {
- close(done)
- })
-
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Fatal("task not executed")
- }
-}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
- h := &SoraGatewayHandler{}
- var called atomic.Bool
-
- h.submitUsageRecordTask(func(ctx context.Context) {
- if _, ok := ctx.Deadline(); !ok {
- t.Fatal("expected deadline in fallback context")
- }
- called.Store(true)
- })
-
- require.True(t, called.Load())
-}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
- h := &SoraGatewayHandler{}
- require.NotPanics(t, func() {
- h.submitUsageRecordTask(nil)
- })
-}
-
-func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) {
- h := &SoraGatewayHandler{}
- var called atomic.Bool
-
- require.NotPanics(t, func() {
- h.submitUsageRecordTask(func(ctx context.Context) {
- panic("usage task panic")
- })
- })
-
- h.submitUsageRecordTask(func(ctx context.Context) {
- called.Store(true)
- })
- require.True(t, called.Load(), "panic 后后续任务应仍可执行")
-}
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 35862f1c..3f6ed8c2 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -1,6 +1,9 @@
package handler
import (
+ "context"
+ "strings"
+
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -11,13 +14,27 @@ import (
// UserHandler handles user-related requests
type UserHandler struct {
- userService *service.UserService
+ userService *service.UserService
+ authService *service.AuthService
+ emailService *service.EmailService
+ emailCache service.EmailCache
+ affiliateService *service.AffiliateService
}
// NewUserHandler creates a new UserHandler
-func NewUserHandler(userService *service.UserService) *UserHandler {
+func NewUserHandler(
+ userService *service.UserService,
+ authService *service.AuthService,
+ emailService *service.EmailService,
+ emailCache service.EmailCache,
+ affiliateService *service.AffiliateService,
+) *UserHandler {
return &UserHandler{
- userService: userService,
+ userService: userService,
+ authService: authService,
+ emailService: emailService,
+ emailCache: emailCache,
+ affiliateService: affiliateService,
}
}
@@ -29,7 +46,32 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct {
- Username *string `json:"username"`
+ Username *string `json:"username"`
+ AvatarURL *string `json:"avatar_url"`
+ BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
+}
+
+type userProfileResponse struct {
+ dto.User
+ AvatarURL string `json:"avatar_url,omitempty"`
+ AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"`
+ UsernameSource *userProfileSourceContext `json:"username_source,omitempty"`
+ DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"`
+ NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"`
+ ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"`
+ Identities service.UserIdentitySummarySet `json:"identities"`
+ AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"`
+ IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"`
+ EmailBound bool `json:"email_bound"`
+ LinuxDoBound bool `json:"linuxdo_bound"`
+ OIDCBound bool `json:"oidc_bound"`
+ WeChatBound bool `json:"wechat_bound"`
+}
+
+type userProfileSourceContext struct {
+ Provider string `json:"provider,omitempty"`
+ Source string `json:"source,omitempty"`
}
// GetProfile handles getting user profile
@@ -41,13 +83,19 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return
}
- userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, dto.UserFromService(userData))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, userData)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
}
// ChangePassword handles changing user password
@@ -94,7 +142,10 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
}
svcReq := service.UpdateProfileRequest{
- Username: req.Username,
+ Username: req.Username,
+ AvatarURL: req.AvatarURL,
+ BalanceNotifyEnabled: req.BalanceNotifyEnabled,
+ BalanceNotifyThreshold: req.BalanceNotifyThreshold,
}
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
@@ -102,5 +153,453 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return
}
- response.Success(c, dto.UserFromService(updatedUser))
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// GetAffiliate returns the current user's affiliate details.
+// GET /api/v1/user/aff
+func (h *UserHandler) GetAffiliate(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ detail, err := h.affiliateService.GetAffiliateDetail(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, detail)
+}
+
+// TransferAffiliateQuota transfers all available affiliate quota into current balance.
+// POST /api/v1/user/aff/transfer
+func (h *UserHandler) TransferAffiliateQuota(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ transferred, balance, err := h.affiliateService.TransferAffiliateQuota(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{
+ "transferred_quota": transferred,
+ "balance": balance,
+ })
+}
+
+type StartIdentityBindingRequest struct {
+ Provider string `json:"provider" binding:"required"`
+ RedirectTo string `json:"redirect_to"`
+}
+
+type BindEmailIdentityRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ VerifyCode string `json:"verify_code" binding:"required"`
+ Password string `json:"password" binding:"required"`
+}
+
+type SendEmailBindingCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// StartIdentityBinding returns the backend authorize URL for starting a third-party identity bind flow.
+// POST /api/v1/user/auth-identities/bind/start
+func (h *UserHandler) StartIdentityBinding(c *gin.Context) {
+ if _, ok := middleware2.GetAuthSubjectFromContext(c); !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req StartIdentityBindingRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ result, err := h.userService.PrepareIdentityBindingStart(c.Request.Context(), service.StartUserIdentityBindingRequest{
+ Provider: req.Provider,
+ RedirectTo: req.RedirectTo,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, result)
+}
+
+// BindEmailIdentity verifies and binds a local email identity for the current user.
+// POST /api/v1/user/account-bindings/email
+func (h *UserHandler) BindEmailIdentity(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req BindEmailIdentityRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ updatedUser, err := h.authService.BindEmailIdentity(
+ c.Request.Context(),
+ subject.UserID,
+ req.Email,
+ req.VerifyCode,
+ req.Password,
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// UnbindIdentity removes a third-party sign-in provider from the current user.
+// DELETE /api/v1/user/account-bindings/:provider
+func (h *UserHandler) UnbindIdentity(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult(
+ c.Request.Context(),
+ subject.UserID,
+ c.Param("provider"),
+ )
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if unbound && h.authService != nil {
+ if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// SendEmailBindingCode sends a verification code for the current user's email binding flow.
+// POST /api/v1/user/account-bindings/email/send-code
+func (h *UserHandler) SendEmailBindingCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+ if h.authService == nil {
+ response.InternalError(c, "Auth service not configured")
+ return
+ }
+
+ var req SendEmailBindingCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.authService.SendEmailIdentityBindCode(c.Request.Context(), subject.UserID, req.Email); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Verification code sent successfully"})
+}
+
+// SendNotifyEmailCodeRequest represents the request to send notify email verification code
+type SendNotifyEmailCodeRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// SendNotifyEmailCode sends verification code to extra notification email
+// POST /api/v1/user/notify-email/send-code
+func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req SendNotifyEmailCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"message": "Verification code sent successfully"})
+}
+
+// VerifyNotifyEmailRequest represents the request to verify and add notify email
+type VerifyNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Code string `json:"code" binding:"required,len=6"`
+}
+
+// VerifyNotifyEmail verifies code and adds email to notification list
+// POST /api/v1/user/notify-email/verify
+func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req VerifyNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Return updated user
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// RemoveNotifyEmailRequest represents the request to remove a notify email
+type RemoveNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+}
+
+// RemoveNotifyEmail removes email from notification list
+// DELETE /api/v1/user/notify-email
+func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req RemoveNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Return updated user
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
+type ToggleNotifyEmailRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Disabled bool `json:"disabled"`
+}
+
+// ToggleNotifyEmail toggles the disabled state of a notification email
+// PUT /api/v1/user/notify-email/toggle
+func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req ToggleNotifyEmailRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, profileResp)
+}
+
+func (h *UserHandler) buildUserProfileResponse(ctx context.Context, userID int64, user *service.User) (userProfileResponse, error) {
+ identities, err := h.userService.GetProfileIdentitySummaries(ctx, userID, user)
+ if err != nil {
+ return userProfileResponse{}, err
+ }
+ return userProfileResponseFromService(user, identities), nil
+}
+
+func userProfileResponseFromService(user *service.User, identities service.UserIdentitySummarySet) userProfileResponse {
+ base := dto.UserFromService(user)
+ if base == nil {
+ return userProfileResponse{}
+ }
+ bindings := userProfileBindingMap(identities)
+ profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities)
+ return userProfileResponse{
+ User: *base,
+ AvatarURL: user.AvatarURL,
+ AvatarSource: avatarSource,
+ UsernameSource: usernameSource,
+ DisplayNameSource: usernameSource,
+ NicknameSource: usernameSource,
+ ProfileSources: profileSources,
+ Identities: identities,
+ AuthBindings: bindings,
+ IdentityBindings: bindings,
+ EmailBound: identities.Email.Bound,
+ LinuxDoBound: identities.LinuxDo.Bound,
+ OIDCBound: identities.OIDC.Bound,
+ WeChatBound: identities.WeChat.Bound,
+ }
+}
+
+func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
+ return map[string]service.UserIdentitySummary{
+ "email": identities.Email,
+ "linuxdo": identities.LinuxDo,
+ "oidc": identities.OIDC,
+ "wechat": identities.WeChat,
+ }
+}
+
+func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) (
+ map[string]*userProfileSourceContext,
+ *userProfileSourceContext,
+ *userProfileSourceContext,
+) {
+ if user == nil {
+ return nil, nil, nil
+ }
+
+ thirdParty := thirdPartyIdentityProviders(identities)
+ var avatarSource *userProfileSourceContext
+ avatarValue := strings.TrimSpace(user.AvatarURL)
+ for _, summary := range thirdParty {
+ if avatarValue != "" && avatarValue == strings.TrimSpace(summary.AvatarURL) {
+ avatarSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+
+ usernameValue := strings.TrimSpace(user.Username)
+ var usernameSource *userProfileSourceContext
+ for _, summary := range thirdParty {
+ if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) {
+ usernameSource = buildUserProfileSourceContext(summary.Provider)
+ break
+ }
+ }
+
+ profileSources := map[string]*userProfileSourceContext{}
+ if avatarSource != nil {
+ profileSources["avatar"] = avatarSource
+ }
+ if usernameSource != nil {
+ profileSources["username"] = usernameSource
+ profileSources["display_name"] = usernameSource
+ profileSources["nickname"] = usernameSource
+ }
+ if len(profileSources) == 0 {
+ return nil, avatarSource, usernameSource
+ }
+ return profileSources, avatarSource, usernameSource
+}
+
+func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
+ out := make([]service.UserIdentitySummary, 0, 3)
+ for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
+ if summary.Bound {
+ out = append(out, summary)
+ }
+ }
+ return out
+}
+
+func buildUserProfileSourceContext(provider string) *userProfileSourceContext {
+ provider = strings.TrimSpace(provider)
+ if provider == "" {
+ return nil
+ }
+ return &userProfileSourceContext{
+ Provider: provider,
+ Source: provider,
+ }
}
diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go
new file mode 100644
index 00000000..8a864b51
--- /dev/null
+++ b/backend/internal/handler/user_handler_test.go
@@ -0,0 +1,783 @@
+//go:build unit
+
+package handler
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type userHandlerRepoStub struct {
+ user *service.User
+ identities []service.UserAuthIdentityRecord
+ unbound []string
+}
+
+func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
+func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ cloned := *s.user
+ return &cloned, nil
+}
+func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
+ cloned := *user
+ s.user = &cloned
+ return nil
+}
+func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ if s.user == nil || s.user.AvatarURL == "" {
+ return nil, nil
+ }
+ return &service.UserAvatar{
+ StorageProvider: s.user.AvatarSource,
+ URL: s.user.AvatarURL,
+ ContentType: s.user.AvatarMIME,
+ ByteSize: s.user.AvatarByteSize,
+ SHA256: s.user.AvatarSHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ s.user.AvatarURL = input.URL
+ s.user.AvatarSource = input.StorageProvider
+ s.user.AvatarMIME = input.ContentType
+ s.user.AvatarByteSize = input.ByteSize
+ s.user.AvatarSHA256 = input.SHA256
+ return &service.UserAvatar{
+ StorageProvider: input.StorageProvider,
+ URL: input.URL,
+ ContentType: input.ContentType,
+ ByteSize: input.ByteSize,
+ SHA256: input.SHA256,
+ }, nil
+}
+func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ s.user.AvatarURL = ""
+ s.user.AvatarSource = ""
+ s.user.AvatarMIME = ""
+ s.user.AvatarByteSize = 0
+ s.user.AvatarSHA256 = ""
+ return nil
+}
+func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ return nil, nil, nil
+}
+func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+func (s *userHandlerRepoStub) UpdateUserLastActiveAt(_ context.Context, _ int64, activeAt time.Time) error {
+ if s.user != nil {
+ s.user.LastActiveAt = &activeAt
+ }
+ return nil
+}
+func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
+func (s *userHandlerRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
+ out := make([]service.UserAuthIdentityRecord, len(s.identities))
+ copy(out, s.identities)
+ return out, nil
+}
+func (s *userHandlerRepoStub) UnbindUserAuthProvider(_ context.Context, _ int64, provider string) error {
+ s.unbound = append(s.unbound, provider)
+ filtered := s.identities[:0]
+ for _, identity := range s.identities {
+ if identity.ProviderType == provider {
+ continue
+ }
+ filtered = append(filtered, identity)
+ }
+ s.identities = append([]service.UserAuthIdentityRecord(nil), filtered...)
+ return nil
+}
+
+func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "handler-avatar@example.com",
+ Username: "handler-avatar",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.UpdateProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ AvatarURL string `json:"avatar_url"`
+ Username string `json:"username"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
+ require.Equal(t, "handler-avatar", resp.Data.Username)
+}
+
+func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-123456",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ {
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example.com",
+ ProviderSubject: "oidc-user-abc",
+ Metadata: map[string]any{
+ "suggested_display_name": "OIDC Display",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Identities struct {
+ Email struct {
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name"`
+ } `json:"email"`
+ LinuxDo struct {
+ Bound bool `json:"bound"`
+ BoundCount int `json:"bound_count"`
+ DisplayName string `json:"display_name"`
+ ProviderKey string `json:"provider_key"`
+ } `json:"linuxdo"`
+ OIDC struct {
+ Bound bool `json:"bound"`
+ DisplayName string `json:"display_name"`
+ ProviderKey string `json:"provider_key"`
+ } `json:"oidc"`
+ WeChat struct {
+ Bound bool `json:"bound"`
+ CanBind bool `json:"can_bind"`
+ BindStartPath string `json:"bind_start_path"`
+ } `json:"wechat"`
+ } `json:"identities"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.True(t, resp.Data.Identities.Email.Bound)
+ require.Equal(t, 1, resp.Data.Identities.Email.BoundCount)
+ require.Equal(t, "identity@example.com", resp.Data.Identities.Email.DisplayName)
+ require.True(t, resp.Data.Identities.LinuxDo.Bound)
+ require.Equal(t, 1, resp.Data.Identities.LinuxDo.BoundCount)
+ require.Equal(t, "linuxdo-handle", resp.Data.Identities.LinuxDo.DisplayName)
+ require.Equal(t, "linuxdo", resp.Data.Identities.LinuxDo.ProviderKey)
+ require.True(t, resp.Data.Identities.OIDC.Bound)
+ require.Equal(t, "OIDC Display", resp.Data.Identities.OIDC.DisplayName)
+ require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey)
+ require.False(t, resp.Data.Identities.WeChat.Bound)
+ require.True(t, resp.Data.Identities.WeChat.CanBind)
+ require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/bind/start")
+}
+
+func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "legacy-profile@example.com",
+ Username: "linuxdo-handle",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/linuxdo.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ VerifiedAt: &verifiedAt,
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, true, resp.Data["email_bound"])
+ require.Equal(t, true, resp.Data["linuxdo_bound"])
+ require.Equal(t, false, resp.Data["oidc_bound"])
+ require.Equal(t, false, resp.Data["wechat_bound"])
+ require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
+
+ avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", avatarSource["provider"])
+ require.Equal(t, "linuxdo", avatarSource["source"])
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, linuxdoBinding["bound"])
+ require.Equal(t, "linuxdo", linuxdoBinding["provider"])
+
+ identityBindings, ok := resp.Data["identity_bindings"].(map[string]any)
+ require.True(t, ok)
+ emailBinding, ok := identityBindings["email"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, emailBinding["bound"])
+ require.Equal(t, "profile.authBindings.notes.emailManagedFromProfile", emailBinding["note_key"])
+
+ linuxdoCompatBinding, ok := identityBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "profile.authBindings.notes.canUnbind", linuxdoCompatBinding["note_key"])
+
+ profileSources, ok := resp.Data["profile_sources"].(map[string]any)
+ require.True(t, ok)
+ usernameSource, ok := profileSources["username"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "linuxdo", usernameSource["provider"])
+ require.Equal(t, "linuxdo", usernameSource["source"])
+}
+
+func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 22,
+ Email: "edited-profile@example.com",
+ Username: "custom-name",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ AvatarURL: "https://cdn.example.com/custom.png",
+ AvatarSource: "remote_url",
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-22",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ "avatar_url": "https://cdn.example.com/linuxdo.png",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 22})
+
+ handler.GetProfile(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.NotContains(t, resp.Data, "avatar_source")
+ require.NotContains(t, resp.Data, "username_source")
+ require.NotContains(t, resp.Data, "profile_sources")
+}
+
+type userHandlerEmailCacheStub struct {
+ data *service.VerificationCodeData
+}
+
+type userHandlerRefreshTokenCacheStub struct {
+ revokedUserIDs []int64
+}
+
+func (s *userHandlerRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
+ return nil, service.ErrRefreshTokenNotFound
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
+ s.revokedUserIDs = append(s.revokedUserIDs, userID)
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *userHandlerRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
+func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return s.data, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *userHandlerEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *userHandlerEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *userHandlerEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+func TestUserHandlerBindEmailIdentityReturnsProfileResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "legacy-user" + service.LinuxDoConnectSyntheticEmailDomain,
+ Username: "legacy-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"new-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Params = gin.Params{{Key: "provider", Value: "email"}}
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Email string `json:"email"`
+ EmailBound bool `json:"email_bound"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "new@example.com", resp.Data.Email)
+ require.True(t, resp.Data.EmailBound)
+}
+
+func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 21,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-21",
+ Metadata: map[string]any{
+ "username": "linuxdo-handle",
+ },
+ },
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []string{"linuxdo"}, repo.unbound)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data map[string]any `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+
+ authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
+ require.True(t, ok)
+ linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, false, linuxdoBinding["bound"])
+}
+
+func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 23,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ {
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "linuxdo-subject-23",
+ },
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 23})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(5), repo.user.TokenVersion)
+}
+
+func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 24,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ },
+ identities: []service.UserAuthIdentityRecord{
+ {
+ ProviderType: "email",
+ ProviderKey: "email",
+ ProviderSubject: "identity@example.com",
+ },
+ },
+ }
+ refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24})
+ c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
+
+ handler.UnbindIdentity(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+ require.Empty(t, repo.unbound)
+ require.Empty(t, refreshTokenCache.revokedUserIDs)
+ require.Equal(t, int64(4), repo.user.TokenVersion)
+}
+
+func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ user := &service.User{
+ ID: 11,
+ Email: "current@example.com",
+ Username: "bound-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, user.SetPassword("current-password"))
+
+ repo := &userHandlerRepoStub{user: user}
+ emailCache := &userHandlerEmailCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-secret",
+ ExpireHour: 1,
+ },
+ }
+ emailService := service.NewEmailService(nil, emailCache)
+ authService := service.NewAuthService(nil, repo, nil, nil, cfg, nil, emailService, nil, nil, nil, nil, nil)
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil, nil)
+
+ body := []byte(`{"email":"new@example.com","verify_code":"123456","password":"wrong-password"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/account-bindings/email", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.BindEmailIdentity(c)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Reason string `json:"reason"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, http.StatusBadRequest, resp.Code)
+ require.Equal(t, "PASSWORD_INCORRECT", resp.Reason)
+ require.Equal(t, "current password is incorrect", resp.Message)
+ require.Equal(t, "current@example.com", repo.user.Email)
+}
+
+func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ repo := &userHandlerRepoStub{
+ user: &service.User{
+ ID: 11,
+ Email: "identity@example.com",
+ Username: "identity-user",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ },
+ }
+ handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil, nil)
+
+ body := []byte(`{"provider":"wechat","redirect_to":"/settings/profile"}`)
+ recorder := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(recorder)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/user/auth-identities/bind/start", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
+
+ handler.StartIdentityBinding(c)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Provider string `json:"provider"`
+ AuthorizeURL string `json:"authorize_url"`
+ Method string `json:"method"`
+ UseBrowserRedirect bool `json:"use_browser_redirect"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, "wechat", resp.Data.Provider)
+ require.Equal(t, "GET", resp.Data.Method)
+ require.True(t, resp.Data.UseBrowserRedirect)
+ require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start")
+ require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user")
+ require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile")
+}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 02ddd030..a8725875 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -33,32 +33,42 @@ func ProvideAdminHandlers(
tlsFingerprintProfileHandler *admin.TLSFingerprintProfileHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
scheduledTestHandler *admin.ScheduledTestHandler,
+ channelHandler *admin.ChannelHandler,
+ channelMonitorHandler *admin.ChannelMonitorHandler,
+ channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler,
+ paymentHandler *admin.PaymentHandler,
+ affiliateHandler *admin.AffiliateHandler,
) *AdminHandlers {
return &AdminHandlers{
- Dashboard: dashboardHandler,
- User: userHandler,
- Group: groupHandler,
- Account: accountHandler,
- Announcement: announcementHandler,
- DataManagement: dataManagementHandler,
- Backup: backupHandler,
- OAuth: oauthHandler,
- OpenAIOAuth: openaiOAuthHandler,
- GeminiOAuth: geminiOAuthHandler,
- AntigravityOAuth: antigravityOAuthHandler,
- Proxy: proxyHandler,
- Redeem: redeemHandler,
- Promo: promoHandler,
- Setting: settingHandler,
- Ops: opsHandler,
- System: systemHandler,
- Subscription: subscriptionHandler,
- Usage: usageHandler,
- UserAttribute: userAttributeHandler,
- ErrorPassthrough: errorPassthroughHandler,
- TLSFingerprintProfile: tlsFingerprintProfileHandler,
- APIKey: apiKeyHandler,
- ScheduledTest: scheduledTestHandler,
+ Dashboard: dashboardHandler,
+ User: userHandler,
+ Group: groupHandler,
+ Account: accountHandler,
+ Announcement: announcementHandler,
+ DataManagement: dataManagementHandler,
+ Backup: backupHandler,
+ OAuth: oauthHandler,
+ OpenAIOAuth: openaiOAuthHandler,
+ GeminiOAuth: geminiOAuthHandler,
+ AntigravityOAuth: antigravityOAuthHandler,
+ Proxy: proxyHandler,
+ Redeem: redeemHandler,
+ Promo: promoHandler,
+ Setting: settingHandler,
+ Ops: opsHandler,
+ System: systemHandler,
+ Subscription: subscriptionHandler,
+ Usage: usageHandler,
+ UserAttribute: userAttributeHandler,
+ ErrorPassthrough: errorPassthroughHandler,
+ TLSFingerprintProfile: tlsFingerprintProfileHandler,
+ APIKey: apiKeyHandler,
+ ScheduledTest: scheduledTestHandler,
+ Channel: channelHandler,
+ ChannelMonitor: channelMonitorHandler,
+ ChannelMonitorTemplate: channelMonitorTemplateHandler,
+ Payment: paymentHandler,
+ Affiliate: affiliateHandler,
}
}
@@ -81,31 +91,35 @@ func ProvideHandlers(
redeemHandler *RedeemHandler,
subscriptionHandler *SubscriptionHandler,
announcementHandler *AnnouncementHandler,
+ channelMonitorUserHandler *ChannelMonitorUserHandler,
adminHandlers *AdminHandlers,
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
- soraGatewayHandler *SoraGatewayHandler,
- soraClientHandler *SoraClientHandler,
settingHandler *SettingHandler,
totpHandler *TotpHandler,
+ paymentHandler *PaymentHandler,
+ paymentWebhookHandler *PaymentWebhookHandler,
+ availableChannelHandler *AvailableChannelHandler,
_ *service.IdempotencyCoordinator,
_ *service.IdempotencyCleanupService,
) *Handlers {
return &Handlers{
- Auth: authHandler,
- User: userHandler,
- APIKey: apiKeyHandler,
- Usage: usageHandler,
- Redeem: redeemHandler,
- Subscription: subscriptionHandler,
- Announcement: announcementHandler,
- Admin: adminHandlers,
- Gateway: gatewayHandler,
- OpenAIGateway: openaiGatewayHandler,
- SoraGateway: soraGatewayHandler,
- SoraClient: soraClientHandler,
- Setting: settingHandler,
- Totp: totpHandler,
+ Auth: authHandler,
+ User: userHandler,
+ APIKey: apiKeyHandler,
+ Usage: usageHandler,
+ Redeem: redeemHandler,
+ Subscription: subscriptionHandler,
+ Announcement: announcementHandler,
+ ChannelMonitor: channelMonitorUserHandler,
+ Admin: adminHandlers,
+ Gateway: gatewayHandler,
+ OpenAIGateway: openaiGatewayHandler,
+ Setting: settingHandler,
+ Totp: totpHandler,
+ Payment: paymentHandler,
+ PaymentWebhook: paymentWebhookHandler,
+ AvailableChannel: availableChannelHandler,
}
}
@@ -119,11 +133,14 @@ var ProviderSet = wire.NewSet(
NewRedeemHandler,
NewSubscriptionHandler,
NewAnnouncementHandler,
+ NewChannelMonitorUserHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
- NewSoraGatewayHandler,
NewTotpHandler,
ProvideSettingHandler,
+ NewPaymentHandler,
+ NewPaymentWebhookHandler,
+ NewAvailableChannelHandler,
// Admin handlers
admin.NewDashboardHandler,
@@ -150,6 +167,11 @@ var ProviderSet = wire.NewSet(
admin.NewTLSFingerprintProfileHandler,
admin.NewAdminAPIKeyHandler,
admin.NewScheduledTestHandler,
+ admin.NewChannelHandler,
+ admin.NewChannelMonitorHandler,
+ admin.NewChannelMonitorRequestTemplateHandler,
+ admin.NewPaymentHandler,
+ admin.NewAffiliateHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
diff --git a/backend/internal/payment/amount.go b/backend/internal/payment/amount.go
new file mode 100644
index 00000000..8489ffa3
--- /dev/null
+++ b/backend/internal/payment/amount.go
@@ -0,0 +1,24 @@
+package payment
+
+import (
+ "fmt"
+
+ "github.com/shopspring/decimal"
+)
+
+const centsPerYuan = 100
+
+// YuanToFen converts a CNY yuan string (e.g. "10.50") to fen (int64).
+// Uses shopspring/decimal for precision.
+func YuanToFen(yuanStr string) (int64, error) {
+ d, err := decimal.NewFromString(yuanStr)
+ if err != nil {
+ return 0, fmt.Errorf("invalid amount: %s", yuanStr)
+ }
+ return d.Mul(decimal.NewFromInt(centsPerYuan)).IntPart(), nil
+}
+
+// FenToYuan converts fen (int64) to yuan as a float64 for interface compatibility.
+func FenToYuan(fen int64) float64 {
+ return decimal.NewFromInt(fen).Div(decimal.NewFromInt(centsPerYuan)).InexactFloat64()
+}
diff --git a/backend/internal/payment/amount_test.go b/backend/internal/payment/amount_test.go
new file mode 100644
index 00000000..6120b189
--- /dev/null
+++ b/backend/internal/payment/amount_test.go
@@ -0,0 +1,128 @@
+//go:build unit
+
+package payment
+
+import (
+ "math"
+ "testing"
+)
+
+func TestYuanToFen(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want int64
+ wantErr bool
+ }{
+ // Normal values
+ {name: "one yuan", input: "1.00", want: 100},
+ {name: "ten yuan fifty fen", input: "10.50", want: 1050},
+ {name: "one fen", input: "0.01", want: 1},
+ {name: "large amount", input: "99999.99", want: 9999999},
+
+ // Edge: zero
+ {name: "zero no decimal", input: "0", want: 0},
+ {name: "zero with decimal", input: "0.00", want: 0},
+
+ // IEEE 754 precision edge case: 1.15 * 100 = 114.99999... in float64
+ {name: "ieee754 precision 1.15", input: "1.15", want: 115},
+
+ // More precision edge cases
+ {name: "ieee754 precision 0.1", input: "0.1", want: 10},
+ {name: "ieee754 precision 0.2", input: "0.2", want: 20},
+ {name: "ieee754 precision 33.33", input: "33.33", want: 3333},
+
+ // Large value
+ {name: "hundred thousand", input: "100000.00", want: 10000000},
+
+ // Integer without decimal
+ {name: "integer 5", input: "5", want: 500},
+ {name: "integer 100", input: "100", want: 10000},
+
+ // Single decimal place
+ {name: "single decimal 1.5", input: "1.5", want: 150},
+
+ // Negative values
+ {name: "negative one yuan", input: "-1.00", want: -100},
+ {name: "negative with fen", input: "-10.50", want: -1050},
+
+ // Invalid inputs
+ {name: "empty string", input: "", wantErr: true},
+ {name: "alphabetic", input: "abc", wantErr: true},
+ {name: "double dot", input: "1.2.3", wantErr: true},
+ {name: "spaces", input: " ", wantErr: true},
+ {name: "special chars", input: "$10.00", wantErr: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := YuanToFen(tt.input)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("YuanToFen(%q) expected error, got %d", tt.input, got)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("YuanToFen(%q) unexpected error: %v", tt.input, err)
+ }
+ if got != tt.want {
+ t.Errorf("YuanToFen(%q) = %d, want %d", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFenToYuan(t *testing.T) {
+ tests := []struct {
+ name string
+ fen int64
+ want float64
+ }{
+ {name: "one yuan", fen: 100, want: 1.0},
+ {name: "ten yuan fifty fen", fen: 1050, want: 10.5},
+ {name: "one fen", fen: 1, want: 0.01},
+ {name: "zero", fen: 0, want: 0.0},
+ {name: "large amount", fen: 9999999, want: 99999.99},
+ {name: "negative", fen: -100, want: -1.0},
+ {name: "negative with fen", fen: -1050, want: -10.5},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := FenToYuan(tt.fen)
+ if math.Abs(got-tt.want) > 1e-9 {
+ t.Errorf("FenToYuan(%d) = %f, want %f", tt.fen, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestYuanToFenRoundTrip(t *testing.T) {
+ // Verify that converting yuan->fen->yuan preserves the value.
+ cases := []struct {
+ yuan string
+ fen int64
+ }{
+ {"0.01", 1},
+ {"1.00", 100},
+ {"10.50", 1050},
+ {"99999.99", 9999999},
+ }
+
+ for _, tc := range cases {
+ fen, err := YuanToFen(tc.yuan)
+ if err != nil {
+ t.Fatalf("YuanToFen(%q) unexpected error: %v", tc.yuan, err)
+ }
+ if fen != tc.fen {
+ t.Errorf("YuanToFen(%q) = %d, want %d", tc.yuan, fen, tc.fen)
+ }
+ yuan := FenToYuan(fen)
+ // Parse expected yuan back for comparison
+ expectedYuan := FenToYuan(tc.fen)
+ if math.Abs(yuan-expectedYuan) > 1e-9 {
+ t.Errorf("round-trip: FenToYuan(%d) = %f, want %f", fen, yuan, expectedYuan)
+ }
+ }
+}
diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go
new file mode 100644
index 00000000..0581469d
--- /dev/null
+++ b/backend/internal/payment/crypto.go
@@ -0,0 +1,111 @@
+package payment
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/base64"
+ "fmt"
+ "io"
+ "strings"
+)
+
+// AES256KeySize is the required key length (in bytes) for AES-256-GCM.
+const AES256KeySize = 32
+
+// Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key.
+// The output format is "iv:authTag:ciphertext" where each component is base64-encoded,
+// matching the Node.js crypto.ts format for cross-compatibility.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function is kept only for seeding legacy ciphertext in tests and for
+// the transitional Decrypt fallback. Scheduled for removal after all live
+// deployments complete migration by re-saving their configs.
+func Encrypt(plaintext string, key []byte) (string, error) {
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
+ }
+
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return "", fmt.Errorf("create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return "", fmt.Errorf("create GCM: %w", err)
+ }
+
+ nonce := make([]byte, gcm.NonceSize()) // 12 bytes for GCM
+ if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+ return "", fmt.Errorf("generate nonce: %w", err)
+ }
+
+ // Seal appends the ciphertext + auth tag
+ sealed := gcm.Seal(nil, nonce, []byte(plaintext), nil)
+
+ // Split sealed into ciphertext and auth tag (last 16 bytes)
+ tagSize := gcm.Overhead()
+ ciphertext := sealed[:len(sealed)-tagSize]
+ authTag := sealed[len(sealed)-tagSize:]
+
+ // Format: iv:authTag:ciphertext (all base64)
+ return fmt.Sprintf("%s:%s:%s",
+ base64.StdEncoding.EncodeToString(nonce),
+ base64.StdEncoding.EncodeToString(authTag),
+ base64.StdEncoding.EncodeToString(ciphertext),
+ ), nil
+}
+
+// Decrypt decrypts a ciphertext string produced by Encrypt.
+// The input format is "iv:authTag:ciphertext" where each component is base64-encoded.
+//
+// Deprecated: payment provider configs are now stored as plaintext JSON.
+// This function remains only as a read-path fallback for pre-migration
+// ciphertext records. Scheduled for removal once all deployments re-save
+// their provider configs through the admin UI.
+func Decrypt(ciphertext string, key []byte) (string, error) {
+ if len(key) != AES256KeySize {
+ return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key))
+ }
+
+ parts := strings.SplitN(ciphertext, ":", 3)
+ if len(parts) != 3 {
+ return "", fmt.Errorf("invalid ciphertext format: expected iv:authTag:ciphertext")
+ }
+
+ nonce, err := base64.StdEncoding.DecodeString(parts[0])
+ if err != nil {
+ return "", fmt.Errorf("decode IV: %w", err)
+ }
+
+ authTag, err := base64.StdEncoding.DecodeString(parts[1])
+ if err != nil {
+ return "", fmt.Errorf("decode auth tag: %w", err)
+ }
+
+ encrypted, err := base64.StdEncoding.DecodeString(parts[2])
+ if err != nil {
+ return "", fmt.Errorf("decode ciphertext: %w", err)
+ }
+
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return "", fmt.Errorf("create AES cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return "", fmt.Errorf("create GCM: %w", err)
+ }
+
+ // Reconstruct the sealed data: ciphertext + authTag
+ sealed := append(encrypted, authTag...)
+
+ plaintext, err := gcm.Open(nil, nonce, sealed, nil)
+ if err != nil {
+ return "", fmt.Errorf("decrypt: %w", err)
+ }
+
+ return string(plaintext), nil
+}
diff --git a/backend/internal/payment/crypto_test.go b/backend/internal/payment/crypto_test.go
new file mode 100644
index 00000000..da8b6006
--- /dev/null
+++ b/backend/internal/payment/crypto_test.go
@@ -0,0 +1,183 @@
+package payment
+
+import (
+ "crypto/rand"
+ "strings"
+ "testing"
+)
+
+func makeKey(t *testing.T) []byte {
+ t.Helper()
+ key := make([]byte, 32)
+ if _, err := rand.Read(key); err != nil {
+ t.Fatalf("generate random key: %v", err)
+ }
+ return key
+}
+
+func TestEncryptDecryptRoundTrip(t *testing.T) {
+ t.Parallel()
+ key := makeKey(t)
+
+ plaintexts := []string{
+ "hello world",
+ "short",
+ "a longer string with special chars: !@#$%^&*()",
+ `{"key":"value","num":42}`,
+ "你好世界 unicode test 🎉",
+ strings.Repeat("x", 10000),
+ }
+
+ for _, pt := range plaintexts {
+ encrypted, err := Encrypt(pt, key)
+ if err != nil {
+ t.Fatalf("Encrypt(%q) error: %v", pt[:min(len(pt), 30)], err)
+ }
+ decrypted, err := Decrypt(encrypted, key)
+ if err != nil {
+ t.Fatalf("Decrypt error for plaintext %q: %v", pt[:min(len(pt), 30)], err)
+ }
+ if decrypted != pt {
+ t.Fatalf("round-trip failed: got %q, want %q", decrypted[:min(len(decrypted), 30)], pt[:min(len(pt), 30)])
+ }
+ }
+}
+
+func TestEncryptProducesDifferentCiphertexts(t *testing.T) {
+ t.Parallel()
+ key := makeKey(t)
+
+ ct1, err := Encrypt("same plaintext", key)
+ if err != nil {
+ t.Fatalf("first Encrypt error: %v", err)
+ }
+ ct2, err := Encrypt("same plaintext", key)
+ if err != nil {
+ t.Fatalf("second Encrypt error: %v", err)
+ }
+ if ct1 == ct2 {
+ t.Fatal("two encryptions of the same plaintext should produce different ciphertexts (random nonce)")
+ }
+}
+
+func TestDecryptWithWrongKeyFails(t *testing.T) {
+ t.Parallel()
+ key1 := makeKey(t)
+ key2 := makeKey(t)
+
+ encrypted, err := Encrypt("secret data", key1)
+ if err != nil {
+ t.Fatalf("Encrypt error: %v", err)
+ }
+
+ _, err = Decrypt(encrypted, key2)
+ if err == nil {
+ t.Fatal("Decrypt with wrong key should fail, but got nil error")
+ }
+}
+
+func TestEncryptRejectsInvalidKeyLength(t *testing.T) {
+ t.Parallel()
+ badKeys := [][]byte{
+ nil,
+ make([]byte, 0),
+ make([]byte, 16),
+ make([]byte, 31),
+ make([]byte, 33),
+ make([]byte, 64),
+ }
+ for _, key := range badKeys {
+ _, err := Encrypt("test", key)
+ if err == nil {
+ t.Fatalf("Encrypt should reject key of length %d", len(key))
+ }
+ }
+}
+
+func TestDecryptRejectsInvalidKeyLength(t *testing.T) {
+ t.Parallel()
+ badKeys := [][]byte{
+ nil,
+ make([]byte, 16),
+ make([]byte, 33),
+ }
+ for _, key := range badKeys {
+ _, err := Decrypt("dummydata:dummydata:dummydata", key)
+ if err == nil {
+ t.Fatalf("Decrypt should reject key of length %d", len(key))
+ }
+ }
+}
+
+func TestEncryptEmptyPlaintext(t *testing.T) {
+ t.Parallel()
+ key := makeKey(t)
+
+ encrypted, err := Encrypt("", key)
+ if err != nil {
+ t.Fatalf("Encrypt empty plaintext error: %v", err)
+ }
+ decrypted, err := Decrypt(encrypted, key)
+ if err != nil {
+ t.Fatalf("Decrypt empty plaintext error: %v", err)
+ }
+ if decrypted != "" {
+ t.Fatalf("expected empty string, got %q", decrypted)
+ }
+}
+
+func TestEncryptDecryptUnicodeJSON(t *testing.T) {
+ t.Parallel()
+ key := makeKey(t)
+
+ jsonContent := `{"name":"测试用户","email":"test@example.com","balance":100.50}`
+ encrypted, err := Encrypt(jsonContent, key)
+ if err != nil {
+ t.Fatalf("Encrypt JSON error: %v", err)
+ }
+ decrypted, err := Decrypt(encrypted, key)
+ if err != nil {
+ t.Fatalf("Decrypt JSON error: %v", err)
+ }
+ if decrypted != jsonContent {
+ t.Fatalf("JSON round-trip failed: got %q, want %q", decrypted, jsonContent)
+ }
+}
+
+func TestDecryptInvalidFormat(t *testing.T) {
+ t.Parallel()
+ key := makeKey(t)
+
+ invalidInputs := []string{
+ "",
+ "nodelimiter",
+ "only:two",
+ "invalid:base64:!!!",
+ }
+ for _, input := range invalidInputs {
+ _, err := Decrypt(input, key)
+ if err == nil {
+ t.Fatalf("Decrypt(%q) should fail but got nil error", input)
+ }
+ }
+}
+
+func TestCiphertextFormat(t *testing.T) {
+ t.Parallel()
+ key := makeKey(t)
+
+ encrypted, err := Encrypt("test", key)
+ if err != nil {
+ t.Fatalf("Encrypt error: %v", err)
+ }
+
+ parts := strings.SplitN(encrypted, ":", 3)
+ if len(parts) != 3 {
+ t.Fatalf("ciphertext should have format iv:authTag:ciphertext, got %d parts", len(parts))
+ }
+ for i, part := range parts {
+ if part == "" {
+ t.Fatalf("ciphertext part %d is empty", i)
+ }
+ }
+}
diff --git a/backend/internal/payment/fee.go b/backend/internal/payment/fee.go
new file mode 100644
index 00000000..e2128e5e
--- /dev/null
+++ b/backend/internal/payment/fee.go
@@ -0,0 +1,19 @@
+package payment
+
+import (
+ "github.com/shopspring/decimal"
+)
+
+// CalculatePayAmount computes the total pay amount given a recharge amount and
+// fee rate (percentage). Fee = amount * feeRate / 100, rounded UP (away from zero)
+// to 2 decimal places. The returned string is formatted to exactly 2 decimal places.
+// If feeRate <= 0, the amount is returned as-is (formatted to 2 decimal places).
+func CalculatePayAmount(rechargeAmount float64, feeRate float64) string {
+ amount := decimal.NewFromFloat(rechargeAmount)
+ if feeRate <= 0 {
+ return amount.StringFixed(2)
+ }
+ rate := decimal.NewFromFloat(feeRate)
+ fee := amount.Mul(rate).Div(decimal.NewFromInt(100)).RoundUp(2)
+ return amount.Add(fee).StringFixed(2)
+}
diff --git a/backend/internal/payment/fee_test.go b/backend/internal/payment/fee_test.go
new file mode 100644
index 00000000..c58d1082
--- /dev/null
+++ b/backend/internal/payment/fee_test.go
@@ -0,0 +1,111 @@
+package payment
+
+import (
+ "testing"
+)
+
+func TestCalculatePayAmount(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ amount float64
+ feeRate float64
+ expected string
+ }{
+ {
+ name: "zero fee rate returns same amount",
+ amount: 100.00,
+ feeRate: 0,
+ expected: "100.00",
+ },
+ {
+ name: "negative fee rate returns same amount",
+ amount: 50.00,
+ feeRate: -5,
+ expected: "50.00",
+ },
+ {
+ name: "1 percent fee rate",
+ amount: 100.00,
+ feeRate: 1,
+ expected: "101.00",
+ },
+ {
+ name: "5 percent fee on 200",
+ amount: 200.00,
+ feeRate: 5,
+ expected: "210.00",
+ },
+ {
+ name: "fee rounds UP to 2 decimal places",
+ amount: 100.00,
+ feeRate: 3,
+ expected: "103.00",
+ },
+ {
+ name: "fee rounds UP small remainder",
+ amount: 10.00,
+ feeRate: 3.33,
+ expected: "10.34", // 10 * 3.33 / 100 = 0.333 -> round up -> 0.34
+ },
+ {
+ name: "very small amount",
+ amount: 0.01,
+ feeRate: 1,
+ expected: "0.02", // 0.01 * 1/100 = 0.0001 -> round up -> 0.01 -> total 0.02
+ },
+ {
+ name: "large amount",
+ amount: 99999.99,
+ feeRate: 10,
+ expected: "109999.99", // 99999.99 * 10/100 = 9999.999 -> round up -> 10000.00 -> total 109999.99
+ },
+ {
+ name: "100 percent fee rate doubles amount",
+ amount: 50.00,
+ feeRate: 100,
+ expected: "100.00",
+ },
+ {
+ name: "precision 0.01 fee difference",
+ amount: 100.00,
+ feeRate: 1.01,
+ expected: "101.01", // 100 * 1.01/100 = 1.01
+ },
+ {
+ name: "precision 0.02 fee",
+ amount: 100.00,
+ feeRate: 1.02,
+ expected: "101.02",
+ },
+ {
+ name: "zero amount with positive fee",
+ amount: 0,
+ feeRate: 5,
+ expected: "0.00",
+ },
+ {
+ name: "fractional amount no fee",
+ amount: 19.99,
+ feeRate: 0,
+ expected: "19.99",
+ },
+ {
+ name: "fractional fee that causes rounding up",
+ amount: 33.33,
+ feeRate: 7.77,
+ expected: "35.92", // 33.33 * 7.77 / 100 = 2.589741 -> round up -> 2.59 -> total 35.92
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := CalculatePayAmount(tt.amount, tt.feeRate)
+ if got != tt.expected {
+ t.Fatalf("CalculatePayAmount(%v, %v) = %q, want %q", tt.amount, tt.feeRate, got, tt.expected)
+ }
+ })
+ }
+}
diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go
new file mode 100644
index 00000000..41fd2c50
--- /dev/null
+++ b/backend/internal/payment/load_balancer.go
@@ -0,0 +1,429 @@
+package payment
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/paymentorder"
+ "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance"
+)
+
+// Strategy represents a load balancing strategy for provider instance selection.
+type Strategy string
+
+const (
+ StrategyRoundRobin Strategy = "round-robin"
+ StrategyLeastAmount Strategy = "least-amount"
+)
+
+// ChannelLimits holds limits for a single payment channel within a provider instance.
+type ChannelLimits struct {
+ DailyLimit float64 `json:"dailyLimit,omitempty"`
+ SingleMin float64 `json:"singleMin,omitempty"`
+ SingleMax float64 `json:"singleMax,omitempty"`
+}
+
+// InstanceLimits holds per-channel limits for a provider instance (JSON).
+type InstanceLimits map[string]ChannelLimits
+
+// LoadBalancer selects a provider instance for a given payment type and exposes per-instance config lookup.
+type LoadBalancer interface {
+ GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error)
+ SelectInstance(ctx context.Context, providerKey string, paymentType PaymentType, strategy Strategy, orderAmount float64) (*InstanceSelection, error)
+}
+
+// DefaultLoadBalancer implements LoadBalancer using database queries.
+type DefaultLoadBalancer struct {
+ db *dbent.Client
+ encryptionKey []byte
+ counter atomic.Uint64
+}
+
+type contextKey string
+
+const wxpayJSAPIAppIDContextKey contextKey = "payment.wxpay.jsapi_app_id"
+
+// NewDefaultLoadBalancer creates a new load balancer.
+func NewDefaultLoadBalancer(db *dbent.Client, encryptionKey []byte) *DefaultLoadBalancer {
+ return &DefaultLoadBalancer{db: db, encryptionKey: encryptionKey}
+}
+
+func WithWxpayJSAPIAppID(ctx context.Context, appID string) context.Context {
+ appID = strings.TrimSpace(appID)
+ if appID == "" {
+ return ctx
+ }
+ return context.WithValue(ctx, wxpayJSAPIAppIDContextKey, appID)
+}
+
+func wxpayJSAPIAppIDFromContext(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ appID, _ := ctx.Value(wxpayJSAPIAppIDContextKey).(string)
+ return strings.TrimSpace(appID)
+}
+
+// instanceCandidate pairs an instance with its pre-fetched daily usage.
+type instanceCandidate struct {
+ inst *dbent.PaymentProviderInstance
+ dailyUsed float64 // includes PENDING orders
+}
+
+// SelectInstance picks an enabled instance for the given provider key and payment type.
+//
+// Flow:
+// 1. Query all enabled instances for providerKey, filter by supported paymentType
+// 2. Batch-query daily usage (PENDING + PAID + COMPLETED + RECHARGING) for all candidates
+// 3. Filter out instances where: single-min/max violated OR daily remaining < orderAmount
+// 4. Pick from survivors using the configured strategy (round-robin / least-amount)
+// 5. If all filtered out, fall back to full list (let the provider itself reject)
+func (lb *DefaultLoadBalancer) SelectInstance(
+ ctx context.Context,
+ providerKey string,
+ paymentType PaymentType,
+ strategy Strategy,
+ orderAmount float64,
+) (*InstanceSelection, error) {
+ // Step 1: query enabled instances matching payment type.
+ instances, err := lb.queryEnabledInstances(ctx, providerKey, paymentType)
+ if err != nil {
+ return nil, err
+ }
+
+ // Step 2: batch-fetch daily usage for all candidates.
+ candidates := lb.attachDailyUsage(ctx, instances)
+
+ // Step 3: filter by limits.
+ available := filterByLimits(candidates, paymentType, orderAmount)
+ if len(available) == 0 {
+ slog.Warn("all instances exceeded limits, using full candidate list",
+ "provider", providerKey, "payment_type", paymentType,
+ "order_amount", orderAmount, "count", len(candidates))
+ available = candidates
+ }
+
+ // Step 4: pick by strategy.
+ selected := lb.pickByStrategy(available, strategy)
+ return lb.buildSelection(selected.inst)
+}
+
+// queryEnabledInstances returns enabled instances that support paymentType.
+// When providerKey is non-empty, only instances with that provider key are considered.
+// When providerKey is empty, instances across all providers are considered,
+// enabling cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
+func (lb *DefaultLoadBalancer) queryEnabledInstances(
+ ctx context.Context,
+ providerKey string,
+ paymentType PaymentType,
+) ([]*dbent.PaymentProviderInstance, error) {
+ query := lb.db.PaymentProviderInstance.Query().
+ Where(paymentproviderinstance.Enabled(true))
+ if providerKey != "" {
+ query = query.Where(paymentproviderinstance.ProviderKey(providerKey))
+ }
+ instances, err := query.
+ Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("query provider instances: %w", err)
+ }
+
+ var matched []*dbent.PaymentProviderInstance
+ expectedWxpayJSAPIAppID := wxpayJSAPIAppIDFromContext(ctx)
+ for _, inst := range instances {
+ // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
+ // not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
+ if paymentType == TypeStripe {
+ if inst.ProviderKey == TypeStripe {
+ matched = append(matched, inst)
+ }
+ } else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
+ if expectedWxpayJSAPIAppID != "" && normalizeVisibleMethodSupportType(paymentType) == TypeWxpay && inst.ProviderKey == TypeWxpay {
+ config, cfgErr := lb.decryptConfig(inst.Config)
+ if cfgErr != nil {
+ slog.Warn("skip wxpay instance with unreadable config during jsapi filtering", "instance_id", inst.ID, "error", cfgErr)
+ continue
+ }
+ if resolveWxpayJSAPIAppID(config) != expectedWxpayJSAPIAppID {
+ continue
+ }
+ }
+ matched = append(matched, inst)
+ }
+ }
+ if len(matched) == 0 {
+ return nil, fmt.Errorf("no enabled instance for payment type %s", paymentType)
+ }
+ return matched, nil
+}
+
+// attachDailyUsage queries daily usage for each instance in a single pass.
+// Usage includes PENDING orders to avoid over-committing capacity.
+func (lb *DefaultLoadBalancer) attachDailyUsage(
+ ctx context.Context,
+ instances []*dbent.PaymentProviderInstance,
+) []instanceCandidate {
+ todayStart := startOfDay(time.Now())
+
+ // Collect instance IDs.
+ ids := make([]string, len(instances))
+ for i, inst := range instances {
+ ids[i] = fmt.Sprintf("%d", inst.ID)
+ }
+
+ // Batch query: sum pay_amount grouped by provider_instance_id.
+ type row struct {
+ InstanceID string `json:"provider_instance_id"`
+ Sum float64 `json:"sum"`
+ }
+ var rows []row
+ err := lb.db.PaymentOrder.Query().
+ Where(
+ paymentorder.ProviderInstanceIDIn(ids...),
+ paymentorder.StatusIn(
+ OrderStatusPending, OrderStatusPaid,
+ OrderStatusCompleted, OrderStatusRecharging,
+ ),
+ paymentorder.CreatedAtGTE(todayStart),
+ ).
+ GroupBy(paymentorder.FieldProviderInstanceID).
+ Aggregate(dbent.Sum(paymentorder.FieldPayAmount)).
+ Scan(ctx, &rows)
+ if err != nil {
+ slog.Warn("batch daily usage query failed, treating all as zero", "error", err)
+ }
+
+ usageMap := make(map[string]float64, len(rows))
+ for _, r := range rows {
+ usageMap[r.InstanceID] = r.Sum
+ }
+
+ candidates := make([]instanceCandidate, len(instances))
+ for i, inst := range instances {
+ candidates[i] = instanceCandidate{
+ inst: inst,
+ dailyUsed: usageMap[fmt.Sprintf("%d", inst.ID)],
+ }
+ }
+ return candidates
+}
+
+// filterByLimits removes instances that cannot accommodate the order:
+// - orderAmount outside single-transaction [min, max]
+// - daily remaining capacity (limit - used) < orderAmount
+func filterByLimits(candidates []instanceCandidate, paymentType PaymentType, orderAmount float64) []instanceCandidate {
+ var result []instanceCandidate
+ for _, c := range candidates {
+ cl := getInstanceChannelLimits(c.inst, paymentType)
+
+ if cl.SingleMin > 0 && orderAmount < cl.SingleMin {
+ slog.Info("order below instance single min, skipping",
+ "instance_id", c.inst.ID, "order", orderAmount, "min", cl.SingleMin)
+ continue
+ }
+ if cl.SingleMax > 0 && orderAmount > cl.SingleMax {
+ slog.Info("order above instance single max, skipping",
+ "instance_id", c.inst.ID, "order", orderAmount, "max", cl.SingleMax)
+ continue
+ }
+ if cl.DailyLimit > 0 && c.dailyUsed+orderAmount > cl.DailyLimit {
+ slog.Info("instance daily remaining insufficient, skipping",
+ "instance_id", c.inst.ID, "used", c.dailyUsed,
+ "order", orderAmount, "limit", cl.DailyLimit)
+ continue
+ }
+
+ result = append(result, c)
+ }
+ return result
+}
+
+// getInstanceChannelLimits returns the channel limits for a specific payment type.
+func getInstanceChannelLimits(inst *dbent.PaymentProviderInstance, paymentType PaymentType) ChannelLimits {
+ if inst.Limits == "" {
+ return ChannelLimits{}
+ }
+ var limits InstanceLimits
+ if err := json.Unmarshal([]byte(inst.Limits), &limits); err != nil {
+ return ChannelLimits{}
+ }
+ // For Stripe, limits are stored under the provider key "stripe".
+ lookupKey := paymentType
+ if inst.ProviderKey == "stripe" {
+ lookupKey = "stripe"
+ }
+ if cl, ok := limits[lookupKey]; ok {
+ return cl
+ }
+ if aliasKey := legacyVisibleMethodAlias(lookupKey); aliasKey != "" {
+ if cl, ok := limits[aliasKey]; ok {
+ return cl
+ }
+ }
+ return ChannelLimits{}
+}
+
+// pickByStrategy selects one instance from the available candidates; candidates must be non-empty (the modulo below divides by len(candidates)).
+func (lb *DefaultLoadBalancer) pickByStrategy(candidates []instanceCandidate, strategy Strategy) instanceCandidate {
+ if strategy == StrategyLeastAmount && len(candidates) > 1 {
+ return pickLeastAmount(candidates)
+ }
+ // Default: round-robin.
+ idx := lb.counter.Add(1) % uint64(len(candidates))
+ return candidates[idx]
+}
+
+// pickLeastAmount selects the instance with the lowest daily usage.
+// No extra DB queries — usage was pre-fetched in attachDailyUsage.
+func pickLeastAmount(candidates []instanceCandidate) instanceCandidate {
+ best := candidates[0]
+ for _, c := range candidates[1:] {
+ if c.dailyUsed < best.dailyUsed {
+ best = c
+ }
+ }
+ return best
+}
+
+func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderInstance) (*InstanceSelection, error) {
+ config, err := lb.decryptConfig(selected.Config)
+ if err != nil {
+ return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err)
+ }
+ if config == nil {
+ config = map[string]string{}
+ }
+
+ if selected.PaymentMode != "" {
+ config["paymentMode"] = selected.PaymentMode
+ }
+
+ return &InstanceSelection{
+ InstanceID: fmt.Sprintf("%d", selected.ID),
+ ProviderKey: selected.ProviderKey,
+ Config: config,
+ SupportedTypes: selected.SupportedTypes,
+ PaymentMode: selected.PaymentMode,
+ }, nil
+}
+
+// decryptConfig parses a stored provider config.
+// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext.
+// Unreadable values (legacy ciphertext without a valid key, or malformed data)
+// are treated as empty so the service keeps running while the admin re-enters
+// the config via the UI.
+//
+// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a
+// transitional compatibility shim for pre-plaintext records. Remove it (and
+// the encryptionKey field + the Decrypt import) after a few releases once all
+// live deployments have re-saved their provider configs through the UI.
+func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) {
+ if stored == "" {
+ return nil, nil
+ }
+ var config map[string]string
+ if err := json.Unmarshal([]byte(stored), &config); err == nil {
+ return config, nil
+ }
+ // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal.
+ if len(lb.encryptionKey) == AES256KeySize {
+ //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal
+ if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil {
+ if err := json.Unmarshal([]byte(plaintext), &config); err == nil {
+ return config, nil
+ }
+ }
+ }
+ slog.Warn("payment provider config unreadable, treating as empty for re-entry",
+ "stored_len", len(stored))
+ return nil, nil
+}
+
+// GetInstanceDailyAmount returns the total pay amount of today's PAID, RECHARGING, and COMPLETED orders for an instance, filtered by paid_at.
+func (lb *DefaultLoadBalancer) GetInstanceDailyAmount(ctx context.Context, instanceID string) (float64, error) {
+ todayStart := startOfDay(time.Now())
+
+ var result []struct {
+ Sum float64 `json:"sum"`
+ }
+ err := lb.db.PaymentOrder.Query().
+ Where(
+ paymentorder.ProviderInstanceID(instanceID),
+ paymentorder.StatusIn(OrderStatusCompleted, OrderStatusPaid, OrderStatusRecharging),
+ paymentorder.PaidAtGTE(todayStart),
+ ).
+ Aggregate(dbent.Sum(paymentorder.FieldPayAmount)).
+ Scan(ctx, &result)
+ if err != nil {
+ return 0, fmt.Errorf("query daily amount: %w", err)
+ }
+ if len(result) > 0 {
+ return result[0].Sum, nil
+ }
+ return 0, nil
+}
+
+func startOfDay(t time.Time) time.Time {
+ return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
+}
+
+// InstanceSupportsType checks if the given supported types string includes the target type.
+// An empty supportedTypes string means all types are supported.
+func InstanceSupportsType(supportedTypes string, target PaymentType) bool {
+ if supportedTypes == "" {
+ return true
+ }
+ normalizedTarget := normalizeVisibleMethodSupportType(target)
+ for _, t := range strings.Split(supportedTypes, ",") {
+ supported := strings.TrimSpace(t)
+ if supported == target || normalizeVisibleMethodSupportType(supported) == normalizedTarget {
+ return true
+ }
+ }
+ return false
+}
+
+func normalizeVisibleMethodSupportType(paymentType PaymentType) PaymentType {
+ switch strings.TrimSpace(paymentType) {
+ case TypeAlipay, TypeAlipayDirect:
+ return TypeAlipay
+ case TypeWxpay, TypeWxpayDirect:
+ return TypeWxpay
+ default:
+ return strings.TrimSpace(paymentType)
+ }
+}
+
+func legacyVisibleMethodAlias(paymentType PaymentType) PaymentType {
+ switch normalizeVisibleMethodSupportType(paymentType) {
+ case TypeAlipay:
+ return TypeAlipayDirect
+ case TypeWxpay:
+ return TypeWxpayDirect
+ default:
+ return ""
+ }
+}
+
+func resolveWxpayJSAPIAppID(config map[string]string) string {
+ if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
+ return appID
+ }
+ return strings.TrimSpace(config["appId"])
+}
+
+// GetInstanceConfig decrypts and returns the configuration for a provider instance by ID.
+func (lb *DefaultLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
+ inst, err := lb.db.PaymentProviderInstance.Get(ctx, instanceID)
+ if err != nil {
+ return nil, fmt.Errorf("get instance %d: %w", instanceID, err)
+ }
+ return lb.decryptConfig(inst.Config)
+}
diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go
new file mode 100644
index 00000000..ed08a7dd
--- /dev/null
+++ b/backend/internal/payment/load_balancer_test.go
@@ -0,0 +1,593 @@
+//go:build unit
+
+package payment
+
+import (
+ "encoding/json"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+)
+
+func TestInstanceSupportsType(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ supportedTypes string
+ target PaymentType
+ expected bool
+ }{
+ {
+ name: "exact match single type",
+ supportedTypes: "alipay",
+ target: "alipay",
+ expected: true,
+ },
+ {
+ name: "no match single type",
+ supportedTypes: "wxpay",
+ target: "alipay",
+ expected: false,
+ },
+ {
+ name: "match in comma-separated list",
+ supportedTypes: "alipay,wxpay,stripe",
+ target: "wxpay",
+ expected: true,
+ },
+ {
+ name: "first in comma-separated list",
+ supportedTypes: "alipay,wxpay",
+ target: "alipay",
+ expected: true,
+ },
+ {
+ name: "last in comma-separated list",
+ supportedTypes: "alipay,wxpay,stripe",
+ target: "stripe",
+ expected: true,
+ },
+ {
+ name: "no match in comma-separated list",
+ supportedTypes: "alipay,wxpay",
+ target: "stripe",
+ expected: false,
+ },
+ {
+ name: "empty target",
+ supportedTypes: "alipay,wxpay",
+ target: "",
+ expected: false,
+ },
+ {
+ name: "types with spaces are trimmed",
+ supportedTypes: " alipay , wxpay ",
+ target: "alipay",
+ expected: true,
+ },
+ {
+ name: "legacy alipay direct supports canonical visible method",
+ supportedTypes: "alipay_direct",
+ target: "alipay",
+ expected: true,
+ },
+ {
+ name: "legacy wxpay direct supports canonical visible method",
+ supportedTypes: "wxpay_direct",
+ target: "wxpay",
+ expected: true,
+ },
+ {
+ name: "empty supported types means all supported",
+ supportedTypes: "",
+ target: "alipay",
+ expected: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := InstanceSupportsType(tt.supportedTypes, tt.target)
+ if got != tt.expected {
+ t.Fatalf("InstanceSupportsType(%q, %q) = %v, want %v", tt.supportedTypes, tt.target, got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestGetInstanceChannelLimitsFallsBackToLegacyDirectAliases(t *testing.T) {
+ t.Parallel()
+
+ inst := testInstance(1, TypeAlipay, makeLimitsJSON(TypeAlipayDirect, ChannelLimits{SingleMax: 66}))
+ got := getInstanceChannelLimits(inst, TypeAlipay)
+ if got.SingleMax != 66 {
+ t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMax=66", got)
+ }
+
+ wxInst := testInstance(2, TypeWxpay, makeLimitsJSON(TypeWxpayDirect, ChannelLimits{SingleMin: 8}))
+ wxGot := getInstanceChannelLimits(wxInst, TypeWxpay)
+ if wxGot.SingleMin != 8 {
+ t.Fatalf("getInstanceChannelLimits() = %+v, want SingleMin=8", wxGot)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Helper to build test PaymentProviderInstance values
+// ---------------------------------------------------------------------------
+
+func testInstance(id int64, providerKey, limits string) *dbent.PaymentProviderInstance {
+ return &dbent.PaymentProviderInstance{
+ ID: id,
+ ProviderKey: providerKey,
+ Limits: limits,
+ Enabled: true,
+ }
+}
+
+// makeLimitsJSON builds a limits JSON string for a single payment type.
+func makeLimitsJSON(paymentType string, cl ChannelLimits) string {
+ m := map[string]ChannelLimits{paymentType: cl}
+ b, _ := json.Marshal(m)
+ return string(b)
+}
+
+// ---------------------------------------------------------------------------
+// filterByLimits
+// ---------------------------------------------------------------------------
+
+func TestFilterByLimits(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ candidates []instanceCandidate
+ paymentType PaymentType
+ orderAmount float64
+ wantIDs []int64 // expected surviving instance IDs
+ }{
+ {
+ name: "order below SingleMin is filtered out",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10})), dailyUsed: 0},
+ },
+ paymentType: "alipay",
+ orderAmount: 5,
+ wantIDs: nil,
+ },
+ {
+ name: "order at exact SingleMin boundary passes",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10})), dailyUsed: 0},
+ },
+ paymentType: "alipay",
+ orderAmount: 10,
+ wantIDs: []int64{1},
+ },
+ {
+ name: "order above SingleMax is filtered out",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 100})), dailyUsed: 0},
+ },
+ paymentType: "alipay",
+ orderAmount: 150,
+ wantIDs: nil,
+ },
+ {
+ name: "order at exact SingleMax boundary passes",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 100})), dailyUsed: 0},
+ },
+ paymentType: "alipay",
+ orderAmount: 100,
+ wantIDs: []int64{1},
+ },
+ {
+ name: "daily used + orderAmount exceeding dailyLimit is filtered out",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 480},
+ },
+ paymentType: "alipay",
+ orderAmount: 30,
+ wantIDs: nil, // 480+30=510 > 500
+ },
+ {
+ name: "daily used + orderAmount equal to dailyLimit passes (strict greater-than)",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 480},
+ },
+ paymentType: "alipay",
+ orderAmount: 20,
+ wantIDs: []int64{1}, // 480+20=500, 500 > 500 is false → passes
+ },
+ {
+ name: "daily used + orderAmount below dailyLimit passes",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 500})), dailyUsed: 400},
+ },
+ paymentType: "alipay",
+ orderAmount: 50,
+ wantIDs: []int64{1},
+ },
+ {
+ name: "no limits configured passes through",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", ""), dailyUsed: 99999},
+ },
+ paymentType: "alipay",
+ orderAmount: 100,
+ wantIDs: []int64{1},
+ },
+ {
+ name: "multiple candidates with partial filtering",
+ candidates: []instanceCandidate{
+ // singleMax=50, order=80 → filtered out
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMax: 50})), dailyUsed: 0},
+ // no limits → passes
+ {inst: testInstance(2, "easypay", ""), dailyUsed: 0},
+ // singleMin=100, order=80 → filtered out
+ {inst: testInstance(3, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 100})), dailyUsed: 0},
+ // daily limit ok → passes (500+80=580 < 1000)
+ {inst: testInstance(4, "easypay", makeLimitsJSON("alipay", ChannelLimits{DailyLimit: 1000})), dailyUsed: 500},
+ },
+ paymentType: "alipay",
+ orderAmount: 80,
+ wantIDs: []int64{2, 4},
+ },
+ {
+ name: "zero SingleMin and SingleMax means no single-transaction limit",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 0, SingleMax: 0, DailyLimit: 0})), dailyUsed: 0},
+ },
+ paymentType: "alipay",
+ orderAmount: 99999,
+ wantIDs: []int64{1},
+ },
+ {
+ name: "all limits combined - order passes all checks",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10, SingleMax: 200, DailyLimit: 1000})), dailyUsed: 500},
+ },
+ paymentType: "alipay",
+ orderAmount: 50,
+ wantIDs: []int64{1},
+ },
+ {
+ name: "all limits combined - order fails SingleMin",
+ candidates: []instanceCandidate{
+ {inst: testInstance(1, "easypay", makeLimitsJSON("alipay", ChannelLimits{SingleMin: 10, SingleMax: 200, DailyLimit: 1000})), dailyUsed: 500},
+ },
+ paymentType: "alipay",
+ orderAmount: 5,
+ wantIDs: nil,
+ },
+ {
+ name: "empty candidates returns empty",
+ candidates: nil,
+ paymentType: "alipay",
+ orderAmount: 10,
+ wantIDs: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := filterByLimits(tt.candidates, tt.paymentType, tt.orderAmount)
+ gotIDs := make([]int64, len(got))
+ for i, c := range got {
+ gotIDs[i] = c.inst.ID
+ }
+ if !int64SliceEqual(gotIDs, tt.wantIDs) {
+ t.Fatalf("filterByLimits() returned IDs %v, want %v", gotIDs, tt.wantIDs)
+ }
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// pickLeastAmount
+// ---------------------------------------------------------------------------
+
+func TestPickLeastAmount(t *testing.T) {
+ t.Parallel()
+
+ t.Run("picks candidate with lowest dailyUsed", func(t *testing.T) {
+ t.Parallel()
+ candidates := []instanceCandidate{
+ {inst: testInstance(1, "easypay", ""), dailyUsed: 300},
+ {inst: testInstance(2, "easypay", ""), dailyUsed: 100},
+ {inst: testInstance(3, "easypay", ""), dailyUsed: 200},
+ }
+ got := pickLeastAmount(candidates)
+ if got.inst.ID != 2 {
+ t.Fatalf("pickLeastAmount() picked instance %d, want 2", got.inst.ID)
+ }
+ })
+
+ t.Run("with equal dailyUsed picks the first one", func(t *testing.T) {
+ t.Parallel()
+ candidates := []instanceCandidate{
+ {inst: testInstance(1, "easypay", ""), dailyUsed: 100},
+ {inst: testInstance(2, "easypay", ""), dailyUsed: 100},
+ {inst: testInstance(3, "easypay", ""), dailyUsed: 200},
+ }
+ got := pickLeastAmount(candidates)
+ if got.inst.ID != 1 {
+ t.Fatalf("pickLeastAmount() picked instance %d, want 1 (first with lowest)", got.inst.ID)
+ }
+ })
+
+ t.Run("single candidate returns that candidate", func(t *testing.T) {
+ t.Parallel()
+ candidates := []instanceCandidate{
+ {inst: testInstance(42, "easypay", ""), dailyUsed: 999},
+ }
+ got := pickLeastAmount(candidates)
+ if got.inst.ID != 42 {
+ t.Fatalf("pickLeastAmount() picked instance %d, want 42", got.inst.ID)
+ }
+ })
+
+ t.Run("zero usage among non-zero picks zero", func(t *testing.T) {
+ t.Parallel()
+ candidates := []instanceCandidate{
+ {inst: testInstance(1, "easypay", ""), dailyUsed: 500},
+ {inst: testInstance(2, "easypay", ""), dailyUsed: 0},
+ {inst: testInstance(3, "easypay", ""), dailyUsed: 300},
+ }
+ got := pickLeastAmount(candidates)
+ if got.inst.ID != 2 {
+ t.Fatalf("pickLeastAmount() picked instance %d, want 2", got.inst.ID)
+ }
+ })
+}
+
+// ---------------------------------------------------------------------------
+// getInstanceChannelLimits
+// ---------------------------------------------------------------------------
+
+func TestGetInstanceChannelLimits(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ inst *dbent.PaymentProviderInstance
+ paymentType PaymentType
+ want ChannelLimits
+ }{
+ {
+ name: "empty limits string returns zero ChannelLimits",
+ inst: testInstance(1, "easypay", ""),
+ paymentType: "alipay",
+ want: ChannelLimits{},
+ },
+ {
+ name: "invalid JSON returns zero ChannelLimits",
+ inst: testInstance(1, "easypay", "not-json{"),
+ paymentType: "alipay",
+ want: ChannelLimits{},
+ },
+ {
+ name: "valid JSON with matching payment type",
+ inst: testInstance(1, "easypay",
+ `{"alipay":{"singleMin":5,"singleMax":200,"dailyLimit":1000}}`),
+ paymentType: "alipay",
+ want: ChannelLimits{SingleMin: 5, SingleMax: 200, DailyLimit: 1000},
+ },
+ {
+ name: "payment type not in limits returns zero ChannelLimits",
+ inst: testInstance(1, "easypay",
+ `{"alipay":{"singleMin":5,"singleMax":200}}`),
+ paymentType: "wxpay",
+ want: ChannelLimits{},
+ },
+ {
+ name: "stripe provider uses stripe lookup key regardless of payment type",
+ inst: testInstance(1, "stripe",
+ `{"stripe":{"singleMin":10,"singleMax":500,"dailyLimit":5000}}`),
+ paymentType: "alipay",
+ want: ChannelLimits{SingleMin: 10, SingleMax: 500, DailyLimit: 5000},
+ },
+ {
+ name: "stripe provider ignores payment type key even if present",
+ inst: testInstance(1, "stripe",
+ `{"stripe":{"singleMin":10,"singleMax":500},"alipay":{"singleMin":1,"singleMax":100}}`),
+ paymentType: "alipay",
+ want: ChannelLimits{SingleMin: 10, SingleMax: 500},
+ },
+ {
+ name: "non-stripe provider uses payment type as lookup key",
+ inst: testInstance(1, "easypay",
+ `{"alipay":{"singleMin":5},"wxpay":{"singleMin":10}}`),
+ paymentType: "wxpay",
+ want: ChannelLimits{SingleMin: 10},
+ },
+ {
+ name: "valid JSON with partial limits (only dailyLimit)",
+ inst: testInstance(1, "easypay",
+ `{"alipay":{"dailyLimit":800}}`),
+ paymentType: "alipay",
+ want: ChannelLimits{DailyLimit: 800},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := getInstanceChannelLimits(tt.inst, tt.paymentType)
+ if got != tt.want {
+ t.Fatalf("getInstanceChannelLimits() = %+v, want %+v", got, tt.want)
+ }
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// startOfDay
+// ---------------------------------------------------------------------------
+
+func TestStartOfDay(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ in time.Time
+ want time.Time
+ }{
+ {
+ name: "midday returns midnight of same day",
+ in: time.Date(2025, 6, 15, 14, 30, 45, 123456789, time.UTC),
+ want: time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC),
+ },
+ {
+ name: "midnight returns same time",
+ in: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+ want: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
+ },
+ {
+ name: "last second of day returns midnight of same day",
+ in: time.Date(2025, 12, 31, 23, 59, 59, 999999999, time.UTC),
+ want: time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC),
+ },
+ {
+ name: "preserves timezone location",
+ in: time.Date(2025, 3, 10, 15, 0, 0, 0, time.FixedZone("CST", 8*3600)),
+ want: time.Date(2025, 3, 10, 0, 0, 0, 0, time.FixedZone("CST", 8*3600)),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := startOfDay(tt.in)
+ if !got.Equal(tt.want) {
+ t.Fatalf("startOfDay(%v) = %v, want %v", tt.in, got, tt.want)
+ }
+ // Also verify location is preserved.
+ if got.Location().String() != tt.want.Location().String() {
+ t.Fatalf("startOfDay() location = %v, want %v", got.Location(), tt.want.Location())
+ }
+ })
+ }
+}
+
+func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) {
+ t.Parallel()
+
+ key := make([]byte, AES256KeySize)
+ for i := range key {
+ key[i] = byte(i + 1)
+ }
+ wrongKey := make([]byte, AES256KeySize)
+ for i := range wrongKey {
+ wrongKey[i] = byte(0xFF - i)
+ }
+
+ plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}`
+
+ legacyEncrypted, err := Encrypt(plaintextJSON, key)
+ if err != nil {
+ t.Fatalf("seed Encrypt: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ stored string
+ key []byte
+ want map[string]string
+ }{
+ {
+ name: "empty stored returns nil map",
+ stored: "",
+ key: key,
+ want: nil,
+ },
+ {
+ name: "plaintext JSON parses directly",
+ stored: plaintextJSON,
+ key: nil,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "plaintext JSON works even with key present",
+ stored: plaintextJSON,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with correct key decrypts",
+ stored: legacyEncrypted,
+ key: key,
+ want: map[string]string{"appId": "app-123", "secret": "sec-xyz"},
+ },
+ {
+ name: "legacy ciphertext with no key treated as empty",
+ stored: legacyEncrypted,
+ key: nil,
+ want: nil,
+ },
+ {
+ name: "legacy ciphertext with wrong key treated as empty",
+ stored: legacyEncrypted,
+ key: wrongKey,
+ want: nil,
+ },
+ {
+ name: "garbage data treated as empty",
+ stored: "not-json-and-not-ciphertext",
+ key: key,
+ want: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ lb := NewDefaultLoadBalancer(nil, tt.key)
+ got, err := lb.decryptConfig(tt.stored)
+ if err != nil {
+ t.Fatalf("decryptConfig unexpected error: %v", err)
+ }
+ if !stringMapEqual(got, tt.want) {
+ t.Fatalf("decryptConfig = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+// stringMapEqual compares two map[string]string values; nil and empty are equal.
+func stringMapEqual(a, b map[string]string) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for k, v := range a {
+ if bv, ok := b[k]; !ok || bv != v {
+ return false
+ }
+ }
+ return true
+}
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+// int64SliceEqual compares two int64 slices for equality.
+// Both nil and empty slices are treated as equal.
+func int64SliceEqual(a, b []int64) bool {
+ if len(a) == 0 && len(b) == 0 {
+ return true
+ }
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
new file mode 100644
index 00000000..1234b568
--- /dev/null
+++ b/backend/internal/payment/provider/alipay.go
@@ -0,0 +1,390 @@
+package provider
+
+import (
+ "context"
+ "fmt"
+ "net/url"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/smartwalle/alipay/v3"
+)
+
+// Alipay product codes.
+const (
+ alipayProductCodePreCreate = "FACE_TO_FACE_PAYMENT"
+ alipayProductCodeWapPay = "QUICK_WAP_WAY"
+ alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY"
+)
+
+// Alipay response constants.
+const (
+ alipayFundChangeYes = "Y"
+ alipayErrTradeNotExist = "ACQ.TRADE_NOT_EXIST"
+ alipayRefundSuffix = "-refund"
+)
+
+var (
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ return client.TradeWapPay(param)
+ }
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ return client.TradePreCreate(ctx, param)
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ return client.TradePagePay(param)
+ }
+)
+
+// Alipay implements payment.Provider and payment.CancelableProvider using the smartwalle/alipay SDK.
+type Alipay struct {
+ instanceID string
+ config map[string]string // appId, privateKey, publicKey (or alipayPublicKey), notifyUrl, returnUrl
+
+ mu sync.Mutex
+ client *alipay.Client
+}
+
+// NewAlipay creates a new Alipay provider instance.
+//
+// config must contain non-empty "appId" and "privateKey" entries; the
+// platform public key ("publicKey" or "alipayPublicKey") is validated
+// lazily in getClient, not here. The config map is retained by reference,
+// so callers must not mutate it after construction.
+func NewAlipay(instanceID string, config map[string]string) (*Alipay, error) {
+	required := []string{"appId", "privateKey"}
+	for _, k := range required {
+		// Lookup on a nil or sparse map yields "", so absent and empty
+		// values are rejected uniformly.
+		if config[k] == "" {
+			return nil, fmt.Errorf("alipay config missing required key: %s", k)
+		}
+	}
+	return &Alipay{
+		instanceID: instanceID,
+		config:     config,
+	}, nil
+}
+
+// getClient lazily initializes and caches the underlying alipay SDK client.
+// Initialization is serialized by a.mu, so it is safe for concurrent use.
+func (a *Alipay) getClient() (*alipay.Client, error) {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+	if a.client != nil {
+		return a.client, nil
+	}
+	// Third argument `true` selects production mode in the smartwalle SDK;
+	// sandbox use would need a config switch here.
+	client, err := alipay.New(a.config["appId"], a.config["privateKey"], true)
+	if err != nil {
+		return nil, fmt.Errorf("alipay init client: %w", err)
+	}
+	// Accept either config key for the Alipay platform public key.
+	pubKey := a.config["publicKey"]
+	if pubKey == "" {
+		pubKey = a.config["alipayPublicKey"]
+	}
+	if pubKey == "" {
+		return nil, fmt.Errorf("alipay config missing required key: publicKey (or alipayPublicKey)")
+	}
+	if err := client.LoadAliPayPublicKey(pubKey); err != nil {
+		return nil, fmt.Errorf("alipay load public key: %w", err)
+	}
+	a.client = client
+	return a.client, nil
+}
+
+func (a *Alipay) Name() string { return "Alipay" }
+func (a *Alipay) ProviderKey() string { return payment.TypeAlipay }
+func (a *Alipay) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay}
+}
+
+// MerchantIdentityMetadata exposes the configured app_id so callers can
+// attribute payments to a merchant identity. It returns nil when the
+// receiver is nil or no appId is configured.
+func (a *Alipay) MerchantIdentityMetadata() map[string]string {
+	if a == nil {
+		return nil
+	}
+	if appID := strings.TrimSpace(a.config["appId"]); appID != "" {
+		return map[string]string{"app_id": appID}
+	}
+	return nil
+}
+
+// CreatePayment creates an Alipay payment using the following routing:
+// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
+// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
+// - Desktop fallback: if precreate is unavailable for the merchant, fall back
+// to alipay.trade.page.pay and expose both pay_url and qr_code so the
+// frontend can render a QR while still allowing direct page open.
+func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ client, err := a.getClient()
+ if err != nil {
+ return nil, err
+ }
+
+ notifyURL := a.config["notifyUrl"]
+ if req.NotifyURL != "" {
+ notifyURL = req.NotifyURL
+ }
+ returnURL := a.config["returnUrl"]
+ if req.ReturnURL != "" {
+ returnURL = req.ReturnURL
+ }
+
+ if req.IsMobile {
+ return a.createWapTrade(client, req, notifyURL, returnURL)
+ }
+ return a.createDesktopTrade(ctx, client, req, notifyURL, returnURL)
+}
+
+func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradeWapPay{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodeWapPay
+ param.NotifyURL = notifyURL
+ param.ReturnURL = returnURL
+
+ payURL, err := alipayTradeWapPay(client, param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradeWapPay: %w", err)
+ }
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ PayURL: payURL.String(),
+ }, nil
+}
+
+func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
+ if precreateErr == nil {
+ return resp, nil
+ }
+
+ resp, pagePayErr := a.createPagePayTrade(client, req, notifyURL, returnURL)
+ if pagePayErr == nil {
+ return resp, nil
+ }
+
+ return nil, fmt.Errorf("alipay desktop payment failed: precreate=%v; pagepay=%w", precreateErr, pagePayErr)
+}
+
+func (a *Alipay) createPrecreateTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradePreCreate{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodePreCreate
+ param.NotifyURL = notifyURL
+
+ rsp, err := alipayTradePreCreate(ctx, client, param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradePreCreate: %w", err)
+ }
+ if rsp == nil {
+ return nil, fmt.Errorf("alipay TradePreCreate: empty response")
+ }
+ if rsp.IsFailure() {
+ return nil, fmt.Errorf("alipay TradePreCreate failed: %s", rsp.Error.Error())
+ }
+ if strings.TrimSpace(rsp.QRCode) == "" {
+ return nil, fmt.Errorf("alipay TradePreCreate: empty qr_code")
+ }
+
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ QRCode: rsp.QRCode,
+ }, nil
+}
+
+func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ param := alipay.TradePagePay{}
+ param.OutTradeNo = req.OrderID
+ param.TotalAmount = req.Amount
+ param.Subject = req.Subject
+ param.ProductCode = alipayProductCodePagePay
+ param.NotifyURL = notifyURL
+ param.ReturnURL = returnURL
+
+ payURL, err := alipayTradePagePay(client, param)
+ if err != nil {
+ return nil, fmt.Errorf("alipay TradePagePay: %w", err)
+ }
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ PayURL: payURL.String(),
+ QRCode: payURL.String(),
+ }, nil
+}
+
+// QueryOrder queries the trade status via Alipay.
+//
+// A trade Alipay does not know about yet (ACQ.TRADE_NOT_EXIST) is reported
+// as pending rather than as an error. Paid status is derived from
+// TRADE_SUCCESS / TRADE_FINISHED; TRADE_CLOSED maps to failed; anything
+// else stays pending.
+func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+	client, err := a.getClient()
+	if err != nil {
+		return nil, err
+	}
+
+	result, err := client.TradeQuery(ctx, alipay.TradeQuery{OutTradeNo: tradeNo})
+	if err != nil {
+		if isTradeNotExist(err) {
+			return &payment.QueryOrderResponse{
+				TradeNo: tradeNo,
+				Status:  payment.ProviderStatusPending,
+			}, nil
+		}
+		return nil, fmt.Errorf("alipay TradeQuery: %w", err)
+	}
+
+	status := payment.ProviderStatusPending
+	switch result.TradeStatus {
+	case alipay.TradeStatusSuccess, alipay.TradeStatusFinished:
+		status = payment.ProviderStatusPaid
+	case alipay.TradeStatusClosed:
+		status = payment.ProviderStatusFailed
+	}
+
+	// parseAlipayAmount already tries TotalAmount first and then the
+	// secondary fields, so a separate strconv.ParseFloat pre-pass was
+	// redundant; a single call is behavior-equivalent.
+	amount, err := parseAlipayAmount(
+		result.TotalAmount,
+		result.ReceiptAmount,
+		result.BuyerPayAmount,
+		result.InvoiceAmount,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("alipay parse amount: %w", err)
+	}
+
+	return &payment.QueryOrderResponse{
+		TradeNo:  result.TradeNo,
+		Status:   status,
+		Amount:   amount,
+		PaidAt:   result.SendPayDate,
+		Metadata: a.MerchantIdentityMetadata(),
+	}, nil
+}
+
+// VerifyNotification decodes and verifies an Alipay async notification.
+//
+// rawBody is the URL-encoded form payload from Alipay; the SDK's
+// DecodeNotification verifies its signature before any field is trusted.
+// Only TRADE_SUCCESS / TRADE_FINISHED map to success; every other trade
+// status is reported as failed.
+func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
+	client, err := a.getClient()
+	if err != nil {
+		return nil, err
+	}
+
+	values, err := url.ParseQuery(rawBody)
+	if err != nil {
+		return nil, fmt.Errorf("alipay parse notification: %w", err)
+	}
+
+	notification, err := client.DecodeNotification(ctx, values)
+	if err != nil {
+		return nil, fmt.Errorf("alipay verify notification: %w", err)
+	}
+
+	status := payment.ProviderStatusFailed
+	if notification.TradeStatus == alipay.TradeStatusSuccess || notification.TradeStatus == alipay.TradeStatusFinished {
+		status = payment.ProviderStatusSuccess
+	}
+
+	// parseAlipayAmount already tries TotalAmount first, so the former
+	// strconv.ParseFloat pre-pass was redundant; one call is equivalent.
+	amount, err := parseAlipayAmount(
+		notification.TotalAmount,
+		notification.ReceiptAmount,
+		notification.BuyerPayAmount,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("alipay parse notification amount: %w", err)
+	}
+
+	// Prefer the app_id echoed in the (signature-verified) notification
+	// over the locally configured one.
+	metadata := a.MerchantIdentityMetadata()
+	if appID := strings.TrimSpace(notification.AppId); appID != "" {
+		if metadata == nil {
+			metadata = map[string]string{}
+		}
+		metadata["app_id"] = appID
+	}
+
+	return &payment.PaymentNotification{
+		TradeNo:  notification.TradeNo,
+		OrderID:  notification.OutTradeNo,
+		Amount:   amount,
+		Status:   status,
+		RawData:  rawBody,
+		Metadata: metadata,
+	}, nil
+}
+
+// Refund requests a refund through Alipay.
+//
+// NOTE(review): OutRequestNo embeds time.Now().UnixNano(), so retrying the
+// same logical refund produces a fresh request number and Alipay will treat
+// it as a second refund rather than an idempotent retry — confirm callers
+// never retry, or derive OutRequestNo from a stable refund identifier.
+func (a *Alipay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
+	client, err := a.getClient()
+	if err != nil {
+		return nil, err
+	}
+
+	result, err := client.TradeRefund(ctx, alipay.TradeRefund{
+		OutTradeNo:   req.OrderID,
+		RefundAmount: req.Amount,
+		RefundReason: req.Reason,
+		OutRequestNo: fmt.Sprintf("%s-refund-%d", req.OrderID, time.Now().UnixNano()),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("alipay TradeRefund: %w", err)
+	}
+
+	// fund_change == "Y" means Alipay confirmed the money movement;
+	// anything else is reported as pending.
+	refundStatus := payment.ProviderStatusPending
+	if result.FundChange == alipayFundChangeYes {
+		refundStatus = payment.ProviderStatusSuccess
+	}
+
+	// Fall back to a synthetic refund id when Alipay returns no trade_no.
+	refundID := result.TradeNo
+	if refundID == "" {
+		refundID = req.OrderID + alipayRefundSuffix
+	}
+
+	return &payment.RefundResponse{
+		RefundID: refundID,
+		Status:   refundStatus,
+	}, nil
+}
+
+// CancelPayment closes a pending trade on Alipay. A trade Alipay has never
+// seen (ACQ.TRADE_NOT_EXIST) is treated as already closed and returns nil.
+func (a *Alipay) CancelPayment(ctx context.Context, tradeNo string) error {
+	client, err := a.getClient()
+	if err != nil {
+		return err
+	}
+
+	_, closeErr := client.TradeClose(ctx, alipay.TradeClose{OutTradeNo: tradeNo})
+	if closeErr != nil && !isTradeNotExist(closeErr) {
+		return fmt.Errorf("alipay TradeClose: %w", closeErr)
+	}
+	return nil
+}
+
+// isTradeNotExist reports whether err carries Alipay's ACQ.TRADE_NOT_EXIST
+// sub-code (matched textually, as the SDK surfaces it in the error string).
+func isTradeNotExist(err error) bool {
+	return err != nil && strings.Contains(err.Error(), alipayErrTradeNotExist)
+}
+
+// parseAlipayAmount returns the first candidate that parses as a float
+// after whitespace trimming, skipping empty fields. It errors only when no
+// candidate parses.
+func parseAlipayAmount(values ...string) (float64, error) {
+	for _, candidate := range values {
+		candidate = strings.TrimSpace(candidate)
+		if candidate == "" {
+			continue
+		}
+		if amount, err := strconv.ParseFloat(candidate, 64); err == nil {
+			return amount, nil
+		}
+	}
+	return 0, fmt.Errorf("no valid amount field")
+}
+
+// Ensure interface compliance.
+var (
+ _ payment.Provider = (*Alipay)(nil)
+ _ payment.CancelableProvider = (*Alipay)(nil)
+ _ payment.MerchantIdentityProvider = (*Alipay)(nil)
+)
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
new file mode 100644
index 00000000..fdc8eec1
--- /dev/null
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -0,0 +1,307 @@
+//go:build unit
+
+package provider
+
+import (
+ "context"
+ "errors"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/smartwalle/alipay/v3"
+)
+
+func TestIsTradeNotExist(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {
+ name: "nil error returns false",
+ err: nil,
+ want: false,
+ },
+ {
+ name: "error containing ACQ.TRADE_NOT_EXIST returns true",
+ err: errors.New("alipay: sub_code=ACQ.TRADE_NOT_EXIST, sub_msg=交易不存在"),
+ want: true,
+ },
+ {
+ name: "error not containing the code returns false",
+ err: errors.New("alipay: sub_code=ACQ.SYSTEM_ERROR, sub_msg=系统错误"),
+ want: false,
+ },
+ {
+ name: "error with only partial match returns false",
+ err: errors.New("ACQ.TRADE_NOT"),
+ want: false,
+ },
+ {
+ name: "error with exact constant value returns true",
+ err: errors.New(alipayErrTradeNotExist),
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := isTradeNotExist(tt.err)
+ if got != tt.want {
+ t.Errorf("isTradeNotExist(%v) = %v, want %v", tt.err, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestNewAlipay(t *testing.T) {
+ t.Parallel()
+
+ validConfig := map[string]string{
+ "appId": "2021001234567890",
+ "privateKey": "MIIEvQIBADANBgkqhkiG9w0BAQEFAASC...",
+ }
+
+ // helper to clone and override config fields
+ withOverride := func(overrides map[string]string) map[string]string {
+ cfg := make(map[string]string, len(validConfig))
+ for k, v := range validConfig {
+ cfg[k] = v
+ }
+ for k, v := range overrides {
+ cfg[k] = v
+ }
+ return cfg
+ }
+
+ tests := []struct {
+ name string
+ config map[string]string
+ wantErr bool
+ errSubstr string
+ }{
+ {
+ name: "valid config succeeds",
+ config: validConfig,
+ wantErr: false,
+ },
+ {
+ name: "missing appId",
+ config: withOverride(map[string]string{"appId": ""}),
+ wantErr: true,
+ errSubstr: "appId",
+ },
+ {
+ name: "missing privateKey",
+ config: withOverride(map[string]string{"privateKey": ""}),
+ wantErr: true,
+ errSubstr: "privateKey",
+ },
+ {
+ name: "nil config map returns error for appId",
+ config: map[string]string{},
+ wantErr: true,
+ errSubstr: "appId",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, err := NewAlipay("test-instance", tt.config)
+ if tt.wantErr {
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) {
+ t.Errorf("error %q should contain %q", err.Error(), tt.errSubstr)
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got == nil {
+ t.Fatal("expected non-nil Alipay instance")
+ }
+ if got.instanceID != "test-instance" {
+ t.Errorf("instanceID = %q, want %q", got.instanceID, "test-instance")
+ }
+ })
+ }
+}
+
+func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origPagePay := alipayTradePagePay
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradePagePay = origPagePay
+ alipayTradeWapPay = origWapPay
+ })
+
+ preCreateCalls := 0
+ pagePayCalls := 0
+ wapPayCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ return nil, errors.New("merchant does not have FACE_TO_FACE_PAYMENT")
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ if param.OutTradeNo != "sub2_100" {
+ t.Fatalf("out_trade_no = %q, want %q", param.OutTradeNo, "sub2_100")
+ }
+ if param.NotifyURL != "https://merchant.example.com/api/v1/payment/webhook/alipay" {
+ t.Fatalf("notify_url = %q", param.NotifyURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_100",
+ Amount: "88.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 1 {
+ t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
+ }
+ if pagePayCalls != 1 {
+ t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
+ }
+ if wapPayCalls != 0 {
+ t.Fatalf("wap pay calls = %d, want 0", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for desktop page pay")
+ }
+ if resp.QRCode != resp.PayURL {
+ t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
+ }
+}
+
+func TestCreateTradeUsesWapPayForMobile(t *testing.T) {
+ origWapPay := alipayTradeWapPay
+ t.Cleanup(func() {
+ alipayTradeWapPay = origWapPay
+ })
+
+ wapPayCalls := 0
+ alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) {
+ wapPayCalls++
+ if param.ReturnURL != "https://merchant.example.com/payment/result" {
+ t.Fatalf("return_url = %q", param.ReturnURL)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?wap-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createWapTrade(&alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_101",
+ Amount: "18.00",
+ Subject: "Balance recharge",
+ IsMobile: true,
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if wapPayCalls != 1 {
+ t.Fatalf("wap pay calls = %d, want 1", wapPayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for mobile wap pay")
+ }
+}
+
+func TestCreateTradeUsesPrecreateForDesktopWhenAvailable(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origPagePay := alipayTradePagePay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradePagePay = origPagePay
+ })
+
+ preCreateCalls := 0
+ pagePayCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ if param.ProductCode != alipayProductCodePreCreate {
+ t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePreCreate)
+ }
+ return &alipay.TradePreCreateRsp{
+ Error: alipay.Error{Code: alipay.CodeSuccess},
+ QRCode: "https://qr.alipay.example.com/precreate-token",
+ }, nil
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+
+ provider := &Alipay{}
+ resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_102",
+ Amount: "66.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 1 {
+ t.Fatalf("precreate calls = %d, want 1", preCreateCalls)
+ }
+ if pagePayCalls != 0 {
+ t.Fatalf("page pay calls = %d, want 0", pagePayCalls)
+ }
+ if resp.QRCode != "https://qr.alipay.example.com/precreate-token" {
+ t.Fatalf("qr_code = %q", resp.QRCode)
+ }
+ if resp.PayURL != "" {
+ t.Fatalf("pay_url = %q, want empty for precreate", resp.PayURL)
+ }
+}
+
+func TestAlipayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &Alipay{
+ config: map[string]string{
+ "appId": "2021001234567890",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["app_id"] != "2021001234567890" {
+ t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890")
+ }
+}
+
+func TestParseAlipayAmount(t *testing.T) {
+ t.Parallel()
+
+ amount, err := parseAlipayAmount("", "88.00", "77.00")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if amount != 88 {
+ t.Fatalf("amount = %v, want 88", amount)
+ }
+
+ if _, err := parseAlipayAmount("", "not-a-number"); err == nil {
+ t.Fatal("expected error when no valid amount field exists")
+ }
+}
diff --git a/backend/internal/payment/provider/easypay.go b/backend/internal/payment/provider/easypay.go
new file mode 100644
index 00000000..e7d8aab9
--- /dev/null
+++ b/backend/internal/payment/provider/easypay.go
@@ -0,0 +1,466 @@
+// Package provider contains concrete payment provider implementations.
+package provider
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/md5"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+// EasyPay constants.
+const (
+ easypayCodeSuccess = 1
+ easypayStatusPaid = 1
+ easypayHTTPTimeout = 10 * time.Second
+ maxEasypayResponseSize = 1 << 20 // 1MB
+ maxEasypayErrorSummary = 512
+ tradeStatusSuccess = "TRADE_SUCCESS"
+ signTypeMD5 = "MD5"
+ paymentModePopup = "popup"
+ deviceMobile = "mobile"
+)
+
+// EasyPay implements payment.Provider for the EasyPay aggregation platform.
+type EasyPay struct {
+ instanceID string
+ config map[string]string
+ httpClient *http.Client
+}
+
+// NewEasyPay creates a new EasyPay provider.
+// config keys: pid, pkey, apiBase, notifyUrl, returnUrl, cid, cidAlipay, cidWxpay
+func NewEasyPay(instanceID string, config map[string]string) (*EasyPay, error) {
+ for _, k := range []string{"pid", "pkey", "apiBase", "notifyUrl", "returnUrl"} {
+ if strings.TrimSpace(config[k]) == "" {
+ return nil, fmt.Errorf("easypay config missing required key: %s", k)
+ }
+ }
+ cfg := make(map[string]string, len(config))
+ for k, v := range config {
+ cfg[k] = v
+ }
+ cfg["apiBase"] = normalizeEasyPayAPIBase(cfg["apiBase"])
+ return &EasyPay{
+ instanceID: instanceID,
+ config: cfg,
+ httpClient: &http.Client{Timeout: easypayHTTPTimeout},
+ }, nil
+}
+
+// normalizeEasyPayAPIBase canonicalizes a configured gateway base URL so
+// callers can safely append endpoint paths: query, fragment, and a trailing
+// submit.php / mapi.php / api.php segment are removed, along with trailing
+// slashes. Values that do not parse as absolute URLs are trimmed textually.
+func normalizeEasyPayAPIBase(apiBase string) string {
+	base := strings.TrimSpace(apiBase)
+	if base == "" {
+		return ""
+	}
+	if parsed, err := url.Parse(base); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+		// Strip everything after the path, then drop a known endpoint
+		// filename so "https://host/mapi.php" and "https://host" normalize
+		// to the same base.
+		parsed.RawQuery = ""
+		parsed.Fragment = ""
+		parsed.RawPath = ""
+		parsed.Path = trimEasyPayEndpointPath(parsed.Path)
+		return strings.TrimRight(parsed.String(), "/")
+	}
+	// Not an absolute URL — fall back to plain string trimming.
+	return strings.TrimRight(trimEasyPayEndpointPath(base), "/")
+}
+
+// trimEasyPayEndpointPath strips a trailing EasyPay endpoint filename
+// (submit.php, mapi.php, or api.php — matched case-insensitively) plus any
+// surrounding slashes and whitespace from path.
+func trimEasyPayEndpointPath(path string) string {
+	path = strings.TrimRight(strings.TrimSpace(path), "/")
+	lower := strings.ToLower(path)
+	for _, endpoint := range []string{"/submit.php", "/mapi.php", "/api.php"} {
+		if strings.HasSuffix(lower, endpoint) {
+			// Slice the original (case-preserved) path, not the lowered copy.
+			return strings.TrimRight(path[:len(path)-len(endpoint)], "/")
+		}
+	}
+	return path
+}
+
+func (e *EasyPay) apiBase() string {
+ if e == nil {
+ return ""
+ }
+ return normalizeEasyPayAPIBase(e.config["apiBase"])
+}
+
+func (e *EasyPay) Name() string { return "EasyPay" }
+func (e *EasyPay) ProviderKey() string { return payment.TypeEasyPay }
+func (e *EasyPay) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeAlipay, payment.TypeWxpay}
+}
+
+func (e *EasyPay) MerchantIdentityMetadata() map[string]string {
+ if e == nil {
+ return nil
+ }
+ pid := strings.TrimSpace(e.config["pid"])
+ if pid == "" {
+ return nil
+ }
+ return map[string]string{"pid": pid}
+}
+
+func (e *EasyPay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ // Payment mode determined by instance config, not payment type.
+ // "popup" → hosted page (submit.php); "qrcode"/default → API call (mapi.php).
+ mode := e.config["paymentMode"]
+ if mode == paymentModePopup {
+ return e.createRedirectPayment(req)
+ }
+ return e.createAPIPayment(ctx, req)
+}
+
+// createRedirectPayment builds a submit.php URL for browser redirect.
+// No server-side API call — the user is redirected to EasyPay's hosted page.
+// TradeNo is empty; it arrives via the notify callback after payment.
+func (e *EasyPay) createRedirectPayment(req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ notifyURL, returnURL := e.resolveURLs(req)
+ params := map[string]string{
+ "pid": e.config["pid"], "type": req.PaymentType,
+ "out_trade_no": req.OrderID, "notify_url": notifyURL,
+ "return_url": returnURL, "name": req.Subject,
+ "money": req.Amount,
+ }
+ if cid := e.resolveCID(req.PaymentType); cid != "" {
+ params["cid"] = cid
+ }
+ if req.IsMobile {
+ params["device"] = deviceMobile
+ }
+ params["sign"] = easyPaySign(params, e.config["pkey"])
+ params["sign_type"] = signTypeMD5
+
+ q := url.Values{}
+ for k, v := range params {
+ q.Set(k, v)
+ }
+ payURL := e.apiBase() + "/submit.php?" + q.Encode()
+ return &payment.CreatePaymentResponse{PayURL: payURL}, nil
+}
+
+// createAPIPayment calls mapi.php to get payurl/qrcode (existing behavior).
+func (e *EasyPay) createAPIPayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ notifyURL, returnURL := e.resolveURLs(req)
+ params := map[string]string{
+ "pid": e.config["pid"], "type": req.PaymentType,
+ "out_trade_no": req.OrderID, "notify_url": notifyURL,
+ "return_url": returnURL, "name": req.Subject,
+ "money": req.Amount, "clientip": req.ClientIP,
+ }
+ if cid := e.resolveCID(req.PaymentType); cid != "" {
+ params["cid"] = cid
+ }
+ if req.IsMobile {
+ params["device"] = deviceMobile
+ }
+ params["sign"] = easyPaySign(params, e.config["pkey"])
+ params["sign_type"] = signTypeMD5
+
+ body, err := e.post(ctx, e.apiBase()+"/mapi.php", params)
+ if err != nil {
+ return nil, fmt.Errorf("easypay create: %w", err)
+ }
+ var resp struct {
+ Code int `json:"code"`
+ Msg string `json:"msg"`
+ TradeNo string `json:"trade_no"`
+ PayURL string `json:"payurl"`
+ PayURL2 string `json:"payurl2"` // H5 mobile payment URL
+ QRCode string `json:"qrcode"`
+ }
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return nil, fmt.Errorf("easypay parse: %w", err)
+ }
+ if resp.Code != easypayCodeSuccess {
+ return nil, fmt.Errorf("easypay error: %s", resp.Msg)
+ }
+ payURL := resp.PayURL
+ if req.IsMobile && resp.PayURL2 != "" {
+ payURL = resp.PayURL2
+ }
+ return &payment.CreatePaymentResponse{TradeNo: resp.TradeNo, PayURL: payURL, QRCode: resp.QRCode}, nil
+}
+
+// resolveURLs returns (notifyURL, returnURL) preferring request values,
+// falling back to instance config.
+func (e *EasyPay) resolveURLs(req payment.CreatePaymentRequest) (string, string) {
+ notifyURL := req.NotifyURL
+ if notifyURL == "" {
+ notifyURL = e.config["notifyUrl"]
+ }
+ returnURL := req.ReturnURL
+ if returnURL == "" {
+ returnURL = e.config["returnUrl"]
+ }
+ return notifyURL, returnURL
+}
+
+func (e *EasyPay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ params := map[string]string{
+ "act": "order", "pid": e.config["pid"],
+ "key": e.config["pkey"], "out_trade_no": tradeNo,
+ }
+ body, err := e.post(ctx, e.apiBase()+"/api.php", params)
+ if err != nil {
+ return nil, fmt.Errorf("easypay query: %w", err)
+ }
+ var resp struct {
+ Code int `json:"code"`
+ Msg string `json:"msg"`
+ Status int `json:"status"`
+ Money string `json:"money"`
+ }
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return nil, fmt.Errorf("easypay parse query: %w", err)
+ }
+ status := payment.ProviderStatusPending
+ if resp.Status == easypayStatusPaid {
+ status = payment.ProviderStatusPaid
+ }
+ amount, _ := strconv.ParseFloat(resp.Money, 64)
+ return &payment.QueryOrderResponse{
+ TradeNo: tradeNo,
+ Status: status,
+ Amount: amount,
+ Metadata: e.MerchantIdentityMetadata(),
+ }, nil
+}
+
+// VerifyNotification validates an EasyPay asynchronous notification.
+//
+// rawBody is the URL-encoded callback payload. The MD5 signature is checked
+// via easyPayVerifySign before any field is trusted; only
+// trade_status=TRADE_SUCCESS maps to success.
+func (e *EasyPay) VerifyNotification(_ context.Context, rawBody string, _ map[string]string) (*payment.PaymentNotification, error) {
+	values, err := url.ParseQuery(rawBody)
+	if err != nil {
+		return nil, fmt.Errorf("parse notify: %w", err)
+	}
+	// url.ParseQuery already decodes values — no additional decode needed.
+	params := make(map[string]string)
+	for k := range values {
+		params[k] = values.Get(k)
+	}
+	sign := params["sign"]
+	if sign == "" {
+		return nil, fmt.Errorf("missing sign")
+	}
+	if !easyPayVerifySign(params, e.config["pkey"], sign) {
+		return nil, fmt.Errorf("invalid signature")
+	}
+	status := payment.ProviderStatusFailed
+	if params["trade_status"] == tradeStatusSuccess {
+		status = payment.ProviderStatusSuccess
+	}
+	// NOTE(review): a malformed "money" field is silently treated as 0 —
+	// confirm downstream reconciles the amount against the order.
+	amount, _ := strconv.ParseFloat(params["money"], 64)
+
+	// Prefer the pid echoed in the (signature-verified) callback over the
+	// configured one.
+	metadata := e.MerchantIdentityMetadata()
+	if pid := strings.TrimSpace(params["pid"]); pid != "" {
+		if metadata == nil {
+			metadata = map[string]string{}
+		}
+		metadata["pid"] = pid
+	}
+	return &payment.PaymentNotification{
+		TradeNo: params["trade_no"], OrderID: params["out_trade_no"],
+		Amount: amount, Status: status, RawData: rawBody, Metadata: metadata,
+	}, nil
+}
+
+func (e *EasyPay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
+ attempts := e.refundAttempts(req)
+ if len(attempts) == 0 {
+ return nil, fmt.Errorf("easypay refund missing order identifier")
+ }
+ var firstErr error
+ for i, attempt := range attempts {
+ body, status, err := e.postRaw(ctx, e.apiBase()+"/api.php?act=refund", attempt.params)
+ if err != nil {
+ return nil, fmt.Errorf("easypay refund request: %w", err)
+ }
+ if err := parseEasyPayRefundResponse(status, body); err != nil {
+ if firstErr == nil {
+ firstErr = err
+ }
+ if i+1 < len(attempts) && isEasyPayRefundOrderNotFound(err) {
+ continue
+ }
+ return nil, err
+ }
+ return &payment.RefundResponse{RefundID: attempt.refundID, Status: payment.ProviderStatusSuccess}, nil
+ }
+ return nil, firstErr
+}
+
+type easyPayRefundAttempt struct {
+ params map[string]string
+ refundID string
+}
+
+func (e *EasyPay) refundAttempts(req payment.RefundRequest) []easyPayRefundAttempt {
+ base := map[string]string{
+ "pid": e.config["pid"], "key": e.config["pkey"], "money": req.Amount,
+ }
+ var attempts []easyPayRefundAttempt
+ if orderID := strings.TrimSpace(req.OrderID); orderID != "" {
+ params := cloneStringMap(base)
+ params["out_trade_no"] = orderID
+ attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: orderID})
+ }
+ if tradeNo := strings.TrimSpace(req.TradeNo); tradeNo != "" {
+ params := cloneStringMap(base)
+ params["trade_no"] = tradeNo
+ attempts = append(attempts, easyPayRefundAttempt{params: params, refundID: tradeNo})
+ }
+ return attempts
+}
+
+// cloneStringMap returns a shallow copy of in.
+func cloneStringMap(in map[string]string) map[string]string {
+	out := make(map[string]string, len(in))
+	for key, value := range in {
+		out[key] = value
+	}
+	return out
+}
+
+// isEasyPayRefundOrderNotFound reports whether err's text indicates the
+// gateway did not recognize the order identifier (Chinese "order does not
+// exist" messages or English "order not found" / "not exist" variants).
+// Matching is textual; Refund uses this to retry with an alternate
+// identifier (out_trade_no vs trade_no).
+func isEasyPayRefundOrderNotFound(err error) bool {
+	if err == nil {
+		return false
+	}
+	msg := err.Error()
+	lower := strings.ToLower(msg)
+	return strings.Contains(msg, "订单编号不存在") ||
+		strings.Contains(msg, "订单不存在") ||
+		strings.Contains(lower, "order not found") ||
+		strings.Contains(lower, "not exist")
+}
+
+func parseEasyPayRefundResponse(status int, body []byte) error {
+ summary := summarizeEasyPayResponse(body)
+ if status < http.StatusOK || status >= http.StatusMultipleChoices {
+ return fmt.Errorf("easypay refund HTTP %d: %s", status, summary)
+ }
+
+ trimmed := strings.TrimSpace(string(body))
+ if trimmed == "" {
+ return fmt.Errorf("easypay refund empty response (HTTP %d): %s", status, summary)
+ }
+
+ lower := strings.ToLower(trimmed)
+ if strings.HasPrefix(lower, ""
+ }
+ if len(summary) > maxEasypayErrorSummary {
+ return summary[:maxEasypayErrorSummary] + "..."
+ }
+ return summary
+}
+
+func (e *EasyPay) resolveCID(paymentType string) string {
+ if strings.HasPrefix(paymentType, "alipay") {
+ if v := e.config["cidAlipay"]; v != "" {
+ return v
+ }
+ return e.config["cid"]
+ }
+ if v := e.config["cidWxpay"]; v != "" {
+ return v
+ }
+ return e.config["cid"]
+}
+
+func (e *EasyPay) post(ctx context.Context, endpoint string, params map[string]string) ([]byte, error) {
+ body, _, err := e.postRaw(ctx, endpoint, params)
+ return body, err
+}
+
+func (e *EasyPay) postRaw(ctx context.Context, endpoint string, params map[string]string) ([]byte, int, error) {
+ form := url.Values{}
+ for k, v := range params {
+ form.Set(k, v)
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
+ if err != nil {
+ return nil, 0, err
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ client := e.httpClient
+ if client == nil {
+ client = &http.Client{Timeout: easypayHTTPTimeout}
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxEasypayResponseSize))
+ if err != nil {
+ return nil, resp.StatusCode, err
+ }
+ return body, resp.StatusCode, nil
+}
+
+// easyPaySign computes the EasyPay MD5 request signature: parameters are
+// sorted by key, empty values and the sign/sign_type fields are excluded,
+// the pairs are concatenated as "k=v" joined by '&', the merchant key is
+// appended, and the hex MD5 digest of that string is returned.
+// NOTE(review): MD5 follows the EasyPay wire protocol — it cannot be
+// swapped for a stronger hash without breaking interoperability.
+func easyPaySign(params map[string]string, pkey string) string {
+	keys := make([]string, 0, len(params))
+	for k, v := range params {
+		// The signature covers neither the signature fields themselves nor
+		// empty values.
+		if k == "sign" || k == "sign_type" || v == "" {
+			continue
+		}
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	var buf strings.Builder
+	for i, k := range keys {
+		if i > 0 {
+			_ = buf.WriteByte('&')
+		}
+		_, _ = buf.WriteString(k + "=" + params[k])
+	}
+	_, _ = buf.WriteString(pkey)
+	hash := md5.Sum([]byte(buf.String()))
+	return hex.EncodeToString(hash[:])
+}
+
+// easyPayVerifySign recomputes the expected signature and compares it with
+// the one supplied in the notification. hmac.Equal gives a constant-time
+// comparison, avoiding a timing side channel on the signature check.
+func easyPayVerifySign(params map[string]string, pkey string, sign string) bool {
+	return hmac.Equal([]byte(easyPaySign(params, pkey)), []byte(sign))
+}
diff --git a/backend/internal/payment/provider/easypay_refund_test.go b/backend/internal/payment/provider/easypay_refund_test.go
new file mode 100644
index 00000000..9e0e4942
--- /dev/null
+++ b/backend/internal/payment/provider/easypay_refund_test.go
@@ -0,0 +1,196 @@
+package provider
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+func TestNormalizeEasyPayAPIBase(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ input string
+ want string
+ }{
+ {input: "https://zpayz.cn", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/mapi.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/submit.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/api.php", want: "https://zpayz.cn"},
+ {input: "https://zpayz.cn/api.php?act=refund", want: "https://zpayz.cn"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ t.Parallel()
+ if got := normalizeEasyPayAPIBase(tt.input); got != tt.want {
+ t.Fatalf("normalizeEasyPayAPIBase(%q) = %q, want %q", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEasyPayRefundNormalizesAPIBaseAndSendsOutTradeNoOnly(t *testing.T) {
+ t.Parallel()
+
+ var gotPath string
+ var gotQuery url.Values
+ var gotForm url.Values
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ gotQuery = r.URL.Query()
+ if err := r.ParseForm(); err != nil {
+ t.Errorf("ParseForm: %v", err)
+ }
+ gotForm = r.PostForm
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL+"/mapi.php")
+ resp, err := provider.Refund(context.Background(), payment.RefundRequest{
+ TradeNo: "trade-123",
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err != nil {
+ t.Fatalf("Refund returned error: %v", err)
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusSuccess {
+ t.Fatalf("Refund response = %+v, want success", resp)
+ }
+ if gotPath != "/api.php" {
+ t.Fatalf("refund path = %q, want /api.php", gotPath)
+ }
+ if gotQuery.Get("act") != "refund" {
+ t.Fatalf("refund act query = %q, want refund", gotQuery.Get("act"))
+ }
+ for key, want := range map[string]string{
+ "pid": "pid-1",
+ "key": "pkey-1",
+ "out_trade_no": "out-456",
+ "money": "1.50",
+ } {
+ if got := gotForm.Get(key); got != want {
+ t.Fatalf("form[%s] = %q, want %q (form=%v)", key, got, want, gotForm)
+ }
+ }
+ if got := gotForm.Get("trade_no"); got != "" {
+ t.Fatalf("form[trade_no] = %q, want empty (form=%v)", got, gotForm)
+ }
+}
+
+func TestEasyPayRefundRetriesWithTradeNoWhenOutTradeNoNotFound(t *testing.T) {
+ t.Parallel()
+
+ var gotForms []url.Values
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/api.php" {
+ t.Errorf("refund path = %q, want /api.php", r.URL.Path)
+ }
+ if r.URL.Query().Get("act") != "refund" {
+ t.Errorf("refund act query = %q, want refund", r.URL.Query().Get("act"))
+ }
+ if err := r.ParseForm(); err != nil {
+ t.Errorf("ParseForm: %v", err)
+ }
+ gotForms = append(gotForms, r.PostForm)
+ w.Header().Set("Content-Type", "application/json")
+ if len(gotForms) == 1 {
+ _, _ = w.Write([]byte(`{"code":0,"msg":"订单编号不存在!"}`))
+ return
+ }
+ _, _ = w.Write([]byte(`{"code":1,"msg":"ok"}`))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL+"/mapi.php")
+ resp, err := provider.Refund(context.Background(), payment.RefundRequest{
+ TradeNo: "trade-123",
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err != nil {
+ t.Fatalf("Refund returned error: %v", err)
+ }
+ if resp == nil || resp.Status != payment.ProviderStatusSuccess || resp.RefundID != "trade-123" {
+ t.Fatalf("Refund response = %+v, want success with trade refund id", resp)
+ }
+ if len(gotForms) != 2 {
+ t.Fatalf("refund attempts = %d, want 2", len(gotForms))
+ }
+ if got := gotForms[0].Get("out_trade_no"); got != "out-456" {
+ t.Fatalf("first form[out_trade_no] = %q, want out-456 (form=%v)", got, gotForms[0])
+ }
+ if got := gotForms[0].Get("trade_no"); got != "" {
+ t.Fatalf("first form[trade_no] = %q, want empty (form=%v)", got, gotForms[0])
+ }
+ if got := gotForms[1].Get("trade_no"); got != "trade-123" {
+ t.Fatalf("second form[trade_no] = %q, want trade-123 (form=%v)", got, gotForms[1])
+ }
+ if got := gotForms[1].Get("out_trade_no"); got != "" {
+ t.Fatalf("second form[out_trade_no] = %q, want empty (form=%v)", got, gotForms[1])
+ }
+}
+
+func TestEasyPayRefundResponseErrors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ statusCode int
+ body string
+ want string
+ }{
+ {name: "html response", statusCode: http.StatusOK, body: "bad config", want: "non-JSON response (HTTP 200): bad config"},
+ {name: "non json response", statusCode: http.StatusOK, body: "not json", want: "non-JSON response (HTTP 200): not json"},
+ {name: "non 2xx response", statusCode: http.StatusBadGateway, body: "bad gateway", want: "HTTP 502: bad gateway"},
+ {name: "empty response", statusCode: http.StatusOK, body: "", want: "empty response (HTTP 200): "},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(tt.statusCode)
+ _, _ = w.Write([]byte(tt.body))
+ }))
+ defer server.Close()
+
+ provider := newTestEasyPay(t, server.URL)
+ _, err := provider.Refund(context.Background(), payment.RefundRequest{
+ OrderID: "out-456",
+ Amount: "1.50",
+ })
+ if err == nil {
+ t.Fatal("Refund returned nil error")
+ }
+ if !strings.Contains(err.Error(), tt.want) {
+ t.Fatalf("Refund error = %q, want substring %q", err.Error(), tt.want)
+ }
+ })
+ }
+}
+
+func newTestEasyPay(t *testing.T, apiBase string) *EasyPay {
+ t.Helper()
+
+ provider, err := NewEasyPay("test-instance", map[string]string{
+ "pid": "pid-1",
+ "pkey": "pkey-1",
+ "apiBase": apiBase,
+ "notifyUrl": "https://example.com/notify",
+ "returnUrl": "https://example.com/return",
+ })
+ if err != nil {
+ t.Fatalf("NewEasyPay: %v", err)
+ }
+ return provider
+}
diff --git a/backend/internal/payment/provider/easypay_sign_test.go b/backend/internal/payment/provider/easypay_sign_test.go
new file mode 100644
index 00000000..8328d294
--- /dev/null
+++ b/backend/internal/payment/provider/easypay_sign_test.go
@@ -0,0 +1,195 @@
+package provider
+
+import (
+ "testing"
+)
+
+func TestEasyPaySignConsistentOutput(t *testing.T) {
+ t.Parallel()
+
+ params := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ "out_trade_no": "ORDER123",
+ "name": "Test Product",
+ "money": "10.00",
+ }
+ pkey := "test_secret_key"
+
+ sign1 := easyPaySign(params, pkey)
+ sign2 := easyPaySign(params, pkey)
+ if sign1 != sign2 {
+ t.Fatalf("easyPaySign should be deterministic: %q != %q", sign1, sign2)
+ }
+ if len(sign1) != 32 {
+ t.Fatalf("MD5 hex should be 32 chars, got %d", len(sign1))
+ }
+}
+
+func TestEasyPaySignExcludesSignAndSignType(t *testing.T) {
+ t.Parallel()
+
+ pkey := "my_key"
+ base := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ }
+ withSign := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ "sign": "should_be_ignored",
+ "sign_type": "MD5",
+ }
+
+ signBase := easyPaySign(base, pkey)
+ signWithExtra := easyPaySign(withSign, pkey)
+
+ if signBase != signWithExtra {
+ t.Fatalf("sign and sign_type should be excluded: base=%q, withExtra=%q", signBase, signWithExtra)
+ }
+}
+
+func TestEasyPaySignExcludesEmptyValues(t *testing.T) {
+ t.Parallel()
+
+ pkey := "key123"
+ base := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ }
+ withEmpty := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ "device": "",
+ "clientip": "",
+ }
+
+ signBase := easyPaySign(base, pkey)
+ signWithEmpty := easyPaySign(withEmpty, pkey)
+
+ if signBase != signWithEmpty {
+ t.Fatalf("empty values should be excluded: base=%q, withEmpty=%q", signBase, signWithEmpty)
+ }
+}
+
+func TestEasyPayVerifySignValid(t *testing.T) {
+ t.Parallel()
+
+ params := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ "out_trade_no": "ORDER456",
+ "money": "25.00",
+ }
+ pkey := "secret"
+
+ sign := easyPaySign(params, pkey)
+
+ // Add sign to params (as would come in a real callback)
+ params["sign"] = sign
+ params["sign_type"] = "MD5"
+
+ if !easyPayVerifySign(params, pkey, sign) {
+ t.Fatal("easyPayVerifySign should return true for a valid signature")
+ }
+}
+
+func TestEasyPayVerifySignTampered(t *testing.T) {
+ t.Parallel()
+
+ params := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ "out_trade_no": "ORDER789",
+ "money": "50.00",
+ }
+ pkey := "secret"
+
+ sign := easyPaySign(params, pkey)
+
+ // Tamper with the amount
+ params["money"] = "99.99"
+
+ if easyPayVerifySign(params, pkey, sign) {
+ t.Fatal("easyPayVerifySign should return false for tampered params")
+ }
+}
+
+func TestEasyPayVerifySignWrongKey(t *testing.T) {
+ t.Parallel()
+
+ params := map[string]string{
+ "pid": "1001",
+ "type": "wxpay",
+ }
+
+ sign := easyPaySign(params, "correct_key")
+
+ if easyPayVerifySign(params, "wrong_key", sign) {
+ t.Fatal("easyPayVerifySign should return false with wrong key")
+ }
+}
+
+func TestEasyPaySignEmptyParams(t *testing.T) {
+ t.Parallel()
+
+ sign := easyPaySign(map[string]string{}, "key123")
+ if sign == "" {
+ t.Fatal("easyPaySign with empty params should still produce a hash")
+ }
+ if len(sign) != 32 {
+ t.Fatalf("MD5 hex should be 32 chars, got %d", len(sign))
+ }
+}
+
+func TestEasyPaySignSortOrder(t *testing.T) {
+ t.Parallel()
+
+ pkey := "test_key"
+ params1 := map[string]string{
+ "a": "1",
+ "b": "2",
+ "c": "3",
+ }
+ params2 := map[string]string{
+ "c": "3",
+ "a": "1",
+ "b": "2",
+ }
+
+ sign1 := easyPaySign(params1, pkey)
+ sign2 := easyPaySign(params2, pkey)
+
+ if sign1 != sign2 {
+ t.Fatalf("easyPaySign should be order-independent: %q != %q", sign1, sign2)
+ }
+}
+
+func TestEasyPayVerifySignWrongSignValue(t *testing.T) {
+ t.Parallel()
+
+ params := map[string]string{
+ "pid": "1001",
+ "type": "alipay",
+ }
+ pkey := "key"
+
+ if easyPayVerifySign(params, pkey, "00000000000000000000000000000000") {
+ t.Fatal("easyPayVerifySign should return false for an incorrect sign value")
+ }
+}
+
+func TestEasyPayMerchantIdentityMetadata(t *testing.T) {
+ t.Parallel()
+
+ provider := &EasyPay{
+ config: map[string]string{
+ "pid": "1001",
+ },
+ }
+
+ metadata := provider.MerchantIdentityMetadata()
+ if metadata["pid"] != "1001" {
+ t.Fatalf("pid = %q, want %q", metadata["pid"], "1001")
+ }
+}
diff --git a/backend/internal/payment/provider/factory.go b/backend/internal/payment/provider/factory.go
new file mode 100644
index 00000000..0adbd267
--- /dev/null
+++ b/backend/internal/payment/provider/factory.go
@@ -0,0 +1,23 @@
+package provider
+
+import (
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+)
+
+// CreateProvider creates a Provider from a provider key, instance ID and decrypted config.
+func CreateProvider(providerKey string, instanceID string, config map[string]string) (payment.Provider, error) {
+ switch providerKey {
+ case payment.TypeEasyPay:
+ return NewEasyPay(instanceID, config)
+ case payment.TypeAlipay:
+ return NewAlipay(instanceID, config)
+ case payment.TypeWxpay:
+ return NewWxpay(instanceID, config)
+ case payment.TypeStripe:
+ return NewStripe(instanceID, config)
+ default:
+ return nil, fmt.Errorf("unknown provider key: %s", providerKey)
+ }
+}
diff --git a/backend/internal/payment/provider/stripe.go b/backend/internal/payment/provider/stripe.go
new file mode 100644
index 00000000..15359d45
--- /dev/null
+++ b/backend/internal/payment/provider/stripe.go
@@ -0,0 +1,262 @@
+package provider
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "sync"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ stripe "github.com/stripe/stripe-go/v85"
+ "github.com/stripe/stripe-go/v85/webhook"
+)
+
+// Stripe constants.
+const (
+ stripeCurrency = "cny"
+ stripeEventPaymentSuccess = "payment_intent.succeeded"
+ stripeEventPaymentFailed = "payment_intent.payment_failed"
+)
+
+// Stripe implements the payment.CancelableProvider interface for Stripe payments.
+type Stripe struct {
+ instanceID string
+ config map[string]string
+
+ mu sync.Mutex
+ initialized bool
+ sc *stripe.Client
+}
+
+// NewStripe creates a new Stripe provider instance.
+func NewStripe(instanceID string, config map[string]string) (*Stripe, error) {
+ if config["secretKey"] == "" {
+ return nil, fmt.Errorf("stripe config missing required key: secretKey")
+ }
+ return &Stripe{
+ instanceID: instanceID,
+ config: config,
+ }, nil
+}
+
+func (s *Stripe) ensureInit() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if !s.initialized {
+ s.sc = stripe.NewClient(s.config["secretKey"])
+ s.initialized = true
+ }
+}
+
+// GetPublishableKey returns the publishable key for frontend use.
+func (s *Stripe) GetPublishableKey() string {
+ return s.config["publishableKey"]
+}
+
+func (s *Stripe) Name() string { return "Stripe" }
+func (s *Stripe) ProviderKey() string { return payment.TypeStripe }
+func (s *Stripe) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeStripe}
+}
+
+// stripePaymentMethodTypes maps our PaymentType to Stripe payment_method_types.
+var stripePaymentMethodTypes = map[string][]string{
+ payment.TypeCard: {"card"},
+ payment.TypeAlipay: {"alipay"},
+ payment.TypeWxpay: {"wechat_pay"},
+ payment.TypeLink: {"link"},
+}
+
+// CreatePayment creates a Stripe PaymentIntent.
+func (s *Stripe) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ s.ensureInit()
+
+ amountInCents, err := payment.YuanToFen(req.Amount)
+ if err != nil {
+ return nil, fmt.Errorf("stripe create payment: %w", err)
+ }
+
+ // Collect all Stripe payment_method_types from the instance's configured sub-methods
+ methods := resolveStripeMethodTypes(req.InstanceSubMethods)
+
+ pmTypes := make([]*string, len(methods))
+ for i, m := range methods {
+ pmTypes[i] = stripe.String(m)
+ }
+
+ params := &stripe.PaymentIntentCreateParams{
+ Amount: stripe.Int64(amountInCents),
+ Currency: stripe.String(stripeCurrency),
+ PaymentMethodTypes: pmTypes,
+ Description: stripe.String(req.Subject),
+ Metadata: map[string]string{"orderId": req.OrderID},
+ }
+
+ // WeChat Pay requires payment_method_options with client type
+ if hasStripeMethod(methods, "wechat_pay") {
+ params.PaymentMethodOptions = &stripe.PaymentIntentCreatePaymentMethodOptionsParams{
+ WeChatPay: &stripe.PaymentIntentCreatePaymentMethodOptionsWeChatPayParams{
+ Client: stripe.String("web"),
+ },
+ }
+ }
+
+ params.SetIdempotencyKey(fmt.Sprintf("pi-%s", req.OrderID))
+ params.Context = ctx
+
+ pi, err := s.sc.V1PaymentIntents.Create(ctx, params)
+ if err != nil {
+ return nil, fmt.Errorf("stripe create payment: %w", err)
+ }
+
+ return &payment.CreatePaymentResponse{
+ TradeNo: pi.ID,
+ ClientSecret: pi.ClientSecret,
+ }, nil
+}
+
+// QueryOrder retrieves a PaymentIntent by ID.
+func (s *Stripe) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+ s.ensureInit()
+
+ pi, err := s.sc.V1PaymentIntents.Retrieve(ctx, tradeNo, nil)
+ if err != nil {
+ return nil, fmt.Errorf("stripe query order: %w", err)
+ }
+
+ status := payment.ProviderStatusPending
+ switch pi.Status {
+ case stripe.PaymentIntentStatusSucceeded:
+ status = payment.ProviderStatusPaid
+ case stripe.PaymentIntentStatusCanceled:
+ status = payment.ProviderStatusFailed
+ }
+
+ return &payment.QueryOrderResponse{
+ TradeNo: pi.ID,
+ Status: status,
+ Amount: payment.FenToYuan(pi.Amount),
+ }, nil
+}
+
+// VerifyNotification verifies a Stripe webhook event; unhandled event types return (nil, nil).
+func (s *Stripe) VerifyNotification(_ context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
+ s.ensureInit()
+
+ webhookSecret := s.config["webhookSecret"]
+ if webhookSecret == "" {
+ return nil, fmt.Errorf("stripe webhookSecret not configured")
+ }
+
+ sig := headers["stripe-signature"]
+ if sig == "" {
+ return nil, fmt.Errorf("stripe notification missing stripe-signature header")
+ }
+
+ event, err := webhook.ConstructEvent([]byte(rawBody), sig, webhookSecret)
+ if err != nil {
+ return nil, fmt.Errorf("stripe verify notification: %w", err)
+ }
+
+ switch event.Type {
+ case stripeEventPaymentSuccess:
+ return parseStripePaymentIntent(&event, payment.ProviderStatusSuccess, rawBody)
+ case stripeEventPaymentFailed:
+ return parseStripePaymentIntent(&event, payment.ProviderStatusFailed, rawBody)
+ }
+
+ return nil, nil
+}
+
+func parseStripePaymentIntent(event *stripe.Event, status string, rawBody string) (*payment.PaymentNotification, error) {
+ var pi stripe.PaymentIntent
+ if err := json.Unmarshal(event.Data.Raw, &pi); err != nil {
+ return nil, fmt.Errorf("stripe parse payment_intent: %w", err)
+ }
+ return &payment.PaymentNotification{
+ TradeNo: pi.ID,
+ OrderID: pi.Metadata["orderId"],
+ Amount: payment.FenToYuan(pi.Amount),
+ Status: status,
+ RawData: rawBody,
+ }, nil
+}
+
+// Refund creates a Stripe refund.
+func (s *Stripe) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
+ s.ensureInit()
+
+ amountInCents, err := payment.YuanToFen(req.Amount)
+ if err != nil {
+ return nil, fmt.Errorf("stripe refund: %w", err)
+ }
+
+ params := &stripe.RefundCreateParams{
+ PaymentIntent: stripe.String(req.TradeNo),
+ Amount: stripe.Int64(amountInCents),
+ Reason: stripe.String(string(stripe.RefundReasonRequestedByCustomer)),
+ }
+ params.Context = ctx
+
+ r, err := s.sc.V1Refunds.Create(ctx, params)
+ if err != nil {
+ return nil, fmt.Errorf("stripe refund: %w", err)
+ }
+
+ refundStatus := payment.ProviderStatusPending
+ if r.Status == stripe.RefundStatusSucceeded {
+ refundStatus = payment.ProviderStatusSuccess
+ }
+
+ return &payment.RefundResponse{
+ RefundID: r.ID,
+ Status: refundStatus,
+ }, nil
+}
+
+// resolveStripeMethodTypes converts instance supported_types (comma-separated)
+// into Stripe API payment_method_types; falls back to ["card"] when the input is empty or no entry maps to a known Stripe method.
+func resolveStripeMethodTypes(instanceSubMethods string) []string {
+ if instanceSubMethods == "" {
+ return []string{"card"}
+ }
+ var methods []string
+ for _, t := range strings.Split(instanceSubMethods, ",") {
+ t = strings.TrimSpace(t)
+ if mapped, ok := stripePaymentMethodTypes[t]; ok {
+ methods = append(methods, mapped...)
+ }
+ }
+ if len(methods) == 0 {
+ return []string{"card"}
+ }
+ return methods
+}
+
+// hasStripeMethod checks if the given Stripe method list contains the target method.
+func hasStripeMethod(methods []string, target string) bool {
+ for _, m := range methods {
+ if m == target {
+ return true
+ }
+ }
+ return false
+}
+
+// CancelPayment cancels a pending PaymentIntent.
+func (s *Stripe) CancelPayment(ctx context.Context, tradeNo string) error {
+ s.ensureInit()
+
+ _, err := s.sc.V1PaymentIntents.Cancel(ctx, tradeNo, nil)
+ if err != nil {
+ return fmt.Errorf("stripe cancel payment: %w", err)
+ }
+ return nil
+}
+
+// Ensure interface compliance.
+var (
+ _ payment.Provider = (*Stripe)(nil)
+ _ payment.CancelableProvider = (*Stripe)(nil)
+)
diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go
new file mode 100644
index 00000000..e6291dd3
--- /dev/null
+++ b/backend/internal/payment/provider/wxpay.go
@@ -0,0 +1,527 @@
+package provider
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/wechatpay-apiv3/wechatpay-go/core"
+ "github.com/wechatpay-apiv3/wechatpay-go/core/auth/verifiers"
+ "github.com/wechatpay-apiv3/wechatpay-go/core/notify"
+ "github.com/wechatpay-apiv3/wechatpay-go/core/option"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/refunddomestic"
+ "github.com/wechatpay-apiv3/wechatpay-go/utils"
+)
+
+// WeChat Pay constants.
+const (
+ wxpayCurrency = "CNY"
+ wxpayH5Type = "Wap"
+ wxpayResultPath = "/payment/result"
+)
+
+const (
+ wxpayMetadataAppID = "appid"
+ wxpayMetadataMerchantID = "mchid"
+ wxpayMetadataCurrency = "currency"
+ wxpayMetadataTradeState = "trade_state"
+)
+
+// WeChat Pay create-payment modes.
+const (
+ wxpayModeNative = "native"
+ wxpayModeH5 = "h5"
+ wxpayModeJSAPI = "jsapi"
+)
+
+// WeChat Pay trade states.
+const (
+ wxpayTradeStateSuccess = "SUCCESS"
+ wxpayTradeStateRefund = "REFUND"
+ wxpayTradeStateClosed = "CLOSED"
+ wxpayTradeStatePayError = "PAYERROR"
+)
+
+// WeChat Pay notification event types.
+const (
+ wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS"
+)
+
+var (
+ wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+ return svc.Prepay(ctx, req)
+ }
+ wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+ return svc.Prepay(ctx, req)
+ }
+ wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+ return svc.PrepayWithRequestPayment(ctx, req)
+ }
+)
+
+type Wxpay struct {
+ instanceID string
+ config map[string]string
+ mu sync.Mutex
+ coreClient *core.Client
+ notifyHandler *notify.Handler
+}
+
+const wxpayAPIv3KeyLength = 32
+
+func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) {
+ // All fields are required. Platform-certificate mode is intentionally unsupported —
+ // WeChat has been migrating all merchants to the pubkey verifier since 2024-10,
+ // and newly-provisioned merchants cannot download platform certificates at all.
+ required := []string{"appId", "mchId", "privateKey", "apiV3Key", "certSerial", "publicKey", "publicKeyId"}
+ for _, k := range required {
+ if config[k] == "" {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_MISSING_KEY", "missing_required_key").
+ WithMetadata(map[string]string{"key": k})
+ }
+ }
+ if len(config["apiV3Key"]) != wxpayAPIv3KeyLength {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY_LENGTH", "invalid_key_length").
+ WithMetadata(map[string]string{
+ "key": "apiV3Key",
+ "expected": strconv.Itoa(wxpayAPIv3KeyLength),
+ "actual": strconv.Itoa(len(config["apiV3Key"])),
+ })
+ }
+ // Parse PEMs eagerly so malformed keys surface at save time, not at order creation.
+ if _, err := utils.LoadPrivateKey(formatPEM(config["privateKey"], "PRIVATE KEY")); err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "privateKey"})
+ }
+ if _, err := utils.LoadPublicKey(formatPEM(config["publicKey"], "PUBLIC KEY")); err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "publicKey"})
+ }
+ return &Wxpay{instanceID: instanceID, config: config}, nil
+}
+
+func (w *Wxpay) Name() string { return "Wxpay" }
+func (w *Wxpay) ProviderKey() string { return payment.TypeWxpay }
+func (w *Wxpay) SupportedTypes() []payment.PaymentType {
+ return []payment.PaymentType{payment.TypeWxpay}
+}
+
+// ResolveWxpayJSAPIAppID returns the AppID that JSAPI prepay will use for a
+// given provider config. A dedicated MP AppID takes precedence over the base
+// merchant AppID.
+func ResolveWxpayJSAPIAppID(config map[string]string) string {
+ if appID := strings.TrimSpace(config["mpAppId"]); appID != "" {
+ return appID
+ }
+ return strings.TrimSpace(config["appId"])
+}
+
+func formatPEM(key, keyType string) string {
+ key = strings.TrimSpace(key)
+ if strings.HasPrefix(key, "-----BEGIN") {
+ return key
+ }
+ return fmt.Sprintf("-----BEGIN %s-----\n%s\n-----END %s-----", keyType, key, keyType)
+}
+
+func (w *Wxpay) ensureClient() (*core.Client, error) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.coreClient != nil {
+ return w.coreClient, nil
+ }
+ privateKey, err := utils.LoadPrivateKey(formatPEM(w.config["privateKey"], "PRIVATE KEY"))
+ if err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "privateKey"})
+ }
+ publicKey, err := utils.LoadPublicKey(formatPEM(w.config["publicKey"], "PUBLIC KEY"))
+ if err != nil {
+ return nil, infraerrors.BadRequest("WXPAY_CONFIG_INVALID_KEY", "invalid_key").
+ WithMetadata(map[string]string{"key": "publicKey"})
+ }
+ verifier := verifiers.NewSHA256WithRSAPubkeyVerifier(w.config["publicKeyId"], *publicKey)
+ client, err := core.NewClient(context.Background(),
+ option.WithMerchantCredential(w.config["mchId"], w.config["certSerial"], privateKey),
+ option.WithVerifier(verifier))
+ if err != nil {
+ return nil, fmt.Errorf("wxpay init client: %w", err)
+ }
+ handler, err := notify.NewRSANotifyHandler(w.config["apiV3Key"], verifier)
+ if err != nil {
+ return nil, fmt.Errorf("wxpay init notify handler: %w", err)
+ }
+ w.notifyHandler = handler
+ w.coreClient = client
+ return w.coreClient, nil
+}
+
+func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
+ client, err := w.ensureClient()
+ if err != nil {
+ return nil, err
+ }
+ // Request-first, config-fallback (consistent with EasyPay/Alipay)
+ notifyURL := req.NotifyURL
+ if notifyURL == "" {
+ notifyURL = w.config["notifyUrl"]
+ }
+ if notifyURL == "" {
+ return nil, fmt.Errorf("wxpay notifyUrl is required")
+ }
+ totalFen, err := payment.YuanToFen(req.Amount)
+ if err != nil {
+ return nil, fmt.Errorf("wxpay create payment: %w", err)
+ }
+
+ mode, err := resolveWxpayCreateMode(req)
+ if err != nil {
+ return nil, err
+ }
+ switch mode {
+ case wxpayModeJSAPI:
+ return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen)
+ case wxpayModeH5:
+ return w.prepayH5(ctx, client, req, notifyURL, totalFen)
+ case wxpayModeNative:
+ return w.prepayNative(ctx, client, req, notifyURL, totalFen)
+ default:
+ return nil, fmt.Errorf("wxpay create payment: unsupported mode %q", mode)
+ }
+}
+
+func (w *Wxpay) prepayJSAPI(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
+ svc := jsapi.JsapiApiService{Client: c}
+ cur := wxpayCurrency
+ appID := ResolveWxpayJSAPIAppID(w.config)
+ prepayReq := jsapi.PrepayRequest{
+ Appid: core.String(appID),
+ Mchid: core.String(w.config["mchId"]),
+ Description: core.String(req.Subject),
+ OutTradeNo: core.String(req.OrderID),
+ NotifyUrl: core.String(notifyURL),
+ Amount: &jsapi.Amount{Total: core.Int64(totalFen), Currency: &cur},
+ Payer: &jsapi.Payer{Openid: core.String(strings.TrimSpace(req.OpenID))},
+ }
+ if clientIP := strings.TrimSpace(req.ClientIP); clientIP != "" {
+ prepayReq.SceneInfo = &jsapi.SceneInfo{PayerClientIp: core.String(clientIP)}
+ }
+ resp, _, err := wxpayJSAPIPrepayWithRequestPayment(ctx, svc, prepayReq)
+ if err != nil {
+ return nil, fmt.Errorf("wxpay jsapi prepay: %w", err)
+ }
+ return &payment.CreatePaymentResponse{
+ TradeNo: req.OrderID,
+ ResultType: payment.CreatePaymentResultJSAPIReady,
+ JSAPI: &payment.WechatJSAPIPayload{
+ AppID: wxSV(resp.Appid),
+ TimeStamp: wxSV(resp.TimeStamp),
+ NonceStr: wxSV(resp.NonceStr),
+ Package: wxSV(resp.Package),
+ SignType: wxSV(resp.SignType),
+ PaySign: wxSV(resp.PaySign),
+ },
+ }, nil
+}
+
+func (w *Wxpay) prepayNative(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
+ svc := native.NativeApiService{Client: c}
+ cur := wxpayCurrency
+ resp, _, err := wxpayNativePrepay(ctx, svc, native.PrepayRequest{
+ Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
+ Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
+ NotifyUrl: core.String(notifyURL),
+ Amount: &native.Amount{Total: core.Int64(totalFen), Currency: &cur},
+ })
+ if err != nil {
+ return nil, fmt.Errorf("wxpay native prepay: %w", err)
+ }
+ codeURL := ""
+ if resp.CodeUrl != nil {
+ codeURL = *resp.CodeUrl
+ }
+ return &payment.CreatePaymentResponse{TradeNo: req.OrderID, QRCode: codeURL}, nil
+}
+
+func (w *Wxpay) prepayH5(ctx context.Context, c *core.Client, req payment.CreatePaymentRequest, notifyURL string, totalFen int64) (*payment.CreatePaymentResponse, error) {
+ svc := h5.H5ApiService{Client: c}
+ cur := wxpayCurrency
+ resp, _, err := wxpayH5Prepay(ctx, svc, h5.PrepayRequest{
+ Appid: core.String(w.config["appId"]), Mchid: core.String(w.config["mchId"]),
+ Description: core.String(req.Subject), OutTradeNo: core.String(req.OrderID),
+ NotifyUrl: core.String(notifyURL),
+ Amount: &h5.Amount{Total: core.Int64(totalFen), Currency: &cur},
+ SceneInfo: &h5.SceneInfo{PayerClientIp: core.String(req.ClientIP), H5Info: buildWxpayH5Info(w.config)},
+ })
+ if err != nil {
+ return nil, fmt.Errorf("wxpay h5 prepay: %w", err)
+ }
+ h5URL := ""
+ if resp.H5Url != nil {
+ h5URL = *resp.H5Url
+ }
+ h5URL, err = appendWxpayRedirectURL(h5URL, req)
+ if err != nil {
+ return nil, err
+ }
+ return &payment.CreatePaymentResponse{TradeNo: req.OrderID, PayURL: h5URL}, nil
+}
+
+func buildWxpayH5Info(config map[string]string) *h5.H5Info {
+ tp := wxpayH5Type
+ info := &h5.H5Info{Type: &tp}
+ if appName := strings.TrimSpace(config["h5AppName"]); appName != "" {
+ info.AppName = core.String(appName)
+ }
+ if appURL := strings.TrimSpace(config["h5AppUrl"]); appURL != "" {
+ info.AppUrl = core.String(appURL)
+ }
+ return info
+}
+
+func resolveWxpayCreateMode(req payment.CreatePaymentRequest) (string, error) {
+ if strings.TrimSpace(req.OpenID) != "" {
+ return wxpayModeJSAPI, nil
+ }
+ if req.IsMobile {
+ if strings.TrimSpace(req.ClientIP) == "" {
+ return "", fmt.Errorf("wxpay H5 payment requires client IP")
+ }
+ return wxpayModeH5, nil
+ }
+ return wxpayModeNative, nil
+}
+
+func appendWxpayRedirectURL(h5URL string, req payment.CreatePaymentRequest) (string, error) {
+ h5URL = strings.TrimSpace(h5URL)
+ returnURL := strings.TrimSpace(req.ReturnURL)
+ if h5URL == "" || returnURL == "" {
+ return h5URL, nil
+ }
+
+ redirectURL, err := buildWxpayResultURL(returnURL, req)
+ if err != nil {
+ return "", err
+ }
+
+ sep := "&"
+ if !strings.Contains(h5URL, "?") {
+ sep = "?"
+ }
+ return h5URL + sep + "redirect_url=" + url.QueryEscape(redirectURL), nil
+}
+
+func buildWxpayResultURL(returnURL string, req payment.CreatePaymentRequest) (string, error) {
+ u, err := url.Parse(returnURL)
+ if err != nil || !u.IsAbs() || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
+ return "", fmt.Errorf("return URL must be an absolute http(s) URL")
+ }
+
+ values := u.Query()
+ values.Set("out_trade_no", strings.TrimSpace(req.OrderID))
+ if paymentType := strings.TrimSpace(req.PaymentType); paymentType != "" {
+ values.Set("payment_type", paymentType)
+ }
+ if strings.TrimSpace(u.Path) == "" {
+ u.Path = wxpayResultPath
+ }
+ u.RawPath = ""
+ u.RawQuery = values.Encode()
+ u.Fragment = ""
+ return u.String(), nil
+}
+
+func wxSV(s *string) string {
+ if s == nil {
+ return ""
+ }
+ return *s
+}
+
+func mapWxState(s string) string {
+ switch s {
+ case wxpayTradeStateSuccess:
+ return payment.ProviderStatusPaid
+ case wxpayTradeStateRefund:
+ return payment.ProviderStatusRefunded
+ case wxpayTradeStateClosed, wxpayTradeStatePayError:
+ return payment.ProviderStatusFailed
+ default:
+ return payment.ProviderStatusPending
+ }
+}
+
+// buildWxpayTransactionMetadata flattens the optional identifying fields of a
+// WeChat Pay transaction into a string map. It returns nil when the
+// transaction is nil or carries none of the fields, so callers can treat
+// "no metadata" uniformly.
+func buildWxpayTransactionMetadata(tx *payments.Transaction) map[string]string {
+	if tx == nil {
+		return nil
+	}
+
+	md := make(map[string]string)
+	// put records a key only when the SDK actually supplied a value.
+	put := func(key, value string) {
+		if value != "" {
+			md[key] = value
+		}
+	}
+	put(wxpayMetadataAppID, wxSV(tx.Appid))
+	put(wxpayMetadataMerchantID, wxSV(tx.Mchid))
+	put(wxpayMetadataTradeState, wxSV(tx.TradeState))
+	if tx.Amount != nil {
+		put(wxpayMetadataCurrency, wxSV(tx.Amount.Currency))
+	}
+	if len(md) == 0 {
+		return nil
+	}
+	return md
+}
+
+// QueryOrder fetches the current state of an order by its out-trade number
+// via the native query API and maps it into the provider-neutral response.
+// Missing SDK fields degrade gracefully: a nil amount yields 0, a nil
+// transaction id echoes tradeNo back, and a nil success time yields "".
+func (w *Wxpay) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) {
+	c, err := w.ensureClient()
+	if err != nil {
+		return nil, err
+	}
+	svc := native.NativeApiService{Client: c}
+	tx, _, err := svc.QueryOrderByOutTradeNo(ctx, native.QueryOrderByOutTradeNoRequest{
+		OutTradeNo: core.String(tradeNo), Mchid: core.String(w.config["mchId"]),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("wxpay query order: %w", err)
+	}
+	var amt float64
+	if tx.Amount != nil && tx.Amount.Total != nil {
+		// WeChat reports amounts in fen; convert to yuan for the caller.
+		amt = payment.FenToYuan(*tx.Amount.Total)
+	}
+	id := tradeNo
+	if tx.TransactionId != nil {
+		id = *tx.TransactionId
+	}
+	pa := ""
+	if tx.SuccessTime != nil {
+		pa = *tx.SuccessTime
+	}
+	return &payment.QueryOrderResponse{
+		TradeNo: id,
+		Status: mapWxState(wxSV(tx.TradeState)),
+		Amount: amt,
+		PaidAt: pa,
+		Metadata: buildWxpayTransactionMetadata(tx),
+	}, nil
+}
+
+// VerifyNotification authenticates an asynchronous WeChat Pay notification by
+// replaying the raw body and headers through the SDK's signature-verifying
+// notify handler. Event types other than TRANSACTION.SUCCESS yield
+// (nil, nil), which callers treat as "ignore this notification".
+// NOTE(review): presumably ensureClient also initializes w.notifyHandler —
+// confirm, otherwise ParseNotifyRequest below could hit a nil handler.
+func (w *Wxpay) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) {
+	if _, err := w.ensureClient(); err != nil {
+		return nil, err
+	}
+	// Rebuild an *http.Request because the SDK verifies the Wechatpay-*
+	// signature headers against the exact raw body bytes.
+	r, err := http.NewRequestWithContext(ctx, http.MethodPost, "/", io.NopCloser(bytes.NewBufferString(rawBody)))
+	if err != nil {
+		return nil, fmt.Errorf("wxpay construct request: %w", err)
+	}
+	for k, v := range headers {
+		r.Header.Set(k, v)
+	}
+	var tx payments.Transaction
+	nr, err := w.notifyHandler.ParseNotifyRequest(ctx, r, &tx)
+	if err != nil {
+		return nil, fmt.Errorf("wxpay verify notification: %w", err)
+	}
+	if nr.EventType != wxpayEventTransactionSuccess {
+		return nil, nil
+	}
+	var amt float64
+	if tx.Amount != nil && tx.Amount.Total != nil {
+		amt = payment.FenToYuan(*tx.Amount.Total)
+	}
+	// For a success event anything other than trade state SUCCESS is
+	// reported as failed (notifications never surface pending/refunded here).
+	st := payment.ProviderStatusFailed
+	if wxSV(tx.TradeState) == wxpayTradeStateSuccess {
+		st = payment.ProviderStatusSuccess
+	}
+	return &payment.PaymentNotification{
+		TradeNo: wxSV(tx.TransactionId), OrderID: wxSV(tx.OutTradeNo),
+		Amount: amt, Status: st, RawData: rawBody, Metadata: buildWxpayTransactionMetadata(&tx),
+	}, nil
+}
+
+// Refund issues a (possibly partial) refund against a paid order. The refund
+// amount is converted from yuan to fen, and the original order total is
+// re-queried so WeChat can validate the partial-refund arithmetic.
+// NOTE(review): OutRefundNo embeds time.Now().UnixNano(), so a retried call
+// creates a brand-new refund number rather than an idempotent retry —
+// confirm callers de-duplicate refund requests upstream.
+func (w *Wxpay) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) {
+	c, err := w.ensureClient()
+	if err != nil {
+		return nil, err
+	}
+	rf, err := payment.YuanToFen(req.Amount)
+	if err != nil {
+		return nil, fmt.Errorf("wxpay refund amount: %w", err)
+	}
+	tf, err := w.queryOrderTotalFen(ctx, c, req.OrderID)
+	if err != nil {
+		return nil, err
+	}
+	rs := refunddomestic.RefundsApiService{Client: c}
+	cur := wxpayCurrency
+	res, _, err := rs.Create(ctx, refunddomestic.CreateRequest{
+		OutTradeNo: core.String(req.OrderID),
+		OutRefundNo: core.String(fmt.Sprintf("%s-refund-%d", req.OrderID, time.Now().UnixNano())),
+		Reason: core.String(req.Reason),
+		Amount: &refunddomestic.AmountReq{Refund: core.Int64(rf), Total: core.Int64(tf), Currency: &cur},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("wxpay refund: %w", err)
+	}
+	// Fall back to a deterministic id when the SDK returns no refund id.
+	rid := wxSV(res.RefundId)
+	if rid == "" {
+		rid = fmt.Sprintf("%s-refund", req.OrderID)
+	}
+	// Anything other than an immediate SUCCESS is reported as pending; the
+	// final state arrives asynchronously.
+	st := payment.ProviderStatusPending
+	if res.Status != nil && *res.Status == refunddomestic.STATUS_SUCCESS {
+		st = payment.ProviderStatusSuccess
+	}
+	return &payment.RefundResponse{RefundID: rid, Status: st}, nil
+}
+
+func (w *Wxpay) queryOrderTotalFen(ctx context.Context, c *core.Client, orderID string) (int64, error) {
+ svc := native.NativeApiService{Client: c}
+ tx, _, err := svc.QueryOrderByOutTradeNo(ctx, native.QueryOrderByOutTradeNoRequest{
+ OutTradeNo: core.String(orderID), Mchid: core.String(w.config["mchId"]),
+ })
+ if err != nil {
+ return 0, fmt.Errorf("wxpay refund query order: %w", err)
+ }
+ var tf int64
+ if tx.Amount != nil && tx.Amount.Total != nil {
+ tf = *tx.Amount.Total
+ }
+ return tf, nil
+}
+
+// CancelPayment closes an unpaid WeChat Pay order identified by its
+// out-trade number so it can no longer be paid.
+func (w *Wxpay) CancelPayment(ctx context.Context, tradeNo string) error {
+	c, err := w.ensureClient()
+	if err != nil {
+		return err
+	}
+	svc := native.NativeApiService{Client: c}
+	if _, err := svc.CloseOrder(ctx, native.CloseOrderRequest{
+		OutTradeNo: core.String(tradeNo), Mchid: core.String(w.config["mchId"]),
+	}); err != nil {
+		return fmt.Errorf("wxpay cancel payment: %w", err)
+	}
+	return nil
+}
+
+// Compile-time assertions that *Wxpay satisfies the payment provider
+// contracts; a signature drift breaks the build rather than runtime wiring.
+var (
+	_ payment.Provider = (*Wxpay)(nil)
+	_ payment.CancelableProvider = (*Wxpay)(nil)
+)
diff --git a/backend/internal/payment/provider/wxpay_test.go b/backend/internal/payment/provider/wxpay_test.go
new file mode 100644
index 00000000..e8ac5e54
--- /dev/null
+++ b/backend/internal/payment/provider/wxpay_test.go
@@ -0,0 +1,709 @@
+//go:build unit
+
+package provider
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/payment"
+ "github.com/wechatpay-apiv3/wechatpay-go/core"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
+ "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
+)
+
+// generateTestKeyPair returns a fresh RSA 2048 key pair as PEM strings.
+// The wechatpay-go SDK expects PKCS8 private keys and PKIX public keys.
+func generateTestKeyPair(t *testing.T) (privPEM, pubPEM string) {
+	t.Helper()
+	key, err := rsa.GenerateKey(rand.Reader, 2048)
+	if err != nil {
+		t.Fatalf("generate rsa key: %v", err)
+	}
+	privDER, err := x509.MarshalPKCS8PrivateKey(key)
+	if err != nil {
+		t.Fatalf("marshal pkcs8: %v", err)
+	}
+	pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey)
+	if err != nil {
+		t.Fatalf("marshal pkix: %v", err)
+	}
+	// PEM-encode with the standard block types the SDK's key loaders expect.
+	return string(pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privDER})),
+		string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}))
+}
+
+// TestMapWxState covers every explicit trade-state mapping plus the
+// default (unknown/empty → pending) branch of mapWxState.
+func TestMapWxState(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		input string
+		want string
+	}{
+		{
+			name: "SUCCESS maps to paid",
+			input: wxpayTradeStateSuccess,
+			want: payment.ProviderStatusPaid,
+		},
+		{
+			name: "REFUND maps to refunded",
+			input: wxpayTradeStateRefund,
+			want: payment.ProviderStatusRefunded,
+		},
+		{
+			name: "CLOSED maps to failed",
+			input: wxpayTradeStateClosed,
+			want: payment.ProviderStatusFailed,
+		},
+		{
+			name: "PAYERROR maps to failed",
+			input: wxpayTradeStatePayError,
+			want: payment.ProviderStatusFailed,
+		},
+		{
+			name: "unknown state maps to pending",
+			input: "NOTPAY",
+			want: payment.ProviderStatusPending,
+		},
+		{
+			name: "empty string maps to pending",
+			input: "",
+			want: payment.ProviderStatusPending,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got := mapWxState(tt.input)
+			if got != tt.want {
+				t.Errorf("mapWxState(%q) = %q, want %q", tt.input, got, tt.want)
+			}
+		})
+	}
+}
+
+// TestWxSV checks the nil, non-nil, and empty-value cases of the string
+// pointer dereference helper.
+func TestWxSV(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		input *string
+		want string
+	}{
+		{
+			name: "nil pointer returns empty string",
+			input: nil,
+			want: "",
+		},
+		{
+			name: "non-nil pointer returns value",
+			input: strPtr("hello"),
+			want: "hello",
+		},
+		{
+			name: "pointer to empty string returns empty string",
+			input: strPtr(""),
+			want: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got := wxSV(tt.input)
+			if got != tt.want {
+				t.Errorf("wxSV() = %q, want %q", got, tt.want)
+			}
+		})
+	}
+}
+
+// TestBuildWxpayTransactionMetadata checks that a fully-populated transaction
+// yields all four metadata keys with the expected values.
+func TestBuildWxpayTransactionMetadata(t *testing.T) {
+	t.Parallel()
+
+	tx := &payments.Transaction{
+		Appid: strPtr("wx-app-id"),
+		Mchid: strPtr("mch-id"),
+		TradeState: strPtr(wxpayTradeStateSuccess),
+		Amount: &payments.TransactionAmount{
+			Currency: strPtr(wxpayCurrency),
+		},
+	}
+
+	metadata := buildWxpayTransactionMetadata(tx)
+	if metadata[wxpayMetadataAppID] != "wx-app-id" {
+		t.Fatalf("appid = %q", metadata[wxpayMetadataAppID])
+	}
+	if metadata[wxpayMetadataMerchantID] != "mch-id" {
+		t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID])
+	}
+	if metadata[wxpayMetadataCurrency] != wxpayCurrency {
+		t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency])
+	}
+	if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess {
+		t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState])
+	}
+}
+
+// strPtr returns a pointer to s, for building SDK structs in tests.
+func strPtr(s string) *string {
+	return &s
+}
+
+// TestFormatPEM checks that raw key material is wrapped with PEM headers,
+// already-wrapped keys pass through, and surrounding whitespace is trimmed.
+func TestFormatPEM(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		key string
+		keyType string
+		want string
+	}{
+		{
+			name: "raw key gets wrapped with headers",
+			key: "MIIBIjANBgkqhki...",
+			keyType: "PUBLIC KEY",
+			want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----",
+		},
+		{
+			name: "already formatted key is returned as-is",
+			key: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----",
+			keyType: "PRIVATE KEY",
+			want: "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg...\n-----END PRIVATE KEY-----",
+		},
+		{
+			name: "key with leading/trailing whitespace is trimmed before check",
+			key: " \n MIIBIjANBgkqhki... \n ",
+			keyType: "PUBLIC KEY",
+			want: "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhki...\n-----END PUBLIC KEY-----",
+		},
+		{
+			name: "already formatted key with whitespace is trimmed and returned",
+			key: " -----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY----- ",
+			keyType: "RSA PRIVATE KEY",
+			want: "-----BEGIN RSA PRIVATE KEY-----\ndata\n-----END RSA PRIVATE KEY-----",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got := formatPEM(tt.key, tt.keyType)
+			if got != tt.want {
+				t.Errorf("formatPEM(%q, %q) =\n%s\nwant:\n%s", tt.key, tt.keyType, got, tt.want)
+			}
+		})
+	}
+}
+
+// TestNewWxpay exercises constructor validation: a fully valid config
+// succeeds, each required field missing produces an error naming that field,
+// malformed PEM keys and a wrong-length apiV3Key produce the dedicated
+// WXPAY_CONFIG_* error codes.
+func TestNewWxpay(t *testing.T) {
+	t.Parallel()
+
+	privPEM, pubPEM := generateTestKeyPair(t)
+	validConfig := map[string]string{
+		"appId": "wx1234567890",
+		"mchId": "1234567890",
+		"privateKey": privPEM,
+		"apiV3Key": "12345678901234567890123456789012", // exactly 32 bytes
+		"publicKey": pubPEM,
+		"publicKeyId": "PUB_KEY_ID_TEST",
+		"certSerial": "SERIAL001",
+	}
+
+	// helper to clone and override config fields
+	withOverride := func(overrides map[string]string) map[string]string {
+		cfg := make(map[string]string, len(validConfig))
+		for k, v := range validConfig {
+			cfg[k] = v
+		}
+		for k, v := range overrides {
+			cfg[k] = v
+		}
+		return cfg
+	}
+
+	tests := []struct {
+		name string
+		config map[string]string
+		wantErr bool
+		errSubstr string
+	}{
+		{
+			name: "valid config succeeds",
+			config: validConfig,
+			wantErr: false,
+		},
+		{
+			name: "missing appId",
+			config: withOverride(map[string]string{"appId": ""}),
+			wantErr: true,
+			errSubstr: "appId",
+		},
+		{
+			name: "missing mchId",
+			config: withOverride(map[string]string{"mchId": ""}),
+			wantErr: true,
+			errSubstr: "mchId",
+		},
+		{
+			name: "missing privateKey",
+			config: withOverride(map[string]string{"privateKey": ""}),
+			wantErr: true,
+			errSubstr: "privateKey",
+		},
+		{
+			name: "missing apiV3Key",
+			config: withOverride(map[string]string{"apiV3Key": ""}),
+			wantErr: true,
+			errSubstr: "apiV3Key",
+		},
+		{
+			name: "missing certSerial",
+			config: withOverride(map[string]string{"certSerial": ""}),
+			wantErr: true,
+			errSubstr: "certSerial",
+		},
+		{
+			name: "missing publicKey",
+			config: withOverride(map[string]string{"publicKey": ""}),
+			wantErr: true,
+			errSubstr: "publicKey",
+		},
+		{
+			name: "missing publicKeyId",
+			config: withOverride(map[string]string{"publicKeyId": ""}),
+			wantErr: true,
+			errSubstr: "publicKeyId",
+		},
+		{
+			name: "malformed privateKey PEM",
+			config: withOverride(map[string]string{"privateKey": "not-a-valid-pem"}),
+			wantErr: true,
+			errSubstr: "WXPAY_CONFIG_INVALID_KEY",
+		},
+		{
+			name: "malformed publicKey PEM",
+			config: withOverride(map[string]string{"publicKey": "not-a-valid-pem"}),
+			wantErr: true,
+			errSubstr: "WXPAY_CONFIG_INVALID_KEY",
+		},
+		{
+			name: "apiV3Key too short",
+			config: withOverride(map[string]string{"apiV3Key": "short"}),
+			wantErr: true,
+			errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
+		},
+		{
+			name: "apiV3Key too long",
+			config: withOverride(map[string]string{"apiV3Key": "123456789012345678901234567890123"}), // 33 bytes
+			wantErr: true,
+			errSubstr: "WXPAY_CONFIG_INVALID_KEY_LENGTH",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got, err := NewWxpay("test-instance", tt.config)
+			if tt.wantErr {
+				if err == nil {
+					t.Fatal("expected error, got nil")
+				}
+				if tt.errSubstr != "" && !strings.Contains(err.Error(), tt.errSubstr) {
+					t.Errorf("error %q should contain %q", err.Error(), tt.errSubstr)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+			if got == nil {
+				t.Fatal("expected non-nil Wxpay instance")
+			}
+			if got.instanceID != "test-instance" {
+				t.Errorf("instanceID = %q, want %q", got.instanceID, "test-instance")
+			}
+		})
+	}
+}
+
+// TestBuildWxpayResultURLPreservesResumeToken checks that decorating the
+// return URL keeps pre-existing query parameters (resume_token, order_id)
+// while adding out_trade_no and normalizing the path.
+func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) {
+	t.Parallel()
+
+	resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{
+		OrderID: "sub2_42",
+		PaymentType: payment.TypeWxpay,
+	})
+	if err != nil {
+		t.Fatalf("buildWxpayResultURL returned error: %v", err)
+	}
+
+	parsed, err := url.Parse(resultURL)
+	if err != nil {
+		t.Fatalf("url.Parse returned error: %v", err)
+	}
+	query := parsed.Query()
+	if parsed.Path != wxpayResultPath {
+		t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath)
+	}
+	if query.Get("resume_token") != "resume-42" {
+		t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42")
+	}
+	if query.Get("order_id") != "42" {
+		t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42")
+	}
+	if query.Get("out_trade_no") != "sub2_42" {
+		t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42")
+	}
+}
+
+// TestResolveWxpayJSAPIAppID checks the mpAppId-over-appId preference and
+// the empty fallback when neither is configured.
+func TestResolveWxpayJSAPIAppID(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		config map[string]string
+		want string
+	}{
+		{
+			name: "prefers dedicated mp app id",
+			config: map[string]string{
+				"mpAppId": "wx-mp-app",
+				"appId": "wx-merchant-app",
+			},
+			want: "wx-mp-app",
+		},
+		{
+			name: "falls back to merchant app id",
+			config: map[string]string{
+				"appId": "wx-merchant-app",
+			},
+			want: "wx-merchant-app",
+		},
+		{
+			name: "missing app ids returns empty",
+			config: map[string]string{},
+			want: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			if got := ResolveWxpayJSAPIAppID(tt.config); got != tt.want {
+				t.Fatalf("ResolveWxpayJSAPIAppID() = %q, want %q", got, tt.want)
+			}
+		})
+	}
+}
+
+// TestResolveWxpayCreateMode checks mode selection: native for desktop,
+// H5 for mobile with a client IP, JSAPI when an OpenID is present, and a
+// clear error for mobile without a client IP.
+func TestResolveWxpayCreateMode(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		req payment.CreatePaymentRequest
+		wantMode string
+		wantErr string
+	}{
+		{
+			name: "desktop uses native",
+			req: payment.CreatePaymentRequest{},
+			wantMode: wxpayModeNative,
+		},
+		{
+			name: "mobile uses h5 when client ip is present",
+			req: payment.CreatePaymentRequest{
+				IsMobile: true,
+				ClientIP: "203.0.113.10",
+			},
+			wantMode: wxpayModeH5,
+		},
+		{
+			name: "mobile without client ip returns clear error",
+			req: payment.CreatePaymentRequest{
+				IsMobile: true,
+			},
+			wantErr: "requires client IP",
+		},
+		{
+			name: "openid uses jsapi mode",
+			req: payment.CreatePaymentRequest{
+				OpenID: "openid-123",
+			},
+			wantMode: wxpayModeJSAPI,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got, err := resolveWxpayCreateMode(tt.req)
+			if tt.wantErr != "" {
+				if err == nil {
+					t.Fatal("expected error, got nil")
+				}
+				if !strings.Contains(err.Error(), tt.wantErr) {
+					t.Fatalf("error %q should contain %q", err.Error(), tt.wantErr)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+			if got != tt.wantMode {
+				t.Fatalf("resolveWxpayCreateMode() = %q, want %q", got, tt.wantMode)
+			}
+		})
+	}
+}
+
+// TestCreatePaymentWithOpenIDReturnsJSAPIResult stubs all three prepay paths
+// and verifies that an OpenID request routes exclusively to JSAPI, forwards
+// openid and client IP, and surfaces the full signed-payment payload.
+func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) {
+	origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+	origNativePrepay := wxpayNativePrepay
+	origH5Prepay := wxpayH5Prepay
+	t.Cleanup(func() {
+		wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+		wxpayNativePrepay = origNativePrepay
+		wxpayH5Prepay = origH5Prepay
+	})
+
+	jsapiCalls := 0
+	nativeCalls := 0
+	h5Calls := 0
+	wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+		jsapiCalls++
+		// Guard nil payer before dereferencing so a missing payer fails the
+		// test instead of panicking.
+		if req.Payer == nil {
+			t.Fatal("expected payer, got nil")
+		}
+		if got := wxSV(req.Payer.Openid); got != "openid-123" {
+			t.Fatalf("openid = %q, want %q", got, "openid-123")
+		}
+		// Bug fix: the original dereferenced req.SceneInfo.PayerClientIp in
+		// the Fatalf arguments after allowing req.SceneInfo == nil, which
+		// panics instead of reporting the failure.
+		if req.SceneInfo == nil {
+			t.Fatal("expected scene_info, got nil")
+		}
+		if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" {
+			t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10")
+		}
+		return &jsapi.PrepayWithRequestPaymentResponse{
+			Appid: core.String("wx123"),
+			TimeStamp: core.String("1712345678"),
+			NonceStr: core.String("nonce-123"),
+			Package: core.String("prepay_id=wx_prepay_123"),
+			SignType: core.String("RSA"),
+			PaySign: core.String("signed-payload"),
+		}, nil, nil
+	}
+	wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+		nativeCalls++
+		return &native.PrepayResponse{}, nil, nil
+	}
+	wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+		h5Calls++
+		return &h5.PrepayResponse{}, nil, nil
+	}
+
+	provider := &Wxpay{
+		config: map[string]string{
+			"appId": "wx123",
+			"mchId": "mch123",
+		},
+		coreClient: &core.Client{},
+	}
+
+	resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+		OrderID: "sub2_88",
+		Amount: "66.88",
+		PaymentType: payment.TypeWxpay,
+		NotifyURL: "https://merchant.example/payment/notify",
+		OpenID: "openid-123",
+		ClientIP: "203.0.113.10",
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if jsapiCalls != 1 {
+		t.Fatalf("jsapi prepay calls = %d, want 1", jsapiCalls)
+	}
+	if nativeCalls != 0 {
+		t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+	}
+	if h5Calls != 0 {
+		t.Fatalf("h5 prepay calls = %d, want 0", h5Calls)
+	}
+	if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
+		t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
+	}
+	if resp.JSAPI == nil {
+		t.Fatal("expected jsapi payload, got nil")
+	}
+	if resp.JSAPI.AppID != "wx123" {
+		t.Fatalf("jsapi appId = %q, want %q", resp.JSAPI.AppID, "wx123")
+	}
+	if resp.JSAPI.TimeStamp != "1712345678" {
+		t.Fatalf("jsapi timeStamp = %q, want %q", resp.JSAPI.TimeStamp, "1712345678")
+	}
+	if resp.JSAPI.NonceStr != "nonce-123" {
+		t.Fatalf("jsapi nonceStr = %q, want %q", resp.JSAPI.NonceStr, "nonce-123")
+	}
+	if resp.JSAPI.Package != "prepay_id=wx_prepay_123" {
+		t.Fatalf("jsapi package = %q, want %q", resp.JSAPI.Package, "prepay_id=wx_prepay_123")
+	}
+	if resp.JSAPI.SignType != "RSA" {
+		t.Fatalf("jsapi signType = %q, want %q", resp.JSAPI.SignType, "RSA")
+	}
+	if resp.JSAPI.PaySign != "signed-payload" {
+		t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload")
+	}
+}
+
+// TestCreatePaymentMobileH5IncludesConfiguredSceneInfo verifies that a mobile
+// request routes to the H5 prepay path only, carries the configured
+// h5AppName/h5AppUrl scene info and payer IP, and that the resulting pay URL
+// has the redirect_url query appended.
+func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
+	origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+	origNativePrepay := wxpayNativePrepay
+	origH5Prepay := wxpayH5Prepay
+	t.Cleanup(func() {
+		wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+		wxpayNativePrepay = origNativePrepay
+		wxpayH5Prepay = origH5Prepay
+	})
+
+	jsapiCalls := 0
+	nativeCalls := 0
+	h5Calls := 0
+	wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+		jsapiCalls++
+		return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
+	}
+	wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+		nativeCalls++
+		return &native.PrepayResponse{}, nil, nil
+	}
+	wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+		h5Calls++
+		if req.SceneInfo == nil {
+			t.Fatal("expected scene_info, got nil")
+		}
+		if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" {
+			t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10")
+		}
+		if req.SceneInfo.H5Info == nil {
+			t.Fatal("expected scene_info.h5_info, got nil")
+		}
+		if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type {
+			t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type)
+		}
+		if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" {
+			t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API")
+		}
+		if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" {
+			t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com")
+		}
+		return &h5.PrepayResponse{
+			H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"),
+		}, nil, nil
+	}
+
+	provider := &Wxpay{
+		config: map[string]string{
+			"appId": "wx123",
+			"mchId": "mch123",
+			"h5AppName": "Sub2API",
+			"h5AppUrl": "https://app.example.com",
+		},
+		coreClient: &core.Client{},
+	}
+
+	resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+		OrderID: "sub2_99",
+		Amount: "66.88",
+		PaymentType: payment.TypeWxpay,
+		Subject: "Balance Recharge",
+		NotifyURL: "https://merchant.example/payment/notify",
+		ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99",
+		ClientIP: "203.0.113.10",
+		IsMobile: true,
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if jsapiCalls != 0 {
+		t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
+	}
+	if nativeCalls != 0 {
+		t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+	}
+	if h5Calls != 1 {
+		t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
+	}
+	if !strings.Contains(resp.PayURL, "redirect_url=") {
+		t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL)
+	}
+}
+
+// TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback verifies
+// that an H5 NO_AUTH failure propagates to the caller and does NOT silently
+// fall back to a native (QR-code) payment.
+func TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback(t *testing.T) {
+	origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
+	origNativePrepay := wxpayNativePrepay
+	origH5Prepay := wxpayH5Prepay
+	t.Cleanup(func() {
+		wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
+		wxpayNativePrepay = origNativePrepay
+		wxpayH5Prepay = origH5Prepay
+	})
+
+	jsapiCalls := 0
+	nativeCalls := 0
+	h5Calls := 0
+	wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
+		jsapiCalls++
+		return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
+	}
+	wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
+		h5Calls++
+		return nil, nil, errors.New("NO_AUTH")
+	}
+	// The native stub would succeed — the test asserts it is never reached.
+	wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
+		nativeCalls++
+		return &native.PrepayResponse{
+			CodeUrl: core.String("weixin://wxpay/bizpayurl?pr=fallback-native"),
+		}, nil, nil
+	}
+
+	provider := &Wxpay{
+		config: map[string]string{
+			"appId": "wx123",
+			"mchId": "mch123",
+		},
+		coreClient: &core.Client{},
+	}
+
+	resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
+		OrderID: "sub2_100",
+		Amount: "66.88",
+		PaymentType: payment.TypeWxpay,
+		Subject: "Balance Recharge",
+		NotifyURL: "https://merchant.example/payment/notify",
+		ClientIP: "203.0.113.10",
+		IsMobile: true,
+	})
+	if err == nil {
+		t.Fatal("expected no-auth error, got nil")
+	}
+	if jsapiCalls != 0 {
+		t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
+	}
+	if h5Calls != 1 {
+		t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
+	}
+	if nativeCalls != 0 {
+		t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
+	}
+	if resp != nil {
+		t.Fatalf("expected nil response, got %+v", resp)
+	}
+	if !strings.Contains(err.Error(), "NO_AUTH") {
+		t.Fatalf("error = %v, want NO_AUTH", err)
+	}
+}
diff --git a/backend/internal/payment/registry.go b/backend/internal/payment/registry.go
new file mode 100644
index 00000000..259eb4bb
--- /dev/null
+++ b/backend/internal/payment/registry.go
@@ -0,0 +1,85 @@
+package payment
+
+import (
+ "sync"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+// Registry is a thread-safe registry mapping PaymentType to Provider.
+type Registry struct {
+	mu sync.RWMutex // guards providers
+	providers map[PaymentType]Provider // one provider per type; last registration wins
+}
+
+// ErrProviderNotFound is returned when a requested payment provider is not registered.
+// It is a shared sentinel value; callers should compare with errors.Is
+// (assuming infraerrors.NotFound supports it — TODO confirm).
+var ErrProviderNotFound = infraerrors.NotFound("PROVIDER_NOT_FOUND", "payment provider not registered")
+
+// NewRegistry creates a new empty provider registry.
+func NewRegistry() *Registry {
+	return &Registry{providers: map[PaymentType]Provider{}}
+}
+
+// Register adds a provider under every payment type it supports,
+// overwriting any previous registration for the same type.
+func (r *Registry) Register(p Provider) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	for _, paymentType := range p.SupportedTypes() {
+		r.providers[paymentType] = p
+	}
+}
+
+// GetProvider returns the provider registered for the given payment type,
+// or ErrProviderNotFound when none is registered.
+func (r *Registry) GetProvider(t PaymentType) (Provider, error) {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	if p, ok := r.providers[t]; ok {
+		return p, nil
+	}
+	return nil, ErrProviderNotFound
+}
+
+// GetProviderByKey returns the first provider whose ProviderKey matches the given key.
+// NOTE(review): Go map iteration order is randomized — if two registered
+// providers ever share a ProviderKey, "first" is nondeterministic. Confirm
+// keys are unique, or index providers by key instead.
+func (r *Registry) GetProviderByKey(key string) (Provider, error) {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	for _, p := range r.providers {
+		if p.ProviderKey() == key {
+			return p, nil
+		}
+	}
+	return nil, ErrProviderNotFound
+}
+
+// GetProviderKey returns the provider key for the given payment type, or
+// the empty string when that type has no registered provider.
+func (r *Registry) GetProviderKey(t PaymentType) string {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	if p, ok := r.providers[t]; ok {
+		return p.ProviderKey()
+	}
+	return ""
+}
+
+// SupportedTypes returns all currently registered payment types, in no
+// particular order.
+func (r *Registry) SupportedTypes() []PaymentType {
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+	out := make([]PaymentType, 0, len(r.providers))
+	for paymentType := range r.providers {
+		out = append(out, paymentType)
+	}
+	return out
+}
+
+// Clear removes all registered providers, leaving an empty registry.
+func (r *Registry) Clear() {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	r.providers = map[PaymentType]Provider{}
+}
diff --git a/backend/internal/payment/registry_test.go b/backend/internal/payment/registry_test.go
new file mode 100644
index 00000000..9684945c
--- /dev/null
+++ b/backend/internal/payment/registry_test.go
@@ -0,0 +1,234 @@
+package payment
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "testing"
+)
+
+// mockProvider implements the Provider interface for testing.
+// All operation methods are no-ops returning (nil, nil); only the
+// identity accessors (Name/ProviderKey/SupportedTypes) carry data.
+type mockProvider struct {
+	name string
+	key string
+	supportedTypes []PaymentType
+}
+
+func (m *mockProvider) Name() string { return m.name }
+func (m *mockProvider) ProviderKey() string { return m.key }
+func (m *mockProvider) SupportedTypes() []PaymentType { return m.supportedTypes }
+func (m *mockProvider) CreatePayment(_ context.Context, _ CreatePaymentRequest) (*CreatePaymentResponse, error) {
+	return nil, nil
+}
+func (m *mockProvider) QueryOrder(_ context.Context, _ string) (*QueryOrderResponse, error) {
+	return nil, nil
+}
+func (m *mockProvider) VerifyNotification(_ context.Context, _ string, _ map[string]string) (*PaymentNotification, error) {
+	return nil, nil
+}
+func (m *mockProvider) Refund(_ context.Context, _ RefundRequest) (*RefundResponse, error) {
+	return nil, nil
+}
+
+// TestRegistryRegisterAndGetProvider checks that one provider registered for
+// two types is retrievable under both.
+func TestRegistryRegisterAndGetProvider(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	p := &mockProvider{
+		name: "TestPay",
+		key: "testpay",
+		supportedTypes: []PaymentType{TypeAlipay, TypeWxpay},
+	}
+	r.Register(p)
+
+	got, err := r.GetProvider(TypeAlipay)
+	if err != nil {
+		t.Fatalf("GetProvider(alipay) error: %v", err)
+	}
+	if got.ProviderKey() != "testpay" {
+		t.Fatalf("GetProvider(alipay) key = %q, want %q", got.ProviderKey(), "testpay")
+	}
+
+	got2, err := r.GetProvider(TypeWxpay)
+	if err != nil {
+		t.Fatalf("GetProvider(wxpay) error: %v", err)
+	}
+	if got2.ProviderKey() != "testpay" {
+		t.Fatalf("GetProvider(wxpay) key = %q, want %q", got2.ProviderKey(), "testpay")
+	}
+}
+
+// TestRegistryGetProviderNotFound checks the error path for an
+// unregistered payment type.
+func TestRegistryGetProviderNotFound(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	_, err := r.GetProvider("nonexistent")
+	if err == nil {
+		t.Fatal("GetProvider for unregistered type should return error")
+	}
+}
+
+// TestRegistryGetProviderByKey checks lookup by provider key rather than
+// payment type.
+func TestRegistryGetProviderByKey(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	p := &mockProvider{
+		name: "EasyPay",
+		key: "easypay",
+		supportedTypes: []PaymentType{TypeAlipay},
+	}
+	r.Register(p)
+
+	got, err := r.GetProviderByKey("easypay")
+	if err != nil {
+		t.Fatalf("GetProviderByKey error: %v", err)
+	}
+	if got.Name() != "EasyPay" {
+		t.Fatalf("GetProviderByKey name = %q, want %q", got.Name(), "EasyPay")
+	}
+}
+
+// TestRegistryGetProviderByKeyNotFound checks the error path for an
+// unknown provider key.
+func TestRegistryGetProviderByKeyNotFound(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	_, err := r.GetProviderByKey("nonexistent")
+	if err == nil {
+		t.Fatal("GetProviderByKey for unknown key should return error")
+	}
+}
+
+// TestRegistryGetProviderKeyUnknownType checks that an unknown type yields
+// the empty string rather than an error.
+func TestRegistryGetProviderKeyUnknownType(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	key := r.GetProviderKey("unknown_type")
+	if key != "" {
+		t.Fatalf("GetProviderKey for unknown type should return empty, got %q", key)
+	}
+}
+
+// TestRegistryGetProviderKeyKnownType checks the happy path of
+// GetProviderKey for a registered type.
+func TestRegistryGetProviderKeyKnownType(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	p := &mockProvider{
+		name: "Stripe",
+		key: "stripe",
+		supportedTypes: []PaymentType{TypeStripe},
+	}
+	r.Register(p)
+
+	key := r.GetProviderKey(TypeStripe)
+	if key != "stripe" {
+		t.Fatalf("GetProviderKey(stripe) = %q, want %q", key, "stripe")
+	}
+}
+
+// TestRegistrySupportedTypes checks that types from multiple providers are
+// all reported, using a set comparison since ordering is unspecified.
+func TestRegistrySupportedTypes(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	p1 := &mockProvider{
+		name: "EasyPay",
+		key: "easypay",
+		supportedTypes: []PaymentType{TypeAlipay, TypeWxpay},
+	}
+	p2 := &mockProvider{
+		name: "Stripe",
+		key: "stripe",
+		supportedTypes: []PaymentType{TypeStripe},
+	}
+	r.Register(p1)
+	r.Register(p2)
+
+	types := r.SupportedTypes()
+	if len(types) != 3 {
+		t.Fatalf("SupportedTypes() len = %d, want 3", len(types))
+	}
+
+	typeSet := make(map[PaymentType]bool)
+	for _, tp := range types {
+		typeSet[tp] = true
+	}
+	for _, expected := range []PaymentType{TypeAlipay, TypeWxpay, TypeStripe} {
+		if !typeSet[expected] {
+			t.Fatalf("SupportedTypes() missing %q", expected)
+		}
+	}
+}
+
+// TestRegistrySupportedTypesEmpty checks the empty-registry case.
+func TestRegistrySupportedTypesEmpty(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	types := r.SupportedTypes()
+	if len(types) != 0 {
+		t.Fatalf("SupportedTypes() on empty registry should be empty, got %d", len(types))
+	}
+}
+
+// TestRegistryOverwriteExisting checks last-registration-wins semantics for
+// a payment type claimed by two providers.
+func TestRegistryOverwriteExisting(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	p1 := &mockProvider{
+		name: "OldPay",
+		key: "old",
+		supportedTypes: []PaymentType{TypeAlipay},
+	}
+	p2 := &mockProvider{
+		name: "NewPay",
+		key: "new",
+		supportedTypes: []PaymentType{TypeAlipay},
+	}
+	r.Register(p1)
+	r.Register(p2)
+
+	got, err := r.GetProvider(TypeAlipay)
+	if err != nil {
+		t.Fatalf("GetProvider error: %v", err)
+	}
+	if got.Name() != "NewPay" {
+		t.Fatalf("expected overwritten provider, got %q", got.Name())
+	}
+}
+
+// TestRegistryConcurrentAccess hammers the registry with concurrent
+// registrations and reads; it relies on the race detector to surface
+// unsynchronized access, then verifies all registrations landed.
+func TestRegistryConcurrentAccess(t *testing.T) {
+	t.Parallel()
+	r := NewRegistry()
+
+	const goroutines = 50
+	var wg sync.WaitGroup
+	wg.Add(goroutines * 2)
+
+	// Concurrent writers
+	for i := 0; i < goroutines; i++ {
+		go func(idx int) {
+			defer wg.Done()
+			p := &mockProvider{
+				name: fmt.Sprintf("Provider-%d", idx),
+				key: fmt.Sprintf("key-%d", idx),
+				supportedTypes: []PaymentType{PaymentType(fmt.Sprintf("type-%d", idx))},
+			}
+			r.Register(p)
+		}(i)
+	}
+
+	// Concurrent readers
+	for i := 0; i < goroutines; i++ {
+		go func() {
+			defer wg.Done()
+			_ = r.SupportedTypes()
+			_, _ = r.GetProvider("some-type")
+			_ = r.GetProviderKey("some-type")
+		}()
+	}
+
+	wg.Wait()
+
+	types := r.SupportedTypes()
+	if len(types) != goroutines {
+		t.Fatalf("after concurrent registration, expected %d types, got %d", goroutines, len(types))
+	}
+}
diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go
new file mode 100644
index 00000000..e7ac6727
--- /dev/null
+++ b/backend/internal/payment/types.go
@@ -0,0 +1,222 @@
+// Package payment provides the core payment provider abstraction,
+// registry, load balancing, and shared utilities for the payment subsystem.
+package payment
+
+import "context"
+
+// PaymentType represents a supported payment method.
+type PaymentType = string
+
+// Supported payment type constants.
+const (
+ TypeAlipay PaymentType = "alipay"
+ TypeWxpay PaymentType = "wxpay"
+ TypeAlipayDirect PaymentType = "alipay_direct"
+ TypeWxpayDirect PaymentType = "wxpay_direct"
+ TypeStripe PaymentType = "stripe"
+ TypeCard PaymentType = "card"
+ TypeLink PaymentType = "link"
+ TypeEasyPay PaymentType = "easypay"
+)
+
+// Order status constants shared across payment and service layers.
+const (
+ OrderStatusPending = "PENDING"
+ OrderStatusPaid = "PAID"
+ OrderStatusRecharging = "RECHARGING"
+ OrderStatusCompleted = "COMPLETED"
+ OrderStatusExpired = "EXPIRED"
+ OrderStatusCancelled = "CANCELLED"
+ OrderStatusFailed = "FAILED"
+ OrderStatusRefundRequested = "REFUND_REQUESTED"
+ OrderStatusRefunding = "REFUNDING"
+ OrderStatusPartiallyRefunded = "PARTIALLY_REFUNDED"
+ OrderStatusRefunded = "REFUNDED"
+ OrderStatusRefundFailed = "REFUND_FAILED"
+)
+
+// Order types distinguish balance recharges from subscription purchases.
+const (
+ OrderTypeBalance = "balance"
+ OrderTypeSubscription = "subscription"
+)
+
+// Entity statuses shared across users, groups, etc.
+const (
+ EntityStatusActive = "active"
+)
+
+// Deduction types for refund flow.
+const (
+ DeductionTypeBalance = "balance"
+ DeductionTypeSubscription = "subscription"
+ DeductionTypeNone = "none"
+)
+
+// Payment notification status values.
+const (
+ NotificationStatusSuccess = "success"
+ NotificationStatusPaid = "paid"
+)
+
+// Provider-level status constants returned by provider implementations
+// to the service layer (lowercase, distinct from OrderStatus uppercase constants).
+const (
+ ProviderStatusPending = "pending"
+ ProviderStatusPaid = "paid"
+ ProviderStatusSuccess = "success"
+ ProviderStatusFailed = "failed"
+ ProviderStatusRefunded = "refunded"
+)
+
+// DefaultLoadBalanceStrategy is the default load-balancing strategy
+// used when no strategy is configured.
+const DefaultLoadBalanceStrategy = "round-robin"
+
+// ConfigKeyPublishableKey is the config map key for Stripe's publishable key.
+const ConfigKeyPublishableKey = "publishableKey"
+
+// GetBasePaymentType extracts the base payment method from a composite key.
+// For example, "alipay_direct" -> "alipay".
+func GetBasePaymentType(t string) string {
+ switch {
+ case t == TypeEasyPay:
+ return TypeEasyPay
+ case t == TypeStripe || t == TypeCard || t == TypeLink:
+ return TypeStripe
+ case len(t) >= len(TypeAlipay) && t[:len(TypeAlipay)] == TypeAlipay:
+ return TypeAlipay
+ case len(t) >= len(TypeWxpay) && t[:len(TypeWxpay)] == TypeWxpay:
+ return TypeWxpay
+ default:
+ return t
+ }
+}
+
+// CreatePaymentRequest holds the parameters for creating a new payment.
+type CreatePaymentRequest struct {
+ OrderID string // Internal order ID
+ Amount string // Pay amount in CNY (formatted to 2 decimal places)
+ PaymentType string // e.g. "alipay", "wxpay", "stripe"
+ Subject string // Product description
+ NotifyURL string // Webhook callback URL
+ ReturnURL string // Browser redirect URL after payment
+ OpenID string // WeChat JSAPI payer OpenID when available
+ ClientIP string // Payer's IP address
+ IsMobile bool // Whether the request comes from a mobile device
+ InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe)
+}
+
+// CreatePaymentResultType describes the shape of the create-payment result.
+type CreatePaymentResultType = string
+
+const (
+ CreatePaymentResultOrderCreated CreatePaymentResultType = "order_created"
+ CreatePaymentResultOAuthRequired CreatePaymentResultType = "oauth_required"
+ CreatePaymentResultJSAPIReady CreatePaymentResultType = "jsapi_ready"
+)
+
+// WechatOAuthInfo describes the next step when WeChat OAuth is required before payment.
+type WechatOAuthInfo struct {
+ AuthorizeURL string `json:"authorize_url,omitempty"`
+ AppID string `json:"appid,omitempty"`
+ OpenID string `json:"openid,omitempty"`
+ Scope string `json:"scope,omitempty"`
+ State string `json:"state,omitempty"`
+ RedirectURL string `json:"redirect_url,omitempty"`
+}
+
+// WechatJSAPIPayload contains the fields the frontend needs to invoke WeChat JSAPI payment.
+type WechatJSAPIPayload struct {
+ AppID string `json:"appId,omitempty"`
+ TimeStamp string `json:"timeStamp,omitempty"`
+ NonceStr string `json:"nonceStr,omitempty"`
+ Package string `json:"package,omitempty"`
+ SignType string `json:"signType,omitempty"`
+ PaySign string `json:"paySign,omitempty"`
+}
+
+// CreatePaymentResponse is returned after successfully initiating a payment.
+type CreatePaymentResponse struct {
+ TradeNo string // Third-party transaction ID
+ PayURL string // H5 payment URL (alipay/wxpay)
+ QRCode string // QR code content for scanning
+ ClientSecret string // Stripe PaymentIntent client secret
+ ResultType CreatePaymentResultType // Typed result contract for frontend flows
+ OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required
+ JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready
+}
+
+// QueryOrderResponse describes the payment status from the upstream provider.
+type QueryOrderResponse struct {
+ TradeNo string
+ Status string // "pending", "paid", "failed", "refunded"
+ Amount float64 // Amount in CNY
+ PaidAt string // RFC3339 timestamp or empty
+ Metadata map[string]string
+}
+
+// PaymentNotification is the parsed result of a webhook/notify callback.
+type PaymentNotification struct {
+ TradeNo string
+ OrderID string
+ Amount float64
+ Status string // "success" or "failed"
+ RawData string // Raw notification body for audit
+ Metadata map[string]string
+}
+
+// RefundRequest contains the parameters for requesting a refund.
+type RefundRequest struct {
+ TradeNo string
+ OrderID string
+ Amount string // Refund amount formatted to 2 decimal places
+ Reason string
+}
+
+// RefundResponse is returned after a refund request.
+type RefundResponse struct {
+ RefundID string
+ Status string // "success", "pending", "failed"
+}
+
+// InstanceSelection holds the selected provider instance and its decrypted config.
+type InstanceSelection struct {
+ InstanceID string
+ ProviderKey string // Provider key of the selected instance (e.g. "alipay", "easypay")
+ Config map[string]string
+ SupportedTypes string // Comma-separated list of supported payment types from the instance
+ PaymentMode string // Payment display mode: "qrcode", "redirect", "popup"
+}
+
+// Provider defines the interface that all payment providers must implement.
+type Provider interface {
+ // Name returns a human-readable name for this provider.
+ Name() string
+ // ProviderKey returns the unique key identifying this provider type (e.g. "easypay").
+ ProviderKey() string
+ // SupportedTypes returns the list of payment types this provider handles.
+ SupportedTypes() []PaymentType
+ // CreatePayment initiates a payment and returns the upstream response.
+ CreatePayment(ctx context.Context, req CreatePaymentRequest) (*CreatePaymentResponse, error)
+ // QueryOrder queries the payment status of the given trade number.
+ QueryOrder(ctx context.Context, tradeNo string) (*QueryOrderResponse, error)
+ // VerifyNotification parses and verifies a webhook callback.
+ // Returns nil for unrecognized or irrelevant events (caller should return 200).
+ VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*PaymentNotification, error)
+ // Refund requests a refund from the upstream provider.
+ Refund(ctx context.Context, req RefundRequest) (*RefundResponse, error)
+}
+
+// CancelableProvider extends Provider with the ability to cancel pending payments.
+type CancelableProvider interface {
+ Provider
+ // CancelPayment cancels/expires a pending payment on the upstream platform.
+ CancelPayment(ctx context.Context, tradeNo string) error
+}
+
+// MerchantIdentityProvider exposes the current non-sensitive merchant identity
+// derived from provider configuration for snapshot consistency checks.
+type MerchantIdentityProvider interface {
+ MerchantIdentityMetadata() map[string]string
+}
diff --git a/backend/internal/payment/wire.go b/backend/internal/payment/wire.go
new file mode 100644
index 00000000..4b7f422d
--- /dev/null
+++ b/backend/internal/payment/wire.go
@@ -0,0 +1,65 @@
+package payment
+
+import (
+ "encoding/hex"
+ "fmt"
+ "log/slog"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/google/wire"
+)
+
+// EncryptionKey is a named type for the payment encryption key (AES-256, 32 bytes).
+// Using a named type avoids Wire ambiguity with other []byte parameters.
+type EncryptionKey []byte
+
+// ProvideEncryptionKey derives the payment encryption key from the TOTP encryption key in config.
+// When the key is empty, nil is returned (payment features that need encryption will be disabled).
+// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
+// to prevent startup with a misconfigured encryption key.
+func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
+ if cfg == nil {
+ slog.Warn("payment config is nil — encrypted payment config and resume signing will be unavailable")
+ return nil, nil
+ }
+ keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey)
+ if keyHex == "" {
+ slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable")
+ return nil, nil
+ }
+ // Reject auto-generated TOTP keys for payment signing.
+ // They change across restarts/instances and can silently break resume-token flows.
+ if !cfg.Totp.EncryptionKeyConfigured {
+ slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens")
+ return nil, nil
+ }
+ key, err := hex.DecodeString(keyHex)
+ if err != nil {
+ return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
+ }
+ if len(key) != 32 {
+ return nil, fmt.Errorf("payment encryption key must be 32 bytes, got %d", len(key))
+ }
+ return EncryptionKey(key), nil
+}
+
+// ProvideRegistry creates an empty payment provider registry.
+// Providers are registered at runtime after application startup.
+func ProvideRegistry() *Registry {
+ return NewRegistry()
+}
+
+// ProvideDefaultLoadBalancer creates a DefaultLoadBalancer backed by the ent client.
+func ProvideDefaultLoadBalancer(client *dbent.Client, key EncryptionKey) *DefaultLoadBalancer {
+ return NewDefaultLoadBalancer(client, []byte(key))
+}
+
+// ProviderSet is the Wire provider set for the payment package.
+var ProviderSet = wire.NewSet(
+ ProvideEncryptionKey,
+ ProvideRegistry,
+ ProvideDefaultLoadBalancer,
+ wire.Bind(new(LoadBalancer), new(*DefaultLoadBalancer)),
+)
diff --git a/backend/internal/payment/wire_test.go b/backend/internal/payment/wire_test.go
new file mode 100644
index 00000000..1b360f89
--- /dev/null
+++ b/backend/internal/payment/wire_test.go
@@ -0,0 +1,62 @@
+package payment
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: strings.Repeat("a", 64),
+ EncryptionKeyConfigured: false,
+ },
+ }
+
+ key, err := ProvideEncryptionKey(cfg)
+ if err != nil {
+ t.Fatalf("ProvideEncryptionKey returned error: %v", err)
+ }
+ if len(key) != 0 {
+ t.Fatalf("encryption key len = %d, want 0", len(key))
+ }
+}
+
+func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+ EncryptionKeyConfigured: true,
+ },
+ }
+
+ key, err := ProvideEncryptionKey(cfg)
+ if err != nil {
+ t.Fatalf("ProvideEncryptionKey returned error: %v", err)
+ }
+ if len(key) != 32 {
+ t.Fatalf("encryption key len = %d, want 32", len(key))
+ }
+}
+
+func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) {
+ t.Parallel()
+
+ cfg := &config.Config{
+ Totp: config.TotpConfig{
+ EncryptionKey: "abcd",
+ EncryptionKeyConfigured: true,
+ },
+ }
+
+ _, err := ProvideEncryptionKey(cfg)
+ if err == nil {
+ t.Fatal("expected error for invalid key length")
+ }
+}
diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go
index 8ea87f18..0b8ae5f2 100644
--- a/backend/internal/pkg/antigravity/claude_types.go
+++ b/backend/internal/pkg/antigravity/claude_types.go
@@ -125,6 +125,7 @@ type ClaudeUsage struct {
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
+ ImageOutputTokens int `json:"image_output_tokens,omitempty"`
}
// ClaudeError Claude 错误响应
@@ -153,6 +154,7 @@ var claudeModels = []modelDef{
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
+ {ID: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", CreatedAt: "2026-04-17T00:00:00Z"},
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
}
diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go
index 1a0ca5bb..033dccbd 100644
--- a/backend/internal/pkg/antigravity/gemini_types.go
+++ b/backend/internal/pkg/antigravity/gemini_types.go
@@ -149,13 +149,31 @@ type GeminiCandidate struct {
GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"`
}
+// GeminiTokenDetail Gemini token 详情(按模态分类)
+type GeminiTokenDetail struct {
+ Modality string `json:"modality"`
+ TokenCount int `json:"tokenCount"`
+}
+
// GeminiUsageMetadata Gemini 用量元数据
type GeminiUsageMetadata struct {
- PromptTokenCount int `json:"promptTokenCount,omitempty"`
- CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
- CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
- TotalTokenCount int `json:"totalTokenCount,omitempty"`
- ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
+ PromptTokenCount int `json:"promptTokenCount,omitempty"`
+ CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
+ CachedContentTokenCount int `json:"cachedContentTokenCount,omitempty"`
+ TotalTokenCount int `json:"totalTokenCount,omitempty"`
+ ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"` // thinking tokens(按输出价格计费)
+ CandidatesTokensDetails []GeminiTokenDetail `json:"candidatesTokensDetails,omitempty"`
+ PromptTokensDetails []GeminiTokenDetail `json:"promptTokensDetails,omitempty"`
+}
+
+// ImageOutputTokens 从 CandidatesTokensDetails 中提取 IMAGE 模态的 token 数
+func (m *GeminiUsageMetadata) ImageOutputTokens() int {
+ for _, d := range m.CandidatesTokensDetails {
+ if d.Modality == "IMAGE" {
+ return d.TokenCount
+ }
+ }
+ return 0
}
// GeminiGroundingMetadata Gemini grounding 元数据(Web Search)
diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go
index 8a8bed92..7c963d9e 100644
--- a/backend/internal/pkg/antigravity/oauth.go
+++ b/backend/internal/pkg/antigravity/oauth.go
@@ -50,7 +50,7 @@ const (
)
// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.20.5
-var defaultUserAgentVersion = "1.20.5"
+var defaultUserAgentVersion = "1.21.9"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"
diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go
index 3a093fe6..9850af17 100644
--- a/backend/internal/pkg/antigravity/oauth_test.go
+++ b/backend/internal/pkg/antigravity/oauth_test.go
@@ -690,7 +690,7 @@ func TestConstants_值正确(t *testing.T) {
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
- if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
+ if GetUserAgent() != "antigravity/1.21.9 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
}
if SessionTTL != 30*time.Minute {
diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go
index 1b45e507..b5de8166 100644
--- a/backend/internal/pkg/antigravity/request_transformer.go
+++ b/backend/internal/pkg/antigravity/request_transformer.go
@@ -582,8 +582,12 @@ func maxOutputTokensLimit(model string) int {
return maxOutputTokensUpperBound
}
-func isAntigravityOpus46Model(model string) bool {
- return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6")
+// isAntigravityOpusHighTierModel 判断是否为高阶 Opus 模型(4.6+),
+// 用于 adaptive thinking 时覆写为高预算。
+func isAntigravityOpusHighTierModel(model string) bool {
+ lower := strings.ToLower(model)
+ return strings.HasPrefix(lower, "claude-opus-4-6") ||
+ strings.HasPrefix(lower, "claude-opus-4-7")
}
func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
@@ -605,12 +609,12 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig {
}
// - thinking.type=enabled:budget_tokens>0 用显式预算
- // - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576)
+ // - thinking.type=adaptive:在 Antigravity 的高阶 Opus(4.6+)上覆写为 (24576)
budget := -1
if req.Thinking.BudgetTokens > 0 {
budget = req.Thinking.BudgetTokens
}
- if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) {
+ if req.Thinking.Type == "adaptive" && isAntigravityOpusHighTierModel(req.Model) {
budget = ClaudeAdaptiveHighThinkingBudgetTokens
}
@@ -730,13 +734,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
})
}
- if len(funcDecls) == 0 {
- if !hasWebSearch {
- return nil
- }
-
- // Web Search 工具映射
- return []GeminiToolDeclaration{{
+ var declarations []GeminiToolDeclaration
+ if len(funcDecls) > 0 {
+ declarations = append(declarations, GeminiToolDeclaration{
+ FunctionDeclarations: funcDecls,
+ })
+ }
+ if hasWebSearch {
+ declarations = append(declarations, GeminiToolDeclaration{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
@@ -744,10 +749,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
},
},
},
- }}
+ })
+ }
+ if len(declarations) == 0 {
+ return nil
}
- return []GeminiToolDeclaration{{
- FunctionDeclarations: funcDecls,
- }}
+ return declarations
}
diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go
index 9e46295a..6fae5b7c 100644
--- a/backend/internal/pkg/antigravity/request_transformer_test.go
+++ b/backend/internal/pkg/antigravity/request_transformer_test.go
@@ -263,6 +263,29 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
}
}
+func TestBuildTools_PreservesWebSearchAlongsideFunctions(t *testing.T) {
+ tools := []ClaudeTool{
+ {
+ Name: "get_weather",
+ Description: "Get weather information",
+ InputSchema: map[string]any{"type": "object"},
+ },
+ {
+ Type: "web_search_20250305",
+ Name: "web_search",
+ },
+ }
+
+ result := buildTools(tools)
+ require.Len(t, result, 2)
+ require.Len(t, result[0].FunctionDeclarations, 1)
+ require.Equal(t, "get_weather", result[0].FunctionDeclarations[0].Name)
+ require.NotNil(t, result[1].GoogleSearch)
+ require.NotNil(t, result[1].GoogleSearch.EnhancedContent)
+ require.NotNil(t, result[1].GoogleSearch.EnhancedContent.ImageSearch)
+ require.Equal(t, 5, result[1].GoogleSearch.EnhancedContent.ImageSearch.MaxResultCount)
+}
+
func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
tests := []struct {
name string
@@ -400,3 +423,36 @@ func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t
})
}
}
+
+func TestTransformClaudeToGeminiWithOptions_PreservesWebSearchAlongsideFunctions(t *testing.T) {
+ claudeReq := &ClaudeRequest{
+ Model: "claude-3-5-sonnet-latest",
+ Messages: []ClaudeMessage{
+ {
+ Role: "user",
+ Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
+ },
+ },
+ Tools: []ClaudeTool{
+ {
+ Name: "get_weather",
+ Description: "Get weather information",
+ InputSchema: map[string]any{"type": "object"},
+ },
+ {
+ Type: "web_search_20250305",
+ Name: "web_search",
+ },
+ },
+ }
+
+ body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
+ require.NoError(t, err)
+
+ var req V1InternalRequest
+ require.NoError(t, json.Unmarshal(body, &req))
+ require.Len(t, req.Request.Tools, 2)
+ require.Len(t, req.Request.Tools[0].FunctionDeclarations, 1)
+ require.Equal(t, "get_weather", req.Request.Tools[0].FunctionDeclarations[0].Name)
+ require.NotNil(t, req.Request.Tools[1].GoogleSearch)
+}
diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go
index f12effb6..bc1fd32e 100644
--- a/backend/internal/pkg/antigravity/response_transformer.go
+++ b/backend/internal/pkg/antigravity/response_transformer.go
@@ -284,6 +284,7 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
usage.InputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached
+ usage.ImageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
}
// 生成响应 ID
diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go
index deed5f92..4a68f3a9 100644
--- a/backend/internal/pkg/antigravity/stream_transformer.go
+++ b/backend/internal/pkg/antigravity/stream_transformer.go
@@ -18,6 +18,9 @@ const (
BlockTypeFunction
)
+// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
+type UsageMapHook func(usageMap map[string]any)
+
// StreamingProcessor 流式响应处理器
type StreamingProcessor struct {
blockType BlockType
@@ -30,11 +33,13 @@ type StreamingProcessor struct {
originalModel string
webSearchQueries []string
groundingChunks []GeminiGroundingChunk
+ usageMapHook UsageMapHook
// 累计 usage
- inputTokens int
- outputTokens int
- cacheReadTokens int
+ inputTokens int
+ outputTokens int
+ cacheReadTokens int
+ imageOutputTokens int
}
// NewStreamingProcessor 创建流式响应处理器
@@ -45,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
}
}
+// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
+func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
+ p.usageMapHook = fn
+}
+
+func usageToMap(u ClaudeUsage) map[string]any {
+ m := map[string]any{
+ "input_tokens": u.InputTokens,
+ "output_tokens": u.OutputTokens,
+ }
+ if u.CacheCreationInputTokens > 0 {
+ m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
+ }
+ if u.CacheReadInputTokens > 0 {
+ m["cache_read_input_tokens"] = u.CacheReadInputTokens
+ }
+ if u.ImageOutputTokens > 0 {
+ m["image_output_tokens"] = u.ImageOutputTokens
+ }
+ return m
+}
+
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line)
@@ -87,6 +114,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
p.inputTokens = geminiResp.UsageMetadata.PromptTokenCount - cached
p.outputTokens = geminiResp.UsageMetadata.CandidatesTokenCount + geminiResp.UsageMetadata.ThoughtsTokenCount
p.cacheReadTokens = cached
+ p.imageOutputTokens = geminiResp.UsageMetadata.ImageOutputTokens()
}
// 处理 parts
@@ -127,6 +155,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
+ ImageOutputTokens: p.imageOutputTokens,
}
if !p.messageStartSent {
@@ -158,6 +187,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
usage.InputTokens = v1Resp.Response.UsageMetadata.PromptTokenCount - cached
usage.OutputTokens = v1Resp.Response.UsageMetadata.CandidatesTokenCount + v1Resp.Response.UsageMetadata.ThoughtsTokenCount
usage.CacheReadInputTokens = cached
+ usage.ImageOutputTokens = v1Resp.Response.UsageMetadata.ImageOutputTokens()
}
responseID := v1Resp.ResponseID
@@ -168,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID()
}
+ var usageValue any = usage
+ if p.usageMapHook != nil {
+ usageMap := usageToMap(usage)
+ p.usageMapHook(usageMap)
+ usageValue = usageMap
+ }
+
message := map[string]any{
"id": responseID,
"type": "message",
@@ -176,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel,
"stop_reason": nil,
"stop_sequence": nil,
- "usage": usage,
+ "usage": usageValue,
}
event := map[string]any{
@@ -485,6 +522,14 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
InputTokens: p.inputTokens,
OutputTokens: p.outputTokens,
CacheReadInputTokens: p.cacheReadTokens,
+ ImageOutputTokens: p.imageOutputTokens,
+ }
+
+ var usageValue any = usage
+ if p.usageMapHook != nil {
+ usageMap := usageToMap(usage)
+ p.usageMapHook(usageMap)
+ usageValue = usageMap
}
deltaEvent := map[string]any{
@@ -493,7 +538,7 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
"stop_reason": stopReason,
"stop_sequence": nil,
},
- "usage": usage,
+ "usage": usageValue,
}
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go
index 095305c2..facfe572 100644
--- a/backend/internal/pkg/apicompat/anthropic_responses_test.go
+++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go
@@ -181,6 +181,55 @@ func TestResponsesToAnthropic_TextOnly(t *testing.T) {
assert.Equal(t, 5, anth.Usage.OutputTokens)
}
+func TestResponsesToAnthropic_CachedTokensUseAnthropicInputSemantics(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_cached",
+ Model: "gpt-5.2",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "message",
+ Content: []ResponsesContentPart{
+ {Type: "output_text", Text: "Cached response"},
+ },
+ },
+ },
+ Usage: &ResponsesUsage{
+ InputTokens: 54006,
+ OutputTokens: 123,
+ TotalTokens: 54129,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 50688,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
+ assert.Equal(t, 3318, anth.Usage.InputTokens)
+ assert.Equal(t, 50688, anth.Usage.CacheReadInputTokens)
+ assert.Equal(t, 123, anth.Usage.OutputTokens)
+}
+
+func TestResponsesToAnthropic_CachedTokensClampInputTokens(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_cached_clamp",
+ Model: "gpt-5.2",
+ Status: "completed",
+ Usage: &ResponsesUsage{
+ InputTokens: 100,
+ OutputTokens: 5,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 150,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-sonnet-4-5-20250929")
+ assert.Equal(t, 0, anth.Usage.InputTokens)
+ assert.Equal(t, 150, anth.Usage.CacheReadInputTokens)
+ assert.Equal(t, 5, anth.Usage.OutputTokens)
+}
+
func TestResponsesToAnthropic_ToolUse(t *testing.T) {
resp := &ResponsesResponse{
ID: "resp_456",
@@ -209,6 +258,48 @@ func TestResponsesToAnthropic_ToolUse(t *testing.T) {
assert.Equal(t, "tool_use", anth.Content[1].Type)
assert.Equal(t, "call_1", anth.Content[1].ID)
assert.Equal(t, "get_weather", anth.Content[1].Name)
+ assert.JSONEq(t, `{"city":"NYC"}`, string(anth.Content[1].Input))
+}
+
+func TestResponsesToAnthropic_ReadToolDropsEmptyPages(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_read",
+ Model: "gpt-5.5",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "function_call",
+ CallID: "call_read",
+ Name: "Read",
+ Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
+ require.Len(t, anth.Content, 1)
+ assert.Equal(t, "tool_use", anth.Content[0].Type)
+ assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, string(anth.Content[0].Input))
+}
+
+func TestResponsesToAnthropic_PreservesEmptyStringsForOtherTools(t *testing.T) {
+ resp := &ResponsesResponse{
+ ID: "resp_other",
+ Model: "gpt-5.5",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "function_call",
+ CallID: "call_other",
+ Name: "Search",
+ Arguments: `{"query":""}`,
+ },
+ },
+ }
+
+ anth := ResponsesToAnthropic(resp, "claude-opus-4-6")
+ require.Len(t, anth.Content, 1)
+ assert.JSONEq(t, `{"query":""}`, string(anth.Content[0].Input))
}
func TestResponsesToAnthropic_Reasoning(t *testing.T) {
@@ -343,6 +434,36 @@ func TestStreamingTextOnly(t *testing.T) {
assert.Equal(t, "message_stop", events[1].Type)
}
+func TestStreamingCachedTokensUseAnthropicInputSemantics(t *testing.T) {
+ state := NewResponsesEventToAnthropicState()
+ ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.created",
+ Response: &ResponsesResponse{ID: "resp_cached_stream", Model: "gpt-5.2"},
+ }, state)
+
+ events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.completed",
+ Response: &ResponsesResponse{
+ Status: "completed",
+ Usage: &ResponsesUsage{
+ InputTokens: 54006,
+ OutputTokens: 123,
+ TotalTokens: 54129,
+ InputTokensDetails: &ResponsesInputTokensDetails{
+ CachedTokens: 50688,
+ },
+ },
+ },
+ }, state)
+
+ require.Len(t, events, 2)
+ assert.Equal(t, "message_delta", events[0].Type)
+ assert.Equal(t, 3318, events[0].Usage.InputTokens)
+ assert.Equal(t, 50688, events[0].Usage.CacheReadInputTokens)
+ assert.Equal(t, 123, events[0].Usage.OutputTokens)
+ assert.Equal(t, "message_stop", events[1].Type)
+}
+
func TestStreamingToolCall(t *testing.T) {
state := NewResponsesEventToAnthropicState()
@@ -393,6 +514,41 @@ func TestStreamingToolCall(t *testing.T) {
assert.Equal(t, "tool_use", events[0].Delta.StopReason)
}
+func TestStreamingReadToolDropsEmptyPages(t *testing.T) {
+ state := NewResponsesEventToAnthropicState()
+
+ ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.created",
+ Response: &ResponsesResponse{ID: "resp_read_stream", Model: "gpt-5.5"},
+ }, state)
+
+ events := ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.output_item.added",
+ OutputIndex: 0,
+ Item: &ResponsesOutput{Type: "function_call", CallID: "call_read", Name: "Read"},
+ }, state)
+ require.Len(t, events, 1)
+ assert.Equal(t, "content_block_start", events[0].Type)
+
+ events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.function_call_arguments.delta",
+ OutputIndex: 0,
+ Delta: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+ }, state)
+ assert.Len(t, events, 0)
+
+ events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
+ Type: "response.function_call_arguments.done",
+ OutputIndex: 0,
+ Arguments: `{"file_path":"/tmp/demo.py","limit":2000,"offset":0,"pages":""}`,
+ }, state)
+ require.Len(t, events, 2)
+ assert.Equal(t, "content_block_delta", events[0].Type)
+ assert.Equal(t, "input_json_delta", events[0].Delta.Type)
+ assert.JSONEq(t, `{"file_path":"/tmp/demo.py","limit":2000,"offset":0}`, events[0].Delta.PartialJSON)
+ assert.Equal(t, "content_block_stop", events[1].Type)
+}
+
func TestStreamingReasoning(t *testing.T) {
state := NewResponsesEventToAnthropicState()
diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
index f54a4a02..c140449a 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
@@ -181,6 +181,50 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
}
+func TestChatCompletionsToResponses_EmptyBase64ImageURLSkipped(t *testing.T) {
+ content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`
+ req := &ChatCompletionsRequest{
+ Model: "gpt-4o",
+ Messages: []ChatMessage{
+ {Role: "user", Content: json.RawMessage(content)},
+ },
+ }
+ resp, err := ChatCompletionsToResponses(req)
+ require.NoError(t, err)
+
+ var items []ResponsesInputItem
+ require.NoError(t, json.Unmarshal(resp.Input, &items))
+ require.Len(t, items, 1)
+
+ var parts []ResponsesContentPart
+ require.NoError(t, json.Unmarshal(items[0].Content, &parts))
+ require.Len(t, parts, 1)
+ assert.Equal(t, "input_text", parts[0].Type)
+ assert.Equal(t, "Describe this", parts[0].Text)
+}
+
+func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testing.T) {
+ content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64, "}}]`
+ req := &ChatCompletionsRequest{
+ Model: "gpt-4o",
+ Messages: []ChatMessage{
+ {Role: "user", Content: json.RawMessage(content)},
+ },
+ }
+ resp, err := ChatCompletionsToResponses(req)
+ require.NoError(t, err)
+
+ var items []ResponsesInputItem
+ require.NoError(t, json.Unmarshal(resp.Input, &items))
+ require.Len(t, items, 1)
+
+ var parts []ResponsesContentPart
+ require.NoError(t, json.Unmarshal(items[0].Content, &parts))
+ require.Len(t, parts, 1)
+ assert.Equal(t, "input_text", parts[0].Type)
+ assert.Equal(t, "Describe this", parts[0].Text)
+}
+
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
@@ -876,3 +920,182 @@ func TestChatCompletionsStreamRoundTrip(t *testing.T) {
assert.Equal(t, "resp_rt", c.ID)
}
}
+
+// ---------------------------------------------------------------------------
+// BufferedResponseAccumulator tests
+// ---------------------------------------------------------------------------
+
+func TestBufferedResponseAccumulator_TextOnly(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "Hello"})
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: ", world!"})
+
+ assert.True(t, acc.HasContent())
+
+ output := acc.BuildOutput()
+ require.Len(t, output, 1)
+ assert.Equal(t, "message", output[0].Type)
+ assert.Equal(t, "assistant", output[0].Role)
+ require.Len(t, output[0].Content, 1)
+ assert.Equal(t, "output_text", output[0].Content[0].Type)
+ assert.Equal(t, "Hello, world!", output[0].Content[0].Text)
+}
+
+func TestBufferedResponseAccumulator_ToolCalls(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+
+ // Add function call at output_index=1
+ acc.ProcessEvent(&ResponsesStreamEvent{
+ Type: "response.output_item.added",
+ OutputIndex: 1,
+ Item: &ResponsesOutput{
+ Type: "function_call",
+ CallID: "call_abc",
+ Name: "get_weather",
+ },
+ })
+ acc.ProcessEvent(&ResponsesStreamEvent{
+ Type: "response.function_call_arguments.delta",
+ OutputIndex: 1,
+ Delta: `{"city":`,
+ })
+ acc.ProcessEvent(&ResponsesStreamEvent{
+ Type: "response.function_call_arguments.delta",
+ OutputIndex: 1,
+ Delta: `"NYC"}`,
+ })
+
+ assert.True(t, acc.HasContent())
+
+ output := acc.BuildOutput()
+ require.Len(t, output, 1)
+ assert.Equal(t, "function_call", output[0].Type)
+ assert.Equal(t, "call_abc", output[0].CallID)
+ assert.Equal(t, "get_weather", output[0].Name)
+ assert.Equal(t, `{"city":"NYC"}`, output[0].Arguments)
+}
+
+func TestBufferedResponseAccumulator_Reasoning(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.reasoning_summary_text.delta", Delta: "Step 1: "})
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.reasoning_summary_text.delta", Delta: "think about it"})
+
+ assert.True(t, acc.HasContent())
+
+ output := acc.BuildOutput()
+ require.Len(t, output, 1)
+ assert.Equal(t, "reasoning", output[0].Type)
+ require.Len(t, output[0].Summary, 1)
+ assert.Equal(t, "summary_text", output[0].Summary[0].Type)
+ assert.Equal(t, "Step 1: think about it", output[0].Summary[0].Text)
+}
+
+func TestBufferedResponseAccumulator_Mixed(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+
+ // Reasoning first
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.reasoning_summary_text.delta", Delta: "I thought about it."})
+
+ // Then text
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "The answer is 42."})
+
+ // Then a tool call
+ acc.ProcessEvent(&ResponsesStreamEvent{
+ Type: "response.output_item.added",
+ OutputIndex: 2,
+ Item: &ResponsesOutput{
+ Type: "function_call",
+ CallID: "call_1",
+ Name: "verify",
+ },
+ })
+ acc.ProcessEvent(&ResponsesStreamEvent{
+ Type: "response.function_call_arguments.delta",
+ OutputIndex: 2,
+ Delta: `{}`,
+ })
+
+ assert.True(t, acc.HasContent())
+
+ output := acc.BuildOutput()
+ // Order: reasoning → message → function_calls
+ require.Len(t, output, 3)
+ assert.Equal(t, "reasoning", output[0].Type)
+ assert.Equal(t, "message", output[1].Type)
+ assert.Equal(t, "function_call", output[2].Type)
+ assert.Equal(t, "The answer is 42.", output[1].Content[0].Text)
+ assert.Equal(t, "verify", output[2].Name)
+}
+
+func TestBufferedResponseAccumulator_SupplementEmptyOutput(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "Hello"})
+
+ resp := &ResponsesResponse{
+ ID: "resp_1",
+ Status: "completed",
+ Output: nil, // empty output
+ Usage: &ResponsesUsage{InputTokens: 10, OutputTokens: 5},
+ }
+
+ acc.SupplementResponseOutput(resp)
+
+ require.Len(t, resp.Output, 1)
+ assert.Equal(t, "message", resp.Output[0].Type)
+ assert.Equal(t, "Hello", resp.Output[0].Content[0].Text)
+ // Usage should be untouched
+ assert.Equal(t, 10, resp.Usage.InputTokens)
+}
+
+func TestBufferedResponseAccumulator_NoSupplementWhenOutputExists(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: "from deltas"})
+
+ resp := &ResponsesResponse{
+ ID: "resp_2",
+ Status: "completed",
+ Output: []ResponsesOutput{
+ {
+ Type: "message",
+ Content: []ResponsesContentPart{
+ {Type: "output_text", Text: "from terminal event"},
+ },
+ },
+ },
+ }
+
+ acc.SupplementResponseOutput(resp)
+
+ // Output should NOT be overwritten
+ require.Len(t, resp.Output, 1)
+ assert.Equal(t, "from terminal event", resp.Output[0].Content[0].Text)
+}
+
+func TestBufferedResponseAccumulator_EmptyDeltas(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+
+ // Process events with empty delta — should not accumulate
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.output_text.delta", Delta: ""})
+ acc.ProcessEvent(&ResponsesStreamEvent{Type: "response.created"})
+
+ assert.False(t, acc.HasContent())
+
+ resp := &ResponsesResponse{ID: "resp_3", Status: "completed"}
+ acc.SupplementResponseOutput(resp)
+ assert.Nil(t, resp.Output)
+}
+
+func TestBufferedResponseAccumulator_IgnoresNonFunctionCallItems(t *testing.T) {
+ acc := NewBufferedResponseAccumulator()
+
+ // output_item.added with type "message" should be ignored
+ acc.ProcessEvent(&ResponsesStreamEvent{
+ Type: "response.output_item.added",
+ OutputIndex: 0,
+ Item: &ResponsesOutput{Type: "message"},
+ })
+
+ assert.False(t, acc.HasContent())
+}
diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
index 6cdd012a..c2725406 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
@@ -27,13 +27,14 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
}
out := &ResponsesRequest{
- Model: req.Model,
- Input: inputJSON,
- Temperature: req.Temperature,
- TopP: req.TopP,
- Stream: true, // upstream always streams
- Include: []string{"reasoning.encrypted_content"},
- ServiceTier: req.ServiceTier,
+ Model: req.Model,
+ Instructions: req.Instructions,
+ Input: inputJSON,
+ Temperature: req.Temperature,
+ TopP: req.TopP,
+ Stream: true, // upstream always streams
+ Include: []string{"reasoning.encrypted_content"},
+ ServiceTier: req.ServiceTier,
}
storeFalse := false
@@ -339,7 +340,7 @@ func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesCont
})
}
case "image_url":
- if p.ImageURL != nil && p.ImageURL.URL != "" {
+ if p.ImageURL != nil && p.ImageURL.URL != "" && !isEmptyBase64DataURI(p.ImageURL.URL) {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
@@ -350,6 +351,22 @@ func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesCont
return responseParts
}
+func isEmptyBase64DataURI(raw string) bool {
+ if !strings.HasPrefix(raw, "data:") {
+ return false
+ }
+ rest := strings.TrimPrefix(raw, "data:")
+ semicolonIdx := strings.Index(rest, ";")
+ if semicolonIdx < 0 {
+ return false
+ }
+ rest = rest[semicolonIdx+1:]
+ if !strings.HasPrefix(rest, "base64,") {
+ return false
+ }
+ return strings.TrimSpace(strings.TrimPrefix(rest, "base64,")) == ""
+}
+
func flattenChatContentParts(parts []ChatContentPart) string {
var textParts []string
for _, p := range parts {
diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic.go b/backend/internal/pkg/apicompat/responses_to_anthropic.go
index 5409a0f4..489ed238 100644
--- a/backend/internal/pkg/apicompat/responses_to_anthropic.go
+++ b/backend/internal/pkg/apicompat/responses_to_anthropic.go
@@ -52,7 +52,7 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
Type: "tool_use",
ID: fromResponsesCallID(item.CallID),
Name: item.Name,
- Input: json.RawMessage(item.Arguments),
+ Input: sanitizeAnthropicToolUseInput(item.Name, item.Arguments),
})
case "web_search_call":
toolUseID := "srvtoolu_" + item.ID
@@ -84,18 +84,34 @@ func ResponsesToAnthropic(resp *ResponsesResponse, model string) *AnthropicRespo
out.StopReason = responsesStatusToAnthropicStopReason(resp.Status, resp.IncompleteDetails, blocks)
if resp.Usage != nil {
- out.Usage = AnthropicUsage{
- InputTokens: resp.Usage.InputTokens,
- OutputTokens: resp.Usage.OutputTokens,
- }
- if resp.Usage.InputTokensDetails != nil {
- out.Usage.CacheReadInputTokens = resp.Usage.InputTokensDetails.CachedTokens
- }
+ out.Usage = anthropicUsageFromResponsesUsage(resp.Usage)
}
return out
}
+func anthropicUsageFromResponsesUsage(usage *ResponsesUsage) AnthropicUsage {
+ if usage == nil {
+ return AnthropicUsage{}
+ }
+
+ cachedTokens := 0
+ if usage.InputTokensDetails != nil {
+ cachedTokens = usage.InputTokensDetails.CachedTokens
+ }
+
+ inputTokens := usage.InputTokens - cachedTokens
+ if inputTokens < 0 {
+ inputTokens = 0
+ }
+
+ return AnthropicUsage{
+ InputTokens: inputTokens,
+ OutputTokens: usage.OutputTokens,
+ CacheReadInputTokens: cachedTokens,
+ }
+}
+
func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncompleteDetails, blocks []AnthropicContentBlock) string {
switch status {
case "incomplete":
@@ -113,6 +129,28 @@ func responsesStatusToAnthropicStopReason(status string, details *ResponsesIncom
}
}
+func sanitizeAnthropicToolUseInput(name string, raw string) json.RawMessage {
+ if name != "Read" || raw == "" {
+ return json.RawMessage(raw)
+ }
+
+ var input map[string]json.RawMessage
+ if err := json.Unmarshal([]byte(raw), &input); err != nil {
+ return json.RawMessage(raw)
+ }
+
+ if pages, ok := input["pages"]; !ok || string(pages) != `""` {
+ return json.RawMessage(raw)
+ }
+
+ delete(input, "pages")
+ sanitized, err := json.Marshal(input)
+ if err != nil {
+ return json.RawMessage(raw)
+ }
+ return sanitized
+}
+
// ---------------------------------------------------------------------------
// Streaming: ResponsesStreamEvent → []AnthropicStreamEvent (stateful converter)
// ---------------------------------------------------------------------------
@@ -126,6 +164,8 @@ type ResponsesEventToAnthropicState struct {
ContentBlockIndex int
ContentBlockOpen bool
CurrentBlockType string // "text" | "thinking" | "tool_use"
+ CurrentToolName string
+ CurrentToolArgs string
// OutputIndexToBlockIdx maps Responses output_index → Anthropic content block index.
OutputIndexToBlockIdx map[int]int
@@ -165,7 +205,7 @@ func ResponsesEventToAnthropicEvents(
case "response.function_call_arguments.delta":
return resToAnthHandleFuncArgsDelta(evt, state)
case "response.function_call_arguments.done":
- return resToAnthHandleBlockDone(state)
+ return resToAnthHandleFuncArgsDone(evt, state)
case "response.output_item.done":
return resToAnthHandleOutputItemDone(evt, state)
case "response.reasoning_summary_text.delta":
@@ -262,6 +302,8 @@ func resToAnthHandleOutputItemAdded(evt *ResponsesStreamEvent, state *ResponsesE
state.OutputIndexToBlockIdx[evt.OutputIndex] = idx
state.ContentBlockOpen = true
state.CurrentBlockType = "tool_use"
+ state.CurrentToolName = evt.Item.Name
+ state.CurrentToolArgs = ""
events = append(events, AnthropicStreamEvent{
Type: "content_block_start",
@@ -342,6 +384,11 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
return nil
}
+ if state.CurrentBlockType == "tool_use" && state.CurrentToolName == "Read" {
+ state.CurrentToolArgs += evt.Delta
+ return nil
+ }
+
blockIdx, ok := state.OutputIndexToBlockIdx[evt.OutputIndex]
if !ok {
return nil
@@ -357,6 +404,33 @@ func resToAnthHandleFuncArgsDelta(evt *ResponsesStreamEvent, state *ResponsesEve
}}
}
+func resToAnthHandleFuncArgsDone(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
+ if state.CurrentBlockType != "tool_use" || state.CurrentToolName != "Read" {
+ return resToAnthHandleBlockDone(state)
+ }
+
+ raw := evt.Arguments
+ if raw == "" {
+ raw = state.CurrentToolArgs
+ }
+ sanitized := sanitizeAnthropicToolUseInput(state.CurrentToolName, raw)
+ if len(sanitized) == 0 {
+ return closeCurrentBlock(state)
+ }
+
+ idx := state.ContentBlockIndex
+ events := []AnthropicStreamEvent{{
+ Type: "content_block_delta",
+ Index: &idx,
+ Delta: &AnthropicDelta{
+ Type: "input_json_delta",
+ PartialJSON: string(sanitized),
+ },
+ }}
+ events = append(events, closeCurrentBlock(state)...)
+ return events
+}
+
func resToAnthHandleReasoningDelta(evt *ResponsesStreamEvent, state *ResponsesEventToAnthropicState) []AnthropicStreamEvent {
if evt.Delta == "" {
return nil
@@ -466,11 +540,10 @@ func resToAnthHandleCompleted(evt *ResponsesStreamEvent, state *ResponsesEventTo
stopReason := "end_turn"
if evt.Response != nil {
if evt.Response.Usage != nil {
- state.InputTokens = evt.Response.Usage.InputTokens
- state.OutputTokens = evt.Response.Usage.OutputTokens
- if evt.Response.Usage.InputTokensDetails != nil {
- state.CacheReadInputTokens = evt.Response.Usage.InputTokensDetails.CachedTokens
- }
+ usage := anthropicUsageFromResponsesUsage(evt.Response.Usage)
+ state.InputTokens = usage.InputTokens
+ state.OutputTokens = usage.OutputTokens
+ state.CacheReadInputTokens = usage.CacheReadInputTokens
}
switch evt.Response.Status {
case "incomplete":
@@ -509,6 +582,8 @@ func closeCurrentBlock(state *ResponsesEventToAnthropicState) []AnthropicStreamE
idx := state.ContentBlockIndex
state.ContentBlockOpen = false
state.ContentBlockIndex++
+ state.CurrentToolName = ""
+ state.CurrentToolArgs = ""
return []AnthropicStreamEvent{{
Type: "content_block_stop",
Index: &idx,
diff --git a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
index f0a5b07e..49426b88 100644
--- a/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
+++ b/backend/internal/pkg/apicompat/responses_to_anthropic_request.go
@@ -390,7 +390,7 @@ func convertResponsesToAnthropicTools(tools []ResponsesTool) []AnthropicTool {
var out []AnthropicTool
for _, t := range tools {
switch t.Type {
- case "web_search":
+ case "web_search", "google_search", "web_search_20250305":
out = append(out, AnthropicTool{
Type: "web_search_20250305",
Name: "web_search",
diff --git a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go
index 688a68eb..61b3bf9c 100644
--- a/backend/internal/pkg/apicompat/responses_to_chatcompletions.go
+++ b/backend/internal/pkg/apicompat/responses_to_chatcompletions.go
@@ -5,6 +5,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
+ "strings"
"time"
)
@@ -372,3 +373,119 @@ func generateChatCmplID() string {
_, _ = rand.Read(b)
return "chatcmpl-" + hex.EncodeToString(b)
}
+
+// ---------------------------------------------------------------------------
+// BufferedResponseAccumulator: accumulates SSE delta events for non-streaming
+// paths where the terminal event may have empty output.
+// ---------------------------------------------------------------------------
+
+type bufferedFuncCall struct {
+ CallID string
+ Name string
+ Args strings.Builder
+}
+
+// BufferedResponseAccumulator collects content from Responses SSE delta events
+// so that non-streaming handlers can reconstruct output when the terminal event
+// (response.completed / response.done) carries an empty output array.
+type BufferedResponseAccumulator struct {
+ text strings.Builder
+ reasoning strings.Builder
+ funcCalls []bufferedFuncCall
+ outputIndexToFuncIdx map[int]int
+}
+
+// NewBufferedResponseAccumulator returns an initialised accumulator.
+func NewBufferedResponseAccumulator() *BufferedResponseAccumulator {
+ return &BufferedResponseAccumulator{
+ outputIndexToFuncIdx: make(map[int]int),
+ }
+}
+
+// ProcessEvent inspects a single Responses SSE event and accumulates any
+// content it carries. Only delta events that contribute to the final output
+// are handled; all other event types are silently ignored.
+func (a *BufferedResponseAccumulator) ProcessEvent(event *ResponsesStreamEvent) {
+ switch event.Type {
+ case "response.output_text.delta":
+ if event.Delta != "" {
+ _, _ = a.text.WriteString(event.Delta)
+ }
+ case "response.output_item.added":
+ if event.Item != nil && event.Item.Type == "function_call" {
+ idx := len(a.funcCalls)
+ a.outputIndexToFuncIdx[event.OutputIndex] = idx
+ a.funcCalls = append(a.funcCalls, bufferedFuncCall{
+ CallID: event.Item.CallID,
+ Name: event.Item.Name,
+ })
+ }
+ case "response.function_call_arguments.delta":
+ if event.Delta != "" {
+ if idx, ok := a.outputIndexToFuncIdx[event.OutputIndex]; ok {
+ _, _ = a.funcCalls[idx].Args.WriteString(event.Delta)
+ }
+ }
+ case "response.reasoning_summary_text.delta":
+ if event.Delta != "" {
+ _, _ = a.reasoning.WriteString(event.Delta)
+ }
+ }
+}
+
+// HasContent reports whether any content has been accumulated.
+func (a *BufferedResponseAccumulator) HasContent() bool {
+ return a.text.Len() > 0 || len(a.funcCalls) > 0 || a.reasoning.Len() > 0
+}
+
+// BuildOutput constructs a []ResponsesOutput from the accumulated delta
+// content. The order matches what ResponsesToChatCompletions expects:
+// reasoning → message → function_calls.
+func (a *BufferedResponseAccumulator) BuildOutput() []ResponsesOutput {
+ var out []ResponsesOutput
+
+ if a.reasoning.Len() > 0 {
+ out = append(out, ResponsesOutput{
+ Type: "reasoning",
+ Summary: []ResponsesSummary{{
+ Type: "summary_text",
+ Text: a.reasoning.String(),
+ }},
+ })
+ }
+
+ if a.text.Len() > 0 {
+ out = append(out, ResponsesOutput{
+ Type: "message",
+ Role: "assistant",
+ Content: []ResponsesContentPart{{
+ Type: "output_text",
+ Text: a.text.String(),
+ }},
+ })
+ }
+
+ for i := range a.funcCalls {
+ out = append(out, ResponsesOutput{
+ Type: "function_call",
+ CallID: a.funcCalls[i].CallID,
+ Name: a.funcCalls[i].Name,
+ Arguments: a.funcCalls[i].Args.String(),
+ })
+ }
+
+ return out
+}
+
+// SupplementResponseOutput fills resp.Output from accumulated delta content
+// when the terminal event delivered an empty output array. If resp.Output is
+// already populated, this is a no-op (preserves backward compatibility).
+func (a *BufferedResponseAccumulator) SupplementResponseOutput(resp *ResponsesResponse) {
+ if resp == nil || len(resp.Output) > 0 {
+ return
+ }
+ if !a.HasContent() {
+ return
+ }
+ resp.Output = a.BuildOutput()
+}
diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go
index b724a5ed..f8c6b75f 100644
--- a/backend/internal/pkg/apicompat/types.go
+++ b/backend/internal/pkg/apicompat/types.go
@@ -12,23 +12,29 @@ import "encoding/json"
// AnthropicRequest is the request body for POST /v1/messages.
type AnthropicRequest struct {
- Model string `json:"model"`
- MaxTokens int `json:"max_tokens"`
- System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
- Messages []AnthropicMessage `json:"messages"`
- Tools []AnthropicTool `json:"tools,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP *float64 `json:"top_p,omitempty"`
- StopSeqs []string `json:"stop_sequences,omitempty"`
- Thinking *AnthropicThinking `json:"thinking,omitempty"`
- ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+ Model string `json:"model"`
+ MaxTokens int `json:"max_tokens"`
+ System json.RawMessage `json:"system,omitempty"` // string or []AnthropicContentBlock
+ Messages []AnthropicMessage `json:"messages"`
+ Tools []AnthropicTool `json:"tools,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP *float64 `json:"top_p,omitempty"`
+ StopSeqs []string `json:"stop_sequences,omitempty"`
+ Thinking *AnthropicThinking `json:"thinking,omitempty"`
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+ // Metadata 会被原样透传给上游。OAuth/Claude-Code 路径依赖 metadata.user_id
+ // 参与上游的"是否为官方 Claude Code 请求"判定;如果经由本结构体重新序列化
+ // 时丢弃该字段,网关侧后续的 metadata 重写(ensureClaudeOAuthMetadataUserID/
+ // RewriteUserIDWithMasking) 在 body 里拿不到起点,就无法重建一个合法的
+ // user_id,进而导致请求被归类为第三方 app。
+ Metadata json.RawMessage `json:"metadata,omitempty"`
OutputConfig *AnthropicOutputConfig `json:"output_config,omitempty"`
}
// AnthropicOutputConfig controls output generation parameters.
type AnthropicOutputConfig struct {
- Effort string `json:"effort,omitempty"` // "low" | "medium" | "high"
+ Effort string `json:"effort,omitempty"` // "low" | "medium" | "high" | "max"
}
// AnthropicThinking configures extended thinking in the Anthropic API.
@@ -76,10 +82,18 @@ type AnthropicImageSource struct {
// AnthropicTool describes a tool available to the model.
type AnthropicTool struct {
- Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
+ Type string `json:"type,omitempty"` // e.g. "web_search_20250305" for server tools
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ InputSchema json.RawMessage `json:"input_schema"` // JSON Schema object
+ CacheControl *AnthropicCacheControl `json:"cache_control,omitempty"`
+}
+
+// AnthropicCacheControl 对应 Anthropic API 的 cache_control 字段。
+// ttl 默认由调用方决定;本项目策略见 claude.DefaultCacheControlTTL。
+type AnthropicCacheControl struct {
+ Type string `json:"type"` // "ephemeral"
+ TTL string `json:"ttl,omitempty"` // "5m" / "1h" / 省略=默认 5m(由 Anthropic 判定)
}
// AnthropicResponse is the non-streaming response from POST /v1/messages.
@@ -152,6 +166,7 @@ type AnthropicDelta struct {
// ResponsesRequest is the request body for POST /v1/responses.
type ResponsesRequest struct {
Model string `json:"model"`
+ Instructions string `json:"instructions,omitempty"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
@@ -167,7 +182,7 @@ type ResponsesRequest struct {
// ResponsesReasoning configures reasoning effort in the Responses API.
type ResponsesReasoning struct {
- Effort string `json:"effort"` // "low" | "medium" | "high"
+ Effort string `json:"effort"` // "low" | "medium" | "high" | "xhigh"
Summary string `json:"summary,omitempty"` // "auto" | "concise" | "detailed"
}
@@ -337,6 +352,7 @@ type ResponsesStreamEvent struct {
type ChatCompletionsRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
+ Instructions string `json:"instructions,omitempty"` // OpenAI Responses API compat
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
@@ -345,7 +361,7 @@ type ChatCompletionsRequest struct {
StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"`
Tools []ChatTool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
- ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high"
+ ReasoningEffort string `json:"reasoning_effort,omitempty"` // "low" | "medium" | "high" | "xhigh"
ServiceTier string `json:"service_tier,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // string or []string
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index dfca252f..aa59ba64 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -4,6 +4,12 @@ package claude
// Claude Code 客户端相关常量
// Beta header 常量
+//
+// 这里的常量对齐真实 Claude Code CLI 的最新流量(截至 2026-04)。
+// 选型参考:与 Parrot (src/transform/cc_mimicry.py) 的 BETAS 保持一致,
+// 原因:Anthropic 上游会基于 anthropic-beta 的完整集合判定请求来源;
+// 缺少任何"官方 Claude Code 请求才会带"的 beta,都会被降级到第三方额度,
+// 对应报错:`Third-party apps now draw from your extra usage, not your plan limits.`
const (
BetaOAuth = "oauth-2025-04-20"
BetaClaudeCode = "claude-code-20250219"
@@ -12,6 +18,13 @@ const (
BetaTokenCounting = "token-counting-2024-11-01"
BetaContext1M = "context-1m-2025-08-07"
BetaFastMode = "fast-mode-2026-02-01"
+
+ // 新增(对齐官方 CLI 2.1.9x 以来的流量)
+ BetaPromptCachingScope = "prompt-caching-scope-2026-01-05"
+ BetaEffort = "effort-2025-11-24"
+ BetaRedactThinking = "redact-thinking-2026-02-12"
+ BetaContextManagement = "context-management-2025-06-27"
+ BetaExtendedCacheTTL = "extended-cache-ttl-2025-04-11"
)
// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。
@@ -44,11 +57,43 @@ const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," +
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
+// DefaultCacheControlTTL 是网关代理为自己生成的 cache_control 块默认使用的 ttl。
+// 真实 Claude Code CLI 当前使用 "1h",但本仓策略是"客户端透传 ttl 优先;
+// 客户端缺省时统一使用 5m",这样既不浪费 1h 缓存额度,也保留客户端自定义能力。
+const DefaultCacheControlTTL = "5m"
+
+// CLICurrentVersion 是 sub2api 当前对外伪装的 Claude Code CLI 版本号(三段 semver)。
+// 用于 billing attribution block 中的 cc_version=X.Y.Z.{fp} 前缀以及 fingerprint 计算。
+// 必须与 DefaultHeaders["User-Agent"] 中的版本号严格一致;不一致会被 Anthropic 判第三方。
+const CLICurrentVersion = "2.1.92"
+
+// FullClaudeCodeMimicryBetas 返回最"像"真实 Claude Code CLI 的完整 beta 列表,
+// 用于 OAuth 账号伪装成 Claude Code 时使用。
+// 顺序与真实 CLI 抓包一致。
+//
+// 使用建议:
+// - OAuth 账号 + 非 haiku:追加这整份列表,再按需保留 client 带来的 beta。
+// - OAuth 账号 + haiku:Anthropic 对 haiku 不做 third-party 判定,使用 HaikuBetaHeader 即可。
+// - API-key 账号:不要使用本函数,参见 APIKeyBetaHeader。
+func FullClaudeCodeMimicryBetas() []string {
+ return []string{
+ BetaClaudeCode,
+ BetaOAuth,
+ BetaInterleavedThinking,
+ BetaPromptCachingScope,
+ BetaEffort,
+ BetaRedactThinking,
+ BetaContextManagement,
+ BetaExtendedCacheTTL,
+ }
+}
+
// DefaultHeaders 是 Claude Code 客户端默认请求头。
var DefaultHeaders = map[string]string{
// Keep these in sync with recent Claude CLI traffic to reduce the chance
// that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage.
- "User-Agent": "claude-cli/2.1.22 (external, cli)",
+ // 版本参考:对齐 Parrot (src/transform/cc_mimicry.py:49) 的 CLI_USER_AGENT。
+ "User-Agent": "claude-cli/2.1.92 (external, cli)",
"X-Stainless-Lang": "js",
"X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
@@ -83,6 +128,12 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.6",
CreatedAt: "2026-02-06T00:00:00Z",
},
+ {
+ ID: "claude-opus-4-7",
+ Type: "model",
+ DisplayName: "Claude Opus 4.7",
+ CreatedAt: "2026-04-17T00:00:00Z",
+ },
{
ID: "claude-sonnet-4-6",
Type: "model",
diff --git a/backend/internal/pkg/logger/logger_test.go b/backend/internal/pkg/logger/logger_test.go
index 74aae061..06a277a4 100644
--- a/backend/internal/pkg/logger/logger_test.go
+++ b/backend/internal/pkg/logger/logger_test.go
@@ -10,7 +10,13 @@ import (
)
func TestInit_DualOutput(t *testing.T) {
- tmpDir := t.TempDir()
+ // Use os.MkdirTemp instead of t.TempDir to avoid cleanup failures
+ // when lumberjack holds file handles on Windows.
+ tmpDir, err := os.MkdirTemp("", "logger-test-*")
+ if err != nil {
+ t.Fatalf("create temp dir: %v", err)
+ }
+ t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout
@@ -57,7 +63,9 @@ func TestInit_DualOutput(t *testing.T) {
L().Info("dual-output-info")
L().Warn("dual-output-warn")
- Sync()
+
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
+ // The log data is already in the pipe buffer; closing writers is sufficient.
_ = stdoutW.Close()
_ = stderrW.Close()
@@ -166,7 +174,9 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) {
}
L().Info("caller-check")
- Sync()
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
+ os.Stdout = origStdout
+ os.Stderr = origStderr
_ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR)
diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go
index 4482a2ec..30d25b33 100644
--- a/backend/internal/pkg/logger/stdlog_bridge_test.go
+++ b/backend/internal/pkg/logger/stdlog_bridge_test.go
@@ -77,7 +77,7 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) {
log.Printf("service started")
log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout")
- Sync()
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close()
_ = stderrW.Close()
@@ -139,7 +139,7 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) {
LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout")
- Sync()
+ // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close()
_ = stderrW.Close()
diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go
index 49e38bf8..be9f3aae 100644
--- a/backend/internal/pkg/openai/constants.go
+++ b/backend/internal/pkg/openai/constants.go
@@ -15,18 +15,15 @@ type Model struct {
// DefaultModels OpenAI models list
var DefaultModels = []Model{
+ {ID: "gpt-5.5", Object: "model", Created: 1776873600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.5"},
{ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"},
{ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"},
- {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
- {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
- {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
- {ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"},
- {ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"},
- {ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"},
- {ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"},
+ {ID: "gpt-image-1", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1"},
+ {ID: "gpt-image-1.5", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 1.5"},
+ {ID: "gpt-image-2", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT Image 2"},
}
// DefaultModelIDs returns the default model ID list
@@ -39,7 +36,7 @@ func DefaultModelIDs() []string {
}
// DefaultTestModel default model for testing OpenAI accounts
-const DefaultTestModel = "gpt-5.1-codex"
+const DefaultTestModel = "gpt-5.4"
// DefaultInstructions default instructions for non-Codex CLI requests
// Content loaded from instructions.txt at compile time
diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go
index 6b8521bd..618b6adb 100644
--- a/backend/internal/pkg/openai/oauth.go
+++ b/backend/internal/pkg/openai/oauth.go
@@ -17,8 +17,6 @@ import (
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
- // OAuth Client ID for Sora mobile flow (aligned with sora2api)
- SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
@@ -39,8 +37,6 @@ const (
const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai"
- // OAuthPlatformSora uses Sora OAuth client.
- OAuthPlatformSora = "sora"
)
// OAuthSession stores OAuth flow state for OpenAI
@@ -211,15 +207,8 @@ func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platfor
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
-// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
-// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
- switch strings.ToLower(strings.TrimSpace(platform)) {
- case OAuthPlatformSora:
- return ClientID, false
- default:
- return ClientID, true
- }
+ return ClientID, true
}
// TokenRequest represents the token exchange request body
diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go
index 2970addf..56b42fc9 100644
--- a/backend/internal/pkg/openai/oauth_test.go
+++ b/backend/internal/pkg/openai/oauth_test.go
@@ -60,23 +60,3 @@ func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
-
-// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
-// 但不启用 codex_cli_simplified_flow。
-func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
- authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
- parsed, err := url.Parse(authURL)
- if err != nil {
- t.Fatalf("Parse URL failed: %v", err)
- }
- q := parsed.Query()
- if got := q.Get("client_id"); got != ClientID {
- t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
- }
- if got := q.Get("codex_cli_simplified_flow"); got != "" {
- t.Fatalf("codex flow should be empty for sora, got=%q", got)
- }
- if got := q.Get("id_token_add_organizations"); got != "true" {
- t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
- }
-}
diff --git a/backend/internal/pkg/pagination/pagination.go b/backend/internal/pkg/pagination/pagination.go
index c162588a..ce8e74b8 100644
--- a/backend/internal/pkg/pagination/pagination.go
+++ b/backend/internal/pkg/pagination/pagination.go
@@ -1,10 +1,19 @@
// Package pagination provides types and helpers for paginated responses.
package pagination
+import "strings"
+
+const (
+ SortOrderAsc = "asc"
+ SortOrderDesc = "desc"
+)
+
// PaginationParams 分页参数
type PaginationParams struct {
- Page int
- PageSize int
+ Page int
+ PageSize int
+ SortBy string
+ SortOrder string
}
// PaginationResult 分页结果
@@ -18,8 +27,9 @@ type PaginationResult struct {
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
- Page: 1,
- PageSize: 20,
+ Page: 1,
+ PageSize: 20,
+ SortOrder: SortOrderDesc,
}
}
@@ -36,8 +46,32 @@ func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
- if p.PageSize > 100 {
- return 100
+ if p.PageSize > 1000 {
+ return 1000
}
return p.PageSize
}
+
+// NormalizeSortOrder normalizes sort order to asc/desc and falls back to defaultOrder.
+func NormalizeSortOrder(order string, defaultOrder string) string {
+ switch strings.ToLower(strings.TrimSpace(defaultOrder)) {
+ case SortOrderAsc:
+ defaultOrder = SortOrderAsc
+ default:
+ defaultOrder = SortOrderDesc
+ }
+
+ switch strings.ToLower(strings.TrimSpace(order)) {
+ case SortOrderAsc:
+ return SortOrderAsc
+ case SortOrderDesc:
+ return SortOrderDesc
+ default:
+ return defaultOrder
+ }
+}
+
+// NormalizedSortOrder returns the normalized sort order using defaultOrder as fallback.
+func (p PaginationParams) NormalizedSortOrder(defaultOrder string) string {
+ return NormalizeSortOrder(p.SortOrder, defaultOrder)
+}
diff --git a/backend/internal/pkg/pagination/pagination_test.go b/backend/internal/pkg/pagination/pagination_test.go
new file mode 100644
index 00000000..9a3b069d
--- /dev/null
+++ b/backend/internal/pkg/pagination/pagination_test.go
@@ -0,0 +1,71 @@
+package pagination
+
+import "testing"
+
+func TestNormalizeSortOrder(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ defaultOrder string
+ want string
+ }{
+ {name: "asc", input: "asc", defaultOrder: "desc", want: "asc"},
+ {name: "uppercase asc", input: "ASC", defaultOrder: "desc", want: "asc"},
+ {name: "desc", input: "desc", defaultOrder: "asc", want: "desc"},
+ {name: "trim spaces", input: " desc ", defaultOrder: "asc", want: "desc"},
+ {name: "invalid falls back", input: "sideways", defaultOrder: "asc", want: "asc"},
+ {name: "empty falls back", input: "", defaultOrder: "desc", want: "desc"},
+ {name: "invalid default falls back to desc", input: "", defaultOrder: "wat", want: "desc"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ if got := NormalizeSortOrder(tt.input, tt.defaultOrder); got != tt.want {
+ t.Fatalf("NormalizeSortOrder(%q, %q) = %q, want %q", tt.input, tt.defaultOrder, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestPaginationParamsNormalizedSortOrder(t *testing.T) {
+ t.Parallel()
+
+ params := PaginationParams{SortOrder: "ASC"}
+ if got := params.NormalizedSortOrder("desc"); got != "asc" {
+ t.Fatalf("NormalizedSortOrder = %q, want asc", got)
+ }
+
+ params = PaginationParams{SortOrder: "bad"}
+ if got := params.NormalizedSortOrder("asc"); got != "asc" {
+ t.Fatalf("NormalizedSortOrder invalid fallback = %q, want asc", got)
+ }
+}
+
+func TestPaginationParamsLimit(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ pageSize int
+ want int
+ }{
+ {name: "non-positive falls back to default", pageSize: 0, want: 20},
+ {name: "negative falls back to default", pageSize: -1, want: 20},
+ {name: "normal value keeps", pageSize: 50, want: 50},
+ {name: "max value keeps", pageSize: 1000, want: 1000},
+ {name: "beyond max clamps to 1000", pageSize: 1500, want: 1000},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ p := PaginationParams{PageSize: tt.pageSize}
+ if got := p.Limit(); got != tt.want {
+ t.Fatalf("Limit() for PageSize=%d = %d, want %d", tt.pageSize, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go
index 44cddb6a..fe5f98d6 100644
--- a/backend/internal/pkg/usagestats/usage_log_types.go
+++ b/backend/internal/pkg/usagestats/usage_log_types.go
@@ -56,8 +56,9 @@ type DashboardStats struct {
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
- TotalCost float64 `json:"total_cost"` // 累计标准计费
- TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
+ TotalCost float64 `json:"total_cost"` // 累计标准计费
+ TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
+ TotalAccountCost float64 `json:"total_account_cost"` // 累计账号成本
// 今日 Token 使用统计
TodayRequests int64 `json:"today_requests"`
@@ -66,8 +67,9 @@ type DashboardStats struct {
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
TodayTokens int64 `json:"today_tokens"`
- TodayCost float64 `json:"today_cost"` // 今日标准计费
- TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
+ TodayCost float64 `json:"today_cost"` // 今日标准计费
+ TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
+ TodayAccountCost float64 `json:"today_account_cost"` // 今日账号成本
// 系统运行统计
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
@@ -99,8 +101,9 @@ type ModelStat struct {
CacheCreationTokens int64 `json:"cache_creation_tokens"`
CacheReadTokens int64 `json:"cache_read_tokens"`
TotalTokens int64 `json:"total_tokens"`
- Cost float64 `json:"cost"` // 标准计费
- ActualCost float64 `json:"actual_cost"` // 实际扣除
+ Cost float64 `json:"cost"` // 标准计费
+ ActualCost float64 `json:"actual_cost"` // 实际扣除
+ AccountCost float64 `json:"account_cost"` // 账号成本
}
// EndpointStat represents usage statistics for a single request endpoint.
@@ -125,8 +128,9 @@ type GroupStat struct {
GroupName string `json:"group_name"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
- Cost float64 `json:"cost"` // 标准计费
- ActualCost float64 `json:"actual_cost"` // 实际扣除
+ Cost float64 `json:"cost"` // 标准计费
+ ActualCost float64 `json:"actual_cost"` // 实际扣除
+ AccountCost float64 `json:"account_cost"` // 账号成本
}
// UserUsageTrendPoint represents user usage trend data point
@@ -164,8 +168,9 @@ type UserBreakdownItem struct {
Email string `json:"email"`
Requests int64 `json:"requests"`
TotalTokens int64 `json:"total_tokens"`
- Cost float64 `json:"cost"` // 标准计费
- ActualCost float64 `json:"actual_cost"` // 实际扣除
+ Cost float64 `json:"cost"` // 标准计费
+ ActualCost float64 `json:"actual_cost"` // 实际扣除
+ AccountCost float64 `json:"account_cost"` // 账号成本
}
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
@@ -175,6 +180,13 @@ type UserBreakdownDimension struct {
ModelType string // "requested", "upstream", or "mapping"
Endpoint string // filter by endpoint value (non-empty to enable)
EndpointType string // "inbound", "upstream", or "path"
+ // Additional filter conditions
+ UserID int64 // filter by user_id (>0 to enable)
+ APIKeyID int64 // filter by api_key_id (>0 to enable)
+ AccountID int64 // filter by account_id (>0 to enable)
+ RequestType *int16 // filter by request_type (non-nil to enable)
+ Stream *bool // filter by stream flag (non-nil to enable)
+ BillingType *int8 // filter by billing_type (non-nil to enable)
}
// APIKeyUsageTrendPoint represents API key usage trend data point
@@ -230,6 +242,7 @@ type UsageLogFilters struct {
RequestType *int16
Stream *bool
BillingType *int8
+ BillingMode string
StartTime *time.Time
EndTime *time.Time
// ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging.
diff --git a/backend/internal/pkg/websearch/brave.go b/backend/internal/pkg/websearch/brave.go
new file mode 100644
index 00000000..707e7029
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave.go
@@ -0,0 +1,106 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strconv"
+)
+
+const (
+ braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
+ braveMaxCount = 20
+ braveProviderName = "brave"
+)
+
+// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
+var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
+
+// BraveProvider implements web search via the Brave Search API.
+type BraveProvider struct {
+ apiKey string
+ httpClient *http.Client
+}
+
+// NewBraveProvider creates a Brave Search provider.
+// The caller is responsible for configuring the http.Client with proxy/timeouts.
+func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
+ if httpClient == nil {
+ httpClient = http.DefaultClient
+ }
+ return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
+}
+
+func (b *BraveProvider) Name() string { return braveProviderName }
+
+func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
+ count := req.MaxResults
+ if count <= 0 {
+ count = defaultMaxResults
+ }
+ if count > braveMaxCount {
+ count = braveMaxCount
+ }
+
+ u := *braveSearchURL // copy the pre-parsed URL
+ q := u.Query()
+ q.Set("q", req.Query)
+ q.Set("count", strconv.Itoa(count))
+ u.RawQuery = q.Encode()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("brave: build request: %w", err)
+ }
+ httpReq.Header.Set("X-Subscription-Token", b.apiKey)
+ httpReq.Header.Set("Accept", "application/json")
+
+ resp, err := b.httpClient.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("brave: request failed: %w", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
+ if err != nil {
+ return nil, fmt.Errorf("brave: read body: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
+ }
+
+ var raw braveResponse
+ if err := json.Unmarshal(body, &raw); err != nil {
+ return nil, fmt.Errorf("brave: decode response: %w", err)
+ }
+
+ results := make([]SearchResult, 0, len(raw.Web.Results))
+ for _, r := range raw.Web.Results {
+ results = append(results, SearchResult{
+ URL: r.URL,
+ Title: r.Title,
+ Snippet: r.Description,
+ PageAge: r.Age,
+ })
+ }
+
+ return &SearchResponse{Results: results, Query: req.Query}, nil
+}
+
+// braveResponse is the minimal structure of the Brave Search API response.
+type braveResponse struct {
+ Web struct {
+ Results []braveResult `json:"results"`
+ } `json:"web"`
+}
+
+type braveResult struct {
+ URL string `json:"url"`
+ Title string `json:"title"`
+ Description string `json:"description"`
+ Age string `json:"age"`
+}
diff --git a/backend/internal/pkg/websearch/brave_test.go b/backend/internal/pkg/websearch/brave_test.go
new file mode 100644
index 00000000..4dc5b219
--- /dev/null
+++ b/backend/internal/pkg/websearch/brave_test.go
@@ -0,0 +1,119 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestBraveProvider_Name(t *testing.T) {
+ p := NewBraveProvider("key", nil)
+ require.Equal(t, "brave", p.Name())
+}
+
+func TestBraveProvider_Search_Success(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
+ require.Equal(t, "application/json", r.Header.Get("Accept"))
+ require.Equal(t, "golang", r.URL.Query().Get("q"))
+ require.Equal(t, "3", r.URL.Query().Get("count"))
+
+ resp := braveResponse{}
+ resp.Web.Results = []braveResult{
+ {URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
+ {URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
+ {URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("test-key", srv.Client())
+ // Override the endpoint for testing
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
+ require.NoError(t, err)
+ require.Len(t, resp.Results, 3)
+ require.Equal(t, "https://go.dev", resp.Results[0].URL)
+ require.Equal(t, "Go lang", resp.Results[0].Snippet)
+ require.Equal(t, "1 day", resp.Results[0].PageAge)
+}
+
+func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
+ var receivedCount string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ receivedCount = r.URL.Query().Get("count")
+ resp := braveResponse{}
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
+ require.Equal(t, "5", receivedCount)
+}
+
+func TestBraveProvider_Search_HTTPError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(429)
+ _, _ = w.Write([]byte("rate limited"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: status 429")
+}
+
+func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ _, _ = w.Write([]byte("not json"))
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ _, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.ErrorContains(t, err, "brave: decode response")
+}
+
+func TestBraveProvider_Search_EmptyResults(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ resp := braveResponse{}
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer srv.Close()
+
+ p := NewBraveProvider("key", srv.Client())
+ origURL := *braveSearchURL
+ u, _ := http.NewRequest("GET", srv.URL, nil)
+ *braveSearchURL = *u.URL
+ defer func() { *braveSearchURL = origURL }()
+
+ resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
+ require.NoError(t, err)
+ require.Empty(t, resp.Results)
+}
diff --git a/backend/internal/pkg/websearch/helpers.go b/backend/internal/pkg/websearch/helpers.go
new file mode 100644
index 00000000..0d08b749
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers.go
@@ -0,0 +1,14 @@
+package websearch
+
+const (
+ maxResponseSize = 1 << 20 // 1 MB
+ errorBodyTruncLen = 200
+)
+
+// truncateBody returns a truncated string of body for error messages.
+func truncateBody(body []byte) string {
+ if len(body) <= errorBodyTruncLen {
+ return string(body)
+ }
+ return string(body[:errorBodyTruncLen]) + "...(truncated)"
+}
diff --git a/backend/internal/pkg/websearch/helpers_test.go b/backend/internal/pkg/websearch/helpers_test.go
new file mode 100644
index 00000000..e3164329
--- /dev/null
+++ b/backend/internal/pkg/websearch/helpers_test.go
@@ -0,0 +1,25 @@
+package websearch
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTruncateBody_Short(t *testing.T) {
+ body := []byte("short body")
+ require.Equal(t, "short body", truncateBody(body))
+}
+
+func TestTruncateBody_Long(t *testing.T) {
+ body := []byte(strings.Repeat("x", 500))
+ result := truncateBody(body)
+ require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
+ require.True(t, strings.HasSuffix(result, "...(truncated)"))
+}
+
+func TestTruncateBody_ExactBoundary(t *testing.T) {
+ body := []byte(strings.Repeat("x", errorBodyTruncLen))
+ require.Equal(t, string(body), truncateBody(body))
+}
diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go
new file mode 100644
index 00000000..307aa1e9
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager.go
@@ -0,0 +1,528 @@
+package websearch
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "log/slog"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
+ "github.com/redis/go-redis/v9"
+)
+
+// ProviderConfig holds the configuration for a single search provider.
+type ProviderConfig struct {
+ Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
+ APIKey string `json:"api_key"` // secret
+ QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
+ SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date
+ ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
+ ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
+ ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
+}
+
+// Manager selects providers by quota-weighted load balancing and tracks quota via Redis.
+type Manager struct {
+ configs []ProviderConfig
+ redis *redis.Client
+
+ clientMu sync.Mutex
+ clientCache map[string]*http.Client
+}
+
+// Timeout constants for proxy and search operations.
+const (
+ proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout
+ proxyTLSTimeout = 3 * time.Second // TLS handshake timeout
+ searchDataTimeout = 60 * time.Second // response data transfer timeout
+ searchRequestTimeout = searchDataTimeout + proxyDialTimeout
+
+ quotaKeyPrefix = "websearch:quota:"
+ proxyUnavailableKey = "websearch:proxy_unavailable:%d"
+ proxyUnavailableTTL = 5 * time.Minute
+ quotaTTLBuffer = 24 * time.Hour
+ defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date
+ maxCachedClients = 100
+)
+
+// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
+// Callers may use this to trigger account switching instead of direct fallback.
+var ErrProxyUnavailable = errors.New("websearch: proxy unavailable")
+
+// quotaIncrScript atomically increments the counter and sets TTL on first creation.
+var quotaIncrScript = redis.NewScript(`
+local val = redis.call('INCR', KEYS[1])
+if val == 1 then
+ redis.call('EXPIRE', KEYS[1], ARGV[1])
+else
+ local ttl = redis.call('TTL', KEYS[1])
+ if ttl == -1 then
+ redis.call('EXPIRE', KEYS[1], ARGV[1])
+ end
+end
+return val
+`)
+
+// NewManager creates a Manager with the given provider configs and Redis client.
+// Provider order is preserved as-is; selectByQuotaWeight handles load balancing.
+func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
+ copied := make([]ProviderConfig, len(configs))
+ copy(copied, configs)
+ return &Manager{
+ configs: copied,
+ redis: redisClient,
+ clientCache: make(map[string]*http.Client),
+ }
+}
+
+// SearchWithBestProvider selects a provider using quota-weighted load balancing,
+// reserves quota, executes the search, and rolls back quota on failure.
+// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
+func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
+ if strings.TrimSpace(req.Query) == "" {
+ return nil, "", fmt.Errorf("websearch: empty search query")
+ }
+
+ candidates := m.filterAvailableProviders(ctx, req.ProxyURL)
+ if len(candidates) == 0 {
+ return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)")
+ }
+
+ selected := m.selectByQuotaWeight(ctx, candidates)
+
+ for _, cfg := range selected {
+ allowed, incremented := m.tryReserveQuota(ctx, cfg)
+ if !allowed {
+ continue
+ }
+ resp, err := m.executeSearch(ctx, cfg, req)
+ if err != nil {
+ if incremented {
+ m.rollbackQuota(ctx, cfg)
+ }
+ if isProxyError(err) {
+ m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
+ if req.ProxyURL != "" {
+ // Account-level proxy is shared by all providers — no point
+ // trying others with the same broken proxy; signal account switch.
+ slog.Warn("websearch: account proxy error, aborting failover",
+ "provider", cfg.Type, "error", err)
+ return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
+ }
+ // Provider-specific proxy failed — try the next provider which
+ // may use a different (or no) proxy.
+ slog.Warn("websearch: provider proxy error, trying next provider",
+ "provider", cfg.Type, "error", err)
+ continue
+ }
+ slog.Warn("websearch: provider search failed",
+ "provider", cfg.Type, "error", err)
+ continue
+ }
+ return resp, cfg.Type, nil
+ }
+ return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
+}
+
+// filterAvailableProviders returns providers that have API keys, are not expired,
+// and whose proxies are not marked unavailable.
+func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig {
+ var out []ProviderConfig
+ for _, cfg := range m.configs {
+ if !m.isProviderAvailable(cfg) {
+ continue
+ }
+ proxyID := resolveProxyID(cfg, accountProxyURL)
+ if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) {
+ slog.Debug("websearch: proxy marked unavailable, skipping",
+ "provider", cfg.Type, "proxy_id", proxyID)
+ continue
+ }
+ out = append(out, cfg)
+ }
+ return out
+}
+
+// weighted is a provider candidate with computed quota weight.
+type weighted struct {
+ cfg ProviderConfig
+ weight int64
+}
+
+// selectByQuotaWeight orders candidates by remaining quota weight.
+// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
+// Among providers with quota, higher remaining quota = higher priority.
+func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
+ items := m.computeWeights(ctx, candidates)
+ withQuota, withoutQuota := partitionByQuota(items)
+ sortByStableRandomWeight(withQuota)
+ return mergeWeightedResults(withQuota, withoutQuota, len(candidates))
+}
+
+func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted {
+ items := make([]weighted, 0, len(candidates))
+ for _, cfg := range candidates {
+ w := int64(0)
+ if cfg.QuotaLimit > 0 {
+ used, _ := m.GetUsage(ctx, cfg.Type)
+ if remaining := cfg.QuotaLimit - used; remaining > 0 {
+ w = remaining
+ }
+ }
+ items = append(items, weighted{cfg: cfg, weight: w})
+ }
+ return items
+}
+
+func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) {
+ for _, item := range items {
+ if item.weight > 0 {
+ withQuota = append(withQuota, item)
+ } else {
+ withoutQuota = append(withoutQuota, item)
+ }
+ }
+ return
+}
+
+// sortByStableRandomWeight assigns a fixed random factor to each item before sorting,
+// ensuring deterministic sort behavior (transitivity) within a single call.
+func sortByStableRandomWeight(items []weighted) {
+ if len(items) <= 1 {
+ return
+ }
+ type entry struct {
+ item weighted
+ factor float64
+ }
+ entries := make([]entry, len(items))
+ for i, item := range items {
+ entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())}
+ }
+ sort.Slice(entries, func(i, j int) bool {
+ return entries[i].factor > entries[j].factor
+ })
+ for i, e := range entries {
+ items[i] = e.item
+ }
+}
+
+func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
+ result := make([]ProviderConfig, 0, capacity)
+ for _, item := range withQuota {
+ result = append(result, item.cfg)
+ }
+ for _, item := range withoutQuota {
+ result = append(result, item.cfg)
+ }
+ return result
+}
+
+func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
+ if cfg.APIKey == "" {
+ return false
+ }
+ if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
+ slog.Info("websearch: provider expired, skipping",
+ "provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
+ return false
+ }
+ return true
+}
+
+// --- Proxy availability tracking ---
+
+// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
+func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) {
+ proxyID := resolveProxyID(cfg, accountProxyURL)
+ if proxyID <= 0 || m.redis == nil {
+ return
+ }
+ key := fmt.Sprintf(proxyUnavailableKey, proxyID)
+ if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil {
+ slog.Warn("websearch: failed to mark proxy unavailable",
+ "proxy_id", proxyID, "error", err)
+ }
+}
+
+// isProxyAvailable checks whether a proxy is currently marked as unavailable.
+func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool {
+ if m.redis == nil || proxyID <= 0 {
+ return true
+ }
+ key := fmt.Sprintf(proxyUnavailableKey, proxyID)
+ val, err := m.redis.Get(ctx, key).Result()
+ if err != nil {
+ return true // Redis error → assume available
+ }
+ return val == ""
+}
+
+// resolveProxyID determines the effective proxy ID for a provider+account combination.
+func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
+ if accountProxyURL != "" {
+ return 0 // account proxy has no ID in provider config
+ }
+ return cfg.ProxyID
+}
+
+// isProxyError reports whether err looks like a proxy/network connectivity
+// failure rather than an API-level error from the search provider.
+func isProxyError(err error) bool {
+	if err == nil {
+		return false
+	}
+
+	// Typed checks: network errors (timeout, connection refused, DNS) and
+	// TLS record errors (often a proxy intercepting/blocking the handshake).
+	var netErr net.Error
+	var opErr *net.OpError
+	var tlsErr *tls.RecordHeaderError
+	if errors.As(err, &netErr) || errors.As(err, &opErr) || errors.As(err, &tlsErr) {
+		return true
+	}
+
+	// Fall back to substring matching for wrapped/stringified errors.
+	indicators := []string{
+		"proxy",
+		"socks",
+		"connection refused",
+		"no such host",
+		"i/o timeout",
+		"tls handshake",
+		"certificate",
+	}
+	msg := strings.ToLower(err.Error())
+	for _, indicator := range indicators {
+		if strings.Contains(msg, indicator) {
+			return true
+		}
+	}
+	return false
+}
+
+// --- Quota management ---
+
+// tryReserveQuota atomically increments the provider's usage counter and
+// checks it against cfg.QuotaLimit.
+//
+// Returns (allowed, reserved):
+//   - allowed:  the caller may proceed with the request.
+//   - reserved: an increment was actually recorded and must be undone via
+//     rollbackQuota if the request ultimately fails.
+//
+// Fails open: with no quota configured, no Redis, or a Lua INCR error, the
+// request is allowed without reserving. When the new count exceeds the limit,
+// the increment is compensated with a DECR and the request is denied.
+func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
+	if cfg.QuotaLimit <= 0 {
+		// Quota disabled for this provider.
+		return true, false
+	}
+	if m.redis == nil {
+		slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
+		return true, false
+	}
+	key := quotaRedisKey(cfg.Type)
+	// TTL aligns the counter's expiry with the provider's monthly reset date.
+	ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds())
+	newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
+	if err != nil {
+		slog.Warn("websearch: quota Lua INCR failed, allowing request",
+			"provider", cfg.Type, "error", err)
+		return true, false
+	}
+	if newVal > cfg.QuotaLimit {
+		// Over the limit: undo our increment so the counter stays accurate.
+		if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
+			slog.Warn("websearch: quota over-limit DECR failed",
+				"provider", cfg.Type, "error", decrErr)
+		}
+		slog.Info("websearch: provider quota exhausted",
+			"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
+		return false, false
+	}
+	return true, true
+}
+
+// rollbackQuota undoes a reservation made by tryReserveQuota by decrementing
+// the provider's usage counter. No-op when quota is disabled or Redis is nil.
+func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
+	if m.redis == nil || cfg.QuotaLimit <= 0 {
+		return
+	}
+	if err := m.redis.Decr(ctx, quotaRedisKey(cfg.Type)).Err(); err != nil {
+		slog.Warn("websearch: quota rollback DECR failed",
+			"provider", cfg.Type, "error", err)
+	}
+}
+
+// --- Search execution ---
+
+// TestSearch executes a search using the first available provider WITHOUT
+// reserving quota. Intended for admin test functionality only.
+func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
+	if strings.TrimSpace(req.Query) == "" {
+		return nil, "", fmt.Errorf("websearch: empty search query")
+	}
+	// Walk providers in configured order; the first successful search wins.
+	for _, cfg := range m.configs {
+		if !m.isProviderAvailable(cfg) {
+			continue
+		}
+		if resp, err := m.executeSearch(ctx, cfg, req); err == nil {
+			return resp, cfg.Type, nil
+		}
+	}
+	return nil, "", fmt.Errorf("websearch: no available provider")
+}
+
+// executeSearch runs a single search against one provider, honoring a
+// per-request proxy override (req.ProxyURL) over the provider's own proxy.
+func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
+	effectiveProxy := cfg.ProxyURL
+	if req.ProxyURL != "" {
+		effectiveProxy = req.ProxyURL
+	}
+	client, err := m.getOrCreateHTTPClient(effectiveProxy)
+	if err != nil {
+		return nil, fmt.Errorf("websearch: %w", err)
+	}
+	return m.buildProvider(cfg, client).Search(ctx, req)
+}
+
+// --- HTTP client cache ---
+
+// getOrCreateHTTPClient returns a cached *http.Client for the given proxy URL
+// (empty string = direct connection), creating and caching one on demand.
+// Safe for concurrent use via clientMu.
+func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) {
+	m.clientMu.Lock()
+	defer m.clientMu.Unlock()
+
+	if c, ok := m.clientCache[proxyURL]; ok {
+		return c, nil
+	}
+	// Crude eviction: once full, drop the whole cache rather than tracking
+	// LRU order. Already-returned clients keep working for in-flight calls.
+	if len(m.clientCache) >= maxCachedClients {
+		m.clientCache = make(map[string]*http.Client)
+	}
+	c, err := newHTTPClient(proxyURL)
+	if err != nil {
+		return nil, err
+	}
+	m.clientCache[proxyURL] = c
+	return c, nil
+}
+
+// newHTTPClient builds an HTTP client with dial/TLS/response-header timeouts.
+// Proxy support (HTTP/HTTPS/SOCKS5/SOCKS5H) is delegated to
+// proxyutil.ConfigureTransportProxy. An invalid proxy URL is a hard error —
+// this function never silently falls back to a direct connection.
+func newHTTPClient(proxyURL string) (*http.Client, error) {
+	dialer := &net.Dialer{Timeout: proxyDialTimeout}
+	transport := &http.Transport{
+		TLSClientConfig:       &tls.Config{MinVersion: tls.VersionTLS12},
+		DialContext:           dialer.DialContext,
+		TLSHandshakeTimeout:   proxyTLSTimeout,
+		ResponseHeaderTimeout: searchDataTimeout,
+	}
+	if proxyURL != "" {
+		parsed, err := url.Parse(proxyURL)
+		if err != nil {
+			return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
+		}
+		if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
+			return nil, fmt.Errorf("configure proxy: %w", err)
+		}
+	}
+	return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil
+}
+
+// GetUsage returns the current usage count for the given provider.
+// Returns 0 with a nil error when Redis is not configured or the quota key
+// does not exist yet.
+func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) {
+	if m.redis == nil {
+		return 0, nil
+	}
+	key := quotaRedisKey(providerType)
+	val, err := m.redis.Get(ctx, key).Int64()
+	// errors.Is instead of == so a wrapped redis.Nil (missing key) is still
+	// treated as "no usage yet" rather than surfacing as an error.
+	if errors.Is(err, redis.Nil) {
+		return 0, nil
+	}
+	return val, err
+}
+
+// GetAllUsage returns the current usage counter for every configured
+// provider, keyed by provider type. Lookup errors are ignored (reads as 0).
+func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
+	usage := make(map[string]int64, len(m.configs))
+	for _, cfg := range m.configs {
+		count, _ := m.GetUsage(ctx, cfg.Type)
+		usage[cfg.Type] = count
+	}
+	return usage
+}
+
+// ResetUsage deletes the Redis quota key for the given provider, resetting
+// its usage to 0. No-op (nil error) when Redis is not configured.
+func (m *Manager) ResetUsage(ctx context.Context, providerType string) error {
+	if m.redis == nil {
+		return nil
+	}
+	return m.redis.Del(ctx, quotaRedisKey(providerType)).Err()
+}
+
+// --- Provider factory ---
+
+// buildProvider constructs the Provider implementation for cfg, wiring in
+// the supplied HTTP client. An unknown provider type degrades to Brave with
+// a warning rather than failing.
+func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
+	if cfg.Type == tavilyProviderName {
+		return NewTavilyProvider(cfg.APIKey, client)
+	}
+	if cfg.Type != braveProviderName {
+		slog.Warn("websearch: unknown provider type, falling back to brave",
+			"type", cfg.Type)
+	}
+	return NewBraveProvider(cfg.APIKey, client)
+}
+
+// --- Redis key helpers ---
+
+// quotaRedisKey builds the Redis counter key for a provider's quota,
+// e.g. "websearch:quota:brave" (quotaKeyPrefix + provider type).
+func quotaRedisKey(providerType string) string {
+	return quotaKeyPrefix + providerType
+}
+
+// quotaTTLFromSubscription calculates the TTL for the quota counter based on
+// the provider's subscription start date: quota resets monthly on that
+// day-of-month. quotaTTLBuffer is added so the key slightly outlives the
+// boundary; once the key expires naturally, the next INCR creates a fresh
+// counter (lazy refresh). Without a subscription date, defaultQuotaTTL is
+// used.
+func quotaTTLFromSubscription(subscribedAt *int64) time.Duration {
+	if subscribedAt == nil || *subscribedAt == 0 {
+		return defaultQuotaTTL
+	}
+	next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC())
+	ttl := time.Until(next) + quotaTTLBuffer
+	if ttl <= quotaTTLBuffer {
+		// Already past the reset — next cycle
+		ttl = defaultQuotaTTL
+	}
+	return ttl
+}
+
+// nextMonthlyReset returns the next monthly reset time strictly after "now",
+// anchored to the subscription start date. E.g., subscribed on Jan 15 →
+// resets on Feb 15, Mar 15, etc. Day-of-month overflow is clamped:
+// Jan 31 → Feb 28 (not Mar 3). A zero subscription time falls back to
+// one month from now.
+func nextMonthlyReset(subscribedAt time.Time) time.Time {
+	now := time.Now().UTC()
+	if subscribedAt.IsZero() {
+		return now.AddDate(0, 1, 0)
+	}
+	// Whole calendar months between subscription and now; this can overshoot
+	// or undershoot by one depending on the day-of-month — corrected below.
+	months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month())
+	if months < 0 {
+		// Subscription date lies in the future — use its first occurrence.
+		months = 0
+	}
+	candidate := addMonthsClamped(subscribedAt, months)
+	if candidate.After(now) {
+		return candidate
+	}
+	return addMonthsClamped(subscribedAt, months+1)
+}
+
+// addMonthsClamped adds N months to a date, clamping the day to the last day
+// of the target month. E.g., Jan 31 + 1 month = Feb 28 (not Mar 3, which is
+// what time.AddDate's day normalization would produce).
+func addMonthsClamped(t time.Time, months int) time.Time {
+	y, m, d := t.Date()
+	// Day 1 never overflows, so AddDate resolves the target year/month safely.
+	firstOfTarget := time.Date(y, m, 1, 0, 0, 0, 0, time.UTC).AddDate(0, months, 0)
+	// Last day of the target month = day before the first of the next month.
+	lastDay := firstOfTarget.AddDate(0, 1, -1).Day()
+	if d > lastDay {
+		d = lastDay
+	}
+	return time.Date(firstOfTarget.Year(), firstOfTarget.Month(), d, 0, 0, 0, 0, time.UTC)
+}
diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go
new file mode 100644
index 00000000..a4413417
--- /dev/null
+++ b/backend/internal/pkg/websearch/manager_test.go
@@ -0,0 +1,323 @@
+package websearch
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// NewManager must preserve the caller-supplied provider order — it drives
+// failover priority in SearchWithBestProvider.
+func TestNewManager_PreservesOrder(t *testing.T) {
+	configs := []ProviderConfig{
+		{Type: "brave", APIKey: "k3"},
+		{Type: "tavily", APIKey: "k1"},
+	}
+	m := NewManager(configs, nil)
+	require.Equal(t, "brave", m.configs[0].Type)
+	require.Equal(t, "tavily", m.configs[1].Type)
+}
+
+// Empty and whitespace-only queries are rejected before any provider call.
+func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
+	m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
+	_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
+	require.ErrorContains(t, err, "empty search query")
+
+	_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
+	require.ErrorContains(t, err, "empty search query")
+}
+
+// Providers without an API key are skipped entirely.
+func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
+	m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
+	_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+	require.ErrorContains(t, err, "no available provider")
+}
+
+// Providers whose key has expired are skipped.
+func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
+	past := time.Now().Add(-1 * time.Hour).Unix()
+	m := NewManager([]ProviderConfig{
+		{Type: "brave", APIKey: "k", ExpiresAt: &past},
+	}, nil)
+	_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+	require.ErrorContains(t, err, "no available provider")
+}
+
+// The first configured available provider (brave) serves the request; the
+// brave endpoint is redirected to a local httptest server for the test.
+func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) {
+	srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		resp := braveResponse{}
+		resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
+		_ = json.NewEncoder(w).Encode(resp)
+	}))
+	defer srvBrave.Close()
+
+	origURL := *braveSearchURL
+	u, _ := http.NewRequest("GET", srvBrave.URL, nil)
+	*braveSearchURL = *u.URL
+	defer func() { *braveSearchURL = origURL }()
+
+	m := NewManager([]ProviderConfig{
+		{Type: "brave", APIKey: "k1"},
+		{Type: "tavily", APIKey: "k2"},
+	}, nil)
+	m.clientCache[srvBrave.URL] = srvBrave.Client()
+	m.clientCache[""] = srvBrave.Client()
+
+	resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+	require.NoError(t, err)
+	require.Equal(t, "brave", providerName)
+	require.Len(t, resp.Results, 1)
+	require.Equal(t, "from brave", resp.Results[0].Snippet)
+}
+
+// With a nil Redis client, quota checks are skipped and search still works.
+func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		resp := braveResponse{}
+		resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
+		_ = json.NewEncoder(w).Encode(resp)
+	}))
+	defer srv.Close()
+
+	origURL := *braveSearchURL
+	u, _ := http.NewRequest("GET", srv.URL, nil)
+	*braveSearchURL = *u.URL
+	defer func() { *braveSearchURL = origURL }()
+
+	m := NewManager([]ProviderConfig{
+		{Type: "brave", APIKey: "k", QuotaLimit: 100},
+	}, nil)
+	m.clientCache[""] = srv.Client()
+
+	resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
+	require.NoError(t, err)
+	require.Len(t, resp.Results, 1)
+}
+
+// Nil Redis → usage reads as 0 without error.
+func TestManager_GetUsage_NilRedis(t *testing.T) {
+	m := NewManager(nil, nil)
+	used, err := m.GetUsage(context.Background(), "brave")
+	require.NoError(t, err)
+	require.Equal(t, int64(0), used)
+}
+
+// Nil Redis → every configured provider reports 0 usage.
+func TestManager_GetAllUsage_NilRedis(t *testing.T) {
+	m := NewManager([]ProviderConfig{
+		{Type: "brave"},
+	}, nil)
+	usage := m.GetAllUsage(context.Background())
+	require.Equal(t, int64(0), usage["brave"])
+}
+
+// --- Quota TTL from subscription ---
+
+// Nil subscription pointer → default TTL.
+func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) {
+	ttl := quotaTTLFromSubscription(nil)
+	require.Equal(t, defaultQuotaTTL, ttl)
+}
+
+// Zero-valued subscription timestamp → default TTL.
+func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) {
+	zero := int64(0)
+	ttl := quotaTTLFromSubscription(&zero)
+	require.Equal(t, defaultQuotaTTL, ttl)
+}
+
+func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) {
+	// Subscribed 10 days ago — next reset in ~20 days
+	sub := time.Now().Add(-10 * 24 * time.Hour).Unix()
+	ttl := quotaTTLFromSubscription(&sub)
+	require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days
+	require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer)
+}
+
+func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) {
+	// Subscribed on the 10th of this month (always valid day)
+	now := time.Now().UTC()
+	sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC)
+	next := nextMonthlyReset(sub)
+	require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now")
+	require.True(t, next.Before(now.AddDate(0, 1, 1)))
+}
+
+func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) {
+	// Subscribed 6 months ago on the 1st
+	sub := time.Now().UTC().AddDate(0, -6, 0)
+	sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC)
+	next := nextMonthlyReset(sub)
+	require.True(t, next.After(time.Now().UTC()))
+	// Should be within the next 31 days
+	require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1)))
+}
+
+// A future subscription date yields its first occurrence (in the future).
+func TestNextMonthlyReset_FutureSubscription(t *testing.T) {
+	sub := time.Now().UTC().AddDate(0, 0, 5)
+	next := nextMonthlyReset(sub)
+	require.True(t, next.After(time.Now().UTC()))
+}
+
+func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) {
+	sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC)
+	next := addMonthsClamped(sub, 1)
+	require.Equal(t, time.Month(2), next.Month())
+	require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year)
+}
+
+func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) {
+	sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC)
+	next := addMonthsClamped(sub, 1)
+	require.Equal(t, time.Month(2), next.Month())
+	require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year)
+}
+
+func TestAddMonthsClamped_Mar31ToApr(t *testing.T) {
+	sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC)
+	next := addMonthsClamped(sub, 1)
+	require.Equal(t, time.Month(4), next.Month())
+	require.Equal(t, 30, next.Day()) // Apr has 30 days
+}
+
+func TestAddMonthsClamped_NormalDay(t *testing.T) {
+	sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
+	next := addMonthsClamped(sub, 1)
+	require.Equal(t, time.Month(2), next.Month())
+	require.Equal(t, 15, next.Day()) // no clamping needed
+}
+
+// --- Redis key ---
+
+// Pins the exact Redis key format (quotaKeyPrefix + provider type).
+func TestQuotaRedisKey_Format(t *testing.T) {
+	key := quotaRedisKey("brave")
+	require.Equal(t, "websearch:quota:brave", key)
+}
+
+// --- isProviderAvailable ---
+
+func TestIsProviderAvailable_EmptyAPIKey(t *testing.T) {
+	m := NewManager(nil, nil)
+	require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: ""}))
+}
+
+func TestIsProviderAvailable_Expired(t *testing.T) {
+	m := NewManager(nil, nil)
+	past := time.Now().Add(-1 * time.Hour).Unix()
+	require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &past}))
+}
+
+func TestIsProviderAvailable_Valid(t *testing.T) {
+	m := NewManager(nil, nil)
+	future := time.Now().Add(1 * time.Hour).Unix()
+	require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &future}))
+	require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k"})) // no expiry
+}
+
+// --- resolveProxyID ---
+
+// An account-level proxy URL overrides the provider proxy (ID resolves to 0).
+func TestResolveProxyID_AccountProxyOverrides(t *testing.T) {
+	cfg := ProviderConfig{ProxyID: 42}
+	require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080"))
+	require.Equal(t, int64(42), resolveProxyID(cfg, ""))
+}
+
+// --- isProxyError ---
+
+func TestIsProxyError_Nil(t *testing.T) {
+	require.False(t, isProxyError(nil))
+}
+
+func TestIsProxyError_ConnectionRefused(t *testing.T) {
+	require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused")))
+}
+
+func TestIsProxyError_Timeout(t *testing.T) {
+	require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy")))
+}
+
+func TestIsProxyError_SOCKS(t *testing.T) {
+	require.True(t, isProxyError(fmt.Errorf("socks connect failed")))
+}
+
+func TestIsProxyError_TLSHandshake(t *testing.T) {
+	require.True(t, isProxyError(fmt.Errorf("tls handshake timeout")))
+}
+
+// API-level errors must not be mistaken for proxy failures.
+func TestIsProxyError_APIError_NotProxy(t *testing.T) {
+	require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded")))
+}
+
+// --- isProxyAvailable (nil Redis) ---
+
+func TestIsProxyAvailable_NilRedis(t *testing.T) {
+	m := NewManager(nil, nil)
+	require.True(t, m.isProxyAvailable(context.Background(), 42))
+}
+
+func TestIsProxyAvailable_ZeroID(t *testing.T) {
+	m := NewManager(nil, nil)
+	require.True(t, m.isProxyAvailable(context.Background(), 0))
+}
+
+// --- selectByQuotaWeight ---
+
+// Providers without a quota limit should sort after those with one.
+func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) {
+	m := NewManager(nil, nil)
+	candidates := []ProviderConfig{
+		{Type: "brave", APIKey: "k1", QuotaLimit: 0},
+		{Type: "tavily", APIKey: "k2", QuotaLimit: 100},
+	}
+	result := m.selectByQuotaWeight(context.Background(), candidates)
+	require.Len(t, result, 2)
+	require.Equal(t, "tavily", result[0].Type)
+	require.Equal(t, "brave", result[1].Type)
+}
+
+func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) {
+	m := NewManager(nil, nil)
+	candidates := []ProviderConfig{
+		{Type: "brave", APIKey: "k1", QuotaLimit: 0},
+		{Type: "tavily", APIKey: "k2", QuotaLimit: 0},
+	}
+	result := m.selectByQuotaWeight(context.Background(), candidates)
+	require.Len(t, result, 2)
+}
+
+func TestSelectByQuotaWeight_Empty(t *testing.T) {
+	m := NewManager(nil, nil)
+	result := m.selectByQuotaWeight(context.Background(), nil)
+	require.Empty(t, result)
+}
+
+// --- newHTTPClient ---
+
+func TestNewHTTPClient_NoProxy(t *testing.T) {
+	c, err := newHTTPClient("")
+	require.NoError(t, err)
+	require.NotNil(t, c)
+}
+
+// A malformed proxy URL must be a hard error, never a silent direct fallback.
+func TestNewHTTPClient_InvalidProxy(t *testing.T) {
+	_, err := newHTTPClient("://bad-url")
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "invalid proxy URL")
+}
+
+func TestNewHTTPClient_ValidHTTPProxy(t *testing.T) {
+	c, err := newHTTPClient("http://proxy.example.com:8080")
+	require.NoError(t, err)
+	require.NotNil(t, c)
+}
+
+func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) {
+	c, err := newHTTPClient("socks5://proxy.example.com:1080")
+	require.NoError(t, err)
+	require.NotNil(t, c)
+}
+
+// --- ResetUsage ---
+
+func TestManager_ResetUsage_NilRedis(t *testing.T) {
+	m := NewManager(nil, nil)
+	err := m.ResetUsage(context.Background(), "brave")
+	require.NoError(t, err)
+}
diff --git a/backend/internal/pkg/websearch/provider.go b/backend/internal/pkg/websearch/provider.go
new file mode 100644
index 00000000..3424c056
--- /dev/null
+++ b/backend/internal/pkg/websearch/provider.go
@@ -0,0 +1,11 @@
+package websearch
+
+import "context"
+
+// Provider is the interface every search backend must implement.
+type Provider interface {
+	// Name returns the provider identifier ("brave" or "tavily").
+	Name() string
+	// Search executes a web search and returns results. Implementations
+	// should honor ctx cancellation (requests are built with the context).
+	Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
+}
diff --git a/backend/internal/pkg/websearch/tavily.go b/backend/internal/pkg/websearch/tavily.go
new file mode 100644
index 00000000..ac4928a6
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily.go
@@ -0,0 +1,107 @@
+package websearch
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+)
+
+// Tavily API endpoint and identifiers.
+const (
+	tavilySearchEndpoint = "https://api.tavily.com/search"
+	tavilyProviderName = "tavily"
+	tavilySearchDepthBasic = "basic"
+)
+
+// TavilyProvider implements web search via the Tavily Search API.
+type TavilyProvider struct {
+	apiKey string // sent in the request body as api_key
+	httpClient *http.Client // transport (proxy/timeouts) supplied by caller
+}
+
+// NewTavilyProvider creates a Tavily Search provider.
+// The caller is responsible for configuring the http.Client with
+// proxy/timeouts; a nil client falls back to http.DefaultClient.
+func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
+	if httpClient == nil {
+		httpClient = http.DefaultClient
+	}
+	return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
+}
+
+// Name returns the provider identifier ("tavily").
+func (t *TavilyProvider) Name() string { return tavilyProviderName }
+
+// Search executes a web search against the Tavily /search endpoint.
+// req.MaxResults defaults to defaultMaxResults when not positive. The
+// response body read is capped at maxResponseSize; a non-200 status is
+// surfaced as an error carrying a truncated body excerpt.
+func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
+	maxResults := req.MaxResults
+	if maxResults <= 0 {
+		maxResults = defaultMaxResults
+	}
+
+	// Tavily authenticates via the api_key field in the JSON body.
+	payload := tavilyRequest{
+		APIKey: t.apiKey,
+		Query: req.Query,
+		MaxResults: maxResults,
+		SearchDepth: tavilySearchDepthBasic,
+	}
+
+	bodyBytes, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("tavily: encode request: %w", err)
+	}
+
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
+	if err != nil {
+		return nil, fmt.Errorf("tavily: build request: %w", err)
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	resp, err := t.httpClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("tavily: request failed: %w", err)
+	}
+	defer func() { _ = resp.Body.Close() }()
+
+	// Bound the read so a misbehaving server can't exhaust memory.
+	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
+	if err != nil {
+		return nil, fmt.Errorf("tavily: read body: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
+	}
+
+	var raw tavilyResponse
+	if err := json.Unmarshal(body, &raw); err != nil {
+		return nil, fmt.Errorf("tavily: decode response: %w", err)
+	}
+
+	// Map Tavily's "content" field onto the generic Snippet field.
+	results := make([]SearchResult, 0, len(raw.Results))
+	for _, r := range raw.Results {
+		results = append(results, SearchResult{
+			URL: r.URL,
+			Title: r.Title,
+			Snippet: r.Content,
+		})
+	}
+
+	return &SearchResponse{Results: results, Query: req.Query}, nil
+}
+
+// tavilyRequest is the JSON payload sent to the Tavily /search endpoint.
+// The API key travels in the body (api_key), not in a header.
+type tavilyRequest struct {
+	APIKey string `json:"api_key"`
+	Query string `json:"query"`
+	MaxResults int `json:"max_results"`
+	SearchDepth string `json:"search_depth"`
+}
+
+// tavilyResponse is the subset of the Tavily response this package consumes.
+type tavilyResponse struct {
+	Results []tavilyResult `json:"results"`
+}
+
+// tavilyResult is a single search hit returned by Tavily.
+type tavilyResult struct {
+	URL string `json:"url"`
+	Title string `json:"title"`
+	Content string `json:"content"`
+	Score float64 `json:"score"` // relevance score; not used by the Snippet mapping
+}
diff --git a/backend/internal/pkg/websearch/tavily_test.go b/backend/internal/pkg/websearch/tavily_test.go
new file mode 100644
index 00000000..e1b6819a
--- /dev/null
+++ b/backend/internal/pkg/websearch/tavily_test.go
@@ -0,0 +1,63 @@
+package websearch
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTavilyProvider_Name(t *testing.T) {
+	p := NewTavilyProvider("key", nil)
+	require.Equal(t, "tavily", p.Name())
+}
+
+func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
+	// Verify tavilyRequest struct fields map correctly
+	req := tavilyRequest{
+		APIKey: "test-key",
+		Query: "golang",
+		MaxResults: 3,
+		SearchDepth: tavilySearchDepthBasic,
+	}
+	data, err := json.Marshal(req)
+	require.NoError(t, err)
+
+	var parsed map[string]any
+	require.NoError(t, json.Unmarshal(data, &parsed))
+	require.Equal(t, "test-key", parsed["api_key"])
+	require.Equal(t, "golang", parsed["query"])
+	// JSON numbers decode to float64 when targeting map[string]any.
+	require.Equal(t, float64(3), parsed["max_results"])
+	require.Equal(t, "basic", parsed["search_depth"])
+}
+
+func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
+	rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
+	var resp tavilyResponse
+	require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
+	require.Len(t, resp.Results, 1)
+	require.Equal(t, "https://go.dev", resp.Results[0].URL)
+	require.Equal(t, "Go programming language", resp.Results[0].Content)
+	require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
+
+	// Verify mapping to SearchResult
+	results := make([]SearchResult, 0, len(resp.Results))
+	for _, r := range resp.Results {
+		results = append(results, SearchResult{
+			URL: r.URL, Title: r.Title, Snippet: r.Content,
+		})
+	}
+	require.Equal(t, "Go programming language", results[0].Snippet)
+	require.Equal(t, "", results[0].PageAge) // Tavily results carry no page age
+}
+
+func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
+	var resp tavilyResponse
+	require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
+	require.Empty(t, resp.Results)
+}
+
+func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
+	var resp tavilyResponse
+	require.Error(t, json.Unmarshal([]byte("not json"), &resp))
+}
diff --git a/backend/internal/pkg/websearch/types.go b/backend/internal/pkg/websearch/types.go
new file mode 100644
index 00000000..bb489690
--- /dev/null
+++ b/backend/internal/pkg/websearch/types.go
@@ -0,0 +1,30 @@
+package websearch
+
+// SearchResult represents a single web search result.
+type SearchResult struct {
+	URL string `json:"url"`
+	Title string `json:"title"`
+	Snippet string `json:"snippet"`
+	PageAge string `json:"page_age,omitempty"` // only set by providers that report page age
+}
+
+// SearchRequest describes a web search to perform.
+type SearchRequest struct {
+	Query string
+	MaxResults int // defaults to defaultMaxResults if <= 0
+	ProxyURL string // optional HTTP proxy URL
+}
+
+// SearchResponse holds the results of a web search.
+type SearchResponse struct {
+	Results []SearchResult
+	Query string // the query that was actually executed
+}
+
+// defaultMaxResults is used when SearchRequest.MaxResults is not positive.
+const defaultMaxResults = 5
+
+// Provider type identifiers.
+const (
+	ProviderTypeBrave = "brave"
+	ProviderTypeTavily = "tavily"
+)
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index d45e8a12..78f739ac 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -438,6 +438,9 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
if _, err := txClient.AccountGroup.Delete().Where(dbaccountgroup.AccountIDEQ(id)).Exec(ctx); err != nil {
return err
}
+ if _, err := txClient.ExecContext(ctx, "DELETE FROM scheduled_test_plans WHERE account_id = $1", id); err != nil {
+ return err
+ }
if _, err := txClient.Account.Delete().Where(dbaccount.IDEQ(id)).Exec(ctx); err != nil {
return err
}
@@ -468,16 +471,61 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
}
if status != "" {
switch status {
+ case service.StatusActive:
+ q = q.Where(
+ dbaccount.StatusEQ(status),
+ dbaccount.SchedulableEQ(true),
+ dbaccount.Or(
+ dbaccount.RateLimitResetAtIsNil(),
+ dbaccount.RateLimitResetAtLTE(time.Now()),
+ ),
+ dbpredicate.Account(func(s *entsql.Selector) {
+ col := s.C("temp_unschedulable_until")
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.LTE(col, entsql.Expr("NOW()")),
+ ))
+ }),
+ )
case "rate_limited":
- q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
+ q = q.Where(
+ dbaccount.StatusEQ(service.StatusActive),
+ dbaccount.RateLimitResetAtGT(time.Now()),
+ dbpredicate.Account(func(s *entsql.Selector) {
+ col := s.C("temp_unschedulable_until")
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.LTE(col, entsql.Expr("NOW()")),
+ ))
+ }),
+ )
case "temp_unschedulable":
- q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
- col := s.C("temp_unschedulable_until")
- s.Where(entsql.And(
- entsql.Not(entsql.IsNull(col)),
- entsql.GT(col, entsql.Expr("NOW()")),
- ))
- }))
+ q = q.Where(
+ dbaccount.StatusEQ(service.StatusActive),
+ dbpredicate.Account(func(s *entsql.Selector) {
+ col := s.C("temp_unschedulable_until")
+ s.Where(entsql.And(
+ entsql.Not(entsql.IsNull(col)),
+ entsql.GT(col, entsql.Expr("NOW()")),
+ ))
+ }),
+ )
+ case "unschedulable":
+ q = q.Where(
+ dbaccount.StatusEQ(service.StatusActive),
+ dbaccount.SchedulableEQ(false),
+ dbaccount.Or(
+ dbaccount.RateLimitResetAtIsNil(),
+ dbaccount.RateLimitResetAtLTE(time.Now()),
+ ),
+ dbpredicate.Account(func(s *entsql.Selector) {
+ col := s.C("temp_unschedulable_until")
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.LTE(col, entsql.Expr("NOW()")),
+ ))
+ }),
+ )
default:
q = q.Where(dbaccount.StatusEQ(status))
}
@@ -510,11 +558,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return nil, nil, err
}
- accounts, err := q.
+ accountsQuery := q.
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(dbaccount.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range accountListOrder(params) {
+ accountsQuery = accountsQuery.Order(order)
+ }
+
+ accounts, err := accountsQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -526,6 +577,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
return outAccounts, paginationResultFromTotal(int64(total), params), nil
}
+// accountListOrder translates pagination sort parameters into ent order
+// options. Unknown or empty sort keys fall back to the default ordering by
+// account name; ID is always appended as a tiebreaker so paging is stable.
+func accountListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
+	sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+	sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
+
+	// Columns that may be sorted on explicitly (everything else → name).
+	sortable := map[string]string{
+		"id":              dbaccount.FieldID,
+		"status":          dbaccount.FieldStatus,
+		"schedulable":     dbaccount.FieldSchedulable,
+		"priority":        dbaccount.FieldPriority,
+		"rate_multiplier": dbaccount.FieldRateMultiplier,
+		"last_used_at":    dbaccount.FieldLastUsedAt,
+		"expires_at":      dbaccount.FieldExpiresAt,
+		"created_at":      dbaccount.FieldCreatedAt,
+	}
+	field, explicit := sortable[sortBy]
+	if !explicit {
+		// "", "name", and unrecognized keys all sort by name.
+		field = dbaccount.FieldName
+	}
+
+	if sortOrder == pagination.SortOrderDesc {
+		return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbaccount.FieldID)}
+	}
+	return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbaccount.FieldID)}
+}
+
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
status: service.StatusActive,
@@ -1692,20 +1787,13 @@ func itoa(v int) string {
}
// FindByExtraField 根据 extra 字段中的键值对查找账号。
-// 该方法限定 platform='sora',避免误查询其他平台的账号。
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
//
-// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
-//
// FindByExtraField finds accounts by key-value pairs in the extra field.
-// Limited to platform='sora' to avoid querying accounts from other platforms.
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
-//
-// Use case: Finding Sora accounts linked via linked_openai_account_id.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
accounts, err := r.client.Account.Query().
Where(
- dbaccount.PlatformEQ("sora"), // 限定平台为 sora
dbaccount.DeletedAtIsNil(),
func(s *entsql.Selector) {
path := sqljson.Path(key)
diff --git a/backend/internal/repository/account_repo_compact_extra_test.go b/backend/internal/repository/account_repo_compact_extra_test.go
new file mode 100644
index 00000000..604f392e
--- /dev/null
+++ b/backend/internal/repository/account_repo_compact_extra_test.go
@@ -0,0 +1,14 @@
+package repository
+
+import "testing"
+
+// Compact-capability keys in an account's extra updates must be treated as
+// scheduler-relevant so a scheduler outbox entry is enqueued when they change.
+func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRelevant(t *testing.T) {
+	updates := map[string]any{
+		"openai_compact_supported": true,
+		"openai_compact_checked_at": "2026-04-10T10:00:00Z",
+	}
+
+	if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
+		t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
+	}
+}
diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go
index 8da30c92..b249bb61 100644
--- a/backend/internal/repository/account_repo_integration_test.go
+++ b/backend/internal/repository/account_repo_integration_test.go
@@ -255,6 +255,101 @@ func (s *AccountRepoSuite) TestListWithFilters() {
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
},
},
+ {
+ name: "filter_by_status_active_excludes_runtime_blocked_accounts",
+ setup: func(client *dbent.Client) {
+ mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive})
+ rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
+ err := client.Account.UpdateOneID(rateLimited.ID).
+ SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
+ err = client.Account.UpdateOneID(tempUnsched.ID).
+ SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
+ err = client.Account.UpdateOneID(unsched.ID).
+ SetSchedulable(false).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ },
+ status: service.StatusActive,
+ wantCount: 1,
+ validate: func(accounts []service.Account) {
+ s.Require().Equal("active-normal", accounts[0].Name)
+ },
+ },
+ {
+ name: "filter_by_status_unschedulable_excludes_rate_limited_and_temp_unschedulable",
+ setup: func(client *dbent.Client) {
+ mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive, Schedulable: true})
+ unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
+ err := client.Account.UpdateOneID(unsched.ID).
+ SetSchedulable(false).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
+ err = client.Account.UpdateOneID(rateLimited.ID).
+ SetSchedulable(false).
+ SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
+ err = client.Account.UpdateOneID(tempUnsched.ID).
+ SetSchedulable(false).
+ SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ },
+ status: "unschedulable",
+ wantCount: 1,
+ validate: func(accounts []service.Account) {
+ s.Require().Equal("active-unsched", accounts[0].Name)
+ },
+ },
+ {
+ name: "filter_by_status_rate_limited_excludes_temp_unschedulable",
+ setup: func(client *dbent.Client) {
+ rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
+ err := client.Account.UpdateOneID(rateLimited.ID).
+ SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
+ err = client.Account.UpdateOneID(tempUnsched.ID).
+ SetRateLimitResetAt(time.Now().Add(20 * time.Minute)).
+ SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ },
+ status: "rate_limited",
+ wantCount: 1,
+ validate: func(accounts []service.Account) {
+ s.Require().Equal("active-rate-limited", accounts[0].Name)
+ },
+ },
+ {
+ name: "filter_by_status_temp_unschedulable_excludes_manually_unschedulable",
+ setup: func(client *dbent.Client) {
+ tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive, Schedulable: true})
+ err := client.Account.UpdateOneID(tempUnsched.ID).
+ SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
+ err = client.Account.UpdateOneID(unsched.ID).
+ SetSchedulable(false).
+ Exec(context.Background())
+ s.Require().NoError(err)
+ },
+ status: "temp_unschedulable",
+ wantCount: 1,
+ validate: func(accounts []service.Account) {
+ s.Require().Equal("active-temp-unsched", accounts[0].Name)
+ },
+ },
{
name: "filter_by_search",
setup: func(client *dbent.Client) {
diff --git a/backend/internal/repository/account_repo_sort_integration_test.go b/backend/internal/repository/account_repo_sort_integration_test.go
new file mode 100644
index 00000000..098dde7b
--- /dev/null
+++ b/backend/internal/repository/account_repo_sort_integration_test.go
@@ -0,0 +1,35 @@
+//go:build integration
+
+package repository
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
// TestList_DefaultSortByNameAsc verifies that List orders accounts by name
// ascending when no explicit sort parameters are supplied.
func (s *AccountRepoSuite) TestList_DefaultSortByNameAsc() {
	mustCreateAccount(s.T(), s.client, &service.Account{Name: "z-account"})
	mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-account"})

	accounts, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
	s.Require().NoError(err)
	s.Require().Len(accounts, 2)
	s.Require().Equal("a-account", accounts[0].Name)
	s.Require().Equal("z-account", accounts[1].Name)
}
+
// TestListWithFilters_SortByPriorityDesc verifies that ListWithFilters
// honors SortBy=priority with SortOrder=desc (highest priority first).
func (s *AccountRepoSuite) TestListWithFilters_SortByPriorityDesc() {
	mustCreateAccount(s.T(), s.client, &service.Account{Name: "low-priority", Priority: 10})
	mustCreateAccount(s.T(), s.client, &service.Account{Name: "high-priority", Priority: 90})

	accounts, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
		Page:      1,
		PageSize:  10,
		SortBy:    "priority",
		SortOrder: "desc",
	}, "", "", "", "", 0, "")
	s.Require().NoError(err)
	s.Require().Len(accounts, 2)
	s.Require().Equal("high-priority", accounts[0].Name)
	s.Require().Equal("low-priority", accounts[1].Name)
}
diff --git a/backend/internal/repository/affiliate_repo.go b/backend/internal/repository/affiliate_repo.go
new file mode 100644
index 00000000..ef89e5b6
--- /dev/null
+++ b/backend/internal/repository/affiliate_repo.go
@@ -0,0 +1,762 @@
+package repository
+
+import (
+ "context"
+ "crypto/rand"
+ "database/sql"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
const (
	// affiliateCodeLength is the number of characters in a generated invite code.
	affiliateCodeLength = 12
	// affiliateCodeMaxAttempts bounds retries when a generated code collides
	// with an existing one.
	affiliateCodeMaxAttempts = 12
)

// affiliateCodeCharset is the invite-code alphabet: uppercase letters and
// digits with the easily-confused I, O, 0 and 1 removed (32 symbols).
var affiliateCodeCharset = []byte("ABCDEFGHJKLMNPQRSTUVWXYZ23456789")

// affiliateQueryExecer is the minimal raw-SQL surface shared by the ent
// client and its transactional client, so helpers can run on either.
type affiliateQueryExecer interface {
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}

// affiliateRepository implements service.AffiliateRepository on top of the
// ent client plus hand-written SQL for the affiliate tables.
type affiliateRepository struct {
	client *dbent.Client
}

// NewAffiliateRepository builds an affiliate repository. The *sql.DB
// parameter is currently unused; it is kept for constructor-signature
// compatibility with callers.
func NewAffiliateRepository(client *dbent.Client, _ *sql.DB) service.AffiliateRepository {
	return &affiliateRepository{client: client}
}
+
+func (r *affiliateRepository) EnsureUserAffiliate(ctx context.Context, userID int64) (*service.AffiliateSummary, error) {
+ if userID <= 0 {
+ return nil, service.ErrUserNotFound
+ }
+ client := clientFromContext(ctx, r.client)
+ return ensureUserAffiliateWithClient(ctx, client, userID)
+}
+
+func (r *affiliateRepository) GetAffiliateByCode(ctx context.Context, code string) (*service.AffiliateSummary, error) {
+ client := clientFromContext(ctx, r.client)
+ return queryAffiliateByCode(ctx, client, code)
+}
+
// BindInviter records inviterID as the inviter of userID, if and only if
// userID has no inviter yet. It returns true when the binding was applied.
//
// Both affiliate rows are created on demand inside the same transaction;
// on a successful bind the inviter's aff_count is incremented as well.
func (r *affiliateRepository) BindInviter(ctx context.Context, userID, inviterID int64) (bool, error) {
	var bound bool
	err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, inviterID); err != nil {
			return err
		}

		// The "inviter_id IS NULL" guard makes the bind first-wins and
		// idempotent: a second attempt updates zero rows.
		res, err := txClient.ExecContext(txCtx,
			"UPDATE user_affiliates SET inviter_id = $1, updated_at = NOW() WHERE user_id = $2 AND inviter_id IS NULL",
			inviterID, userID,
		)
		if err != nil {
			return fmt.Errorf("bind inviter: %w", err)
		}
		affected, _ := res.RowsAffected()
		if affected == 0 {
			// Already bound (or lost a race) — not an error, just not applied.
			bound = false
			return nil
		}

		if _, err = txClient.ExecContext(txCtx,
			"UPDATE user_affiliates SET aff_count = aff_count + 1, updated_at = NOW() WHERE user_id = $1",
			inviterID,
		); err != nil {
			return fmt.Errorf("increment inviter aff_count: %w", err)
		}
		bound = true
		return nil
	})
	if err != nil {
		return false, err
	}
	return bound, nil
}
+
// AccrueQuota credits a rebate of amount to inviterID, attributed to
// inviteeUserID, and writes a matching 'accrue' ledger row in the same
// transaction.
//
// freezeHours > 0 puts the amount into aff_frozen_quota with a
// frozen_until timestamp; freezeHours == 0 credits aff_quota directly.
// Either way aff_history_quota is increased. Returns (false, nil) when
// amount <= 0 or the inviter has no affiliate row; (true, nil) when the
// credit was applied.
func (r *affiliateRepository) AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error) {
	if amount <= 0 {
		return false, nil
	}

	var applied bool
	err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		// freezeHours > 0: add to frozen quota; == 0: add to available quota directly
		var updateSQL string
		if freezeHours > 0 {
			updateSQL = "UPDATE user_affiliates SET aff_frozen_quota = aff_frozen_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
		} else {
			updateSQL = "UPDATE user_affiliates SET aff_quota = aff_quota + $1, aff_history_quota = aff_history_quota + $1, updated_at = NOW() WHERE user_id = $2"
		}
		res, err := txClient.ExecContext(txCtx, updateSQL, amount, inviterID)
		if err != nil {
			return err
		}
		affected, _ := res.RowsAffected()
		if affected == 0 {
			// No affiliate row for the inviter — skip the accrual silently.
			applied = false
			return nil
		}

		if freezeHours > 0 {
			if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, frozen_until, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW() + make_interval(hours => $4), NOW(), NOW())`,
				inviterID, amount, inviteeUserID, freezeHours); err != nil {
				return fmt.Errorf("insert affiliate accrue ledger: %w", err)
			}
		} else {
			if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'accrue', $2, $3, NOW(), NOW())`, inviterID, amount, inviteeUserID); err != nil {
				return fmt.Errorf("insert affiliate accrue ledger: %w", err)
			}
		}

		applied = true
		return nil
	})
	if err != nil {
		return false, err
	}
	return applied, nil
}
+
+func (r *affiliateRepository) GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error) {
+ client := clientFromContext(ctx, r.client)
+ rows, err := client.QueryContext(ctx,
+ `SELECT COALESCE(SUM(amount), 0)::double precision FROM user_affiliate_ledger WHERE user_id = $1 AND source_user_id = $2 AND action = 'accrue'`,
+ inviterID, inviteeUserID)
+ if err != nil {
+ return 0, fmt.Errorf("query accrued rebate from invitee: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+ var total float64
+ if rows.Next() {
+ if err := rows.Scan(&total); err != nil {
+ return 0, err
+ }
+ }
+ return total, rows.Close()
+}
+
+func (r *affiliateRepository) ThawFrozenQuota(ctx context.Context, userID int64) (float64, error) {
+ var thawed float64
+ err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
+ var err error
+ thawed, err = thawFrozenQuotaTx(txCtx, txClient, userID)
+ return err
+ })
+ return thawed, err
+}
+
// thawFrozenQuotaTx moves matured frozen quota to available quota within an
// existing tx. It clears frozen_until on every matured ledger row, sums the
// released amounts in one statement (the CTE), then shifts that sum from
// aff_frozen_quota to aff_quota. Returns the total amount thawed, or 0
// when nothing has matured.
func thawFrozenQuotaTx(txCtx context.Context, txClient *dbent.Client, userID int64) (float64, error) {
	rows, err := txClient.QueryContext(txCtx, `
WITH matured AS (
  UPDATE user_affiliate_ledger
  SET frozen_until = NULL, updated_at = NOW()
  WHERE user_id = $1
    AND frozen_until IS NOT NULL
    AND frozen_until <= NOW()
  RETURNING amount
)
SELECT COALESCE(SUM(amount), 0) FROM matured`, userID)
	if err != nil {
		return 0, fmt.Errorf("thaw frozen quota: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var thawed float64
	if rows.Next() {
		if err := rows.Scan(&thawed); err != nil {
			return 0, err
		}
	}
	// Close explicitly before the follow-up UPDATE so its error is seen;
	// the deferred Close then becomes a harmless no-op.
	if err := rows.Close(); err != nil {
		return 0, err
	}
	if thawed <= 0 {
		return 0, nil
	}

	// GREATEST clamps aff_frozen_quota at 0 in case the aggregate and the
	// ledger have drifted apart.
	_, err = txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_quota = aff_quota + $1,
    aff_frozen_quota = GREATEST(aff_frozen_quota - $1, 0),
    updated_at = NOW()
WHERE user_id = $2`, thawed, userID)
	if err != nil {
		return 0, fmt.Errorf("move thawed quota: %w", err)
	}
	return thawed, nil
}
+
// TransferQuotaToBalance moves the user's entire available affiliate quota
// into their account balance (also counted toward total_recharged) and
// records a 'transfer' ledger row, all in one transaction.
//
// Matured frozen quota is thawed first so it is included in the transfer.
// Returns (transferred amount, new balance). Fails with
// service.ErrAffiliateQuotaEmpty when there is nothing to transfer and
// service.ErrUserNotFound when the user row is missing.
func (r *affiliateRepository) TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error) {
	var transferred float64
	var newBalance float64

	err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}

		// Thaw any matured frozen quota before transfer.
		if _, err := thawFrozenQuotaTx(txCtx, txClient, userID); err != nil {
			return fmt.Errorf("thaw before transfer: %w", err)
		}

		// FOR UPDATE locks the affiliate row while the claimed amount is
		// read and zeroed, so concurrent transfers cannot double-claim.
		rows, err := txClient.QueryContext(txCtx, `
WITH claimed AS (
  SELECT aff_quota::double precision AS amount
  FROM user_affiliates
  WHERE user_id = $1
    AND aff_quota > 0
  FOR UPDATE
),
cleared AS (
  UPDATE user_affiliates ua
  SET aff_quota = 0,
      updated_at = NOW()
  FROM claimed c
  WHERE ua.user_id = $1
  RETURNING c.amount
)
SELECT amount
FROM cleared`, userID)
		if err != nil {
			return fmt.Errorf("claim affiliate quota: %w", err)
		}

		// No row returned means aff_quota was not > 0 — nothing to transfer.
		if !rows.Next() {
			_ = rows.Close()
			if err := rows.Err(); err != nil {
				return err
			}
			return service.ErrAffiliateQuotaEmpty
		}
		if err := rows.Scan(&transferred); err != nil {
			_ = rows.Close()
			return err
		}
		if err := rows.Close(); err != nil {
			return err
		}
		if transferred <= 0 {
			return service.ErrAffiliateQuotaEmpty
		}

		affected, err := txClient.User.Update().
			Where(user.IDEQ(userID)).
			AddBalance(transferred).
			AddTotalRecharged(transferred).
			Save(txCtx)
		if err != nil {
			return fmt.Errorf("credit user balance by affiliate quota: %w", err)
		}
		if affected == 0 {
			return service.ErrUserNotFound
		}

		// Re-read the balance after the credit so callers get the final value.
		newBalance, err = queryUserBalance(txCtx, txClient, userID)
		if err != nil {
			return err
		}

		if _, err = txClient.ExecContext(txCtx, `
INSERT INTO user_affiliate_ledger (user_id, action, amount, source_user_id, created_at, updated_at)
VALUES ($1, 'transfer', $2, NULL, NOW(), NOW())`, userID, transferred); err != nil {
			return fmt.Errorf("insert affiliate transfer ledger: %w", err)
		}

		return nil
	})
	if err != nil {
		return 0, 0, err
	}

	return transferred, newBalance, nil
}
+
// ListInvitees returns up to limit users invited by inviterID, newest
// first, each with the total 'accrue' rebate the inviter earned from them
// (0 when no ledger rows exist, via LEFT JOIN + COALESCE). A non-positive
// limit defaults to 100.
func (r *affiliateRepository) ListInvitees(ctx context.Context, inviterID int64, limit int) ([]service.AffiliateInvitee, error) {
	if limit <= 0 {
		limit = 100
	}
	client := clientFromContext(ctx, r.client)
	rows, err := client.QueryContext(ctx, `
SELECT ua.user_id,
       COALESCE(u.email, ''),
       COALESCE(u.username, ''),
       ua.created_at,
       COALESCE(SUM(ual.amount), 0)::double precision AS total_rebate
FROM user_affiliates ua
LEFT JOIN users u ON u.id = ua.user_id
LEFT JOIN user_affiliate_ledger ual
  ON ual.user_id = $1
  AND ual.source_user_id = ua.user_id
  AND ual.action = 'accrue'
WHERE ua.inviter_id = $1
GROUP BY ua.user_id, u.email, u.username, ua.created_at
ORDER BY ua.created_at DESC
LIMIT $2`, inviterID, limit)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	// Pre-allocated empty slice so callers always receive non-nil.
	invitees := make([]service.AffiliateInvitee, 0)
	for rows.Next() {
		var item service.AffiliateInvitee
		var createdAt time.Time
		if err := rows.Scan(&item.UserID, &item.Email, &item.Username, &createdAt, &item.TotalRebate); err != nil {
			return nil, err
		}
		item.CreatedAt = &createdAt
		invitees = append(invitees, item)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return invitees, nil
}
+
// withTx runs fn inside a database transaction. When ctx already carries
// an ent transaction it is reused (no nested transaction is opened);
// otherwise a fresh transaction is started, rolled back if fn or commit
// fails, and committed on success.
func (r *affiliateRepository) withTx(ctx context.Context, fn func(txCtx context.Context, txClient *dbent.Client) error) error {
	if tx := dbent.TxFromContext(ctx); tx != nil {
		return fn(ctx, tx.Client())
	}

	tx, err := r.client.Tx(ctx)
	if err != nil {
		return fmt.Errorf("begin affiliate transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op, so this is safe.
	defer func() { _ = tx.Rollback() }()

	txCtx := dbent.NewTxContext(ctx, tx)
	if err := fn(txCtx, tx.Client()); err != nil {
		return err
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit affiliate transaction: %w", err)
	}
	return nil
}
+
+func ensureUserAffiliateWithClient(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
+ summary, err := queryAffiliateByUserID(ctx, client, userID)
+ if err == nil {
+ return summary, nil
+ }
+ if !errors.Is(err, service.ErrAffiliateProfileNotFound) {
+ return nil, err
+ }
+
+ for i := 0; i < affiliateCodeMaxAttempts; i++ {
+ code, codeErr := generateAffiliateCode()
+ if codeErr != nil {
+ return nil, codeErr
+ }
+ _, insertErr := client.ExecContext(ctx, `
+INSERT INTO user_affiliates (user_id, aff_code, created_at, updated_at)
+VALUES ($1, $2, NOW(), NOW())
+ON CONFLICT (user_id) DO NOTHING`, userID, code)
+ if insertErr == nil {
+ break
+ }
+ if isAffiliateUniqueViolation(insertErr) {
+ continue
+ }
+ return nil, insertErr
+ }
+
+ return queryAffiliateByUserID(ctx, client, userID)
+}
+
// queryAffiliateByUserID loads the full affiliate summary for one user.
// Returns service.ErrAffiliateProfileNotFound when no row exists. The
// nullable columns (inviter_id, aff_rebate_rate_percent) become nil
// pointers on the summary when the database holds NULL.
func queryAffiliateByUserID(ctx context.Context, client affiliateQueryExecer, userID int64) (*service.AffiliateSummary, error) {
	rows, err := client.QueryContext(ctx, `
SELECT user_id,
       aff_code,
       aff_code_custom,
       aff_rebate_rate_percent,
       inviter_id,
       aff_count,
       aff_quota::double precision,
       aff_frozen_quota::double precision,
       aff_history_quota::double precision,
       created_at,
       updated_at
FROM user_affiliates
WHERE user_id = $1`, userID)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()
	if !rows.Next() {
		// Distinguish "no row" from a row-iteration error.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, service.ErrAffiliateProfileNotFound
	}

	var out service.AffiliateSummary
	var inviterID sql.NullInt64
	var rebateRate sql.NullFloat64
	if err := rows.Scan(
		&out.UserID,
		&out.AffCode,
		&out.AffCodeCustom,
		&rebateRate,
		&inviterID,
		&out.AffCount,
		&out.AffQuota,
		&out.AffFrozenQuota,
		&out.AffHistoryQuota,
		&out.CreatedAt,
		&out.UpdatedAt,
	); err != nil {
		return nil, err
	}
	if inviterID.Valid {
		out.InviterID = &inviterID.Int64
	}
	if rebateRate.Valid {
		// Copy to a local so the pointer does not alias the scratch struct.
		v := rebateRate.Float64
		out.AffRebateRatePercent = &v
	}
	return &out, nil
}
+
// queryAffiliateByCode loads an affiliate summary by invite code. The
// input code is trimmed and upper-cased before the lookup (codes are
// stored upper-case). Returns service.ErrAffiliateProfileNotFound when no
// row matches.
func queryAffiliateByCode(ctx context.Context, client affiliateQueryExecer, code string) (*service.AffiliateSummary, error) {
	rows, err := client.QueryContext(ctx, `
SELECT user_id,
       aff_code,
       aff_code_custom,
       aff_rebate_rate_percent,
       inviter_id,
       aff_count,
       aff_quota::double precision,
       aff_frozen_quota::double precision,
       aff_history_quota::double precision,
       created_at,
       updated_at
FROM user_affiliates
WHERE aff_code = $1
LIMIT 1`, strings.ToUpper(strings.TrimSpace(code)))
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	if !rows.Next() {
		// Distinguish "no row" from a row-iteration error.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, service.ErrAffiliateProfileNotFound
	}

	var out service.AffiliateSummary
	var inviterID sql.NullInt64
	var rebateRate sql.NullFloat64
	if err := rows.Scan(
		&out.UserID,
		&out.AffCode,
		&out.AffCodeCustom,
		&rebateRate,
		&inviterID,
		&out.AffCount,
		&out.AffQuota,
		&out.AffFrozenQuota,
		&out.AffHistoryQuota,
		&out.CreatedAt,
		&out.UpdatedAt,
	); err != nil {
		return nil, err
	}
	if inviterID.Valid {
		out.InviterID = &inviterID.Int64
	}
	if rebateRate.Valid {
		// Copy to a local so the pointer does not alias the scratch struct.
		v := rebateRate.Float64
		out.AffRebateRatePercent = &v
	}
	return &out, nil
}
+
+func queryUserBalance(ctx context.Context, client affiliateQueryExecer, userID int64) (float64, error) {
+ rows, err := client.QueryContext(ctx,
+ "SELECT balance::double precision FROM users WHERE id = $1 LIMIT 1",
+ userID,
+ )
+ if err != nil {
+ return 0, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ return 0, service.ErrUserNotFound
+ }
+ var balance float64
+ if err := rows.Scan(&balance); err != nil {
+ return 0, err
+ }
+ return balance, nil
+}
+
+func generateAffiliateCode() (string, error) {
+ buf := make([]byte, affiliateCodeLength)
+ if _, err := rand.Read(buf); err != nil {
+ return "", fmt.Errorf("generate affiliate code: %w", err)
+ }
+ for i := range buf {
+ buf[i] = affiliateCodeCharset[int(buf[i])%len(affiliateCodeCharset)]
+ }
+ return string(buf), nil
+}
+
+func isAffiliateUniqueViolation(err error) bool {
+ var pqErr *pq.Error
+ if errors.As(err, &pqErr) {
+ return string(pqErr.Code) == "23505"
+ }
+ return false
+}
+
// UpdateUserAffCode replaces the user's invite code with a custom one and
// marks the profile as having a custom code. Codes are normalized to
// upper case; an empty (post-trim) code yields ErrAffiliateCodeInvalid and
// a uniqueness conflict yields ErrAffiliateCodeTaken.
func (r *affiliateRepository) UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error {
	if userID <= 0 {
		return service.ErrUserNotFound
	}
	code := strings.ToUpper(strings.TrimSpace(newCode))
	if code == "" {
		return service.ErrAffiliateCodeInvalid
	}

	return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		// Create the profile first so the UPDATE below has a row to hit.
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_code = $1,
    aff_code_custom = true,
    updated_at = NOW()
WHERE user_id = $2`, code, userID)
		if err != nil {
			if isAffiliateUniqueViolation(err) {
				return service.ErrAffiliateCodeTaken
			}
			return fmt.Errorf("update aff_code: %w", err)
		}
		affected, _ := res.RowsAffected()
		if affected == 0 {
			return service.ErrUserNotFound
		}
		return nil
	})
}
+
// ResetUserAffCode restores aff_code to a fresh system-generated random
// code and clears the aff_code_custom flag, returning the new code.
// Retries up to affiliateCodeMaxAttempts times when the generated code
// collides with an existing one.
func (r *affiliateRepository) ResetUserAffCode(ctx context.Context, userID int64) (string, error) {
	if userID <= 0 {
		return "", service.ErrUserNotFound
	}
	var newCode string
	err := r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		// Create the profile first so the UPDATE below has a row to hit.
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		for i := 0; i < affiliateCodeMaxAttempts; i++ {
			candidate, codeErr := generateAffiliateCode()
			if codeErr != nil {
				return codeErr
			}
			res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_code = $1,
    aff_code_custom = false,
    updated_at = NOW()
WHERE user_id = $2`, candidate, userID)
			if err != nil {
				if isAffiliateUniqueViolation(err) {
					// Collision with another user's code — try a new one.
					continue
				}
				return fmt.Errorf("reset aff_code: %w", err)
			}
			affected, _ := res.RowsAffected()
			if affected == 0 {
				return service.ErrUserNotFound
			}
			newCode = candidate
			return nil
		}
		return fmt.Errorf("reset aff_code: exhausted attempts")
	})
	if err != nil {
		return "", err
	}
	return newCode, nil
}
+
// SetUserRebateRate sets or clears the user's dedicated rebate percentage.
// ratePercent == nil clears the override (the global rate applies again).
func (r *affiliateRepository) SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
	if userID <= 0 {
		return service.ErrUserNotFound
	}
	return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		// Create the profile first so the UPDATE below has a row to hit.
		if _, err := ensureUserAffiliateWithClient(txCtx, txClient, userID); err != nil {
			return err
		}
		// nullableArg lets us use a single UPDATE for both "set value" and
		// "clear" cases — database/sql converts nil interface{} to SQL NULL.
		res, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_rebate_rate_percent = $1,
    updated_at = NOW()
WHERE user_id = $2`, nullableArg(ratePercent), userID)
		if err != nil {
			return fmt.Errorf("set aff_rebate_rate_percent: %w", err)
		}
		affected, _ := res.RowsAffected()
		if affected == 0 {
			return service.ErrUserNotFound
		}
		return nil
	})
}
+
// BatchSetUserRebateRate sets (or clears, when ratePercent is nil) the
// dedicated rebate percentage for multiple users inside one transaction.
// Non-positive IDs are skipped during profile creation; they remain in the
// ANY($2) array below but match no rows, which is harmless.
func (r *affiliateRepository) BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
	if len(userIDs) == 0 {
		return nil
	}
	return r.withTx(ctx, func(txCtx context.Context, txClient *dbent.Client) error {
		// Make sure every target has an affiliate row before the batch UPDATE.
		for _, uid := range userIDs {
			if uid <= 0 {
				continue
			}
			if _, err := ensureUserAffiliateWithClient(txCtx, txClient, uid); err != nil {
				return err
			}
		}
		_, err := txClient.ExecContext(txCtx, `
UPDATE user_affiliates
SET aff_rebate_rate_percent = $1,
    updated_at = NOW()
WHERE user_id = ANY($2)`, nullableArg(ratePercent), pq.Array(userIDs))
		if err != nil {
			return fmt.Errorf("batch set aff_rebate_rate_percent: %w", err)
		}
		return nil
	})
}
+
// nullableArg converts an optional float into a value suitable for SQL
// parameter binding: a nil pointer maps to SQL NULL, a non-nil pointer is
// dereferenced and passed through.
func nullableArg(v *float64) any {
	if v != nil {
		return *v
	}
	return nil
}
+
// ListUsersWithCustomSettings lists users that have a dedicated affiliate
// configuration (a custom invite code or a per-user rebate rate), with
// pagination and optional fuzzy search, plus the total matching count.
//
// A single query serves both the "no search" and "search by email or
// username" cases: with an empty search the assembled LIKE pattern is
// "%%", which matches every row; otherwise ILIKE does case-insensitive
// substring matching. This avoids maintaining two SQL templates.
func (r *affiliateRepository) ListUsersWithCustomSettings(ctx context.Context, filter service.AffiliateAdminFilter) ([]service.AffiliateAdminEntry, int64, error) {
	// Clamp paging inputs to sane defaults.
	page := filter.Page
	if page < 1 {
		page = 1
	}
	pageSize := filter.PageSize
	if pageSize <= 0 || pageSize > 200 {
		pageSize = 20
	}
	offset := (page - 1) * pageSize
	likePattern := "%" + strings.TrimSpace(filter.Search) + "%"

	// Shared FROM/WHERE so the COUNT and the page query cannot drift apart.
	const baseFrom = `
FROM user_affiliates ua
JOIN users u ON u.id = ua.user_id
WHERE (ua.aff_code_custom = true OR ua.aff_rebate_rate_percent IS NOT NULL)
  AND (u.email ILIKE $1 OR u.username ILIKE $1)`

	client := clientFromContext(ctx, r.client)

	total, err := scanInt64(ctx, client, "SELECT COUNT(*)"+baseFrom, likePattern)
	if err != nil {
		return nil, 0, fmt.Errorf("count affiliate admin entries: %w", err)
	}

	listQuery := `
SELECT ua.user_id,
       COALESCE(u.email, ''),
       COALESCE(u.username, ''),
       ua.aff_code,
       ua.aff_code_custom,
       ua.aff_rebate_rate_percent,
       ua.aff_count` + baseFrom + `
ORDER BY ua.updated_at DESC
LIMIT $2 OFFSET $3`

	rows, err := client.QueryContext(ctx, listQuery, likePattern, pageSize, offset)
	if err != nil {
		return nil, 0, fmt.Errorf("list affiliate admin entries: %w", err)
	}
	defer func() { _ = rows.Close() }()

	// Pre-allocated empty slice so callers always receive non-nil.
	entries := make([]service.AffiliateAdminEntry, 0)
	for rows.Next() {
		var e service.AffiliateAdminEntry
		var rebate sql.NullFloat64
		if err := rows.Scan(&e.UserID, &e.Email, &e.Username, &e.AffCode,
			&e.AffCodeCustom, &rebate, &e.AffCount); err != nil {
			return nil, 0, err
		}
		if rebate.Valid {
			// Copy to a local so the pointer does not alias the scratch value.
			v := rebate.Float64
			e.AffRebateRatePercent = &v
		}
		entries = append(entries, e)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return entries, total, nil
}
+
+// scanInt64 runs a query expected to return a single int64 column (e.g. COUNT).
+func scanInt64(ctx context.Context, client affiliateQueryExecer, query string, args ...any) (int64, error) {
+ rows, err := client.QueryContext(ctx, query, args...)
+ if err != nil {
+ return 0, err
+ }
+ defer func() { _ = rows.Close() }()
+ if !rows.Next() {
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ return 0, nil
+ }
+ var v int64
+ if err := rows.Scan(&v); err != nil {
+ return 0, err
+ }
+ return v, nil
+}
diff --git a/backend/internal/repository/affiliate_repo_integration_test.go b/backend/internal/repository/affiliate_repo_integration_test.go
new file mode 100644
index 00000000..697a193b
--- /dev/null
+++ b/backend/internal/repository/affiliate_repo_integration_test.go
@@ -0,0 +1,399 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func querySingleFloat(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) float64 {
+ t.Helper()
+ rows, err := client.QueryContext(ctx, query, args...)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next(), "expected one row")
+ var value float64
+ require.NoError(t, rows.Scan(&value))
+ require.NoError(t, rows.Err())
+ return value
+}
+
+func querySingleInt(t *testing.T, ctx context.Context, client *dbent.Client, query string, args ...any) int {
+ t.Helper()
+ rows, err := client.QueryContext(ctx, query, args...)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+
+ require.True(t, rows.Next(), "expected one row")
+ var value int
+ require.NoError(t, rows.Scan(&value))
+ require.NoError(t, rows.Err())
+ return value
+}
+
+func TestAffiliateRepository_TransferQuotaToBalance_UsesClaimedQuotaBeforeClear(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-transfer-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 5.5,
+ Concurrency: 5,
+ })
+
+ affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
+ _, err := client.ExecContext(txCtx, `
+INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
+VALUES ($1, $2, $3, $3, NOW(), NOW())`, u.ID, affCode, 12.34)
+ require.NoError(t, err)
+
+ transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
+ require.NoError(t, err)
+ require.InDelta(t, 12.34, transferred, 1e-9)
+ require.InDelta(t, 17.84, balance, 1e-9)
+
+ affQuota := querySingleFloat(t, txCtx, client,
+ "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", u.ID)
+ require.InDelta(t, 0.0, affQuota, 1e-9)
+
+ persistedBalance := querySingleFloat(t, txCtx, client,
+ "SELECT balance::double precision FROM users WHERE id = $1", u.ID)
+ require.InDelta(t, 17.84, persistedBalance, 1e-9)
+
+ ledgerCount := querySingleInt(t, txCtx, client,
+ "SELECT COUNT(*) FROM user_affiliate_ledger WHERE user_id = $1 AND action = 'transfer'", u.ID)
+ require.Equal(t, 1, ledgerCount)
+}
+
+// TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction guards the
+// cross-layer tx propagation invariant: when AccrueQuota is called with a ctx
+// that already carries a transaction (via dbent.NewTxContext), repo.withTx
+// must reuse that tx rather than opening a nested one. If this invariant
+// breaks, AccrueQuota would commit independently and survive a rollback of
+// the outer tx, which would violate payment_fulfillment's all-or-nothing
+// semantics.
+func TestAffiliateRepository_AccrueQuota_ReusesOuterTransaction(t *testing.T) {
+ ctx := context.Background()
+
+ outerTx, err := integrationEntClient.Tx(ctx)
+ require.NoError(t, err, "begin outer tx")
+ // Defensive cleanup: if any require.* below fires before the explicit
+ // Rollback, this prevents the tx from leaking until container teardown.
+ // Rollback is idempotent at the driver level (extra rollback returns an
+ // error we ignore).
+ t.Cleanup(func() { _ = outerTx.Rollback() })
+ client := outerTx.Client()
+ txCtx := dbent.NewTxContext(ctx, outerTx)
+
+ inviter := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-inviter-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 5,
+ })
+ invitee := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-invitee-%d@example.com", time.Now().UnixNano()+1),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Concurrency: 5,
+ })
+
+ repo := NewAffiliateRepository(client, integrationDB)
+ _, err = repo.EnsureUserAffiliate(txCtx, inviter.ID)
+ require.NoError(t, err)
+ _, err = repo.EnsureUserAffiliate(txCtx, invitee.ID)
+ require.NoError(t, err)
+
+ bound, err := repo.BindInviter(txCtx, invitee.ID, inviter.ID)
+ require.NoError(t, err)
+ require.True(t, bound, "invitee must bind to inviter")
+
+ applied, err := repo.AccrueQuota(txCtx, inviter.ID, invitee.ID, 3.5, 0)
+ require.NoError(t, err)
+ require.True(t, applied, "AccrueQuota must report applied=true")
+
+ // Visible inside the outer tx.
+ innerQuota := querySingleFloat(t, txCtx, client,
+ "SELECT aff_quota::double precision FROM user_affiliates WHERE user_id = $1", inviter.ID)
+ require.InDelta(t, 3.5, innerQuota, 1e-9)
+
+ // Roll back the outer tx; if AccrueQuota had opened its own inner tx and
+ // committed it, the rows would still be visible to the global client.
+ require.NoError(t, outerTx.Rollback())
+
+ rows, err := integrationEntClient.QueryContext(ctx,
+ "SELECT COUNT(*) FROM user_affiliates WHERE user_id IN ($1, $2)",
+ inviter.ID, invitee.ID)
+ require.NoError(t, err)
+ defer func() { _ = rows.Close() }()
+ require.True(t, rows.Next())
+ var postRollbackCount int
+ require.NoError(t, rows.Scan(&postRollbackCount))
+ require.Equal(t, 0, postRollbackCount,
+ "AccrueQuota must propagate the outer tx — found persisted rows after rollback")
+}
+
+func TestAffiliateRepository_TransferQuotaToBalance_EmptyQuota(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-empty-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 3.21,
+ Concurrency: 5,
+ })
+
+ affCode := fmt.Sprintf("AFF%09d", time.Now().UnixNano()%1_000_000_000)
+ _, err := client.ExecContext(txCtx, `
+INSERT INTO user_affiliates (user_id, aff_code, aff_quota, aff_history_quota, created_at, updated_at)
+VALUES ($1, $2, 0, 0, NOW(), NOW())`, u.ID, affCode)
+ require.NoError(t, err)
+
+ transferred, balance, err := repo.TransferQuotaToBalance(txCtx, u.ID)
+ require.ErrorIs(t, err, service.ErrAffiliateQuotaEmpty)
+ require.InDelta(t, 0.0, transferred, 1e-9)
+ require.InDelta(t, 0.0, balance, 1e-9)
+
+ persistedBalance := querySingleFloat(t, txCtx, client,
+ "SELECT balance::double precision FROM users WHERE id = $1", u.ID)
+ require.InDelta(t, 3.21, persistedBalance, 1e-9)
+}
+
+// TestAffiliateRepository_AdminCustomCode covers the success path of admin
+// invite-code rewrite + reset within a shared test transaction:
+// - UpdateUserAffCode replaces aff_code, sets aff_code_custom=true, lookup works
+// - the old code can no longer be found
+// - ResetUserAffCode reverts aff_code_custom and assigns a new system-format code
+//
+// The conflict path (duplicate code → ErrAffiliateCodeTaken) lives in its own
+// test because a unique-violation aborts the surrounding Postgres tx, which
+// would poison subsequent assertions in the same transaction.
+func TestAffiliateRepository_AdminCustomCode(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-custom-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+
+ original, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.False(t, original.AffCodeCustom, "system-generated codes start as non-custom")
+ originalCode := original.AffCode
+
+ // Rewrite to a custom code
+ customCode := fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, u.ID, customCode))
+
+ updated, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.Equal(t, customCode, updated.AffCode)
+ require.True(t, updated.AffCodeCustom)
+
+ // Lookup by new custom code finds the user
+ byCode, err := repo.GetAffiliateByCode(txCtx, customCode)
+ require.NoError(t, err)
+ require.Equal(t, u.ID, byCode.UserID)
+
+ // Old system code should no longer match
+ _, err = repo.GetAffiliateByCode(txCtx, originalCode)
+ require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
+
+ // Reset back to a fresh system code, clears custom flag
+ newSysCode, err := repo.ResetUserAffCode(txCtx, u.ID)
+ require.NoError(t, err)
+ require.NotEqual(t, customCode, newSysCode)
+
+ reset, err := repo.EnsureUserAffiliate(txCtx, u.ID)
+ require.NoError(t, err)
+ require.Equal(t, newSysCode, reset.AffCode)
+ require.False(t, reset.AffCodeCustom)
+
+ // The old custom code is now free again
+ _, err = repo.GetAffiliateByCode(txCtx, customCode)
+ require.ErrorIs(t, err, service.ErrAffiliateProfileNotFound)
+}
+
+// TestAffiliateRepository_AdminCustomCode_Conflict isolates the unique-violation
+// path. PostgreSQL aborts the enclosing tx when a unique constraint fires, so
+// this test must be the only assertion and run in its own tx — production
+// callers each have their own outer tx, so this matches real behavior.
+func TestAffiliateRepository_AdminCustomCode_Conflict(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ taker := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-conflict-taker-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ requester := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-conflict-req-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+
+ takenCode := fmt.Sprintf("HOT%09d", time.Now().UnixNano()%1_000_000_000)
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, taker.ID, takenCode))
+
+ // Now requester tries to grab the same code → conflict.
+ err := repo.UpdateUserAffCode(txCtx, requester.ID, takenCode)
+ require.ErrorIs(t, err, service.ErrAffiliateCodeTaken)
+}
+
+// TestAffiliateRepository_AdminRebateRate covers per-user exclusive rate
+// set/clear and the Batch variant including NULL semantics.
+func TestAffiliateRepository_AdminRebateRate(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ u1 := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rate-%d-a@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ u2 := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rate-%d-b@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+
+ // Set exclusive rate for u1
+ rate := 42.5
+ require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, &rate))
+
+ got, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
+ require.NoError(t, err)
+ require.NotNil(t, got.AffRebateRatePercent)
+ require.InDelta(t, 42.5, *got.AffRebateRatePercent, 1e-9)
+
+ // Clear exclusive rate
+ require.NoError(t, repo.SetUserRebateRate(txCtx, u1.ID, nil))
+ cleared, err := repo.EnsureUserAffiliate(txCtx, u1.ID)
+ require.NoError(t, err)
+ require.Nil(t, cleared.AffRebateRatePercent)
+
+ // Batch set both users
+ batchRate := 15.0
+ require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, &batchRate))
+
+ for _, uid := range []int64{u1.ID, u2.ID} {
+ v, err := repo.EnsureUserAffiliate(txCtx, uid)
+ require.NoError(t, err)
+ require.NotNil(t, v.AffRebateRatePercent)
+ require.InDelta(t, 15.0, *v.AffRebateRatePercent, 1e-9)
+ }
+
+ // Batch clear
+ require.NoError(t, repo.BatchSetUserRebateRate(txCtx, []int64{u1.ID, u2.ID}, nil))
+ for _, uid := range []int64{u1.ID, u2.ID} {
+ v, err := repo.EnsureUserAffiliate(txCtx, uid)
+ require.NoError(t, err)
+ require.Nil(t, v.AffRebateRatePercent)
+ }
+}
+
+// TestAffiliateRepository_ListUsersWithCustomSettings verifies the admin list
+// only includes users with at least one override applied.
+func TestAffiliateRepository_ListUsersWithCustomSettings(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ txCtx := dbent.NewTxContext(ctx, tx)
+ client := tx.Client()
+
+ repo := NewAffiliateRepository(client, integrationDB)
+
+ // User without any custom config — should NOT appear in the list.
+ plainEmail := fmt.Sprintf("affiliate-plain-%d@example.com", time.Now().UnixNano())
+ uPlain := mustCreateUser(t, client, &service.User{
+ Email: plainEmail, PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ _, err := repo.EnsureUserAffiliate(txCtx, uPlain.ID)
+ require.NoError(t, err)
+
+ // User with a custom code — should appear.
+ uCode := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-codeonly-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ require.NoError(t, repo.UpdateUserAffCode(txCtx, uCode.ID, fmt.Sprintf("VIP%09d", time.Now().UnixNano()%1_000_000_000)))
+
+ // User with only an exclusive rate — should appear.
+ uRate := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("affiliate-rateonly-%d@example.com", time.Now().UnixNano()),
+ PasswordHash: "hash",
+ Role: service.RoleUser, Status: service.StatusActive,
+ })
+ r := 33.3
+ require.NoError(t, repo.SetUserRebateRate(txCtx, uRate.ID, &r))
+
+ entries, total, err := repo.ListUsersWithCustomSettings(txCtx, service.AffiliateAdminFilter{
+ Page: 1, PageSize: 100,
+ })
+ require.NoError(t, err)
+
+ // Build a quick lookup to assert per-user attributes (other tests may have
+ // inserted custom rows in the same DB; we only care about our 3).
+ byUserID := make(map[int64]service.AffiliateAdminEntry, len(entries))
+ for _, e := range entries {
+ byUserID[e.UserID] = e
+ }
+
+ require.NotContains(t, byUserID, uPlain.ID, "users without overrides must not appear")
+
+ codeEntry, ok := byUserID[uCode.ID]
+ require.True(t, ok, "custom-code user missing from list")
+ require.True(t, codeEntry.AffCodeCustom)
+ require.Nil(t, codeEntry.AffRebateRatePercent)
+
+ rateEntry, ok := byUserID[uRate.ID]
+ require.True(t, ok, "custom-rate user missing from list")
+ require.False(t, rateEntry.AffCodeCustom)
+ require.NotNil(t, rateEntry.AffRebateRatePercent)
+ require.InDelta(t, 33.3, *rateEntry.AffRebateRatePercent, 1e-9)
+
+ require.GreaterOrEqual(t, total, int64(2), "total must include at least our 2 custom rows")
+}
diff --git a/backend/internal/repository/announcement_read_repo.go b/backend/internal/repository/announcement_read_repo.go
index 2dc346b1..5268ec45 100644
--- a/backend/internal/repository/announcement_read_repo.go
+++ b/backend/internal/repository/announcement_read_repo.go
@@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
- return client.AnnouncementRead.Create().
+ err := client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
+ if isSQLNoRowsError(err) {
+ return nil
+ }
+ return err
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
diff --git a/backend/internal/repository/announcement_repo.go b/backend/internal/repository/announcement_repo.go
index 53dc335f..afe1fb25 100644
--- a/backend/internal/repository/announcement_repo.go
+++ b/backend/internal/repository/announcement_repo.go
@@ -2,12 +2,15 @@ package repository
import (
"context"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/announcement"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+
+ entsql "entgo.io/ent/dialect/sql"
)
type announcementRepository struct {
@@ -128,11 +131,14 @@ func (r *announcementRepository) List(
return nil, nil, err
}
- items, err := q.
+ itemsQuery := q.
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(announcement.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range announcementListOrders(params) {
+ itemsQuery = itemsQuery.Order(order)
+ }
+
+ items, err := itemsQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -141,6 +147,56 @@ func (r *announcementRepository) List(
return out, paginationResultFromTotal(int64(total), params), nil
}
+func announcementListOrder(params pagination.PaginationParams) (string, string) {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+
+ switch sortBy {
+ case "title":
+ return announcement.FieldTitle, sortOrder
+ case "status":
+ return announcement.FieldStatus, sortOrder
+ case "notify_mode":
+ return announcement.FieldNotifyMode, sortOrder
+ case "starts_at":
+ return announcement.FieldStartsAt, sortOrder
+ case "ends_at":
+ return announcement.FieldEndsAt, sortOrder
+ case "id":
+ return announcement.FieldID, sortOrder
+ case "", "created_at":
+ return announcement.FieldCreatedAt, sortOrder
+ default:
+ return announcement.FieldCreatedAt, pagination.SortOrderDesc
+ }
+}
+
+func announcementListOrders(params pagination.PaginationParams) []func(*entsql.Selector) {
+ field, sortOrder := announcementListOrder(params)
+
+ if sortOrder == pagination.SortOrderAsc {
+ if field == announcement.FieldID {
+ return []func(*entsql.Selector){
+ dbent.Asc(field),
+ }
+ }
+ return []func(*entsql.Selector){
+ dbent.Asc(field),
+ dbent.Asc(announcement.FieldID),
+ }
+ }
+
+ if field == announcement.FieldID {
+ return []func(*entsql.Selector){
+ dbent.Desc(field),
+ }
+ }
+ return []func(*entsql.Selector){
+ dbent.Desc(field),
+ dbent.Desc(announcement.FieldID),
+ }
+}
+
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
q := r.client.Announcement.Query().
Where(
diff --git a/backend/internal/repository/announcement_repo_sort_test.go b/backend/internal/repository/announcement_repo_sort_test.go
new file mode 100644
index 00000000..e47f98dc
--- /dev/null
+++ b/backend/internal/repository/announcement_repo_sort_test.go
@@ -0,0 +1,63 @@
+package repository
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+)
+
+func TestAnnouncementListOrder(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ params pagination.PaginationParams
+ wantBy string
+ want string
+ }{
+ {
+ name: "default created_at desc",
+ params: pagination.PaginationParams{},
+ wantBy: "created_at",
+ want: "desc",
+ },
+ {
+ name: "title asc",
+ params: pagination.PaginationParams{
+ SortBy: "title",
+ SortOrder: "ASC",
+ },
+ wantBy: "title",
+ want: "asc",
+ },
+ {
+ name: "status desc",
+ params: pagination.PaginationParams{
+ SortBy: "status",
+ SortOrder: "desc",
+ },
+ wantBy: "status",
+ want: "desc",
+ },
+ {
+ name: "invalid falls back",
+ params: pagination.PaginationParams{
+ SortBy: "sideways",
+ SortOrder: "wat",
+ },
+ wantBy: "created_at",
+ want: "desc",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ gotBy, gotOrder := announcementListOrder(tt.params)
+ if gotBy != tt.wantBy || gotOrder != tt.want {
+ t.Fatalf("announcementListOrder(%+v) = (%q, %q), want (%q, %q)", tt.params, gotBy, gotOrder, tt.wantBy, tt.want)
+ }
+ })
+ }
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index ade0d464..3a527405 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -14,6 +15,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+
+ entsql "entgo.io/ent/dialect/sql"
)
type apiKeyRepository struct {
@@ -135,10 +138,21 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
+ user.FieldEmail,
+ user.FieldUsername,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
+ user.FieldBalanceNotifyEnabled,
+ user.FieldBalanceNotifyThresholdType,
+ user.FieldBalanceNotifyThreshold,
+ user.FieldBalanceNotifyExtraEmails,
+ user.FieldTotalRecharged,
+ user.FieldSignupSource,
+ user.FieldLastLoginAt,
+ user.FieldLastActiveAt,
+ user.FieldRpmLimit,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
@@ -155,10 +169,6 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
- group.FieldSoraImagePrice360,
- group.FieldSoraImagePrice540,
- group.FieldSoraVideoPricePerRequest,
- group.FieldSoraVideoPricePerRequestHd,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldFallbackGroupIDOnInvalidRequest,
@@ -168,6 +178,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldSupportedModelScopes,
group.FieldAllowMessagesDispatch,
group.FieldDefaultMappedModel,
+ group.FieldMessagesDispatchModelConfig,
+ group.FieldRpmLimit,
)
}).
Only(ctx)
@@ -313,12 +325,15 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
- keys, err := q.
+ keysQuery := q.
WithGroup().
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(apikey.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range apiKeyListOrder(params) {
+ keysQuery = keysQuery.Order(order)
+ }
+
+ keys, err := keysQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -363,12 +378,15 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
- keys, err := q.
+ keysQuery := q.
WithUser().
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(apikey.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range apiKeyListOrder(params) {
+ keysQuery = keysQuery.Order(order)
+ }
+
+ keys, err := keysQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -381,6 +399,32 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
+func apiKeyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+
+ var field string
+ switch sortBy {
+ case "name":
+ field = apikey.FieldName
+ case "status":
+ field = apikey.FieldStatus
+ case "expires_at":
+ field = apikey.FieldExpiresAt
+ case "last_used_at":
+ field = apikey.FieldLastUsedAt
+ case "created_at":
+ field = apikey.FieldCreatedAt
+ default:
+ field = apikey.FieldID
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(apikey.FieldID)}
+ }
+ return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(apikey.FieldID)}
+}
+
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery()
@@ -607,24 +651,35 @@ func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
- return &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
- SoraStorageUsedBytes: u.SoraStorageUsedBytes,
- TotpSecretEncrypted: u.TotpSecretEncrypted,
- TotpEnabled: u.TotpEnabled,
- TotpEnabledAt: u.TotpEnabledAt,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ out := &service.User{
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ SignupSource: u.SignupSource,
+ LastLoginAt: u.LastLoginAt,
+ LastActiveAt: u.LastActiveAt,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ BalanceNotifyEnabled: u.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: u.BalanceNotifyThreshold,
+ TotalRecharged: u.TotalRecharged,
+ RPMLimit: u.RpmLimit,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
+ // Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format)
+ if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
+ out.BalanceNotifyExtraEmails = service.ParseNotifyEmails(u.BalanceNotifyExtraEmails)
+ }
+ return out
}
func groupEntityToService(g *dbent.Group) *service.Group {
@@ -647,11 +702,6 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
- SoraImagePrice360: g.SoraImagePrice360,
- SoraImagePrice540: g.SoraImagePrice540,
- SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
- SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
@@ -665,6 +715,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RequireOAuthOnly: g.RequireOauthOnly,
RequirePrivacySet: g.RequirePrivacySet,
DefaultMappedModel: g.DefaultMappedModel,
+ MessagesDispatchModelConfig: g.MessagesDispatchModelConfig,
+ RPMLimit: g.RpmLimit,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go
index 7d5c1826..e926ed86 100644
--- a/backend/internal/repository/api_key_repo_integration_test.go
+++ b/backend/internal/repository/api_key_repo_integration_test.go
@@ -86,6 +86,45 @@ func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
s.Require().Error(err, "expected error for non-existent key")
}
+func (s *APIKeyRepoSuite) TestGetByKeyForAuth_PreservesMessagesDispatchModelConfig() {
+ user := s.mustCreateUser("getbykey-auth-dispatch@test.com")
+ group, err := s.client.Group.Create().
+ SetName("g-auth-dispatch").
+ SetPlatform(service.PlatformOpenAI).
+ SetStatus(service.StatusActive).
+ SetSubscriptionType(service.SubscriptionTypeStandard).
+ SetRateMultiplier(1).
+ SetAllowMessagesDispatch(true).
+ SetDefaultMappedModel("gpt-5.4").
+ SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{
+ OpusMappedModel: "gpt-5.4-nano",
+ SonnetMappedModel: "gpt-5.3-codex",
+ HaikuMappedModel: "gpt-5.4-mini",
+ ExactModelMappings: map[string]string{
+ "claude-sonnet-4.5": "gpt-5.4-nano",
+ },
+ }).
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ key := &service.APIKey{
+ UserID: user.ID,
+ Key: "sk-getbykey-auth-dispatch",
+ Name: "Dispatch Key",
+ GroupID: &group.ID,
+ Status: service.StatusActive,
+ }
+ s.Require().NoError(s.repo.Create(s.ctx, key))
+
+ got, err := s.repo.GetByKeyForAuth(s.ctx, key.Key)
+ s.Require().NoError(err)
+ s.Require().NotNil(got.Group)
+ s.Require().True(got.Group.AllowMessagesDispatch)
+ s.Require().Equal("gpt-5.4", got.Group.DefaultMappedModel)
+ s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.OpusMappedModel)
+ s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.ExactModelMappings["claude-sonnet-4.5"])
+}
+
// --- Update ---
func (s *APIKeyRepoSuite) TestUpdate() {
diff --git a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go
new file mode 100644
index 00000000..aba62ead
--- /dev/null
+++ b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go
@@ -0,0 +1,74 @@
+package repository
+
+import (
+ "context"
+ "testing"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGroupEntityToService_PreservesMessagesDispatchModelConfig(t *testing.T) {
+ group := &dbent.Group{
+ ID: 1,
+ Name: "openai-dispatch",
+ Platform: service.PlatformOpenAI,
+ Status: service.StatusActive,
+ SubscriptionType: service.SubscriptionTypeStandard,
+ RateMultiplier: 1,
+ AllowMessagesDispatch: true,
+ DefaultMappedModel: "gpt-5.4",
+ MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
+ OpusMappedModel: "gpt-5.4-nano",
+ SonnetMappedModel: "gpt-5.3-codex",
+ HaikuMappedModel: "gpt-5.4-mini",
+ ExactModelMappings: map[string]string{
+ "claude-sonnet-4.5": "gpt-5.4-nano",
+ },
+ },
+ }
+
+ got := groupEntityToService(group)
+ require.NotNil(t, got)
+ require.Equal(t, group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig)
+}
+
+func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_SQLite(t *testing.T) {
+ repo, client := newAPIKeyRepoSQLite(t)
+ ctx := context.Background()
+ user := mustCreateAPIKeyRepoUser(t, ctx, client, "getbykey-auth-dispatch-unit@test.com")
+
+ group, err := client.Group.Create().
+ SetName("g-auth-dispatch-unit").
+ SetPlatform(service.PlatformOpenAI).
+ SetStatus(service.StatusActive).
+ SetSubscriptionType(service.SubscriptionTypeStandard).
+ SetRateMultiplier(1).
+ SetAllowMessagesDispatch(true).
+ SetDefaultMappedModel("gpt-5.4").
+ SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{
+ OpusMappedModel: "gpt-5.4-nano",
+ SonnetMappedModel: "gpt-5.3-codex",
+ HaikuMappedModel: "gpt-5.4-mini",
+ ExactModelMappings: map[string]string{
+ "claude-sonnet-4.5": "gpt-5.4-nano",
+ },
+ }).
+ Save(ctx)
+ require.NoError(t, err)
+
+ key := &service.APIKey{
+ UserID: user.ID,
+ Key: "sk-getbykey-auth-dispatch-unit",
+ Name: "Dispatch Key Unit",
+ GroupID: &group.ID,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, key))
+
+ got, err := repo.GetByKeyForAuth(ctx, key.Key)
+ require.NoError(t, err)
+ require.NotNil(t, got.Group)
+ require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig)
+}
diff --git a/backend/internal/repository/api_key_repo_sort_integration_test.go b/backend/internal/repository/api_key_repo_sort_integration_test.go
new file mode 100644
index 00000000..69812882
--- /dev/null
+++ b/backend/internal/repository/api_key_repo_sort_integration_test.go
@@ -0,0 +1,25 @@
+//go:build integration
+
+package repository
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *APIKeyRepoSuite) TestListByUserID_SortByNameAsc() {
+ user := s.mustCreateUser("sort-name@example.com")
+ s.mustCreateApiKey(user.ID, "sk-z", "z-key", nil)
+ s.mustCreateApiKey(user.ID, "sk-a", "a-key", nil)
+
+ keys, _, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "name",
+ SortOrder: "asc",
+ }, service.APIKeyListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(keys, 2)
+ s.Require().Equal("a-key", keys[0].Name)
+ s.Require().Equal("z-key", keys[1].Name)
+}
diff --git a/backend/internal/repository/auth_identity_compat_backfill_integration_test.go b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
new file mode 100644
index 00000000..7e34777a
--- /dev/null
+++ b/backend/internal/repository/auth_identity_compat_backfill_integration_test.go
@@ -0,0 +1,80 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) { // replays migrations 108 -> 108a -> 109 from a clean slate and checks the widened report_type column
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration108Path := filepath.Join("..", "..", "migrations", "108_auth_identity_foundation_core.sql")
+ migration108SQL, err := os.ReadFile(migration108Path)
+ require.NoError(t, err)
+
+ migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
+ migration108aSQL, err := os.ReadFile(migration108aPath)
+ require.NoError(t, err)
+
+ migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+ migration109SQL, err := os.ReadFile(migration109Path)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, ` // reset to the pre-108 schema so the full chain is exercised
+DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE;
+DROP TABLE IF EXISTS auth_identity_channels CASCADE;
+DROP TABLE IF EXISTS identity_adoption_decisions CASCADE;
+DROP TABLE IF EXISTS pending_auth_sessions CASCADE;
+DROP TABLE IF EXISTS auth_identities CASCADE;
+
+ALTER TABLE users
+ DROP COLUMN IF EXISTS signup_source,
+ DROP COLUMN IF EXISTS last_login_at,
+ DROP COLUMN IF EXISTS last_active_at;
+`)
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration108aSQL))
+ require.NoError(t, err)
+
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&userID))
+
+ _, err = tx.ExecContext(ctx, string(migration109SQL))
+ require.NoError(t, err)
+
+ var reportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+ AND report_key = $1
+`, strconv.FormatInt(userID, 10)).Scan(&reportCount))
+ require.Equal(t, 1, reportCount) // 109 must flag the synthetic-email user exactly once
+
+ var reportTypeLimit int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+`).Scan(&reportTypeLimit))
+ require.GreaterOrEqual(t, reportTypeLimit, 45) // 45 = length of the report_type literal asserted above
+
+ require.NotZero(t, userID)
+}
diff --git a/backend/internal/repository/auth_identity_legacy_migration_integration_test.go b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
new file mode 100644
index 00000000..e64934c5
--- /dev/null
+++ b/backend/internal/repository/auth_identity_legacy_migration_integration_test.go
@@ -0,0 +1,959 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) { // migration 115: backfills linuxdo/wechat rows from user_external_identities into auth_identities and reports openid-only wechat rows
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoUserID))
+
+ var wechatUnionUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionUserID))
+
+ var wechatOpenIDOnlyUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDOnlyUserID))
+
+ var syntheticAuthIdentityID int64 // pre-existing identity tagged backfill_source=synthetic_email; 115 should flag it for review
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb)
+RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID))
+
+ var linuxDoLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoUserID).Scan(&linuxDoLegacyID))
+
+ var wechatUnionLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}')
+RETURNING id
+`, wechatUnionUserID).Scan(&wechatUnionLegacyID))
+
+ var wechatOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}')
+RETURNING id
+`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-user-1'
+`, linuxDoUserID).Scan(&linuxDoCount))
+ require.Equal(t, 1, linuxDoCount) // linuxdo legacy row becomes a canonical identity
+
+ var wechatSubject string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT provider_subject
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-1'
+`, wechatUnionUserID).Scan(&wechatSubject))
+ require.Equal(t, "union-1", wechatSubject) // union id (not openid) is the canonical wechat subject
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels channel
+JOIN auth_identities ai ON ai.id = channel.identity_id
+WHERE ai.user_id = $1
+ AND channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = 'oa'
+ AND channel.channel_app_id = 'wx-app-1'
+ AND channel.channel_subject = 'openid-union-1'
+`, wechatUnionUserID).Scan(&wechatChannelCount))
+ require.Equal(t, 1, wechatChannelCount) // the openid is preserved as an 'oa' channel row
+
+ var legacyOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount))
+ require.Equal(t, 1, legacyOpenIDOnlyReportCount)
+
+ var syntheticReviewCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount))
+ require.Equal(t, 1, syntheticReviewCount)
+
+ var unionLegacyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount))
+ require.Zero(t, unionLegacyReportCount) // union-bearing rows must not be flagged
+ require.NotZero(t, linuxDoLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) { // migration 115 must be a no-op when user_external_identities is absent
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err) // must not fail even though the legacy source table is missing
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount) // no reports emitted by the no-op run
+}
+
+func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) { // runs 115 then 116 against rows whose metadata is invalid JSON or a JSON array; resulting identity metadata must still be a JSON object
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migration115SQL, err := os.ReadFile(migration115Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoMalformedUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoMalformedUserID))
+
+ var linuxDoArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoArrayUserID))
+
+ var wechatUnionArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatUnionArrayUserID))
+
+ var wechatOpenIDArrayUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatOpenIDArrayUserID))
+
+ var linuxDoMalformedLegacyID int64 // metadata '{invalid' is not parseable JSON
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid')
+RETURNING id
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID))
+
+ var linuxDoArrayLegacyID int64 // metadata is valid JSON but an array, not an object
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]')
+RETURNING id
+`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID))
+
+ var wechatUnionArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]')
+RETURNING id
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID))
+
+ var wechatOpenIDArrayLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]')
+RETURNING id
+`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration115SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var linuxDoMalformedMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-malformed'
+`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType))
+ require.Equal(t, "object", linuxDoMalformedMetadataType) // invalid JSON coerced to an object
+
+ var linuxDoArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-array'
+`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType))
+ require.Equal(t, "object", linuxDoArrayMetadataType)
+
+ var wechatUnionArrayMetadataType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(metadata)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-array'
+`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType))
+ require.Equal(t, "object", wechatUnionArrayMetadataType)
+
+ var invalidJSONReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType))
+ require.Equal(t, "object", invalidJSONReportDetailsType)
+
+ var openIDOnlyReportDetailsType string
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT jsonb_typeof(details)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType))
+ require.Equal(t, "object", openIDOnlyReportDetailsType)
+
+ var preservedArrayMetadataCount int // array payloads must survive under the _legacy_metadata_raw_json key
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE id IN (
+ SELECT id
+ FROM auth_identities
+ WHERE (user_id = $1 AND provider_subject = 'linuxdo-array')
+ OR (user_id = $2 AND provider_subject = 'union-array')
+)
+ AND metadata ? '_legacy_metadata_raw_json'
+`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount))
+ require.Equal(t, 2, preservedArrayMetadataCount)
+
+ require.NotZero(t, linuxDoArrayLegacyID)
+ require.NotZero(t, wechatUnionArrayLegacyID)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) { // migration 116 alone: legacy rows colliding with existing identities/channels get conflict reports; invalid-JSON metadata is downgraded, not fatal
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ userIDs := make([]int64, 0, 8)
+ for _, email := range []string{
+ "linuxdo-conflict-legacy@example.com",
+ "linuxdo-conflict-owner@example.com",
+ "wechat-conflict-legacy@example.com",
+ "wechat-conflict-owner@example.com",
+ "wechat-channel-legacy@example.com",
+ "wechat-channel-owner@example.com",
+ "linuxdo-invalid-json@example.com",
+ "wechat-openid-invalid-json@example.com",
+ } {
+ var userID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ($1, 'hash', 'user', 'active', 0, 1)
+RETURNING id`, email).Scan(&userID))
+ userIDs = append(userIDs, userID)
+ }
+
+ linuxdoConflictLegacyUserID := userIDs[0] // indices mirror the email slice order above
+ linuxdoConflictOwnerUserID := userIDs[1]
+ wechatConflictLegacyUserID := userIDs[2]
+ wechatConflictOwnerUserID := userIDs[3]
+ wechatChannelLegacyUserID := userIDs[4]
+ wechatChannelOwnerUserID := userIDs[5]
+ linuxdoInvalidJSONUserID := userIDs[6]
+ wechatInvalidOpenIDUserID := userIDs[7]
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb)
+RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb)
+RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64)))
+
+ var wechatChannelOwnerIdentityID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
+VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb)
+RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64)))
+
+ var linuxdoConflictLegacyID int64 // same subject as the owner identity but a different user -> conflict
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}')
+RETURNING id
+`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID))
+
+ var wechatConflictLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}')
+RETURNING id
+`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID))
+
+ var wechatChannelConflictLegacyID int64 // collides with the existing oa/wx-app-conflict channel row
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}')
+RETURNING id
+`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID))
+
+ var linuxdoInvalidJSONLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid')
+RETURNING id
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID))
+
+ var wechatInvalidOpenIDLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid')
+RETURNING id
+`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxdoConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount))
+ require.Equal(t, 1, linuxdoConflictReportCount)
+
+ var wechatConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount))
+ require.Equal(t, 1, wechatConflictReportCount)
+
+ var channelConflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_channel_conflict'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount))
+ require.Equal(t, 1, channelConflictReportCount)
+
+ var invalidJSONReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+ AND report_key IN ($1, $2)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount))
+ require.Equal(t, 2, invalidJSONReportCount) // both malformed-metadata rows are reported
+
+ var linuxdoInvalidIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE user_id = $1
+ AND provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-invalid-json'
+`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount))
+ require.Equal(t, 1, linuxdoInvalidIdentityCount) // invalid metadata still yields an identity row
+
+ var wechatOpenIDOnlyReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'wechat_openid_only_requires_remediation'
+ AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount))
+ require.Equal(t, 1, wechatOpenIDOnlyReportCount)
+}
+
+func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) { // migration 116 must be a no-op when user_external_identities is absent
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ var beforeCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&beforeCount))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err) // must not fail even though the legacy source table is missing
+
+ var afterCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+`).Scan(&afterCount))
+ require.Equal(t, beforeCount, afterCount) // no reports emitted by the no-op run
+}
+
+func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) { // migration 115: when two users share one canonical subject, no identity or channel row may be created for it
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migrationSQL, err := os.ReadFile(migrationPath)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoFirstUserID))
+
+ var linuxDoSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoSecondUserID))
+
+ var wechatFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatFirstUserID))
+
+ var wechatSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatSecondUserID))
+
+ require.NoError(t, tx.QueryRowContext(ctx, ` // two linuxdo rows below share provider_user_id 'linuxdo-ambiguous-subject'
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoFirstUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoSecondUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, ` // two wechat rows below share provider_union_id 'union-ambiguous-subject'
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
+RETURNING id
+`, wechatFirstUserID).Scan(new(int64)))
+
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
+RETURNING id
+`, wechatSecondUserID).Scan(new(int64)))
+
+ _, err = tx.ExecContext(ctx, string(migrationSQL))
+ require.NoError(t, err)
+
+ var linuxDoIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE provider_type = 'linuxdo'
+ AND provider_key = 'linuxdo'
+ AND provider_subject = 'linuxdo-ambiguous-subject'
+`).Scan(&linuxDoIdentityCount))
+ require.Zero(t, linuxDoIdentityCount) // ambiguous subject must not be claimed by either user
+
+ var wechatIdentityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND provider_subject = 'union-ambiguous-subject'
+`).Scan(&wechatIdentityCount))
+ require.Zero(t, wechatIdentityCount)
+
+ var wechatChannelCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_channels
+WHERE provider_type = 'wechat'
+ AND provider_key = 'wechat-main'
+ AND channel = 'oa'
+ AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
+`).Scan(&wechatChannelCount))
+ require.Zero(t, wechatChannelCount) // no channel rows either for the skipped subjects
+}
+
+func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) { // 115 + 116: all four ambiguous legacy rows are reported as conflicts, and no report names an existing_identity_id "winner"
+ tx := testTx(t)
+ ctx := context.Background()
+
+ migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
+ migration115SQL, err := os.ReadFile(migration115Path)
+ require.NoError(t, err)
+
+ migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+ migration116SQL, err := os.ReadFile(migration116Path)
+ require.NoError(t, err)
+
+ prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+ truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+ var linuxDoFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoFirstUserID))
+
+ var linuxDoSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxDoSecondUserID))
+
+ var wechatFirstUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatFirstUserID))
+
+ var wechatSecondUserID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&wechatSecondUserID))
+
+ var linuxDoFirstLegacyID int64 // shares 'linuxdo-conflict-subject' with the second row below
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))
+
+ var linuxDoSecondLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
+RETURNING id
+`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))
+
+ var wechatFirstLegacyID int64 // shares 'union-conflict-subject' with the second wechat row below
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
+RETURNING id
+`, wechatFirstUserID).Scan(&wechatFirstLegacyID))
+
+ var wechatSecondLegacyID int64
+ require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+ user_id,
+ provider,
+ provider_user_id,
+ provider_union_id,
+ provider_username,
+ display_name,
+ metadata
+) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
+RETURNING id
+`, wechatSecondUserID).Scan(&wechatSecondLegacyID))
+
+ _, err = tx.ExecContext(ctx, string(migration115SQL))
+ require.NoError(t, err)
+
+ _, err = tx.ExecContext(ctx, string(migration116SQL))
+ require.NoError(t, err)
+
+ var identityCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identities
+WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
+ OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
+`).Scan(&identityCount))
+ require.Zero(t, identityCount) // neither ambiguous subject is materialized
+
+ var conflictReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key IN ($1, $2, $3, $4)
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
+ require.Equal(t, 4, conflictReportCount) // every ambiguous legacy row gets its own report
+
+ var winnerAttributedReportCount int
+ require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_conflict'
+ AND report_key IN ($1, $2, $3, $4)
+ AND details ->> 'existing_identity_id' IS NOT NULL
+`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
+ require.Zero(t, winnerAttributedReportCount) // reports stay symmetric: no row is attributed as the existing owner
+}
+
+// TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121
+// verifies that once the 108a preflight widens
+// auth_identity_migration_reports.report_type, migrations 109 and 116 can run
+// before 121 and still record their reports: the column ends up 80 chars wide,
+// the OIDC synthetic-email user gets a manual-recovery report, and the legacy
+// identity with malformed metadata gets an invalid-metadata report.
+func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
+	tx := testTx(t)
+	ctx := context.Background()
+
+	migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
+	migration108aSQL, err := os.ReadFile(migration108aPath)
+	require.NoError(t, err)
+
+	migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
+	migration109SQL, err := os.ReadFile(migration109Path)
+	require.NoError(t, err)
+
+	migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
+	migration116SQL, err := os.ReadFile(migration116Path)
+	require.NoError(t, err)
+
+	prepareLegacyExternalIdentitiesTable(t, tx, ctx)
+	truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
+
+	// Shrink report_type back to the legacy 40-char width so the 108a preflight
+	// actually has something to widen.
+	_, err = tx.ExecContext(ctx, `
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(40);
+`)
+	require.NoError(t, err)
+
+	var oidcSyntheticUserID int64
+	require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&oidcSyntheticUserID))
+
+	var linuxdoLegacyUserID int64
+	require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO users (email, password_hash, role, status, balance, concurrency)
+VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
+RETURNING id`).Scan(&linuxdoLegacyUserID))
+
+	// '{invalid' is deliberately malformed JSON so migration 116 records an
+	// invalid-metadata report for this legacy identity.
+	var invalidMetadataLegacyID int64
+	require.NoError(t, tx.QueryRowContext(ctx, `
+INSERT INTO user_external_identities (
+    user_id,
+    provider,
+    provider_user_id,
+    provider_union_id,
+    provider_username,
+    display_name,
+    metadata
+) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
+RETURNING id
+`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID))
+
+	// Run the preflight widening, then the two report-writing migrations.
+	_, err = tx.ExecContext(ctx, string(migration108aSQL))
+	require.NoError(t, err)
+
+	_, err = tx.ExecContext(ctx, string(migration109SQL))
+	require.NoError(t, err)
+
+	_, err = tx.ExecContext(ctx, string(migration116SQL))
+	require.NoError(t, err)
+
+	// 108a must have widened report_type from 40 to 80 characters.
+	var reportTypeWidth int
+	require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT character_maximum_length
+FROM information_schema.columns
+WHERE table_schema = 'public'
+  AND table_name = 'auth_identity_migration_reports'
+  AND column_name = 'report_type'
+`).Scan(&reportTypeWidth))
+	require.Equal(t, 80, reportTypeWidth)
+
+	// Migration 109 reports the synthetic-email OIDC user for manual recovery.
+	var oidcSyntheticRecoveryReportCount int
+	require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
+  AND report_key = $1
+`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount))
+	require.Equal(t, 1, oidcSyntheticRecoveryReportCount)
+
+	// Migration 116 reports the legacy identity whose metadata is not valid JSON.
+	var invalidMetadataReportCount int
+	require.NoError(t, tx.QueryRowContext(ctx, `
+SELECT COUNT(*)
+FROM auth_identity_migration_reports
+WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
+  AND report_key = $1
+`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount))
+	require.Equal(t, 1, invalidMetadataReportCount)
+}
+
+// prepareLegacyExternalIdentitiesTable creates the legacy
+// user_external_identities table (idempotent via IF NOT EXISTS) so the
+// migration SQL under test has a source table to read from.
+func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) {
+	t.Helper()
+
+	_, err := tx.ExecContext(ctx, `
+CREATE TABLE IF NOT EXISTS user_external_identities (
+    id BIGSERIAL PRIMARY KEY,
+    user_id BIGINT NOT NULL,
+    provider TEXT NOT NULL,
+    provider_user_id TEXT NOT NULL,
+    provider_union_id TEXT NULL,
+    provider_username TEXT NOT NULL DEFAULT '',
+    display_name TEXT NOT NULL DEFAULT '',
+    profile_url TEXT NOT NULL DEFAULT '',
+    avatar_url TEXT NOT NULL DEFAULT '',
+    metadata TEXT NOT NULL DEFAULT '{}',
+    created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+`)
+	require.NoError(t, err)
+}
+
+// truncateAuthIdentityLegacyFixtureTables clears every table touched by the
+// auth-identity legacy fixtures and resets their identity sequences
+// (RESTART IDENTITY CASCADE) so each test starts from an empty state.
+func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) {
+	t.Helper()
+
+	_, err := tx.ExecContext(ctx, `
+TRUNCATE TABLE
+    auth_identity_channels,
+    identity_adoption_decisions,
+    pending_auth_sessions,
+    auth_identities,
+    auth_identity_migration_reports,
+    user_provider_default_grants,
+    user_avatars,
+    user_external_identities,
+    users
+RESTART IDENTITY CASCADE;
+`)
+	require.NoError(t, err)
+}
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
new file mode 100644
index 00000000..800ee43b
--- /dev/null
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -0,0 +1,755 @@
+package repository
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+	"strings"
+	"time"
+
+	dbent "github.com/Wei-Shaw/sub2api/ent"
+	"github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+	"github.com/Wei-Shaw/sub2api/ent/channelmonitorhistory"
+	"github.com/Wei-Shaw/sub2api/internal/service"
+	"github.com/lib/pq"
+)
+
+// channelMonitorRepository implements service.ChannelMonitorRepository.
+//
+// Design notes:
+//   - CRUD goes through ent, reusing the project's transactional context support
+//   - aggregate queries (latest per model / availability) use raw SQL, avoiding
+//     ent boilerplate around GROUP BY and making sure the indexes are hit
+type channelMonitorRepository struct {
+	client *dbent.Client
+	db     *sql.DB
+}
+
+// NewChannelMonitorRepository creates a repository instance backed by the
+// given ent client (CRUD) and raw DB handle (aggregate SQL).
+func NewChannelMonitorRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRepository {
+	return &channelMonitorRepository{client: client, db: db}
+}
+
+// ---------- CRUD ----------
+
+// Create persists a new monitor and writes the generated ID and timestamps
+// back into m. m.APIKey must already be encrypted by the caller.
+func (r *channelMonitorRepository) Create(ctx context.Context, m *service.ChannelMonitor) error {
+	client := clientFromContext(ctx, r.client)
+	builder := client.ChannelMonitor.Create().
+		SetName(m.Name).
+		SetProvider(channelmonitor.Provider(m.Provider)).
+		SetEndpoint(m.Endpoint).
+		SetAPIKeyEncrypted(m.APIKey). // the caller passes ciphertext already
+		SetPrimaryModel(m.PrimaryModel).
+		SetExtraModels(emptySliceIfNil(m.ExtraModels)).
+		SetGroupName(m.GroupName).
+		SetEnabled(m.Enabled).
+		SetIntervalSeconds(m.IntervalSeconds).
+		SetCreatedBy(m.CreatedBy).
+		SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
+		SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
+	// Optional fields: only set when present.
+	if m.TemplateID != nil {
+		builder = builder.SetTemplateID(*m.TemplateID)
+	}
+	if m.BodyOverride != nil {
+		builder = builder.SetBodyOverride(m.BodyOverride)
+	}
+
+	created, err := builder.Save(ctx)
+	if err != nil {
+		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+	}
+	m.ID = created.ID
+	m.CreatedAt = created.CreatedAt
+	m.UpdatedAt = created.UpdatedAt
+	return nil
+}
+
+// GetByID loads one monitor by primary key; missing rows are translated to
+// service.ErrChannelMonitorNotFound by translatePersistenceError.
+func (r *channelMonitorRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitor, error) {
+	row, err := r.client.ChannelMonitor.Query().
+		Where(channelmonitor.IDEQ(id)).
+		Only(ctx)
+	if err != nil {
+		return nil, translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+	}
+	return entToServiceMonitor(row), nil
+}
+
+// Update rewrites every mutable field of the monitor identified by m.ID and
+// copies the refreshed UpdatedAt back into m. A nil TemplateID / BodyOverride
+// explicitly clears the stored value.
+func (r *channelMonitorRepository) Update(ctx context.Context, m *service.ChannelMonitor) error {
+	client := clientFromContext(ctx, r.client)
+	updater := client.ChannelMonitor.UpdateOneID(m.ID).
+		SetName(m.Name).
+		SetProvider(channelmonitor.Provider(m.Provider)).
+		SetEndpoint(m.Endpoint).
+		SetAPIKeyEncrypted(m.APIKey).
+		SetPrimaryModel(m.PrimaryModel).
+		SetExtraModels(emptySliceIfNil(m.ExtraModels)).
+		SetGroupName(m.GroupName).
+		SetEnabled(m.Enabled).
+		SetIntervalSeconds(m.IntervalSeconds).
+		SetExtraHeaders(emptyHeadersIfNilRepo(m.ExtraHeaders)).
+		SetBodyOverrideMode(defaultBodyModeRepo(m.BodyOverrideMode))
+	if m.TemplateID != nil {
+		updater = updater.SetTemplateID(*m.TemplateID)
+	} else {
+		updater = updater.ClearTemplateID()
+	}
+	if m.BodyOverride != nil {
+		updater = updater.SetBodyOverride(m.BodyOverride)
+	} else {
+		updater = updater.ClearBodyOverride()
+	}
+
+	updated, err := updater.Save(ctx)
+	if err != nil {
+		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+	}
+	m.UpdatedAt = updated.UpdatedAt
+	return nil
+}
+
+// Delete removes the monitor by primary key; missing rows are translated to
+// service.ErrChannelMonitorNotFound.
+func (r *channelMonitorRepository) Delete(ctx context.Context, id int64) error {
+	client := clientFromContext(ctx, r.client)
+	if err := client.ChannelMonitor.DeleteOneID(id).Exec(ctx); err != nil {
+		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+	}
+	return nil
+}
+
+// List returns one page of monitors matching the optional provider / enabled /
+// search filters, newest first (id DESC), plus the total match count.
+// Page defaults to 1 and PageSize to 20 when non-positive.
+func (r *channelMonitorRepository) List(ctx context.Context, params service.ChannelMonitorListParams) ([]*service.ChannelMonitor, int64, error) {
+	q := r.client.ChannelMonitor.Query()
+	if params.Provider != "" {
+		q = q.Where(channelmonitor.ProviderEQ(channelmonitor.Provider(params.Provider)))
+	}
+	if params.Enabled != nil {
+		q = q.Where(channelmonitor.EnabledEQ(*params.Enabled))
+	}
+	// Search matches name, group name, or primary model, case-insensitively.
+	if s := strings.TrimSpace(params.Search); s != "" {
+		q = q.Where(channelmonitor.Or(
+			channelmonitor.NameContainsFold(s),
+			channelmonitor.GroupNameContainsFold(s),
+			channelmonitor.PrimaryModelContainsFold(s),
+		))
+	}
+
+	// Count before pagination so total reflects the full filtered set.
+	total, err := q.Count(ctx)
+	if err != nil {
+		return nil, 0, fmt.Errorf("count monitors: %w", err)
+	}
+
+	pageSize := params.PageSize
+	if pageSize <= 0 {
+		pageSize = 20
+	}
+	page := params.Page
+	if page <= 0 {
+		page = 1
+	}
+
+	rows, err := q.
+		Order(dbent.Desc(channelmonitor.FieldID)).
+		Offset((page - 1) * pageSize).
+		Limit(pageSize).
+		All(ctx)
+	if err != nil {
+		return nil, 0, fmt.Errorf("list monitors: %w", err)
+	}
+
+	out := make([]*service.ChannelMonitor, 0, len(rows))
+	for _, row := range rows {
+		out = append(out, entToServiceMonitor(row))
+	}
+	return out, int64(total), nil
+}
+
+// ---------- 调度器辅助 ----------
+
+// ListEnabled returns every monitor with enabled = true (scheduler helper).
+func (r *channelMonitorRepository) ListEnabled(ctx context.Context) ([]*service.ChannelMonitor, error) {
+	rows, err := r.client.ChannelMonitor.Query().
+		Where(channelmonitor.EnabledEQ(true)).
+		All(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("list enabled monitors: %w", err)
+	}
+	out := make([]*service.ChannelMonitor, 0, len(rows))
+	for _, row := range rows {
+		out = append(out, entToServiceMonitor(row))
+	}
+	return out, nil
+}
+
+// MarkChecked stamps the monitor's last_checked_at with checkedAt; missing
+// rows are translated to service.ErrChannelMonitorNotFound.
+func (r *channelMonitorRepository) MarkChecked(ctx context.Context, id int64, checkedAt time.Time) error {
+	client := clientFromContext(ctx, r.client)
+	if err := client.ChannelMonitor.UpdateOneID(id).
+		SetLastCheckedAt(checkedAt).
+		Exec(ctx); err != nil {
+		return translatePersistenceError(err, service.ErrChannelMonitorNotFound, nil)
+	}
+	return nil
+}
+
+// InsertHistoryBatch bulk-inserts history rows with a single CreateBulk.
+// A nil/empty slice is a no-op; nullable latency fields are only set when
+// the corresponding pointer is non-nil.
+func (r *channelMonitorRepository) InsertHistoryBatch(ctx context.Context, rows []*service.ChannelMonitorHistoryRow) error {
+	if len(rows) == 0 {
+		return nil
+	}
+	client := clientFromContext(ctx, r.client)
+	bulk := make([]*dbent.ChannelMonitorHistoryCreate, 0, len(rows))
+	for _, row := range rows {
+		c := client.ChannelMonitorHistory.Create().
+			SetMonitorID(row.MonitorID).
+			SetModel(row.Model).
+			SetStatus(channelmonitorhistory.Status(row.Status)).
+			SetMessage(row.Message).
+			SetCheckedAt(row.CheckedAt)
+		if row.LatencyMs != nil {
+			c = c.SetLatencyMs(*row.LatencyMs)
+		}
+		if row.PingLatencyMs != nil {
+			c = c.SetPingLatencyMs(*row.PingLatencyMs)
+		}
+		bulk = append(bulk, c)
+	}
+	if _, err := client.ChannelMonitorHistory.CreateBulk(bulk...).Save(ctx); err != nil {
+		return fmt.Errorf("insert history bulk: %w", err)
+	}
+	return nil
+}
+
+// DeleteHistoryBefore physically deletes detail rows with checked_at < before,
+// in batches of channelMonitorPruneBatchSize rows, to avoid the lock/WAL
+// pressure of one huge delete. Each batch locates a small set of ids via the
+// (checked_at) index and then deletes by id.
+func (r *channelMonitorRepository) DeleteHistoryBefore(ctx context.Context, before time.Time) (int64, error) {
+	return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneHistorySQL, before)
+}
+
+// ListHistory returns the most recent `limit` history records of a monitor,
+// ordered by checked_at DESC. An empty model means no model filter; otherwise
+// only that model's records are returned.
+// NOTE(review): limit is forwarded unclamped — callers are expected to pass a
+// positive value; confirm upstream validation.
+func (r *channelMonitorRepository) ListHistory(ctx context.Context, monitorID int64, model string, limit int) ([]*service.ChannelMonitorHistoryEntry, error) {
+	q := r.client.ChannelMonitorHistory.Query().
+		Where(channelmonitorhistory.MonitorIDEQ(monitorID))
+	if strings.TrimSpace(model) != "" {
+		q = q.Where(channelmonitorhistory.ModelEQ(model))
+	}
+	rows, err := q.
+		Order(dbent.Desc(channelmonitorhistory.FieldCheckedAt)).
+		Limit(limit).
+		All(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("list history: %w", err)
+	}
+	out := make([]*service.ChannelMonitorHistoryEntry, 0, len(rows))
+	for _, row := range rows {
+		entry := &service.ChannelMonitorHistoryEntry{
+			ID:            row.ID,
+			Model:         row.Model,
+			Status:        string(row.Status),
+			LatencyMs:     row.LatencyMs,
+			PingLatencyMs: row.PingLatencyMs,
+			Message:       row.Message,
+			CheckedAt:     row.CheckedAt,
+		}
+		out = append(out, entry)
+	}
+	return out, nil
+}
+
+// ---------- 用户视图聚合(原生 SQL) ----------
+
+// ListLatestPerModel uses DISTINCT ON to fetch the latest record per model for
+// one monitor. The (monitor_id, model, checked_at DESC) index lets this run as
+// an index scan.
+func (r *channelMonitorRepository) ListLatestPerModel(ctx context.Context, monitorID int64) ([]*service.ChannelMonitorLatest, error) {
+	const q = `
+	SELECT DISTINCT ON (model)
+		model, status, latency_ms, ping_latency_ms, checked_at
+	FROM channel_monitor_histories
+	WHERE monitor_id = $1
+	ORDER BY model, checked_at DESC
+	`
+	rows, err := r.db.QueryContext(ctx, q, monitorID)
+	if err != nil {
+		return nil, fmt.Errorf("query latest per model: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	out := make([]*service.ChannelMonitorLatest, 0)
+	for rows.Next() {
+		l := &service.ChannelMonitorLatest{}
+		var latency, ping sql.NullInt64
+		if err := rows.Scan(&l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
+			return nil, fmt.Errorf("scan latest row: %w", err)
+		}
+		assignNullInt(&l.LatencyMs, latency)
+		assignNullInt(&l.PingLatencyMs, ping)
+		out = append(out, l)
+	}
+	return out, rows.Err()
+}
+
+// assignNullInt unpacks a sql.NullInt64 into a *int target, allocating a new
+// int only when the value is valid (invalid values leave *dst untouched).
+// Centralized so the latency / ping columns don't each repeat the same
+// nil-check boilerplate.
+func assignNullInt(dst **int, n sql.NullInt64) {
+	if n.Valid {
+		v := int(n.Int64)
+		*dst = &v
+	}
+}
+
+// ComputeAvailability computes per-model availability and average latency for
+// one monitor over the given window (defaulting to 7 days when non-positive).
+// "available" = status IN (operational, degraded).
+//
+// Detail rows are retained for 30 days (monitorHistoryRetentionDays), so for
+// windows <= 30 days this scans channel_monitor_histories directly with
+// second-level precision, avoiding the UTC day-boundary precision loss a UNION
+// with the rollup table would introduce.
+// NOTE(review): the original comment also claimed a 1-day detail retention —
+// the two statements conflict; confirm the actual retention policy.
+func (r *channelMonitorRepository) ComputeAvailability(ctx context.Context, monitorID int64, windowDays int) ([]*service.ChannelMonitorAvailability, error) {
+	if windowDays <= 0 {
+		windowDays = 7
+	}
+	const q = `
+	SELECT model,
+		COUNT(*) AS total,
+		COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
+		CASE WHEN COUNT(latency_ms) > 0
+			THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
+			ELSE NULL END AS avg_latency_ms
+	FROM channel_monitor_histories
+	WHERE monitor_id = $1
+	  AND checked_at >= NOW() - ($2::int || ' days')::interval
+	GROUP BY model
+	`
+	rows, err := r.db.QueryContext(ctx, q, monitorID, windowDays)
+	if err != nil {
+		return nil, fmt.Errorf("query availability: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	out := make([]*service.ChannelMonitorAvailability, 0)
+	for rows.Next() {
+		row, err := scanAvailabilityRow(rows, windowDays)
+		if err != nil {
+			return nil, err
+		}
+		out = append(out, row)
+	}
+	return out, rows.Err()
+}
+
+// scanAvailabilityRow scans one (model, total, ok, avg_latency) row into a
+// ChannelMonitorAvailability. It serves only ComputeAvailability (4 columns);
+// the batched variant has an extra monitor_id column and therefore inlines its
+// scan before calling finalizeAvailabilityRow.
+func scanAvailabilityRow(rows interface{ Scan(...any) error }, windowDays int) (*service.ChannelMonitorAvailability, error) {
+	row := &service.ChannelMonitorAvailability{WindowDays: windowDays}
+	var avgLatency sql.NullFloat64
+	if err := rows.Scan(&row.Model, &row.TotalChecks, &row.OperationalChecks, &avgLatency); err != nil {
+		return nil, fmt.Errorf("scan availability row: %w", err)
+	}
+	finalizeAvailabilityRow(row, avgLatency)
+	return row, nil
+}
+
+// finalizeAvailabilityRow derives the availability percentage from
+// OperationalChecks/TotalChecks and unpacks the nullable average latency into
+// a *int. Shared by the single-monitor and batched queries so the math never
+// drifts between the two.
+func finalizeAvailabilityRow(row *service.ChannelMonitorAvailability, avgLatency sql.NullFloat64) {
+	if total := row.TotalChecks; total > 0 {
+		row.AvailabilityPct = 100.0 * float64(row.OperationalChecks) / float64(total)
+	}
+	if !avgLatency.Valid {
+		return
+	}
+	ms := int(avgLatency.Float64)
+	row.AvgLatencyMs = &ms
+}
+
+// ListLatestForMonitorIDs fetches, in one query, the latest record per
+// (monitor_id, model) for several monitors. PG's DISTINCT ON plus the
+// (monitor_id, model, checked_at DESC) index lets this run as an index scan.
+// An empty ids slice returns an empty map without querying.
+func (r *channelMonitorRepository) ListLatestForMonitorIDs(ctx context.Context, ids []int64) (map[int64][]*service.ChannelMonitorLatest, error) {
+	out := make(map[int64][]*service.ChannelMonitorLatest, len(ids))
+	if len(ids) == 0 {
+		return out, nil
+	}
+	const q = `
+	SELECT DISTINCT ON (monitor_id, model)
+		monitor_id, model, status, latency_ms, ping_latency_ms, checked_at
+	FROM channel_monitor_histories
+	WHERE monitor_id = ANY($1)
+	ORDER BY monitor_id, model, checked_at DESC
+	`
+	rows, err := r.db.QueryContext(ctx, q, pq.Array(ids))
+	if err != nil {
+		return nil, fmt.Errorf("query latest batch: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	for rows.Next() {
+		var monitorID int64
+		l := &service.ChannelMonitorLatest{}
+		var latency, ping sql.NullInt64
+		if err := rows.Scan(&monitorID, &l.Model, &l.Status, &latency, &ping, &l.CheckedAt); err != nil {
+			return nil, fmt.Errorf("scan latest batch row: %w", err)
+		}
+		assignNullInt(&l.LatencyMs, latency)
+		assignNullInt(&l.PingLatencyMs, ping)
+		out[monitorID] = append(out[monitorID], l)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+// ListRecentHistoryForMonitors fetches, for several monitors at once, each
+// monitor's most recent N history rows for its designated model
+// (checked_at DESC, newest first). primaryModels[monitorID] names the model to
+// filter on; monitors absent from primaryModels are not returned.
+// A CTE unnest-ing two parallel arrays (int8 / text) builds the
+// (monitor_id, model) whitelist, then ROW_NUMBER() OVER (PARTITION BY
+// monitor_id) keeps each monitor's top N.
+//
+// Returns map[monitorID] -> []*ChannelMonitorHistoryEntry (message omitted to
+// cut payload size). Empty ids / empty primaryModels yield an empty map, no
+// error.
+func (r *channelMonitorRepository) ListRecentHistoryForMonitors(
+	ctx context.Context,
+	ids []int64,
+	primaryModels map[int64]string,
+	perMonitorLimit int,
+) (map[int64][]*service.ChannelMonitorHistoryEntry, error) {
+	out := make(map[int64][]*service.ChannelMonitorHistoryEntry, len(ids))
+	pairIDs, pairModels := buildMonitorModelPairs(ids, primaryModels)
+	if len(pairIDs) == 0 {
+		return out, nil
+	}
+	perMonitorLimit = clampTimelineLimit(perMonitorLimit)
+
+	const q = `
+	WITH targets AS (
+		SELECT unnest($1::bigint[]) AS monitor_id,
+			unnest($2::text[]) AS model
+	),
+	ranked AS (
+		SELECT h.monitor_id,
+			h.status,
+			h.latency_ms,
+			h.ping_latency_ms,
+			h.checked_at,
+			ROW_NUMBER() OVER (PARTITION BY h.monitor_id ORDER BY h.checked_at DESC) AS rn
+		FROM channel_monitor_histories h
+		JOIN targets t
+			ON t.monitor_id = h.monitor_id AND t.model = h.model
+	)
+	SELECT monitor_id, status, latency_ms, ping_latency_ms, checked_at
+	FROM ranked
+	WHERE rn <= $3
+	ORDER BY monitor_id, checked_at DESC
+	`
+	rows, err := r.db.QueryContext(ctx, q, pq.Array(pairIDs), pq.Array(pairModels), perMonitorLimit)
+	if err != nil {
+		return nil, fmt.Errorf("query recent history batch: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	for rows.Next() {
+		var monitorID int64
+		entry := &service.ChannelMonitorHistoryEntry{}
+		var latency, ping sql.NullInt64
+		if err := rows.Scan(&monitorID, &entry.Status, &latency, &ping, &entry.CheckedAt); err != nil {
+			return nil, fmt.Errorf("scan recent history row: %w", err)
+		}
+		assignNullInt(&entry.LatencyMs, latency)
+		assignNullInt(&entry.PingLatencyMs, ping)
+		out[monitorID] = append(out[monitorID], entry)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+// buildMonitorModelPairs filters ids down to the valid (monitor_id, model)
+// pairs, skipping monitors whose model is missing or blank. The two returned
+// slices have equal length and correspond index-by-index, ready for unnest.
+func buildMonitorModelPairs(ids []int64, primaryModels map[int64]string) ([]int64, []string) {
+	if len(ids) == 0 || len(primaryModels) == 0 {
+		return nil, nil
+	}
+	outIDs := make([]int64, 0, len(ids))
+	outModels := make([]string, 0, len(ids))
+	for _, monitorID := range ids {
+		name := primaryModels[monitorID]
+		if strings.TrimSpace(name) == "" {
+			continue
+		}
+		outIDs = append(outIDs, monitorID)
+		outModels = append(outModels, name)
+	}
+	return outIDs, outModels
+}
+
+// timelineLimit* bound the perMonitorLimit of batched timeline queries.
+// The lower bound of 1 guarantees at least the latest entry; the upper bound
+// of 200 caps response-body size and the ROW_NUMBER window's memory footprint.
+const (
+	timelineLimitMin = 1
+	timelineLimitMax = 200
+)
+
+// clampTimelineLimit clamps n into [timelineLimitMin, timelineLimitMax],
+// guarding against invalid or oversized requests.
+func clampTimelineLimit(n int) int {
+	switch {
+	case n < timelineLimitMin:
+		return timelineLimitMin
+	case n > timelineLimitMax:
+		return timelineLimitMax
+	default:
+		return n
+	}
+}
+
+// ComputeAvailabilityForMonitors computes per-model availability and average
+// latency over one window for several monitors in a single query (window
+// defaults to 7 days when non-positive). Detail rows are kept for 30 days, so
+// windows <= 30 days scan channel_monitor_histories directly without touching
+// the rollups.
+func (r *channelMonitorRepository) ComputeAvailabilityForMonitors(ctx context.Context, ids []int64, windowDays int) (map[int64][]*service.ChannelMonitorAvailability, error) {
+	out := make(map[int64][]*service.ChannelMonitorAvailability, len(ids))
+	if len(ids) == 0 {
+		return out, nil
+	}
+	if windowDays <= 0 {
+		windowDays = 7
+	}
+	const q = `
+	SELECT monitor_id,
+		model,
+		COUNT(*) AS total,
+		COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok,
+		CASE WHEN COUNT(latency_ms) > 0
+			THEN SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL)::float8 / COUNT(latency_ms)
+			ELSE NULL END AS avg_latency_ms
+	FROM channel_monitor_histories
+	WHERE monitor_id = ANY($1)
+	  AND checked_at >= NOW() - ($2::int || ' days')::interval
+	GROUP BY monitor_id, model
+	`
+	rows, err := r.db.QueryContext(ctx, q, pq.Array(ids), windowDays)
+	if err != nil {
+		return nil, fmt.Errorf("query availability batch: %w", err)
+	}
+	defer func() { _ = rows.Close() }()
+
+	for rows.Next() {
+		var monitorID int64
+		row := &service.ChannelMonitorAvailability{WindowDays: windowDays}
+		var avgLatency sql.NullFloat64
+		if err := rows.Scan(&monitorID, &row.Model, &row.TotalChecks, &row.OperationalChecks, &avgLatency); err != nil {
+			return nil, fmt.Errorf("scan availability batch row: %w", err)
+		}
+		// The batched query has an extra leading monitor_id column; the
+		// availability / average-latency math is identical to the single-monitor
+		// version, extracted into finalizeAvailabilityRow so the division and
+		// NullFloat unpacking live in one place.
+		finalizeAvailabilityRow(row, avgLatency)
+		out[monitorID] = append(out[monitorID], row)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
+// ---------- 聚合维护 ----------
+
+// UpsertDailyRollupsFor aggregates the detail rows of targetDate's UTC day
+// ([targetDate, targetDate+1d)) by (monitor_id, model, bucket_date) into
+// channel_monitor_daily_rollups, returning the number of rows written.
+//   - ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE makes the backfill
+//     idempotent: re-running simply overwrites with the latest stats
+//   - $1::date lets PG truncate the argument to a UTC date, so callers don't
+//     need to pre-normalize targetDate
+func (r *channelMonitorRepository) UpsertDailyRollupsFor(ctx context.Context, targetDate time.Time) (int64, error) {
+	const q = `
+	INSERT INTO channel_monitor_daily_rollups (
+		monitor_id, model, bucket_date,
+		total_checks, ok_count,
+		operational_count, degraded_count, failed_count, error_count,
+		sum_latency_ms, count_latency,
+		sum_ping_latency_ms, count_ping_latency,
+		computed_at
+	)
+	SELECT
+		monitor_id,
+		model,
+		$1::date AS bucket_date,
+		COUNT(*) AS total_checks,
+		COUNT(*) FILTER (WHERE status IN ('operational','degraded')) AS ok_count,
+		COUNT(*) FILTER (WHERE status = 'operational') AS operational_count,
+		COUNT(*) FILTER (WHERE status = 'degraded') AS degraded_count,
+		COUNT(*) FILTER (WHERE status = 'failed') AS failed_count,
+		COUNT(*) FILTER (WHERE status = 'error') AS error_count,
+		COALESCE(SUM(latency_ms) FILTER (WHERE latency_ms IS NOT NULL), 0) AS sum_latency_ms,
+		COUNT(latency_ms) AS count_latency,
+		COALESCE(SUM(ping_latency_ms) FILTER (WHERE ping_latency_ms IS NOT NULL), 0) AS sum_ping_latency_ms,
+		COUNT(ping_latency_ms) AS count_ping_latency,
+		NOW()
+	FROM channel_monitor_histories
+	WHERE checked_at >= $1::date
+	  AND checked_at < ($1::date + INTERVAL '1 day')
+	GROUP BY monitor_id, model
+	ON CONFLICT (monitor_id, model, bucket_date) DO UPDATE SET
+		total_checks = EXCLUDED.total_checks,
+		ok_count = EXCLUDED.ok_count,
+		operational_count = EXCLUDED.operational_count,
+		degraded_count = EXCLUDED.degraded_count,
+		failed_count = EXCLUDED.failed_count,
+		error_count = EXCLUDED.error_count,
+		sum_latency_ms = EXCLUDED.sum_latency_ms,
+		count_latency = EXCLUDED.count_latency,
+		sum_ping_latency_ms = EXCLUDED.sum_ping_latency_ms,
+		count_ping_latency = EXCLUDED.count_ping_latency,
+		computed_at = NOW()
+	`
+	res, err := r.db.ExecContext(ctx, q, targetDate)
+	if err != nil {
+		return 0, fmt.Errorf("upsert daily rollups for %s: %w", targetDate.Format("2006-01-02"), err)
+	}
+	n, err := res.RowsAffected()
+	if err != nil {
+		return 0, fmt.Errorf("rows affected (upsert rollups): %w", err)
+	}
+	return n, nil
+}
+
+// DeleteRollupsBefore physically deletes rollup rows with
+// bucket_date < beforeDate, using the same batched strategy as
+// DeleteHistoryBefore.
+func (r *channelMonitorRepository) DeleteRollupsBefore(ctx context.Context, beforeDate time.Time) (int64, error) {
+	return deleteChannelMonitorBatched(ctx, r.db, channelMonitorPruneRollupSQL, beforeDate)
+}
+
+// channelMonitorPruneBatchSize caps a single delete batch. Matches the 5000
+// used by ops_cleanup_service; deleting small id batches on large tables
+// avoids long transactions and WAL buildup.
+const channelMonitorPruneBatchSize = 5000
+
+// channelMonitorPruneHistorySQL deletes expired detail rows batch by batch:
+// the CTE picks up to $2 ids with checked_at < $1, then the DELETE removes
+// exactly that batch.
+const channelMonitorPruneHistorySQL = `
+WITH batch AS (
+    SELECT id FROM channel_monitor_histories
+    WHERE checked_at < $1
+    ORDER BY id
+    LIMIT $2
+)
+DELETE FROM channel_monitor_histories
+WHERE id IN (SELECT id FROM batch)
+`
+
+// channelMonitorPruneRollupSQL deletes expired rollup rows batch by batch.
+// bucket_date needs the ::date cast so the comparison matches the DATE column.
+const channelMonitorPruneRollupSQL = `
+WITH batch AS (
+    SELECT id FROM channel_monitor_daily_rollups
+    WHERE bucket_date < $1::date
+    ORDER BY id
+    LIMIT $2
+)
+DELETE FROM channel_monitor_daily_rollups
+WHERE id IN (SELECT id FROM batch)
+`
+
+// deleteChannelMonitorBatched runs the batched DELETE in a loop until a batch
+// affects zero rows, returning the cumulative number of deleted rows.
+// cutoff is typed by the caller to match the column (time.Time against
+// TIMESTAMPTZ for detail rows; time.Time cast with ::date on the SQL side for
+// rollups).
+func deleteChannelMonitorBatched(ctx context.Context, db *sql.DB, query string, cutoff time.Time) (int64, error) {
+	var total int64
+	for {
+		res, err := db.ExecContext(ctx, query, cutoff, channelMonitorPruneBatchSize)
+		if err != nil {
+			return total, fmt.Errorf("channel_monitor prune batch: %w", err)
+		}
+		affected, err := res.RowsAffected()
+		if err != nil {
+			return total, fmt.Errorf("channel_monitor prune rows affected: %w", err)
+		}
+		total += affected
+		if affected == 0 {
+			break
+		}
+	}
+	return total, nil
+}
+
+// LoadAggregationWatermark reads the watermark row (id = 1).
+// The watermark table is not an ent schema (it only ever holds one row), so
+// this goes through raw SQL.
+//
+// Returns (nil, nil) when the row does not exist or last_aggregated_date is
+// NULL; the caller decides the initial backfill strategy in that case.
+func (r *channelMonitorRepository) LoadAggregationWatermark(ctx context.Context) (*time.Time, error) {
+	const q = `SELECT last_aggregated_date FROM channel_monitor_aggregation_watermark WHERE id = 1`
+	var t sql.NullTime
+	if err := r.db.QueryRowContext(ctx, q).Scan(&t); err != nil {
+		// errors.Is instead of == so drivers that wrap sql.ErrNoRows still match.
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, nil
+		}
+		return nil, fmt.Errorf("load aggregation watermark: %w", err)
+	}
+	if !t.Valid {
+		return nil, nil
+	}
+	return &t.Time, nil
+}
+
+// UpdateAggregationWatermark upserts the watermark (UPSERT onto id = 1).
+// $1::date lets PG truncate the argument to a UTC date, matching the DATE type
+// of the last_aggregated_date column.
+func (r *channelMonitorRepository) UpdateAggregationWatermark(ctx context.Context, date time.Time) error {
+	const q = `
+	INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
+	VALUES (1, $1::date, NOW())
+	ON CONFLICT (id) DO UPDATE SET
+		last_aggregated_date = EXCLUDED.last_aggregated_date,
+		updated_at = NOW()
+	`
+	if _, err := r.db.ExecContext(ctx, q, date); err != nil {
+		return fmt.Errorf("update aggregation watermark: %w", err)
+	}
+	return nil
+}
+
+// ---------- helpers ----------
+
+// entToServiceMonitor maps an ent row to the service-layer model (nil-safe).
+// ExtraModels / ExtraHeaders are normalized to non-nil empty containers, and
+// TemplateID is copied so the returned struct doesn't alias the row's pointer.
+func entToServiceMonitor(row *dbent.ChannelMonitor) *service.ChannelMonitor {
+	if row == nil {
+		return nil
+	}
+	extras := row.ExtraModels
+	if extras == nil {
+		extras = []string{}
+	}
+	headers := row.ExtraHeaders
+	if headers == nil {
+		headers = map[string]string{}
+	}
+	out := &service.ChannelMonitor{
+		ID:               row.ID,
+		Name:             row.Name,
+		Provider:         string(row.Provider),
+		Endpoint:         row.Endpoint,
+		APIKey:           row.APIKeyEncrypted, // still ciphertext; the service layer decrypts
+		PrimaryModel:     row.PrimaryModel,
+		ExtraModels:      extras,
+		GroupName:        row.GroupName,
+		Enabled:          row.Enabled,
+		IntervalSeconds:  row.IntervalSeconds,
+		LastCheckedAt:    row.LastCheckedAt,
+		CreatedBy:        row.CreatedBy,
+		CreatedAt:        row.CreatedAt,
+		UpdatedAt:        row.UpdatedAt,
+		ExtraHeaders:     headers,
+		BodyOverrideMode: row.BodyOverrideMode,
+		BodyOverride:     row.BodyOverride,
+	}
+	if row.TemplateID != nil {
+		id := *row.TemplateID
+		out.TemplateID = &id
+	}
+	return out
+}
+
+// emptyHeadersIfNilRepo mirrors service.emptyHeadersIfNil: a nil map becomes
+// an empty one. The repo keeps its own copy to avoid an import cycle.
+func emptyHeadersIfNilRepo(h map[string]string) map[string]string {
+	if h != nil {
+		return h
+	}
+	return map[string]string{}
+}
+
+// defaultBodyModeRepo normalizes an empty mode string to "off" (kept local to
+// the repo package to avoid an import cycle, same as above).
+func defaultBodyModeRepo(mode string) string {
+	if mode != "" {
+		return mode
+	}
+	return "off"
+}
+
+// emptySliceIfNil normalizes a nil string slice to an empty, non-nil slice.
+func emptySliceIfNil(in []string) []string {
+	if in != nil {
+		return in
+	}
+	return []string{}
+}
diff --git a/backend/internal/repository/channel_monitor_template_repo.go b/backend/internal/repository/channel_monitor_template_repo.go
new file mode 100644
index 00000000..845d186b
--- /dev/null
+++ b/backend/internal/repository/channel_monitor_template_repo.go
@@ -0,0 +1,195 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitor"
+ "github.com/Wei-Shaw/sub2api/ent/channelmonitorrequesttemplate"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// channelMonitorRequestTemplateRepository implements
+// service.ChannelMonitorRequestTemplateRepository. Kept in its own file,
+// separate from channelMonitorRepository, for clearer responsibilities.
+type channelMonitorRequestTemplateRepository struct {
+ client *dbent.Client
+ db *sql.DB
+}
+
+// NewChannelMonitorRequestTemplateRepository creates a template repository instance.
+func NewChannelMonitorRequestTemplateRepository(client *dbent.Client, db *sql.DB) service.ChannelMonitorRequestTemplateRepository {
+ return &channelMonitorRequestTemplateRepository{client: client, db: db}
+}
+
+// ---------- CRUD ----------
+
+// Create persists a new request template and backfills the generated
+// ID/CreatedAt/UpdatedAt onto t. Uses the transactional client from ctx,
+// when present, so the insert participates in an enclosing transaction.
+func (r *channelMonitorRequestTemplateRepository) Create(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
+ client := clientFromContext(ctx, r.client)
+ builder := client.ChannelMonitorRequestTemplate.Create().
+ SetName(t.Name).
+ SetProvider(channelmonitorrequesttemplate.Provider(t.Provider)).
+ SetDescription(t.Description).
+ SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
+ if t.BodyOverride != nil {
+ builder = builder.SetBodyOverride(t.BodyOverride)
+ }
+
+ created, err := builder.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ t.ID = created.ID
+ t.CreatedAt = created.CreatedAt
+ t.UpdatedAt = created.UpdatedAt
+ return nil
+}
+
+// GetByID loads a single template by primary key.
+// Uses the transactional client from ctx when present — consistent with
+// Create/Update/Delete/ApplyToMonitors — so reads issued inside a
+// transaction observe that transaction's uncommitted writes.
+func (r *channelMonitorRequestTemplateRepository) GetByID(ctx context.Context, id int64) (*service.ChannelMonitorRequestTemplate, error) {
+ client := clientFromContext(ctx, r.client)
+ row, err := client.ChannelMonitorRequestTemplate.Query().
+ Where(channelmonitorrequesttemplate.IDEQ(id)).
+ Only(ctx)
+ if err != nil {
+ return nil, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ return entToServiceTemplate(row), nil
+}
+
+// Update overwrites the stored template with t's fields (Provider is not
+// changed here) and backfills UpdatedAt. A nil BodyOverride explicitly
+// clears the column rather than leaving the old value in place.
+func (r *channelMonitorRequestTemplateRepository) Update(ctx context.Context, t *service.ChannelMonitorRequestTemplate) error {
+ client := clientFromContext(ctx, r.client)
+ updater := client.ChannelMonitorRequestTemplate.UpdateOneID(t.ID).
+ SetName(t.Name).
+ SetDescription(t.Description).
+ SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
+ if t.BodyOverride != nil {
+ updater = updater.SetBodyOverride(t.BodyOverride)
+ } else {
+ updater = updater.ClearBodyOverride()
+ }
+ updated, err := updater.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ t.UpdatedAt = updated.UpdatedAt
+ return nil
+}
+
+// Delete removes the template by ID, translating "not found" into
+// service.ErrChannelMonitorTemplateNotFound.
+func (r *channelMonitorRequestTemplateRepository) Delete(ctx context.Context, id int64) error {
+ client := clientFromContext(ctx, r.client)
+ if err := client.ChannelMonitorRequestTemplate.DeleteOneID(id).Exec(ctx); err != nil {
+ return translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+ return nil
+}
+
+// List returns templates, optionally filtered by provider, ordered by
+// (provider, name) for stable output.
+// NOTE(review): this reads via r.client rather than clientFromContext —
+// presumably intentional for a read-only listing; confirm against callers.
+func (r *channelMonitorRequestTemplateRepository) List(ctx context.Context, params service.ChannelMonitorRequestTemplateListParams) ([]*service.ChannelMonitorRequestTemplate, error) {
+ q := r.client.ChannelMonitorRequestTemplate.Query()
+ if params.Provider != "" {
+ q = q.Where(channelmonitorrequesttemplate.ProviderEQ(channelmonitorrequesttemplate.Provider(params.Provider)))
+ }
+ rows, err := q.
+ Order(dbent.Asc(channelmonitorrequesttemplate.FieldProvider), dbent.Asc(channelmonitorrequesttemplate.FieldName)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list monitor templates: %w", err)
+ }
+ out := make([]*service.ChannelMonitorRequestTemplate, 0, len(rows))
+ for _, row := range rows {
+ out = append(out, entToServiceTemplate(row))
+ }
+ return out, nil
+}
+
+// ApplyToMonitors overwrites the monitors listed in monitorIDs with the
+// template's current configuration. The WHERE clause filters twice —
+// template_id = id AND id IN (monitorIDs) — so a caller-supplied ID that is
+// not associated with this template cannot be overwritten. Uses ent's
+// UpdateMany so hooks still run. Returns the number of monitors updated.
+func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error) {
+ if len(monitorIDs) == 0 {
+ return 0, nil
+ }
+ client := clientFromContext(ctx, r.client)
+ tpl, err := client.ChannelMonitorRequestTemplate.Query().
+ Where(channelmonitorrequesttemplate.IDEQ(id)).
+ Only(ctx)
+ if err != nil {
+ return 0, translatePersistenceError(err, service.ErrChannelMonitorTemplateNotFound, nil)
+ }
+
+ updater := client.ChannelMonitor.Update().
+ Where(
+ channelmonitor.TemplateIDEQ(id),
+ channelmonitor.IDIn(monitorIDs...),
+ ).
+ SetExtraHeaders(emptyHeadersIfNilRepo(tpl.ExtraHeaders)).
+ SetBodyOverrideMode(defaultBodyModeRepo(tpl.BodyOverrideMode))
+ if tpl.BodyOverride != nil {
+ updater = updater.SetBodyOverride(tpl.BodyOverride)
+ } else {
+ // Template has no override: clear it on the monitors as well.
+ updater = updater.ClearBodyOverride()
+ }
+
+ affected, err := updater.Save(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("apply template to monitors: %w", err)
+ }
+ return int64(affected), nil
+}
+
+// CountAssociatedMonitors returns the number of monitors associated with the
+// template (used by the UI to display "N configurations").
+func (r *channelMonitorRequestTemplateRepository) CountAssociatedMonitors(ctx context.Context, id int64) (int64, error) {
+ count, err := r.client.ChannelMonitor.Query().
+ Where(channelmonitor.TemplateIDEQ(id)).
+ Count(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("count monitors for template %d: %w", id, err)
+ }
+ return int64(count), nil
+}
+
+// ListAssociatedMonitors lists brief fields for every monitor associated
+// with the template. Ordered by name for stable frontend display.
+func (r *channelMonitorRequestTemplateRepository) ListAssociatedMonitors(ctx context.Context, id int64) ([]*service.AssociatedMonitorBrief, error) {
+ rows, err := r.client.ChannelMonitor.Query().
+ Where(channelmonitor.TemplateIDEQ(id)).
+ Order(dbent.Asc(channelmonitor.FieldName)).
+ All(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("list associated monitors for template %d: %w", id, err)
+ }
+ out := make([]*service.AssociatedMonitorBrief, 0, len(rows))
+ for _, row := range rows {
+ out = append(out, &service.AssociatedMonitorBrief{
+ ID: row.ID,
+ Name: row.Name,
+ Provider: string(row.Provider),
+ Enabled: row.Enabled,
+ })
+ }
+ return out, nil
+}
+
+// ---------- helpers ----------
+
+// entToServiceTemplate converts an ent template row into the service-layer
+// model. Returns nil for a nil row; a nil headers map is normalized to an
+// empty map.
+func entToServiceTemplate(row *dbent.ChannelMonitorRequestTemplate) *service.ChannelMonitorRequestTemplate {
+ if row == nil {
+ return nil
+ }
+ headers := row.ExtraHeaders
+ if headers == nil {
+ headers = map[string]string{}
+ }
+ return &service.ChannelMonitorRequestTemplate{
+ ID: row.ID,
+ Name: row.Name,
+ Provider: string(row.Provider),
+ Description: row.Description,
+ ExtraHeaders: headers,
+ BodyOverrideMode: row.BodyOverrideMode,
+ BodyOverride: row.BodyOverride,
+ CreatedAt: row.CreatedAt,
+ UpdatedAt: row.UpdatedAt,
+ }
+}
diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go
new file mode 100644
index 00000000..2cb90aab
--- /dev/null
+++ b/backend/internal/repository/channel_repo.go
@@ -0,0 +1,551 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+// channelRepository is the raw-SQL implementation of service.ChannelRepository.
+type channelRepository struct {
+ db *sql.DB
+}
+
+// NewChannelRepository creates a channel data-access instance.
+func NewChannelRepository(db *sql.DB) service.ChannelRepository {
+ return &channelRepository{db: db}
+}
+
+// runInTx executes fn inside a transaction: commit on success, rollback on
+// failure. The deferred Rollback after a successful Commit is a harmless
+// no-op (database/sql returns ErrTxDone, which is discarded).
+func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
+ tx, err := r.db.BeginTx(ctx, nil)
+ if err != nil {
+ return fmt.Errorf("begin tx: %w", err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ if err := fn(tx); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+// Create inserts a channel plus its group associations, model pricing and
+// account-stats pricing rules in a single transaction, backfilling the
+// generated ID/CreatedAt/UpdatedAt onto channel. Returns
+// service.ErrChannelExists on a unique-name violation.
+func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
+ return r.runInTx(ctx, func(tx *sql.Tx) error {
+ modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
+ if err != nil {
+ return err
+ }
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
+ err = tx.QueryRowContext(ctx,
+ `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+ RETURNING id, created_at, updated_at`,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats,
+ ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
+ if err != nil {
+ if isUniqueViolation(err) {
+ return service.ErrChannelExists
+ }
+ return fmt.Errorf("insert channel: %w", err)
+ }
+
+ // Set group associations
+ if len(channel.GroupIDs) > 0 {
+ if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
+ return err
+ }
+ }
+
+ // Set model pricing
+ if len(channel.ModelPricing) > 0 {
+ if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
+ return err
+ }
+ }
+
+ // Set account-stats pricing rules
+ if len(channel.AccountStatsPricingRules) > 0 {
+ if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+// GetByID loads a channel with its group IDs, model pricing and
+// account-stats pricing rules. Returns service.ErrChannelNotFound when the
+// row does not exist. (The `==` comparison with sql.ErrNoRows is valid
+// here: QueryRow.Scan returns the sentinel unwrapped.)
+func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
+ ch := &service.Channel{}
+ var modelMappingJSON, featuresConfigJSON []byte
+ err := r.db.QueryRowContext(ctx,
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at
+ FROM channels WHERE id = $1`, id,
+ ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
+ if err == sql.ErrNoRows {
+ return nil, service.ErrChannelNotFound
+ }
+ if err != nil {
+ return nil, fmt.Errorf("get channel: %w", err)
+ }
+ ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
+
+ groupIDs, err := r.GetGroupIDs(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ ch.GroupIDs = groupIDs
+
+ pricing, err := r.ListModelPricing(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ ch.ModelPricing = pricing
+
+ statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ ch.AccountStatsPricingRules = statsPricingRules
+
+ return ch, nil
+}
+
+// Update rewrites the channel row and, transactionally, its associations.
+// Nil GroupIDs/ModelPricing/AccountStatsPricingRules mean "leave unchanged";
+// non-nil (including empty) slices replace the stored set. Returns
+// service.ErrChannelExists on a unique-name violation and
+// service.ErrChannelNotFound when no row matched the ID.
+func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
+ return r.runInTx(ctx, func(tx *sql.Tx) error {
+ modelMappingJSON, err := marshalModelMapping(channel.ModelMapping)
+ if err != nil {
+ return err
+ }
+ featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
+ if err != nil {
+ return err
+ }
+ result, err := tx.ExecContext(ctx,
+ `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW()
+ WHERE id = $10`,
+ channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID,
+ )
+ if err != nil {
+ if isUniqueViolation(err) {
+ return service.ErrChannelExists
+ }
+ return fmt.Errorf("update channel: %w", err)
+ }
+ rows, _ := result.RowsAffected()
+ if rows == 0 {
+ return service.ErrChannelNotFound
+ }
+
+ // Replace group associations (nil = keep existing)
+ if channel.GroupIDs != nil {
+ if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
+ return err
+ }
+ }
+
+ // Replace model pricing (nil = keep existing)
+ if channel.ModelPricing != nil {
+ if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
+ return err
+ }
+ }
+
+ // Replace account-stats pricing rules (nil = keep existing)
+ if channel.AccountStatsPricingRules != nil {
+ if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+// Delete removes a channel by ID. Returns service.ErrChannelNotFound when
+// no row matched. Associated rows are expected to be removed by FK cascade.
+func (r *channelRepository) Delete(ctx context.Context, id int64) error {
+ result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
+ if err != nil {
+ return fmt.Errorf("delete channel: %w", err)
+ }
+ rows, _ := result.RowsAffected()
+ if rows == 0 {
+ return service.ErrChannelNotFound
+ }
+ return nil
+}
+
+// List returns one page of channels matching the optional status filter and
+// name/description search, with group IDs, model pricing and account-stats
+// pricing rules batch-loaded to avoid N+1 queries. The WHERE fragments and
+// args are built in lockstep (argIdx numbers the placeholders), so their
+// ordering is load-bearing.
+func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
+ where := []string{"1=1"}
+ args := []any{}
+ argIdx := 1
+
+ if status != "" {
+ where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
+ args = append(args, status)
+ argIdx++
+ }
+ if search != "" {
+ // Same placeholder used twice on purpose: one arg serves both columns.
+ where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
+ args = append(args, "%"+escapeLike(search)+"%")
+ argIdx++
+ }
+
+ whereClause := strings.Join(where, " AND ")
+
+ // Count total matches
+ var total int64
+ countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
+ if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
+ return nil, nil, fmt.Errorf("count channels: %w", err)
+ }
+
+ pageSize := params.Limit() // clamped to [1, 100]
+ page := params.Page
+ if page < 1 {
+ page = 1
+ }
+ offset := (page - 1) * pageSize
+
+ // Query the channel page
+ dataQuery := fmt.Sprintf(
+ `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
+ FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
+ whereClause, channelListOrderBy(params), argIdx, argIdx+1,
+ )
+ args = append(args, pageSize, offset)
+
+ rows, err := r.db.QueryContext(ctx, dataQuery, args...)
+ if err != nil {
+ return nil, nil, fmt.Errorf("query channels: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var channels []service.Channel
+ var channelIDs []int64
+ for rows.Next() {
+ var ch service.Channel
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ return nil, nil, fmt.Errorf("scan channel: %w", err)
+ }
+ ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
+ channels = append(channels, ch)
+ channelIDs = append(channelIDs, ch.ID)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, nil, fmt.Errorf("iterate channels: %w", err)
+ }
+
+ // Batch-load group IDs and model pricing (avoids N+1 queries)
+ if len(channelIDs) > 0 {
+ groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
+ if err != nil {
+ return nil, nil, err
+ }
+ pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
+ if err != nil {
+ return nil, nil, err
+ }
+ statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
+ if err != nil {
+ return nil, nil, err
+ }
+ for i := range channels {
+ channels[i].GroupIDs = groupMap[channels[i].ID]
+ channels[i].ModelPricing = pricingMap[channels[i].ID]
+ channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
+ }
+ }
+
+ pages := 0
+ if total > 0 {
+ // Ceiling division without floats
+ pages = int((total + int64(pageSize) - 1) / int64(pageSize))
+ }
+
+ paginationResult := &pagination.PaginationResult{
+ Total: total,
+ Page: page,
+ PageSize: pageSize,
+ Pages: pages,
+ }
+
+ return channels, paginationResult, nil
+}
+
+func channelListOrderBy(params pagination.PaginationParams) string {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
+
+ var column string
+ switch sortBy {
+ case "":
+ column = "c.id"
+ sortOrder = "ASC"
+ case "id":
+ column = "c.id"
+ case "name":
+ column = "c.name"
+ case "status":
+ column = "c.status"
+ case "created_at":
+ column = "c.created_at"
+ default:
+ column = "c.id"
+ sortOrder = "ASC"
+ }
+
+ return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
+}
+
+// ListAll returns every channel ordered by ID, with group IDs, model
+// pricing and account-stats pricing rules batch-loaded (no pagination).
+func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("query all channels: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var channels []service.Channel
+ var channelIDs []int64
+ for rows.Next() {
+ var ch service.Channel
+ var modelMappingJSON, featuresConfigJSON []byte
+ if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
+ return nil, fmt.Errorf("scan channel: %w", err)
+ }
+ ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
+ ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
+ channels = append(channels, ch)
+ channelIDs = append(channelIDs, ch.ID)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate channels: %w", err)
+ }
+
+ if len(channelIDs) == 0 {
+ return channels, nil
+ }
+
+ // Batch-load group IDs
+ groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Batch-load model pricing
+ pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Batch-load account-stats pricing rules
+ statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ for i := range channels {
+ channels[i].GroupIDs = groupMap[channels[i].ID]
+ channels[i].ModelPricing = pricingMap[channels[i].ID]
+ channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
+ }
+
+ return channels, nil
+}
+
+// --- Batch-loading helpers ---
+
+// batchLoadGroupIDs loads the group IDs for multiple channels in one query.
+// Channels without groups are simply absent from the returned map.
+func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT channel_id, group_id FROM channel_groups
+ WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
+ pq.Array(channelIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load group ids: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ groupMap := make(map[int64][]int64, len(channelIDs))
+ for rows.Next() {
+ var channelID, groupID int64
+ if err := rows.Scan(&channelID, &groupID); err != nil {
+ return nil, fmt.Errorf("scan group id: %w", err)
+ }
+ groupMap[channelID] = append(groupMap[channelID], groupID)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate group ids: %w", err)
+ }
+ return groupMap, nil
+}
+
+// ExistsByName reports whether a channel with the exact name exists.
+func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
+ var exists bool
+ err := r.db.QueryRowContext(ctx,
+ `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
+ ).Scan(&exists)
+ return exists, err
+}
+
+// ExistsByNameExcluding reports whether another channel (id != excludeID)
+// already uses the name — used for uniqueness checks during updates.
+func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
+ var exists bool
+ err := r.db.QueryRowContext(ctx,
+ `SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
+ ).Scan(&exists)
+ return exists, err
+}
+
+// --- Group associations ---
+
+// GetGroupIDs returns the group IDs associated with a channel, sorted
+// ascending. An unassociated channel yields a nil slice and no error.
+func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("get group ids: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var ids []int64
+ for rows.Next() {
+ var id int64
+ if err := rows.Scan(&id); err != nil {
+ return nil, fmt.Errorf("scan group id: %w", err)
+ }
+ ids = append(ids, id)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate group ids: %w", err)
+ }
+ return ids, nil
+}
+
+// SetGroupIDs replaces a channel's group associations outside any caller
+// transaction, delegating to the shared setGroupIDsTx helper.
+func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
+ return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
+}
+
+// GetChannelIDByGroupID returns the channel owning the group, or (0, nil)
+// when the group is not associated with any channel.
+func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
+ var channelID int64
+ err := r.db.QueryRowContext(ctx,
+ `SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
+ ).Scan(&channelID)
+ if err == sql.ErrNoRows {
+ return 0, nil
+ }
+ return channelID, err
+}
+
+// GetGroupsInOtherChannels returns the subset of groupIDs already claimed by
+// a channel other than channelID — used to detect association conflicts.
+func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
+ if len(groupIDs) == 0 {
+ return nil, nil
+ }
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
+ pq.Array(groupIDs), channelID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("get groups in other channels: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var conflicting []int64
+ for rows.Next() {
+ var id int64
+ if err := rows.Scan(&id); err != nil {
+ return nil, fmt.Errorf("scan conflicting group id: %w", err)
+ }
+ conflicting = append(conflicting, id)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
+ }
+ return conflicting, nil
+}
+
+// marshalModelMapping serializes the nested model mapping to JSON bytes.
+// Shape: {"platform": {"src": "dst"}, ...}. A nil/empty map is normalized
+// to "{}" so the DB column never stores SQL NULL-ish content.
+func marshalModelMapping(m map[string]map[string]string) ([]byte, error) {
+ if len(m) == 0 {
+ return []byte(`{}`), nil
+ }
+ data, err := json.Marshal(m)
+ if err == nil {
+ return data, nil
+ }
+ return nil, fmt.Errorf("marshal model_mapping: %w", err)
+}
+
+// unmarshalModelMapping decodes JSON bytes into the nested model mapping.
+// Empty input or malformed JSON yields nil — decode failures are treated
+// as "no mapping" rather than surfaced to the caller.
+func unmarshalModelMapping(data []byte) map[string]map[string]string {
+ if len(data) == 0 {
+ return nil
+ }
+ var mapping map[string]map[string]string
+ if json.Unmarshal(data, &mapping) != nil {
+ return nil
+ }
+ return mapping
+}
+
+// marshalFeaturesConfig serializes the features config map to JSON bytes;
+// a nil/empty map is normalized to "{}" for the DB column.
+func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
+ if len(m) == 0 {
+ return []byte(`{}`), nil
+ }
+ data, err := json.Marshal(m)
+ if err == nil {
+ return data, nil
+ }
+ return nil, fmt.Errorf("marshal features_config: %w", err)
+}
+
+// unmarshalFeaturesConfig decodes JSON bytes into a generic config map.
+// Empty input or malformed JSON yields nil ("no config").
+func unmarshalFeaturesConfig(data []byte) map[string]any {
+ if len(data) == 0 {
+ return nil
+ }
+ var cfg map[string]any
+ if json.Unmarshal(data, &cfg) != nil {
+ return nil
+ }
+ return cfg
+}
+
+// GetGroupPlatforms batch-looks-up the platform for each group ID.
+// IDs with no matching group row are simply absent from the returned map.
+func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
+ if len(groupIDs) == 0 {
+ return make(map[int64]string), nil
+ }
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, platform FROM groups WHERE id = ANY($1)`,
+ pq.Array(groupIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("get group platforms: %w", err)
+ }
+ defer rows.Close() //nolint:errcheck
+
+ result := make(map[int64]string, len(groupIDs))
+ for rows.Next() {
+ var id int64
+ var platform string
+ if err := rows.Scan(&id, &platform); err != nil {
+ return nil, fmt.Errorf("scan group platform: %w", err)
+ }
+ result[id] = platform
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate group platforms: %w", err)
+ }
+ return result, nil
+}
diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go
new file mode 100644
index 00000000..9e00fed8
--- /dev/null
+++ b/backend/internal/repository/channel_repo_account_stats_pricing.go
@@ -0,0 +1,244 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+// --- Account stats pricing rules ---
+
+// batchLoadAccountStatsPricingRules loads the account-stats pricing rules
+// (including their model pricing) for multiple channels in bulk, keyed by
+// channel ID.
+func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
+ // 1. Load the rules themselves
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
+ FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
+ pq.Array(channelIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ var allRules []service.AccountStatsPricingRule
+ var ruleIDs []int64
+ for rows.Next() {
+ var rule service.AccountStatsPricingRule
+ if err := rows.Scan(
+ &rule.ID, &rule.ChannelID, &rule.Name,
+ pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
+ &rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
+ }
+ ruleIDs = append(ruleIDs, rule.ID)
+ allRules = append(allRules, rule)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
+ }
+
+ // 2. Batch-load the model pricing for these rules
+ pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // 3. Group by channel ID and attach pricing
+ result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
+ for i := range allRules {
+ allRules[i].Pricing = pricingMap[allRules[i].ID]
+ result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
+ }
+
+ return result, nil
+}
+
+// batchLoadAccountStatsModelPricing loads the model pricing rows (and their
+// tiered intervals) for multiple rules in bulk, keyed by rule ID.
+func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
+ if len(ruleIDs) == 0 {
+ return make(map[int64][]service.ChannelModelPricing), nil
+ }
+
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
+ cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
+ FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
+ pq.Array(ruleIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
+ for rows.Next() {
+ var p service.ChannelModelPricing
+ var ruleID int64
+ var modelsJSON []byte
+ if err := rows.Scan(
+ &p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
+ &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
+ &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats model pricing: %w", err)
+ }
+ // Malformed models JSON is deliberately degraded to an empty list
+ // instead of failing the whole load.
+ if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
+ p.Models = []string{}
+ }
+ pricingMap[ruleID] = append(pricingMap[ruleID], p)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
+ }
+
+ // Load intervals for all pricing entries.
+ var allPricingIDs []int64
+ for _, pricings := range pricingMap {
+ for _, p := range pricings {
+ allPricingIDs = append(allPricingIDs, p.ID)
+ }
+ }
+ if len(allPricingIDs) > 0 {
+ intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs)
+ if err != nil {
+ return nil, err
+ }
+ for ruleID, pricings := range pricingMap {
+ for i := range pricings {
+ pricings[i].Intervals = intervalsMap[pricings[i].ID]
+ }
+ pricingMap[ruleID] = pricings
+ }
+ }
+
+ return pricingMap, nil
+}
+
+// loadAccountStatsPricingRules loads the account-stats pricing rules for a
+// single channel (used by GetByID); thin wrapper over the batch loader.
+func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
+ result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
+ if err != nil {
+ return nil, err
+ }
+ return result[channelID], nil
+}
+
+// replaceAccountStatsPricingRulesTx replaces a channel's account-stats
+// pricing rules inside tx: delete the existing set, then insert the new one.
+// The FK CASCADE removes the associated model_pricing rows automatically.
+func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
+ if _, err := tx.ExecContext(ctx,
+ `DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
+ ); err != nil {
+ return fmt.Errorf("delete old account stats pricing rules: %w", err)
+ }
+
+ for i := range rules {
+ rules[i].ChannelID = channelID
+ if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
+ // createAccountStatsPricingRuleTx already wraps with the operation
+ // name; add the failing rule's identity instead of repeating it.
+ return fmt.Errorf("rule %q: %w", rules[i].Name, err)
+ }
+ }
+ return nil
+}
+
+// createAccountStatsPricingRuleTx inserts a single account-stats pricing
+// rule plus its model pricing rows inside tx, backfilling the rule's
+// ID/CreatedAt/UpdatedAt.
+func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
+ err := tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
+ VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
+ rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
+ ).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
+ if err != nil {
+ return fmt.Errorf("insert account stats pricing rule: %w", err)
+ }
+
+ for j := range rule.Pricing {
+ if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// createAccountStatsModelPricingTx inserts one account-stats model pricing
+// row (plus its tier intervals) inside tx, backfilling generated fields.
+// An empty BillingMode defaults to service.BillingModeToken.
+func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
+ modelsJSON, err := json.Marshal(pricing.Models)
+ if err != nil {
+ return fmt.Errorf("marshal models: %w", err)
+ }
+ billingMode := pricing.BillingMode
+ if billingMode == "" {
+ billingMode = service.BillingModeToken
+ }
+ platform := pricing.Platform
+ err = tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
+ ruleID, platform, modelsJSON, billingMode,
+ pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
+ pricing.ImageOutputPrice, pricing.PerRequestPrice,
+ ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
+ if err != nil {
+ return fmt.Errorf("insert account stats model pricing: %w", err)
+ }
+ // Persist intervals (mirrors channel_pricing_intervals logic).
+ for i := range pricing.Intervals {
+ iv := &pricing.Intervals[i]
+ iv.PricingID = pricing.ID
+ if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// createAccountStatsIntervalTx inserts a single tier interval for an
+// account-stats pricing entry, backfilling ID/CreatedAt/UpdatedAt onto iv.
+func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
+ return tx.QueryRowContext(ctx,
+ `INSERT INTO channel_account_stats_pricing_intervals
+ (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
+ iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
+ iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
+ iv.PerRequestPrice, iv.SortOrder,
+ ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
+}
+
+// batchLoadAccountStatsIntervals loads the tier intervals for multiple
+// account-stats pricing entries in one query, keyed by pricing ID.
+// Returns (nil, nil) for an empty ID list.
+func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
+ if len(pricingIDs) == 0 {
+ return nil, nil
+ }
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
+ input_price, output_price, cache_write_price, cache_read_price,
+ per_request_price, sort_order, created_at, updated_at
+ FROM channel_account_stats_pricing_intervals
+ WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
+ pq.Array(pricingIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ result := make(map[int64][]service.PricingInterval)
+ for rows.Next() {
+ var iv service.PricingInterval
+ if err := rows.Scan(
+ &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
+ &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
+ &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
+ }
+ result[iv.PricingID] = append(result[iv.PricingID], iv)
+ }
+ return result, rows.Err()
+}
diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go
new file mode 100644
index 00000000..6dcf3c91
--- /dev/null
+++ b/backend/internal/repository/channel_repo_pricing.go
@@ -0,0 +1,291 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+)
+
+// --- 模型定价 ---
+
+func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
+ FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("list model pricing: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ result, pricingIDs, err := scanModelPricingRows(rows)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(pricingIDs) > 0 {
+ intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
+ if err != nil {
+ return nil, err
+ }
+ for i := range result {
+ result[i].Intervals = intervalMap[result[i].ID]
+ }
+ }
+
+ return result, nil
+}
+
+func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
+ return createModelPricingExec(ctx, r.db, pricing)
+}
+
+func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
+ modelsJSON, err := json.Marshal(pricing.Models)
+ if err != nil {
+ return fmt.Errorf("marshal models: %w", err)
+ }
+ billingMode := pricing.BillingMode
+ if billingMode == "" {
+ billingMode = service.BillingModeToken
+ }
+ result, err := r.db.ExecContext(ctx,
+ `UPDATE channel_model_pricing
+ SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, per_request_price = $8, platform = $9, updated_at = NOW()
+ WHERE id = $10`,
+ modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
+ pricing.ImageOutputPrice, pricing.PerRequestPrice, pricing.Platform, pricing.ID,
+ )
+ if err != nil {
+ return fmt.Errorf("update model pricing: %w", err)
+ }
+ rows, _ := result.RowsAffected()
+ if rows == 0 {
+ return fmt.Errorf("pricing entry not found: %d", pricing.ID)
+ }
+ return nil
+}
+
+func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
+ _, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
+ if err != nil {
+ return fmt.Errorf("delete model pricing: %w", err)
+ }
+ return nil
+}
+
+func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
+ return r.runInTx(ctx, func(tx *sql.Tx) error {
+ return replaceModelPricingTx(ctx, tx, channelID, pricingList)
+ })
+}
+
+// --- 批量加载辅助方法 ---
+
+// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
+func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
+ FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
+ pq.Array(channelIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load model pricing: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ allPricing, allPricingIDs, err := scanModelPricingRows(rows)
+ if err != nil {
+ return nil, err
+ }
+
+ // 按 channelID 分组
+ pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
+ for _, p := range allPricing {
+ pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
+ }
+
+ // 批量加载所有区间
+ if len(allPricingIDs) > 0 {
+ intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
+ if err != nil {
+ return nil, err
+ }
+ for chID := range pricingMap {
+ for i := range pricingMap[chID] {
+ pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
+ }
+ }
+ }
+
+ return pricingMap, nil
+}
+
+// batchLoadIntervals 批量加载多个定价条目的区间
+func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
+ rows, err := r.db.QueryContext(ctx,
+ `SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
+ input_price, output_price, cache_write_price, cache_read_price,
+ per_request_price, sort_order, created_at, updated_at
+ FROM channel_pricing_intervals
+ WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
+ pq.Array(pricingIDs),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch load intervals: %w", err)
+ }
+ defer func() { _ = rows.Close() }()
+
+ intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
+ for rows.Next() {
+ var iv service.PricingInterval
+ if err := rows.Scan(
+ &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
+ &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
+ &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scan interval: %w", err)
+ }
+ intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterate intervals: %w", err)
+ }
+ return intervalMap, nil
+}
+
+// --- 共享 scan 辅助 ---
+
+// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
+func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
+ var result []service.ChannelModelPricing
+ var pricingIDs []int64
+ for rows.Next() {
+ var p service.ChannelModelPricing
+ var modelsJSON []byte
+ if err := rows.Scan(
+ &p.ID, &p.ChannelID, &p.Platform, &modelsJSON, &p.BillingMode,
+ &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
+ &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
+ ); err != nil {
+ return nil, nil, fmt.Errorf("scan model pricing: %w", err)
+ }
+ if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
+ p.Models = []string{}
+ }
+ pricingIDs = append(pricingIDs, p.ID)
+ result = append(result, p)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
+ }
+ return result, pricingIDs, nil
+}
+
+// --- 事务内辅助方法 ---
+
+// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
+type dbExec interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+ QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
+}
+
+func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
+ if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
+ return fmt.Errorf("delete old group associations: %w", err)
+ }
+ if len(groupIDs) == 0 {
+ return nil
+ }
+ _, err := exec.ExecContext(ctx,
+ `INSERT INTO channel_groups (channel_id, group_id)
+ SELECT $1, unnest($2::bigint[])`,
+ channelID, pq.Array(groupIDs),
+ )
+ if err != nil {
+ return fmt.Errorf("insert group associations: %w", err)
+ }
+ return nil
+}
+
+func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
+ modelsJSON, err := json.Marshal(pricing.Models)
+ if err != nil {
+ return fmt.Errorf("marshal models: %w", err)
+ }
+ billingMode := pricing.BillingMode
+ if billingMode == "" {
+ billingMode = service.BillingModeToken
+ }
+ platform := pricing.Platform
+ if platform == "" {
+ platform = "anthropic"
+ }
+ err = exec.QueryRowContext(ctx,
+ `INSERT INTO channel_model_pricing (channel_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
+ pricing.ChannelID, platform, modelsJSON, billingMode,
+ pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
+ pricing.ImageOutputPrice, pricing.PerRequestPrice,
+ ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
+ if err != nil {
+ return fmt.Errorf("insert model pricing: %w", err)
+ }
+
+ for i := range pricing.Intervals {
+ pricing.Intervals[i].PricingID = pricing.ID
+ if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
+ return exec.QueryRowContext(ctx,
+ `INSERT INTO channel_pricing_intervals
+ (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
+ iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
+ iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
+ iv.PerRequestPrice, iv.SortOrder,
+ ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
+}
+
+func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
+ if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
+ return fmt.Errorf("delete old model pricing: %w", err)
+ }
+ for i := range pricingList {
+ pricingList[i].ChannelID = channelID
+ if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
+ return fmt.Errorf("insert model pricing: %w", err)
+ }
+ }
+ return nil
+}
+
+// isUniqueViolation 检查 pq 唯一约束违反错误
+func isUniqueViolation(err error) bool {
+ var pqErr *pq.Error
+ if errors.As(err, &pqErr) && pqErr != nil {
+ return pqErr.Code == "23505"
+ }
+ return false
+}
+
+// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
+func escapeLike(s string) string {
+ s = strings.ReplaceAll(s, `\`, `\\`)
+ s = strings.ReplaceAll(s, `%`, `\%`)
+ s = strings.ReplaceAll(s, `_`, `\_`)
+ return s
+}
diff --git a/backend/internal/repository/channel_repo_test.go b/backend/internal/repository/channel_repo_test.go
new file mode 100644
index 00000000..e761866d
--- /dev/null
+++ b/backend/internal/repository/channel_repo_test.go
@@ -0,0 +1,237 @@
+//go:build unit
+
+package repository
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/lib/pq"
+ "github.com/stretchr/testify/require"
+)
+
+// --- marshalModelMapping ---
+
+func TestMarshalModelMapping(t *testing.T) {
+ tests := []struct {
+ name string
+ input map[string]map[string]string
+ wantJSON string // expected JSON output (exact match)
+ }{
+ {
+ name: "empty map",
+ input: map[string]map[string]string{},
+ wantJSON: "{}",
+ },
+ {
+ name: "nil map",
+ input: nil,
+ wantJSON: "{}",
+ },
+ {
+ name: "populated map",
+ input: map[string]map[string]string{
+ "openai": {"gpt-4": "gpt-4-turbo"},
+ },
+ },
+ {
+ name: "nested values",
+ input: map[string]map[string]string{
+ "openai": {"*": "gpt-5.4"},
+ "anthropic": {"claude-old": "claude-new"},
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := marshalModelMapping(tt.input)
+ require.NoError(t, err)
+
+ if tt.wantJSON != "" {
+ require.Equal(t, []byte(tt.wantJSON), result)
+ } else {
+ // round-trip: unmarshal and compare with input
+ var parsed map[string]map[string]string
+ require.NoError(t, json.Unmarshal(result, &parsed))
+ require.Equal(t, tt.input, parsed)
+ }
+ })
+ }
+}
+
+// --- unmarshalModelMapping ---
+
+func TestUnmarshalModelMapping(t *testing.T) {
+ tests := []struct {
+ name string
+ input []byte
+ wantNil bool
+ want map[string]map[string]string
+ }{
+ {
+ name: "nil data",
+ input: nil,
+ wantNil: true,
+ },
+ {
+ name: "empty data",
+ input: []byte{},
+ wantNil: true,
+ },
+ {
+ name: "invalid JSON",
+ input: []byte("not-json"),
+ wantNil: true,
+ },
+ {
+ name: "type error - number",
+ input: []byte("42"),
+ wantNil: true,
+ },
+ {
+ name: "type error - array",
+ input: []byte("[1,2,3]"),
+ wantNil: true,
+ },
+ {
+ name: "valid JSON",
+ input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`),
+ want: map[string]map[string]string{
+ "openai": {"gpt-4": "gpt-4-turbo"},
+ "anthropic": {"old": "new"},
+ },
+ },
+ {
+ name: "empty object",
+ input: []byte("{}"),
+ want: map[string]map[string]string{},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := unmarshalModelMapping(tt.input)
+ if tt.wantNil {
+ require.Nil(t, result)
+ } else {
+ require.NotNil(t, result)
+ require.Equal(t, tt.want, result)
+ }
+ })
+ }
+}
+
+// --- escapeLike ---
+
+func TestEscapeLike(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {
+ name: "no special chars",
+ input: "hello",
+ want: "hello",
+ },
+ {
+ name: "backslash",
+ input: `a\b`,
+ want: `a\\b`,
+ },
+ {
+ name: "percent",
+ input: "50%",
+ want: `50\%`,
+ },
+ {
+ name: "underscore",
+ input: "a_b",
+ want: `a\_b`,
+ },
+ {
+ name: "all special chars",
+ input: `a\b%c_d`,
+ want: `a\\b\%c\_d`,
+ },
+ {
+ name: "empty string",
+ input: "",
+ want: "",
+ },
+ {
+ name: "consecutive special chars",
+ input: "%_%",
+ want: `\%\_\%`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, escapeLike(tt.input))
+ })
+ }
+}
+
+// --- isUniqueViolation ---
+
+func TestIsUniqueViolation(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {
+ name: "unique violation code 23505",
+ err: &pq.Error{Code: "23505"},
+ want: true,
+ },
+ {
+ name: "different pq error code",
+ err: &pq.Error{Code: "23503"},
+ want: false,
+ },
+ {
+ name: "non-pq error",
+ err: errors.New("some generic error"),
+ want: false,
+ },
+ {
+ name: "typed nil pq.Error",
+ err: func() error {
+ var pqErr *pq.Error
+ return pqErr
+ }(),
+ want: false,
+ },
+ {
+ name: "bare nil",
+ err: nil,
+ want: false,
+ },
+ {
+ name: "wrapped pq error with 23505",
+ err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}),
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, isUniqueViolation(tt.err))
+ })
+ }
+}
+
+func TestChannelListOrderBy_AllowsDescendingIDSort(t *testing.T) {
+ params := pagination.PaginationParams{
+ SortBy: "id",
+ SortOrder: "desc",
+ }
+
+ require.Equal(t, "c.id DESC, c.id DESC", channelListOrderBy(params))
+}
diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go
index e82a73a3..5e09e75d 100644
--- a/backend/internal/repository/dashboard_aggregation_repo.go
+++ b/backend/internal/repository/dashboard_aggregation_repo.go
@@ -331,6 +331,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) AS account_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
@@ -351,6 +352,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
cache_read_tokens,
total_cost,
actual_cost,
+ account_cost,
total_duration_ms,
active_users,
computed_at
@@ -364,6 +366,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
hourly.cache_read_tokens,
hourly.total_cost,
hourly.actual_cost,
+ hourly.account_cost,
hourly.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
@@ -378,6 +381,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
+ account_cost = EXCLUDED.account_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
@@ -399,6 +403,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
+ COALESCE(SUM(account_cost), 0) AS account_cost,
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
@@ -419,6 +424,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
cache_read_tokens,
total_cost,
actual_cost,
+ account_cost,
total_duration_ms,
active_users,
computed_at
@@ -432,6 +438,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
daily.cache_read_tokens,
daily.total_cost,
daily.actual_cost,
+ daily.account_cost,
daily.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
@@ -446,6 +453,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
+ account_cost = EXCLUDED.account_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go
index 8f2b8eca..96a23a8e 100644
--- a/backend/internal/repository/email_cache.go
+++ b/backend/internal/repository/email_cache.go
@@ -3,6 +3,8 @@ package repository
import (
"context"
"encoding/json"
+ "fmt"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -11,23 +13,33 @@ import (
const (
verifyCodeKeyPrefix = "verify_code:"
+ notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:"
+ notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
)
// verifyCodeKey generates the Redis key for email verification code.
+// Email is lowercased for case-insensitive consistency.
func verifyCodeKey(email string) string {
- return verifyCodeKeyPrefix + email
+ return verifyCodeKeyPrefix + strings.ToLower(email)
+}
+
+// notifyVerifyKey generates the Redis key for notify email verification code.
+// Email is lowercased to prevent case-sensitive key mismatch (the business layer
+// uses strings.EqualFold for comparison).
+func notifyVerifyKey(email string) string {
+ return notifyVerifyKeyPrefix + strings.ToLower(email)
}
// passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string {
- return passwordResetKeyPrefix + email
+ return passwordResetKeyPrefix + strings.ToLower(email)
}
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string {
- return passwordResetSentAtKeyPrefix + email
+ return passwordResetSentAtKeyPrefix + strings.ToLower(email)
}
type emailCache struct {
@@ -106,3 +118,60 @@ func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email st
key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err()
}
+
+// Notify email verification code methods
+
+func (c *emailCache) GetNotifyVerifyCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
+ key := notifyVerifyKey(email)
+ val, err := c.rdb.Get(ctx, key).Result()
+ if err != nil {
+ return nil, err
+ }
+ var data service.VerificationCodeData
+ if err := json.Unmarshal([]byte(val), &data); err != nil {
+ return nil, err
+ }
+ return &data, nil
+}
+
+func (c *emailCache) SetNotifyVerifyCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
+ key := notifyVerifyKey(email)
+ val, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ return c.rdb.Set(ctx, key, val, ttl).Err()
+}
+
+func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
+ key := notifyVerifyKey(email)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// User-level rate limiting for notify email verification codes
+
+func notifyCodeUserRateKey(userID int64) string {
+ return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID)
+}
+
+func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
+ key := notifyCodeUserRateKey(userID)
+ count, err := c.rdb.Incr(ctx, key).Result()
+ if err != nil {
+ return 0, err
+ }
+	// NOTE(review): EXPIRE runs on every INCR, so the TTL is refreshed on each call — this yields a sliding window (a steadily active user never lets the counter expire), not a fixed one; it does avoid orphan keys if the process dies between INCR and EXPIRE. Confirm sliding-window semantics are intended, otherwise set the TTL only when count == 1.
+ if err := c.rdb.Expire(ctx, key, window).Err(); err != nil {
+ return count, fmt.Errorf("expire notify code rate key: %w", err)
+ }
+ return count, nil
+}
+
+func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
+ key := notifyCodeUserRateKey(userID)
+ count, err := c.rdb.Get(ctx, key).Int64()
+ if err != nil {
+ return 0, err
+ }
+ return count, nil
+}
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 3cfd649b..5e16475a 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
+ "sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -14,6 +15,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
+
+ entsql "entgo.io/ent/dialect/sql"
)
type sqlExecutor interface {
@@ -40,6 +43,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
SetRateMultiplier(groupIn.RateMultiplier).
+ SetSortOrder(groupIn.SortOrder).
SetIsExclusive(groupIn.IsExclusive).
SetStatus(groupIn.Status).
SetSubscriptionType(groupIn.SubscriptionType).
@@ -49,21 +53,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
- SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
- SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
- SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
- SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
- SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
- SetDefaultMappedModel(groupIn.DefaultMappedModel)
+ SetDefaultMappedModel(groupIn.DefaultMappedModel).
+ SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
+ SetRpmLimit(groupIn.RPMLimit)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -122,19 +123,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
- SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
- SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
- SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
- SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject).
- SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
SetRequirePrivacySet(groupIn.RequirePrivacySet).
- SetDefaultMappedModel(groupIn.DefaultMappedModel)
+ SetDefaultMappedModel(groupIn.DefaultMappedModel).
+ SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig).
+ SetRpmLimit(groupIn.RPMLimit)
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
if groupIn.DailyLimitUSD != nil {
@@ -241,11 +239,18 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
- groups, err := q.
+ if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
+ return r.listWithAccountCountSort(ctx, q, params, total)
+ }
+
+ groupsQuery := q.
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range groupListOrder(params) {
+ groupsQuery = groupsQuery.Order(order)
+ }
+
+ groups, err := groupsQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -271,6 +276,104 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
return outGroups, paginationResultFromTotal(int64(total), params), nil
}
+func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) {
+ groups, err := q.
+ Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ groupIDs := make([]int64, 0, len(groups))
+ outGroups := make([]service.Group, 0, len(groups))
+ for i := range groups {
+ g := groupEntityToService(groups[i])
+ outGroups = append(outGroups, *g)
+ groupIDs = append(groupIDs, g.ID)
+ }
+
+ counts, err := r.loadAccountCounts(ctx, groupIDs)
+ if err != nil {
+ return nil, nil, err
+ }
+ for i := range outGroups {
+ c := counts[outGroups[i].ID]
+ outGroups[i].AccountCount = c.Total
+ outGroups[i].ActiveAccountCount = c.Active
+ outGroups[i].RateLimitedAccountCount = c.RateLimited
+ }
+
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+ sort.SliceStable(outGroups, func(i, j int) bool {
+ if outGroups[i].AccountCount == outGroups[j].AccountCount {
+ if outGroups[i].SortOrder == outGroups[j].SortOrder {
+ return outGroups[i].ID < outGroups[j].ID
+ }
+ return outGroups[i].SortOrder < outGroups[j].SortOrder
+ }
+ if sortOrder == pagination.SortOrderAsc {
+ return outGroups[i].AccountCount < outGroups[j].AccountCount
+ }
+ return outGroups[i].AccountCount > outGroups[j].AccountCount
+ })
+
+ return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil
+}
+
+func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
+
+ var field string
+ tieField := group.FieldID
+ defaultOrder := true
+ switch sortBy {
+ case "", "sort_order":
+ field = group.FieldSortOrder
+ case "name":
+ field = group.FieldName
+ defaultOrder = false
+ case "platform":
+ field = group.FieldPlatform
+ defaultOrder = false
+ case "billing_type", "subscription_type":
+ field = group.FieldSubscriptionType
+ defaultOrder = false
+ case "rate_multiplier":
+ field = group.FieldRateMultiplier
+ defaultOrder = false
+ case "is_exclusive":
+ field = group.FieldIsExclusive
+ defaultOrder = false
+ case "status":
+ field = group.FieldStatus
+ defaultOrder = false
+ case "created_at":
+ field = group.FieldCreatedAt
+ defaultOrder = false
+ case "id":
+ field = group.FieldID
+ defaultOrder = false
+ tieField = ""
+ default:
+ field = group.FieldSortOrder
+ }
+
+ if sortOrder == pagination.SortOrderDesc && sortBy != "" {
+ if tieField == "" {
+ return []func(*entsql.Selector){dbent.Desc(field)}
+ }
+ return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(tieField)}
+ }
+ if defaultOrder {
+ return []func(*entsql.Selector){dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)}
+ }
+ if tieField == "" {
+ return []func(*entsql.Selector){dbent.Asc(field)}
+ }
+ return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(tieField)}
+}
+
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive)).
diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go
index eccf5cea..f91dae43 100644
--- a/backend/internal/repository/group_repo_integration_test.go
+++ b/backend/internal/repository/group_repo_integration_test.go
@@ -113,6 +113,33 @@ func (s *GroupRepoSuite) TestUpdate() {
s.Require().Equal("updated", got.Name)
}
+func (s *GroupRepoSuite) TestGetByID_PreservesMessagesDispatchModelConfig() {
+ group := &service.Group{
+ Name: "openai-dispatch",
+ Platform: service.PlatformOpenAI,
+ RateMultiplier: 1.0,
+ IsExclusive: false,
+ Status: service.StatusActive,
+ SubscriptionType: service.SubscriptionTypeStandard,
+ AllowMessagesDispatch: true,
+ DefaultMappedModel: "gpt-5.4",
+ MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{
+ OpusMappedModel: "gpt-5.4",
+ SonnetMappedModel: "gpt-5.3-codex",
+ HaikuMappedModel: "gpt-5.4-mini",
+ ExactModelMappings: map[string]string{
+ "claude-sonnet-4.5": "gpt-5.4-nano",
+ },
+ },
+ }
+
+ s.Require().NoError(s.repo.Create(s.ctx, group))
+
+ got, err := s.repo.GetByID(s.ctx, group.ID)
+ s.Require().NoError(err)
+ s.Require().Equal(group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig)
+}
+
func (s *GroupRepoSuite) TestDelete() {
group := &service.Group{
Name: "to-delete",
diff --git a/backend/internal/repository/group_repo_sort_integration_test.go b/backend/internal/repository/group_repo_sort_integration_test.go
new file mode 100644
index 00000000..85b2efcc
--- /dev/null
+++ b/backend/internal/repository/group_repo_sort_integration_test.go
@@ -0,0 +1,50 @@
+//go:build integration
+
+package repository
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
+ g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
+ g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
+ s.Require().NoError(s.repo.Create(s.ctx, g1))
+ s.Require().NoError(s.repo.Create(s.ctx, g2))
+
+ groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
+ s.Require().NoError(err)
+ s.Require().GreaterOrEqual(len(groups), 2)
+ indexByID := make(map[int64]int, len(groups))
+ for i, g := range groups {
+ indexByID[g.ID] = i
+ }
+ s.Require().Contains(indexByID, g1.ID)
+ s.Require().Contains(indexByID, g2.ID)
+ // g2 has SortOrder=10, g1 has SortOrder=20; ascending means g2 comes first
+ s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
+}
+
+func (s *GroupRepoSuite) TestList_SortBySortOrderDesc() {
+ g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 40}
+ g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 50}
+ s.Require().NoError(s.repo.Create(s.ctx, g1))
+ s.Require().NoError(s.repo.Create(s.ctx, g2))
+
+ groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "sort_order",
+ SortOrder: "desc",
+ })
+ s.Require().NoError(err)
+ s.Require().GreaterOrEqual(len(groups), 2)
+ indexByID := make(map[int64]int, len(groups))
+ for i, group := range groups {
+ indexByID[group.ID] = i
+ }
+ s.Require().Contains(indexByID, g1.ID)
+ s.Require().Contains(indexByID, g2.ID)
+ s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
+}
diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go
index fb9c26c4..5857fbcb 100644
--- a/backend/internal/repository/integration_harness_test.go
+++ b/backend/internal/repository/integration_harness_test.go
@@ -332,6 +332,10 @@ func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
prefixOne(1)
+ case "mget":
+ for i := 1; i < len(args); i++ {
+ prefixOne(i)
+ }
case "del", "unlink":
for i := 1; i < len(args); i++ {
prefixOne(i)
diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go
index 9cf3b392..6dbb9fbd 100644
--- a/backend/internal/repository/migrations_runner.go
+++ b/backend/internal/repository/migrations_runner.go
@@ -51,28 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql"
+const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"
+const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
type migrationChecksumCompatibilityRule struct {
fileChecksum string
acceptedDBChecksum map[string]struct{}
+ acceptedChecksums map[string]struct{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
-// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
+// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行,
+// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
- "054_drop_legacy_cache_columns.sql": {
- fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
- acceptedDBChecksum: map[string]struct{}{
- "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
- },
- },
- "061_add_usage_log_request_type.sql": {
- fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
- acceptedDBChecksum: map[string]struct{}{
- "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
- "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
- },
- },
+ "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
+ "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
+ "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
+ "110_pending_auth_and_provider_default_grants.sql": newMigrationChecksumCompatibilityRule("32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"),
+ "112_add_payment_order_provider_key_snapshot.sql": newMigrationChecksumCompatibilityRule("b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"),
+ "115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
+ "116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
+ "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"),
+ "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
+ "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
@@ -199,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
if nonTx {
+ if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
+ return fmt.Errorf("prepare migration %s: %w", name, err)
+ }
+
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
@@ -248,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return nil
}
+func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
+ switch name {
+ case paymentOrdersOutTradeNoUniqueMigration:
+ return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
+ default:
+ return nil
+ }
+}
+
+func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
+ duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
+ if err != nil {
+ return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
+ }
+ if len(duplicates) > 0 {
+ return fmt.Errorf(
+ "duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
+ paymentOrdersOutTradeNoUniqueMigration,
+ strings.Join(duplicates, ", "),
+ )
+ }
+
+ invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
+ if err != nil {
+ return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
+ }
+ if !invalid {
+ return nil
+ }
+
+ if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
+ return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
+ }
+ return nil
+}
+
+func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
+ rows, err := db.QueryContext(ctx, `
+ SELECT out_trade_no, COUNT(*) AS duplicate_count
+ FROM payment_orders
+ WHERE out_trade_no <> ''
+ GROUP BY out_trade_no
+ HAVING COUNT(*) > 1
+ ORDER BY duplicate_count DESC, out_trade_no
+ LIMIT 5
+ `)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ _ = rows.Close()
+ }()
+
+ duplicates := make([]string, 0, 5)
+ for rows.Next() {
+ var outTradeNo string
+ var duplicateCount int
+ if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil {
+ return nil, err
+ }
+ duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount))
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return duplicates, nil
+}
+
+func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
+ var invalid bool
+ err := db.QueryRowContext(ctx, `
+ SELECT EXISTS (
+ SELECT 1
+ FROM pg_class idx
+ JOIN pg_namespace ns ON ns.oid = idx.relnamespace
+ JOIN pg_index i ON i.indexrelid = idx.oid
+ WHERE ns.nspname = 'public'
+ AND idx.relname = $1
+ AND NOT i.indisvalid
+ )
+ `, indexName).Scan(&invalid)
+ return invalid, err
+}
+
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil {
@@ -322,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
+func checksumSet(values ...string) map[string]struct{} {
+ out := make(map[string]struct{}, len(values))
+ for _, value := range values {
+ out[value] = struct{}{}
+ }
+ return out
+}
+
+func newMigrationChecksumCompatibilityRule(fileChecksum string, acceptedDBChecksums ...string) migrationChecksumCompatibilityRule {
+ return migrationChecksumCompatibilityRule{
+ fileChecksum: fileChecksum,
+ acceptedDBChecksum: checksumSet(acceptedDBChecksums...),
+ acceptedChecksums: checksumSet(append([]string{fileChecksum}, acceptedDBChecksums...)...),
+ }
+}
+
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
rule, ok := migrationChecksumCompatibilityRules[name]
if !ok {
return false
}
- if rule.fileChecksum != fileChecksum {
+ _, dbOK := rule.acceptedChecksums[dbChecksum]
+ if !dbOK {
return false
}
- _, ok = rule.acceptedDBChecksum[dbChecksum]
- return ok
+ _, fileOK := rule.acceptedChecksums[fileChecksum]
+ return fileOK
}
func validateMigrationExecutionMode(name, content string) (bool, error) {
diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go
index 6c3ad725..1fcb3be1 100644
--- a/backend/internal/repository/migrations_runner_checksum_test.go
+++ b/backend/internal/repository/migrations_runner_checksum_test.go
@@ -51,4 +51,114 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
)
require.False(t, ok)
})
+
+ t.Run("109历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("109当前checksum可兼容历史checksum", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "109_auth_identity_compat_backfill.sql",
+ "0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
+ "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("110历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "110_pending_auth_and_provider_default_grants.sql",
+ "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925",
+ "32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("112历史checksum可兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "112_add_payment_order_provider_key_snapshot.sql",
+ "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e",
+ "b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "115_auth_identity_legacy_external_backfill.sql",
+ "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
+ "022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "116_auth_identity_legacy_external_safety_reports.sql",
+ "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
+ "07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "119_enforce_payment_orders_out_trade_no_unique.sql",
+ "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
+ "0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e",
+ )
+ require.True(t, ok)
+ })
+
+ t.Run("118多个历史checksum都可兼容当前版本", func(t *testing.T) {
+ for _, dbChecksum := range []string{
+ "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb",
+ "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227",
+ } {
+ ok := isMigrationChecksumCompatible(
+ "118_wechat_dual_mode_and_auth_source_defaults.sql",
+ dbChecksum,
+ "b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0",
+ )
+ require.True(t, ok)
+ }
+ })
+
+ t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
+ for _, dbChecksum := range []string{
+ "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
+ "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
+ "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
+ } {
+ ok := isMigrationChecksumCompatible(
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql",
+ dbChecksum,
+ "34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
+ )
+ require.True(t, ok)
+ }
+ })
+
+ t.Run("119未知checksum不兼容", func(t *testing.T) {
+ ok := isMigrationChecksumCompatible(
+ "119_enforce_payment_orders_out_trade_no_unique.sql",
+ "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
+ "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+ )
+ require.False(t, ok)
+ })
}
diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go
index 9f8a94c6..5d67665e 100644
--- a/backend/internal/repository/migrations_runner_extra_test.go
+++ b/backend/internal/repository/migrations_runner_extra_test.go
@@ -94,6 +94,24 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
+func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
+ for _, name := range []string{
+ "109_auth_identity_compat_backfill.sql",
+ "110_pending_auth_and_provider_default_grants.sql",
+ "112_add_payment_order_provider_key_snapshot.sql",
+ "115_auth_identity_legacy_external_backfill.sql",
+ "116_auth_identity_legacy_external_safety_reports.sql",
+ "118_wechat_dual_mode_and_auth_source_defaults.sql",
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql",
+ "123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
+ } {
+ rule, ok := migrationChecksumCompatibilityRules[name]
+ require.Truef(t, ok, "missing compatibility rule for %s", name)
+ require.NotEmpty(t, rule.fileChecksum)
+ require.NotEmpty(t, rule.acceptedDBChecksum)
+ }
+}
+
func TestEnsureAtlasBaselineAligned(t *testing.T) {
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go
index db1183cd..b7cb396c 100644
--- a/backend/internal/repository/migrations_runner_notx_test.go
+++ b/backend/internal/repository/migrations_runner_notx_test.go
@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require.NoError(t, mock.ExpectationsWereMet())
}
+func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
+ WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
+ Data: []byte(`
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
+`),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "duplicate out_trade_no")
+ require.Contains(t, err.Error(), "dup-out-trade-no")
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
+ db, mock, err := sqlmock.New()
+ require.NoError(t, err)
+ defer func() { _ = db.Close() }()
+
+ prepareMigrationsBootstrapExpectations(mock)
+ mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
+ WillReturnError(sql.ErrNoRows)
+ mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
+ WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
+ mock.ExpectQuery("SELECT EXISTS \\(").
+ WithArgs("paymentorder_out_trade_no_unique").
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
+ mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
+ WillReturnResult(sqlmock.NewResult(0, 0))
+ mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
+ WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
+ WillReturnResult(sqlmock.NewResult(1, 1))
+ mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
+ WithArgs(migrationsAdvisoryLockID).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ fsys := fstest.MapFS{
+ "120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
+ Data: []byte(`
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+ ON payment_orders (out_trade_no)
+ WHERE out_trade_no <> '';
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
+`),
+ },
+ }
+
+ err = applyMigrationsFS(context.Background(), db, fsys)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go
index dd3019bb..eeee5c23 100644
--- a/backend/internal/repository/migrations_schema_integration_test.go
+++ b/backend/internal/repository/migrations_schema_integration_test.go
@@ -89,6 +89,35 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
+func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) {
+ tx := testTx(t)
+
+ requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
+ requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
+ requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
+ requireConstraintDefinitionContains(
+ t,
+ tx,
+ "users",
+ "users_signup_source_check",
+ "signup_source",
+ "'email'",
+ "'linuxdo'",
+ "'wechat'",
+ "'oidc'",
+ )
+
+ requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL")
+ requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE")
+ requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL")
+
+ requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no")
+ requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE")
+ requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique")
+}
+
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
@@ -106,6 +135,118 @@ SELECT EXISTS (
require.True(t, exists, "expected index %s on %s", index, table)
}
+func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) {
+ t.Helper()
+
+ var exists bool
+ err := tx.QueryRowContext(context.Background(), `
+SELECT EXISTS (
+ SELECT 1
+ FROM pg_indexes
+ WHERE schemaname = 'public'
+ AND tablename = $1
+ AND indexname = $2
+)
+`, table, index).Scan(&exists)
+ require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
+ require.False(t, exists, "expected index %s on %s to be absent", index, table)
+}
+
+func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) {
+ t.Helper()
+
+ var (
+ unique bool
+ def string
+ )
+
+ err := tx.QueryRowContext(context.Background(), `
+SELECT
+ i.indisunique,
+ pg_get_indexdef(i.indexrelid)
+FROM pg_class idx
+JOIN pg_index i ON i.indexrelid = idx.oid
+JOIN pg_class tbl ON tbl.oid = i.indrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+WHERE ns.nspname = 'public'
+ AND tbl.relname = $1
+ AND idx.relname = $2
+`, table, index).Scan(&unique, &def)
+ require.NoError(t, err, "query index definition for %s.%s", table, index)
+ require.True(t, unique, "expected index %s on %s to be unique", index, table)
+
+ for _, fragment := range fragments {
+ require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment)
+ }
+}
+
+func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) {
+ t.Helper()
+
+ var actual string
+ err := tx.QueryRowContext(context.Background(), `
+SELECT CASE c.confdeltype
+ WHEN 'a' THEN 'NO ACTION'
+ WHEN 'r' THEN 'RESTRICT'
+ WHEN 'c' THEN 'CASCADE'
+ WHEN 'n' THEN 'SET NULL'
+ WHEN 'd' THEN 'SET DEFAULT'
+END
+FROM pg_constraint c
+JOIN pg_class tbl ON tbl.oid = c.conrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
+JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
+WHERE ns.nspname = 'public'
+ AND c.contype = 'f'
+ AND tbl.relname = $1
+ AND attr.attname = $2
+ AND ref_tbl.relname = $3
+LIMIT 1
+`, table, column, refTable).Scan(&actual)
+ require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, refTable)
+ require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
+}
+
+func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
+ t.Helper()
+
+ var def string
+ err := tx.QueryRowContext(context.Background(), `
+SELECT pg_get_constraintdef(c.oid)
+FROM pg_constraint c
+JOIN pg_class tbl ON tbl.oid = c.conrelid
+JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+WHERE ns.nspname = 'public'
+ AND tbl.relname = $1
+ AND c.conname = $2
+`, table, constraint).Scan(&def)
+ require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
+
+ for _, fragment := range fragments {
+ require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
+ }
+}
+
+func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
+ t.Helper()
+
+ var columnDefault sql.NullString
+ err := tx.QueryRowContext(context.Background(), `
+SELECT column_default
+FROM information_schema.columns
+WHERE table_schema = 'public'
+ AND table_name = $1
+ AND column_name = $2
+`, table, column).Scan(&columnDefault)
+ require.NoError(t, err, "query column_default for %s.%s", table, column)
+ require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
+
+ for _, fragment := range fragments {
+ require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
+ }
+}
+
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()
diff --git a/backend/internal/repository/openai_403_counter_cache.go b/backend/internal/repository/openai_403_counter_cache.go
new file mode 100644
index 00000000..a68d2518
--- /dev/null
+++ b/backend/internal/repository/openai_403_counter_cache.go
@@ -0,0 +1,51 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+const openAI403CounterPrefix = "openai_403_count:account:"
+
+var openAI403CounterIncrScript = redis.NewScript(`
+ local key = KEYS[1]
+ local ttl = tonumber(ARGV[1])
+
+ local count = redis.call('INCR', key)
+ if count == 1 then
+ redis.call('EXPIRE', key, ttl)
+ end
+
+ return count
+`)
+
+type openAI403CounterCache struct {
+ rdb *redis.Client
+}
+
+func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
+ return &openAI403CounterCache{rdb: rdb}
+}
+
+func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
+ key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
+
+ ttlSeconds := windowMinutes * 60
+ if ttlSeconds < 60 {
+ ttlSeconds = 60
+ }
+
+ result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
+ if err != nil {
+ return 0, fmt.Errorf("increment openai 403 count: %w", err)
+ }
+ return result, nil
+}
+
+func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
+ key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
+ return c.rdb.Del(ctx, key).Err()
+}
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index dca0b612..acb270a3 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -2,6 +2,7 @@ package repository
import (
"context"
+ "errors"
"net/http"
"net/url"
"strings"
@@ -53,6 +54,9 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
Post(s.tokenURL)
if err != nil {
+ if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
+ return nil, newOpenAINoProxyHintError(err)
+ }
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
@@ -98,6 +102,9 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
Post(s.tokenURL)
if err != nil {
+ if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) {
+ return nil, newOpenAINoProxyHintError(err)
+ }
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
}
@@ -114,3 +121,21 @@ func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
Timeout: 120 * time.Second,
})
}
+
+func shouldReturnOpenAINoProxyHint(ctx context.Context, proxyURL string, err error) bool {
+ if strings.TrimSpace(proxyURL) != "" || err == nil {
+ return false
+ }
+ if ctx != nil && ctx.Err() != nil {
+ return false
+ }
+ return !errors.Is(err, context.Canceled)
+}
+
+func newOpenAINoProxyHintError(cause error) error {
+ return infraerrors.New(
+ http.StatusBadGateway,
+ "OPENAI_OAUTH_PROXY_REQUIRED",
+ "OpenAI OAuth request failed: no proxy is configured and this server could not reach OpenAI directly. Select a proxy that can access OpenAI, then retry; if the authorization code has expired, regenerate the authorization URL.",
+ ).WithCause(cause)
+}
diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go
index 44fa291b..b43e2b52 100644
--- a/backend/internal/repository/openai_oauth_service_test.go
+++ b/backend/internal/repository/openai_oauth_service_test.go
@@ -8,6 +8,7 @@ import (
"net/url"
"testing"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@@ -158,30 +159,6 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
}
-// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
-func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
- var seenClientIDs []string
- s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if err := r.ParseForm(); err != nil {
- w.WriteHeader(http.StatusBadRequest)
- return
- }
- clientID := r.PostForm.Get("client_id")
- seenClientIDs = append(seenClientIDs, clientID)
- if clientID == openai.SoraClientID {
- w.Header().Set("Content-Type", "application/json")
- _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
- return
- }
- w.WriteHeader(http.StatusBadRequest)
- }))
-
- resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
- require.NoError(s.T(), err, "RefreshTokenWithClientID")
- require.Equal(s.T(), "at-sora", resp.AccessToken)
- require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
-}
-
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
const customClientID = "custom-client-id"
var seenClientIDs []string
@@ -228,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
require.ErrorContains(s.T(), err, "request failed")
}
+func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() {
+ s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+ s.srv.Close()
+
+ _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
+
+ require.Error(s.T(), err)
+ require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err))
+ require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured")
+}
+
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
@@ -276,7 +264,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
- wantClientID := openai.SoraClientID
+ wantClientID := "custom-exchange-client-id"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
diff --git a/backend/internal/repository/pagination.go b/backend/internal/repository/pagination.go
index ff08c34b..87c42a59 100644
--- a/backend/internal/repository/pagination.go
+++ b/backend/internal/repository/pagination.go
@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams)
Pages: pages,
}
}
+
+func paginateSlice[T any](items []T, params pagination.PaginationParams) []T {
+ if len(items) == 0 {
+ return []T{}
+ }
+
+ offset := params.Offset()
+ if offset >= len(items) {
+ return []T{}
+ }
+
+ limit := params.Limit()
+ end := offset + limit
+ if end > len(items) {
+ end = len(items)
+ }
+
+ return items[offset:end]
+}
diff --git a/backend/internal/repository/promo_code_repo.go b/backend/internal/repository/promo_code_repo.go
index 95ce687a..d9c76bb3 100644
--- a/backend/internal/repository/promo_code_repo.go
+++ b/backend/internal/repository/promo_code_repo.go
@@ -2,12 +2,15 @@ package repository
import (
"context"
+ "strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+
+ entsql "entgo.io/ent/dialect/sql"
)
type promoCodeRepository struct {
@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return nil, nil, err
}
- codes, err := q.
+ codesQuery := q.
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(promocode.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range promoCodeListOrder(params) {
+ codesQuery = codesQuery.Order(order)
+ }
+
+ codes, err := codesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
+func promoCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+
+ var field string
+ switch sortBy {
+ case "bonus_amount":
+ field = promocode.FieldBonusAmount
+ case "status":
+ field = promocode.FieldStatus
+ case "expires_at":
+ field = promocode.FieldExpiresAt
+ case "created_at":
+ field = promocode.FieldCreatedAt
+ case "code":
+ field = promocode.FieldCode
+ default:
+ field = promocode.FieldID
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(promocode.FieldID)}
+ }
+ return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(promocode.FieldID)}
+}
+
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create().
diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go
index 07c2a204..60b2f069 100644
--- a/backend/internal/repository/proxy_repo.go
+++ b/backend/internal/repository/proxy_repo.go
@@ -3,12 +3,16 @@ package repository
import (
"context"
"database/sql"
+ "sort"
+ "strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+
+ entsql "entgo.io/ent/dialect/sql"
)
type sqlQuerier interface {
@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return nil, nil, err
}
- proxies, err := q.
+ proxiesQuery := q.
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(proxy.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range proxyListOrder(params) {
+ proxiesQuery = proxiesQuery.Order(order)
+ }
+
+ proxies, err := proxiesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
return nil, nil, err
}
- proxies, err := q.
+ if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
+ return r.listWithAccountCountSort(ctx, q, params, total)
+ }
+
+ proxiesQuery := q.
Offset(params.Offset()).
- Limit(params.Limit()).
+ Limit(params.Limit())
+ for _, order := range proxyListOrder(params) {
+ proxiesQuery = proxiesQuery.Order(order)
+ }
+
+ proxies, err := proxiesQuery.All(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
+}
+
+func (r *proxyRepository) listWithAccountCountSort(ctx context.Context, q *dbent.ProxyQuery, params pagination.PaginationParams, total int) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
+ proxies, err := q.
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
- // Get account counts
+ result, _, err := r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
+ if err != nil {
+ return nil, nil, err
+ }
+
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+ sort.SliceStable(result, func(i, j int) bool {
+ if result[i].AccountCount == result[j].AccountCount {
+ return result[i].ID > result[j].ID
+ }
+ if sortOrder == pagination.SortOrderAsc {
+ return result[i].AccountCount < result[j].AccountCount
+ }
+ return result[i].AccountCount > result[j].AccountCount
+ })
+
+ return paginateSlice(result, params), paginationResultFromTotal(int64(total), params), nil
+}
+
+func (r *proxyRepository) buildProxyWithAccountCountResult(ctx context.Context, proxies []*dbent.Proxy, params pagination.PaginationParams, total int64) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
- // Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
@@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
})
}
- return result, paginationResultFromTotal(int64(total), params), nil
+ return result, paginationResultFromTotal(total, params), nil
+}
+
+func proxyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+
+ var field string
+ switch sortBy {
+ case "name":
+ field = proxy.FieldName
+ case "protocol":
+ field = proxy.FieldProtocol
+ case "status":
+ field = proxy.FieldStatus
+ case "created_at":
+ field = proxy.FieldCreatedAt
+ default:
+ field = proxy.FieldID
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(proxy.FieldID)}
+ }
+ return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(proxy.FieldID)}
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
diff --git a/backend/internal/repository/proxy_repo_sort_integration_test.go b/backend/internal/repository/proxy_repo_sort_integration_test.go
new file mode 100644
index 00000000..fe1c2873
--- /dev/null
+++ b/backend/internal/repository/proxy_repo_sort_integration_test.go
@@ -0,0 +1,28 @@
+//go:build integration
+
+package repository
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *ProxyRepoSuite) TestListWithFiltersAndAccountCount_SortByAccountCountDesc() {
+ p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
+ p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
+ s.mustInsertAccount("a1", &p1.ID)
+ s.mustInsertAccount("a2", &p1.ID)
+ s.mustInsertAccount("a3", &p2.ID)
+
+ proxies, _, err := s.repo.ListWithFiltersAndAccountCount(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "account_count",
+ SortOrder: "desc",
+ }, "", "", "")
+ s.Require().NoError(err)
+ s.Require().Len(proxies, 2)
+ s.Require().Equal(p1.ID, proxies[0].ID)
+ s.Require().Equal(int64(2), proxies[0].AccountCount)
+ s.Require().Equal(p2.ID, proxies[1].ID)
+}
diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go
index 934a3095..07975970 100644
--- a/backend/internal/repository/redeem_code_repo.go
+++ b/backend/internal/repository/redeem_code_repo.go
@@ -2,6 +2,7 @@ package repository
import (
"context"
+ "strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -9,6 +10,8 @@ import (
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+
+ entsql "entgo.io/ent/dialect/sql"
)
type redeemCodeRepository struct {
@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return nil, nil, err
}
- codes, err := q.
+ codesQuery := q.
WithUser().
WithGroup().
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(redeemcode.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range redeemCodeListOrder(params) {
+ codesQuery = codesQuery.Order(order)
+ }
+
+ codes, err := codesQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
+func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+
+ var field string
+ switch sortBy {
+ case "type":
+ field = redeemcode.FieldType
+ case "value":
+ field = redeemcode.FieldValue
+ case "status":
+ field = redeemcode.FieldStatus
+ case "used_at":
+ field = redeemcode.FieldUsedAt
+ case "created_at":
+ field = redeemcode.FieldCreatedAt
+ case "code":
+ field = redeemcode.FieldCode
+ default:
+ field = redeemcode.FieldID
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(redeemcode.FieldID)}
+ }
+ return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(redeemcode.FieldID)}
+}
+
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
up := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code).
diff --git a/backend/internal/repository/redeem_code_repo_sort_integration_test.go b/backend/internal/repository/redeem_code_repo_sort_integration_test.go
new file mode 100644
index 00000000..30d32f4c
--- /dev/null
+++ b/backend/internal/repository/redeem_code_repo_sort_integration_test.go
@@ -0,0 +1,24 @@
+//go:build integration
+
+package repository
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *RedeemCodeRepoSuite) TestListWithFilters_SortByValueAsc() {
+ s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-20", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused}))
+ s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-10", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused}))
+
+ codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "value",
+ SortOrder: "asc",
+ }, "", "", "")
+ s.Require().NoError(err)
+ s.Require().Len(codes, 2)
+ s.Require().Equal("VALUE-10", codes[0].Code)
+ s.Require().Equal("VALUE-20", codes[1].Code)
+}
diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go
index 4f447e4f..add0e501 100644
--- a/backend/internal/repository/scheduler_cache.go
+++ b/backend/internal/repository/scheduler_cache.go
@@ -15,19 +15,39 @@ const (
schedulerBucketSetKey = "sched:buckets"
schedulerOutboxWatermarkKey = "sched:outbox:watermark"
schedulerAccountPrefix = "sched:acc:"
+ schedulerAccountMetaPrefix = "sched:meta:"
schedulerActivePrefix = "sched:active:"
schedulerReadyPrefix = "sched:ready:"
schedulerVersionPrefix = "sched:ver:"
schedulerSnapshotPrefix = "sched:"
schedulerLockPrefix = "sched:lock:"
+
+ defaultSchedulerSnapshotMGetChunkSize = 128
+ defaultSchedulerSnapshotWriteChunkSize = 256
)
type schedulerCache struct {
- rdb *redis.Client
+ rdb *redis.Client
+ mgetChunkSize int
+ writeChunkSize int
}
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
- return &schedulerCache{rdb: rdb}
+ return newSchedulerCacheWithChunkSizes(rdb, defaultSchedulerSnapshotMGetChunkSize, defaultSchedulerSnapshotWriteChunkSize)
+}
+
+func newSchedulerCacheWithChunkSizes(rdb *redis.Client, mgetChunkSize, writeChunkSize int) service.SchedulerCache {
+ if mgetChunkSize <= 0 {
+ mgetChunkSize = defaultSchedulerSnapshotMGetChunkSize
+ }
+ if writeChunkSize <= 0 {
+ writeChunkSize = defaultSchedulerSnapshotWriteChunkSize
+ }
+ return &schedulerCache{
+ rdb: rdb,
+ mgetChunkSize: mgetChunkSize,
+ writeChunkSize: writeChunkSize,
+ }
}
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
@@ -65,9 +85,9 @@ func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.Schedul
keys := make([]string, 0, len(ids))
for _, id := range ids {
- keys = append(keys, schedulerAccountKey(id))
+ keys = append(keys, schedulerAccountMetaKey(id))
}
- values, err := c.rdb.MGet(ctx, keys...).Result()
+ values, err := c.mgetChunked(ctx, keys)
if err != nil {
return nil, false, err
}
@@ -100,14 +120,11 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
versionStr := strconv.FormatInt(version, 10)
snapshotKey := schedulerSnapshotKey(bucket, versionStr)
- pipe := c.rdb.Pipeline()
- for _, account := range accounts {
- payload, err := json.Marshal(account)
- if err != nil {
- return err
- }
- pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
+ if err := c.writeAccounts(ctx, accounts); err != nil {
+ return err
}
+
+ pipe := c.rdb.Pipeline()
if len(accounts) > 0 {
// 使用序号作为 score,保持数据库返回的排序语义。
members := make([]redis.Z, 0, len(accounts))
@@ -117,7 +134,13 @@ func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.Schedul
Member: strconv.FormatInt(account.ID, 10),
})
}
- pipe.ZAdd(ctx, snapshotKey, members...)
+ for start := 0; start < len(members); start += c.writeChunkSize {
+ end := start + c.writeChunkSize
+ if end > len(members) {
+ end = len(members)
+ }
+ pipe.ZAdd(ctx, snapshotKey, members[start:end]...)
+ }
} else {
pipe.Del(ctx, snapshotKey)
}
@@ -151,20 +174,15 @@ func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Accoun
if account == nil || account.ID <= 0 {
return nil
}
- payload, err := json.Marshal(account)
- if err != nil {
- return err
- }
- key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
- return c.rdb.Set(ctx, key, payload, 0).Err()
+ return c.writeAccounts(ctx, []service.Account{*account})
}
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
if accountID <= 0 {
return nil
}
- key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
- return c.rdb.Del(ctx, key).Err()
+ id := strconv.FormatInt(accountID, 10)
+ return c.rdb.Del(ctx, schedulerAccountKey(id), schedulerAccountMetaKey(id)).Err()
}
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
@@ -179,7 +197,7 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t
ids = append(ids, id)
}
- values, err := c.rdb.MGet(ctx, keys...).Result()
+ values, err := c.mgetChunked(ctx, keys)
if err != nil {
return err
}
@@ -198,7 +216,12 @@ func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]t
if err != nil {
return err
}
+ metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(*account))
+ if err != nil {
+ return err
+ }
pipe.Set(ctx, keys[i], updated, 0)
+ pipe.Set(ctx, schedulerAccountMetaKey(strconv.FormatInt(ids[i], 10)), metaPayload, 0)
}
_, err = pipe.Exec(ctx)
return err
@@ -256,6 +279,10 @@ func schedulerAccountKey(id string) string {
return schedulerAccountPrefix + id
}
+func schedulerAccountMetaKey(id string) string {
+ return schedulerAccountMetaPrefix + id
+}
+
func ptrTime(t time.Time) *time.Time {
return &t
}
@@ -276,3 +303,145 @@ func decodeCachedAccount(val any) (*service.Account, error) {
}
return &account, nil
}
+
+func (c *schedulerCache) writeAccounts(ctx context.Context, accounts []service.Account) error {
+ if len(accounts) == 0 {
+ return nil
+ }
+
+ pipe := c.rdb.Pipeline()
+ pending := 0
+ flush := func() error {
+ if pending == 0 {
+ return nil
+ }
+ if _, err := pipe.Exec(ctx); err != nil {
+ return err
+ }
+ pipe = c.rdb.Pipeline()
+ pending = 0
+ return nil
+ }
+
+ for _, account := range accounts {
+ fullPayload, err := json.Marshal(account)
+ if err != nil {
+ return err
+ }
+ metaPayload, err := json.Marshal(buildSchedulerMetadataAccount(account))
+ if err != nil {
+ return err
+ }
+
+ id := strconv.FormatInt(account.ID, 10)
+ pipe.Set(ctx, schedulerAccountKey(id), fullPayload, 0)
+ pipe.Set(ctx, schedulerAccountMetaKey(id), metaPayload, 0)
+ pending++
+ if pending >= c.writeChunkSize {
+ if err := flush(); err != nil {
+ return err
+ }
+ }
+ }
+
+ return flush()
+}
+
+func (c *schedulerCache) mgetChunked(ctx context.Context, keys []string) ([]any, error) {
+ if len(keys) == 0 {
+ return []any{}, nil
+ }
+
+ out := make([]any, 0, len(keys))
+ chunkSize := c.mgetChunkSize
+ if chunkSize <= 0 {
+ chunkSize = defaultSchedulerSnapshotMGetChunkSize
+ }
+ for start := 0; start < len(keys); start += chunkSize {
+ end := start + chunkSize
+ if end > len(keys) {
+ end = len(keys)
+ }
+ part, err := c.rdb.MGet(ctx, keys[start:end]...).Result()
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, part...)
+ }
+ return out, nil
+}
+
+func buildSchedulerMetadataAccount(account service.Account) service.Account {
+ return service.Account{
+ ID: account.ID,
+ Name: account.Name,
+ Platform: account.Platform,
+ Type: account.Type,
+ Concurrency: account.Concurrency,
+ LoadFactor: account.LoadFactor,
+ Priority: account.Priority,
+ RateMultiplier: account.RateMultiplier,
+ Status: account.Status,
+ LastUsedAt: account.LastUsedAt,
+ ExpiresAt: account.ExpiresAt,
+ AutoPauseOnExpired: account.AutoPauseOnExpired,
+ Schedulable: account.Schedulable,
+ RateLimitedAt: account.RateLimitedAt,
+ RateLimitResetAt: account.RateLimitResetAt,
+ OverloadUntil: account.OverloadUntil,
+ TempUnschedulableUntil: account.TempUnschedulableUntil,
+ TempUnschedulableReason: account.TempUnschedulableReason,
+ SessionWindowStart: account.SessionWindowStart,
+ SessionWindowEnd: account.SessionWindowEnd,
+ SessionWindowStatus: account.SessionWindowStatus,
+ Credentials: filterSchedulerCredentials(account.Credentials),
+ Extra: filterSchedulerExtra(account.Extra),
+ }
+}
+
+func filterSchedulerCredentials(credentials map[string]any) map[string]any {
+ if len(credentials) == 0 {
+ return nil
+ }
+ keys := []string{"model_mapping", "api_key", "project_id", "oauth_type"}
+ filtered := make(map[string]any)
+ for _, key := range keys {
+ if value, ok := credentials[key]; ok && value != nil {
+ filtered[key] = value
+ }
+ }
+ if len(filtered) == 0 {
+ return nil
+ }
+ return filtered
+}
+
+func filterSchedulerExtra(extra map[string]any) map[string]any {
+ if len(extra) == 0 {
+ return nil
+ }
+ keys := []string{
+ "mixed_scheduling",
+ "window_cost_limit",
+ "window_cost_sticky_reserve",
+ "max_sessions",
+ "session_idle_timeout_minutes",
+ "openai_oauth_responses_websockets_v2_enabled",
+ "openai_oauth_responses_websockets_v2_mode",
+ "openai_apikey_responses_websockets_v2_enabled",
+ "openai_apikey_responses_websockets_v2_mode",
+ "responses_websockets_v2_enabled",
+ "openai_ws_enabled",
+ "openai_ws_force_http",
+ }
+ filtered := make(map[string]any)
+ for _, key := range keys {
+ if value, ok := extra[key]; ok && value != nil {
+ filtered[key] = value
+ }
+ }
+ if len(filtered) == 0 {
+ return nil
+ }
+ return filtered
+}
diff --git a/backend/internal/repository/scheduler_cache_integration_test.go b/backend/internal/repository/scheduler_cache_integration_test.go
new file mode 100644
index 00000000..134a6a07
--- /dev/null
+++ b/backend/internal/repository/scheduler_cache_integration_test.go
@@ -0,0 +1,88 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSchedulerCacheSnapshotUsesSlimMetadataButKeepsFullAccount(t *testing.T) {
+ ctx := context.Background()
+ rdb := testRedis(t)
+ cache := NewSchedulerCache(rdb)
+
+ bucket := service.SchedulerBucket{GroupID: 2, Platform: service.PlatformGemini, Mode: service.SchedulerModeSingle}
+ now := time.Now().UTC().Truncate(time.Second)
+ limitReset := now.Add(10 * time.Minute)
+ overloadUntil := now.Add(2 * time.Minute)
+ tempUnschedUntil := now.Add(3 * time.Minute)
+ windowEnd := now.Add(5 * time.Hour)
+
+ account := service.Account{
+ ID: 101,
+ Name: "gemini-heavy",
+ Platform: service.PlatformGemini,
+ Type: service.AccountTypeOAuth,
+ Status: service.StatusActive,
+ Schedulable: true,
+ Concurrency: 3,
+ Priority: 7,
+ LastUsedAt: &now,
+ Credentials: map[string]any{
+ "api_key": "gemini-api-key",
+ "access_token": "secret-access-token",
+ "project_id": "proj-1",
+ "oauth_type": "ai_studio",
+ "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"},
+ "huge_blob": strings.Repeat("x", 4096),
+ },
+ Extra: map[string]any{
+ "mixed_scheduling": true,
+ "window_cost_limit": 12.5,
+ "window_cost_sticky_reserve": 8.0,
+ "max_sessions": 4,
+ "session_idle_timeout_minutes": 11,
+ "unused_large_field": strings.Repeat("y", 4096),
+ },
+ RateLimitResetAt: &limitReset,
+ OverloadUntil: &overloadUntil,
+ TempUnschedulableUntil: &tempUnschedUntil,
+ SessionWindowStart: &now,
+ SessionWindowEnd: &windowEnd,
+ SessionWindowStatus: "active",
+ }
+
+ require.NoError(t, cache.SetSnapshot(ctx, bucket, []service.Account{account}))
+
+ snapshot, hit, err := cache.GetSnapshot(ctx, bucket)
+ require.NoError(t, err)
+ require.True(t, hit)
+ require.Len(t, snapshot, 1)
+
+ got := snapshot[0]
+ require.NotNil(t, got)
+ require.Equal(t, "gemini-api-key", got.GetCredential("api_key"))
+ require.Equal(t, "proj-1", got.GetCredential("project_id"))
+ require.Equal(t, "ai_studio", got.GetCredential("oauth_type"))
+ require.NotEmpty(t, got.GetModelMapping())
+ require.Empty(t, got.GetCredential("access_token"))
+ require.Empty(t, got.GetCredential("huge_blob"))
+ require.Equal(t, true, got.Extra["mixed_scheduling"])
+ require.Equal(t, 12.5, got.GetWindowCostLimit())
+ require.Equal(t, 8.0, got.GetWindowCostStickyReserve())
+ require.Equal(t, 4, got.GetMaxSessions())
+ require.Equal(t, 11, got.GetSessionIdleTimeoutMinutes())
+ require.Nil(t, got.Extra["unused_large_field"])
+
+ full, err := cache.GetAccount(ctx, account.ID)
+ require.NoError(t, err)
+ require.NotNil(t, full)
+ require.Equal(t, "secret-access-token", full.GetCredential("access_token"))
+ require.Equal(t, strings.Repeat("x", 4096), full.GetCredential("huge_blob"))
+}
diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go
new file mode 100644
index 00000000..bcfd0e7a
--- /dev/null
+++ b/backend/internal/repository/scheduler_cache_unit_test.go
@@ -0,0 +1,33 @@
+//go:build unit
+
+package repository
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
+ account := service.Account{
+ ID: 42,
+ Platform: service.PlatformOpenAI,
+ Type: service.AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_enabled": true,
+ "openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
+ "openai_ws_force_http": true,
+ "mixed_scheduling": true,
+ "unused_large_field": "drop-me",
+ },
+ }
+
+ got := buildSchedulerMetadataAccount(account)
+
+ require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"])
+ require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"])
+ require.Equal(t, true, got.Extra["openai_ws_force_http"])
+ require.Equal(t, true, got.Extra["mixed_scheduling"])
+ require.Nil(t, got.Extra["unused_large_field"])
+}
diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go
deleted file mode 100644
index ad2ae638..00000000
--- a/backend/internal/repository/sora_account_repo.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package repository
-
-import (
- "context"
- "database/sql"
- "errors"
-
- "github.com/Wei-Shaw/sub2api/internal/service"
-)
-
-// soraAccountRepository 实现 service.SoraAccountRepository 接口。
-// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
-//
-// 设计说明:
-// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
-// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
-// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
-type soraAccountRepository struct {
- sql *sql.DB
-}
-
-// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
-func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
- return &soraAccountRepository{sql: sqlDB}
-}
-
-// Upsert 创建或更新 Sora 账号扩展信息
-// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
-func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
- accessToken, accessOK := updates["access_token"].(string)
- refreshToken, refreshOK := updates["refresh_token"].(string)
- sessionToken, sessionOK := updates["session_token"].(string)
-
- if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
- if !sessionOK {
- return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
- }
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_accounts
- SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
- updated_at = NOW()
- WHERE account_id = $1
- `, accountID, sessionToken)
- if err != nil {
- return err
- }
- rows, err := result.RowsAffected()
- if err != nil {
- return err
- }
- if rows == 0 {
- return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
- }
- return nil
- }
-
- _, err := r.sql.ExecContext(ctx, `
- INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
- VALUES ($1, $2, $3, $4, NOW(), NOW())
- ON CONFLICT (account_id) DO UPDATE SET
- access_token = EXCLUDED.access_token,
- refresh_token = EXCLUDED.refresh_token,
- session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
- updated_at = NOW()
- `, accountID, accessToken, refreshToken, sessionToken)
- return err
-}
-
-// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
-func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
- rows, err := r.sql.QueryContext(ctx, `
- SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
- FROM sora_accounts
- WHERE account_id = $1
- `, accountID)
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- if !rows.Next() {
- return nil, nil // 记录不存在
- }
-
- var sa service.SoraAccount
- if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
- return nil, err
- }
- return &sa, nil
-}
-
-// Delete 删除 Sora 账号扩展信息
-func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
- _, err := r.sql.ExecContext(ctx, `
- DELETE FROM sora_accounts WHERE account_id = $1
- `, accountID)
- return err
-}
diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go
deleted file mode 100644
index aaf3cb2f..00000000
--- a/backend/internal/repository/sora_generation_repo.go
+++ /dev/null
@@ -1,419 +0,0 @@
-package repository
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "fmt"
- "strings"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/service"
-)
-
-// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。
-// 使用原生 SQL 操作 sora_generations 表。
-type soraGenerationRepository struct {
- sql *sql.DB
-}
-
-// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。
-func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
- return &soraGenerationRepository{sql: sqlDB}
-}
-
-func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
- mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
- s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
-
- err := r.sql.QueryRowContext(ctx, `
- INSERT INTO sora_generations (
- user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message
- ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
- RETURNING id, created_at
- `,
- gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
- gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
- gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
- ).Scan(&gen.ID, &gen.CreatedAt)
- return err
-}
-
-// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。
-func (r *soraGenerationRepository) CreatePendingWithLimit(
- ctx context.Context,
- gen *service.SoraGeneration,
- activeStatuses []string,
- maxActive int64,
-) error {
- if gen == nil {
- return fmt.Errorf("generation is nil")
- }
- if maxActive <= 0 {
- return r.Create(ctx, gen)
- }
- if len(activeStatuses) == 0 {
- activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
- }
-
- tx, err := r.sql.BeginTx(ctx, nil)
- if err != nil {
- return err
- }
- defer func() { _ = tx.Rollback() }()
-
- // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。
- if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
- return err
- }
-
- placeholders := make([]string, len(activeStatuses))
- args := make([]any, 0, 1+len(activeStatuses))
- args = append(args, gen.UserID)
- for i, s := range activeStatuses {
- placeholders[i] = fmt.Sprintf("$%d", i+2)
- args = append(args, s)
- }
- countQuery := fmt.Sprintf(
- `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
- strings.Join(placeholders, ","),
- )
- var activeCount int64
- if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
- return err
- }
- if activeCount >= maxActive {
- return service.ErrSoraGenerationConcurrencyLimit
- }
-
- mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
- s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
- if err := tx.QueryRowContext(ctx, `
- INSERT INTO sora_generations (
- user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message
- ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
- RETURNING id, created_at
- `,
- gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
- gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
- gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
- ).Scan(&gen.ID, &gen.CreatedAt); err != nil {
- return err
- }
-
- return tx.Commit()
-}
-
-func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
- gen := &service.SoraGeneration{}
- var mediaURLsJSON, s3KeysJSON []byte
- var completedAt sql.NullTime
- var apiKeyID sql.NullInt64
-
- err := r.sql.QueryRowContext(ctx, `
- SELECT id, user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message,
- created_at, completed_at
- FROM sora_generations WHERE id = $1
- `, id).Scan(
- &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
- &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
- &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
- &gen.CreatedAt, &completedAt,
- )
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, fmt.Errorf("生成记录不存在")
- }
- return nil, err
- }
-
- if apiKeyID.Valid {
- gen.APIKeyID = &apiKeyID.Int64
- }
- if completedAt.Valid {
- gen.CompletedAt = &completedAt.Time
- }
- _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
- _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
- return gen, nil
-}
-
-func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
- mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
- s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
-
- var completedAt *time.Time
- if gen.CompletedAt != nil {
- completedAt = gen.CompletedAt
- }
-
- _, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations SET
- status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
- storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
- error_message = $9, completed_at = $10
- WHERE id = $1
- `,
- gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
- gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
- gen.ErrorMessage, completedAt,
- )
- return err
-}
-
-// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。
-func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2, upstream_task_id = $3
- WHERE id = $1 AND status = $4
- `,
- id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。
-func (r *soraGenerationRepository) UpdateCompletedIfActive(
- ctx context.Context,
- id int64,
- mediaURL string,
- mediaURLs []string,
- storageType string,
- s3Keys []string,
- fileSizeBytes int64,
- completedAt time.Time,
-) (bool, error) {
- mediaURLsJSON, _ := json.Marshal(mediaURLs)
- s3KeysJSON, _ := json.Marshal(s3Keys)
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2,
- media_url = $3,
- media_urls = $4,
- file_size_bytes = $5,
- storage_type = $6,
- s3_object_keys = $7,
- error_message = '',
- completed_at = $8
- WHERE id = $1 AND status IN ($9, $10)
- `,
- id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
- storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。
-func (r *soraGenerationRepository) UpdateFailedIfActive(
- ctx context.Context,
- id int64,
- errMsg string,
- completedAt time.Time,
-) (bool, error) {
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2,
- error_message = $3,
- completed_at = $4
- WHERE id = $1 AND status IN ($5, $6)
- `,
- id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。
-func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET status = $2, completed_at = $3
- WHERE id = $1 AND status IN ($4, $5)
- `,
- id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。
-func (r *soraGenerationRepository) UpdateStorageIfCompleted(
- ctx context.Context,
- id int64,
- mediaURL string,
- mediaURLs []string,
- storageType string,
- s3Keys []string,
- fileSizeBytes int64,
-) (bool, error) {
- mediaURLsJSON, _ := json.Marshal(mediaURLs)
- s3KeysJSON, _ := json.Marshal(s3Keys)
- result, err := r.sql.ExecContext(ctx, `
- UPDATE sora_generations
- SET media_url = $2,
- media_urls = $3,
- file_size_bytes = $4,
- storage_type = $5,
- s3_object_keys = $6
- WHERE id = $1 AND status = $7
- `,
- id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
- )
- if err != nil {
- return false, err
- }
- affected, err := result.RowsAffected()
- if err != nil {
- return false, err
- }
- return affected > 0, nil
-}
-
-func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
- _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id)
- return err
-}
-
-func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
- // 构建 WHERE 条件
- conditions := []string{"user_id = $1"}
- args := []any{params.UserID}
- argIdx := 2
-
- if params.Status != "" {
- // 支持逗号分隔的多状态
- statuses := strings.Split(params.Status, ",")
- placeholders := make([]string, len(statuses))
- for i, s := range statuses {
- placeholders[i] = fmt.Sprintf("$%d", argIdx)
- args = append(args, strings.TrimSpace(s))
- argIdx++
- }
- conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
- }
- if params.StorageType != "" {
- storageTypes := strings.Split(params.StorageType, ",")
- placeholders := make([]string, len(storageTypes))
- for i, s := range storageTypes {
- placeholders[i] = fmt.Sprintf("$%d", argIdx)
- args = append(args, strings.TrimSpace(s))
- argIdx++
- }
- conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
- }
- if params.MediaType != "" {
- conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
- args = append(args, params.MediaType)
- argIdx++
- }
-
- whereClause := "WHERE " + strings.Join(conditions, " AND ")
-
- // 计数
- var total int64
- countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
- if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
- return nil, 0, err
- }
-
- // 分页查询
- offset := (params.Page - 1) * params.PageSize
- listQuery := fmt.Sprintf(`
- SELECT id, user_id, api_key_id, model, prompt, media_type,
- status, media_url, media_urls, file_size_bytes,
- storage_type, s3_object_keys, upstream_task_id, error_message,
- created_at, completed_at
- FROM sora_generations %s
- ORDER BY created_at DESC
- LIMIT $%d OFFSET $%d
- `, whereClause, argIdx, argIdx+1)
- args = append(args, params.PageSize, offset)
-
- rows, err := r.sql.QueryContext(ctx, listQuery, args...)
- if err != nil {
- return nil, 0, err
- }
- defer func() {
- _ = rows.Close()
- }()
-
- var results []*service.SoraGeneration
- for rows.Next() {
- gen := &service.SoraGeneration{}
- var mediaURLsJSON, s3KeysJSON []byte
- var completedAt sql.NullTime
- var apiKeyID sql.NullInt64
-
- if err := rows.Scan(
- &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
- &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
- &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
- &gen.CreatedAt, &completedAt,
- ); err != nil {
- return nil, 0, err
- }
-
- if apiKeyID.Valid {
- gen.APIKeyID = &apiKeyID.Int64
- }
- if completedAt.Valid {
- gen.CompletedAt = &completedAt.Time
- }
- _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
- _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
- results = append(results, gen)
- }
-
- return results, total, rows.Err()
-}
-
-func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
- if len(statuses) == 0 {
- return 0, nil
- }
-
- placeholders := make([]string, len(statuses))
- args := []any{userID}
- for i, s := range statuses {
- placeholders[i] = fmt.Sprintf("$%d", i+2)
- args = append(args, s)
- }
-
- var count int64
- query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ","))
- err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count)
- return count, err
-}
diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go
index b4c76da5..62f48b58 100644
--- a/backend/internal/repository/usage_billing_repo.go
+++ b/backend/internal/repository/usage_billing_repo.go
@@ -113,9 +113,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if cmd.BalanceCost > 0 {
- if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
+ newBalance, err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost)
+ if err != nil {
return err
}
+ result.NewBalance = &newBalance
}
if cmd.APIKeyQuotaCost > 0 {
@@ -133,9 +135,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
}
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
- if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
+ quotaState, err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost)
+ if err != nil {
return err
}
+ result.QuotaState = quotaState
}
return nil
@@ -169,24 +173,22 @@ func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscrip
return service.ErrSubscriptionNotFound
}
-func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
- res, err := tx.ExecContext(ctx, `
+func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) (float64, error) {
+ var newBalance float64
+ err := tx.QueryRowContext(ctx, `
UPDATE users
SET balance = balance - $1,
updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
- `, amount, userID)
+ RETURNING balance
+ `, amount, userID).Scan(&newBalance)
+ if errors.Is(err, sql.ErrNoRows) {
+ return 0, service.ErrUserNotFound
+ }
if err != nil {
- return err
+ return 0, err
}
- affected, err := res.RowsAffected()
- if err != nil {
- return err
- }
- if affected > 0 {
- return nil
- }
- return service.ErrUserNotFound
+ return newBalance, nil
}
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
@@ -240,7 +242,7 @@ func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKe
return nil
}
-func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
+func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) (*service.AccountQuotaState, error) {
rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb)
@@ -248,61 +250,88 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_daily_used',
- CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
- + '24 hours'::interval <= NOW()
+ CASE WHEN `+dailyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start',
- CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
- + '24 hours'::interval <= NOW()
+ CASE WHEN `+dailyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
)
+ || CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
+ THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
+ ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object(
'quota_weekly_used',
- CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
- + '168 hours'::interval <= NOW()
+ CASE WHEN `+weeklyExpiredExpr+`
THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start',
- CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
- + '168 hours'::interval <= NOW()
+ CASE WHEN `+weeklyExpiredExpr+`
THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
)
+ || CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
+ THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
+ ELSE '{}'::jsonb END
ELSE '{}'::jsonb END
), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL
RETURNING
COALESCE((extra->>'quota_used')::numeric, 0),
- COALESCE((extra->>'quota_limit')::numeric, 0)`,
+ COALESCE((extra->>'quota_limit')::numeric, 0),
+ COALESCE((extra->>'quota_daily_used')::numeric, 0),
+ COALESCE((extra->>'quota_daily_limit')::numeric, 0),
+ COALESCE((extra->>'quota_weekly_used')::numeric, 0),
+ COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`,
amount, accountID)
if err != nil {
- return err
+ return nil, err
}
- defer func() { _ = rows.Close() }()
- var newUsed, limit float64
+ var state service.AccountQuotaState
if rows.Next() {
- if err := rows.Scan(&newUsed, &limit); err != nil {
- return err
+ if err := rows.Scan(
+ &state.TotalUsed, &state.TotalLimit,
+ &state.DailyUsed, &state.DailyLimit,
+ &state.WeeklyUsed, &state.WeeklyLimit,
+ ); err != nil {
+ _ = rows.Close()
+ return nil, err
}
} else {
if err := rows.Err(); err != nil {
- return err
+ _ = rows.Close()
+ return nil, err
}
- return service.ErrAccountNotFound
+ _ = rows.Close()
+ return nil, service.ErrAccountNotFound
}
if err := rows.Err(); err != nil {
- return err
+ _ = rows.Close()
+ return nil, err
}
- if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
+ // 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上
+ // 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回
+ // "unexpected Parse response" 错误。
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ // 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照,
+ // 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号,
+ // 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。
+ // 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount),
+ // 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。
+ crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit
+ crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit
+ crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit
+ if crossedTotal || crossedDaily || crossedWeekly {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
- return err
+ return nil, err
}
}
- return nil
+ return &state, nil
}
diff --git a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go
index eda34cc9..e8d4d327 100644
--- a/backend/internal/repository/usage_billing_repo_integration_test.go
+++ b/backend/internal/repository/usage_billing_repo_integration_test.go
@@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) {
require.InDelta(t, 3.5, quotaUsed, 0.000001)
}
+func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) {
+ ctx := context.Background()
+ client := testEntClient(t)
+ repo := NewUsageBillingRepository(client, integrationDB)
+
+ newFixture := func(t *testing.T, extra map[string]any) (int64, int64) {
+ t.Helper()
+ user := mustCreateUser(t, client, &service.User{
+ Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()),
+ PasswordHash: "hash",
+ })
+ apiKey := mustCreateApiKey(t, client, &service.APIKey{
+ UserID: user.ID,
+ Key: "sk-usage-billing-outbox-" + uuid.NewString(),
+ Name: "billing-outbox",
+ })
+ account := mustCreateAccount(t, client, &service.Account{
+ Name: "usage-billing-outbox-" + uuid.NewString(),
+ Type: service.AccountTypeAPIKey,
+ Extra: extra,
+ })
+ return apiKey.ID, account.ID
+ }
+
+ outboxCountFor := func(t *testing.T, accountID int64) int {
+ t.Helper()
+ var count int
+ require.NoError(t, integrationDB.QueryRowContext(ctx,
+ "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2",
+ service.SchedulerOutboxEventAccountChanged, accountID,
+ ).Scan(&count))
+ return count
+ }
+
+ t.Run("daily_first_crossing_enqueues", func(t *testing.T) {
+ apiKeyID, accountID := newFixture(t, map[string]any{
+ "quota_daily_limit": 10.0,
+ })
+ // 第一次低于日限额:不应入队 outbox
+ _, err := repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 4,
+ })
+ require.NoError(t, err)
+ require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue")
+
+ // 第二次跨越日限额:应入队一次 outbox
+ _, err = repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 8,
+ })
+ require.NoError(t, err)
+ require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once")
+
+ // 再次递增(已超):不应重复入队
+ _, err = repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 2,
+ })
+ require.NoError(t, err)
+ require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue")
+ })
+
+ t.Run("weekly_first_crossing_enqueues", func(t *testing.T) {
+ apiKeyID, accountID := newFixture(t, map[string]any{
+ "quota_weekly_limit": 10.0,
+ })
+ _, err := repo.Apply(ctx, &service.UsageBillingCommand{
+ RequestID: uuid.NewString(),
+ APIKeyID: apiKeyID,
+ AccountID: accountID,
+ AccountType: service.AccountTypeAPIKey,
+ AccountQuotaCost: 15, // 单次即跨越
+ })
+ require.NoError(t, err)
+ require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once")
+ })
+}
+
func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) {
ctx := context.Background()
repo := newDashboardAggregationRepositoryWithSQL(integrationDB)
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index e4da825b..f2fb87da 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -53,6 +53,8 @@ var usageLogInsertArgTypes = [...]string{
"integer", // cache_read_tokens
"integer", // cache_creation_5m_tokens
"integer", // cache_creation_1h_tokens
+ "integer", // image_output_tokens
+ "numeric", // image_output_cost
"numeric", // input_cost
"numeric", // output_cost
"numeric", // cache_creation_cost
@@ -71,12 +73,16 @@ var usageLogInsertArgTypes = [...]string{
"text", // ip_address
"integer", // image_count
"text", // image_size
- "text", // media_type
"text", // service_tier
"text", // reasoning_effort
"text", // inbound_endpoint
"text", // upstream_endpoint
"boolean", // cache_ttl_overridden
+ "bigint", // channel_id
+ "text", // model_mapping_chain
+ "text", // billing_tier
+ "text", // billing_mode
+ "numeric", // account_stats_cost
"timestamptz", // created_at
}
@@ -326,6 +332,8 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -344,20 +352,24 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
- $14, $15,
- $16, $17, $18, $19, $20, $21,
- $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
+ $14, $15, $16, $17,
+ $18, $19, $20, $21, $22, $23,
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -758,6 +770,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -776,16 +790,20 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
) AS (VALUES `)
- args := make([]any, 0, len(keys)*39)
+ args := make([]any, 0, len(keys)*46)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -829,6 +847,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -847,12 +867,16 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
)
SELECT
@@ -871,6 +895,8 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -889,12 +915,16 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -953,6 +983,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -971,16 +1003,20 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
) AS (VALUES `)
- args := make([]any, 0, len(preparedList)*40)
+ args := make([]any, 0, len(preparedList)*46)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1021,6 +1057,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -1039,12 +1077,16 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
)
SELECT
@@ -1063,6 +1105,8 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -1081,12 +1125,16 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING
@@ -1113,6 +1161,8 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
cache_read_tokens,
cache_creation_5m_tokens,
cache_creation_1h_tokens,
+ image_output_tokens,
+ image_output_cost,
input_cost,
output_cost,
cache_creation_cost,
@@ -1131,20 +1181,24 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
ip_address,
image_count,
image_size,
- media_type,
service_tier,
reasoning_effort,
inbound_endpoint,
upstream_endpoint,
cache_ttl_overridden,
+ channel_id,
+ model_mapping_chain,
+ billing_tier,
+ billing_mode,
+ account_stats_cost,
created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7,
$8, $9,
$10, $11, $12, $13,
- $14, $15,
- $16, $17, $18, $19, $20, $21,
- $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40
+ $14, $15, $16, $17,
+ $18, $19, $20, $21, $22, $23,
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1171,11 +1225,14 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
- mediaType := nullString(log.MediaType)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
upstreamEndpoint := nullString(log.UpstreamEndpoint)
+ channelID := nullInt64(log.ChannelID)
+ modelMappingChain := nullString(log.ModelMappingChain)
+ billingTier := nullString(log.BillingTier)
+ billingMode := nullString(log.BillingMode)
requestedModel := strings.TrimSpace(log.RequestedModel)
if requestedModel == "" {
requestedModel = strings.TrimSpace(log.Model)
@@ -1208,6 +1265,8 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
+ log.ImageOutputTokens,
+ log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
@@ -1226,12 +1285,16 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
ipAddress,
log.ImageCount,
imageSize,
- mediaType,
serviceTier,
reasoningEffort,
inboundEndpoint,
upstreamEndpoint,
log.CacheTTLOverridden,
+ channelID,
+ modelMappingChain,
+ billingTier,
+ billingMode,
+ log.AccountStatsCost, // account_stats_cost
createdAt,
},
}
@@ -1465,6 +1528,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
+ COALESCE(SUM(account_cost), 0) as total_account_cost,
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
FROM usage_dashboard_daily
`
@@ -1481,6 +1545,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
+ &stats.TotalAccountCost,
&totalDurationMs,
); err != nil {
return err
@@ -1499,6 +1564,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
cache_read_tokens as today_cache_read_tokens,
total_cost as today_cost,
actual_cost as today_actual_cost,
+ account_cost as today_account_cost,
active_users as active_users
FROM usage_dashboard_daily
WHERE bucket_date = $1::date
@@ -1515,6 +1581,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
+ &stats.TodayAccountCost,
&stats.ActiveUsers,
); err != nil {
if err != sql.ErrNoRows {
@@ -1550,6 +1617,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
cache_read_tokens,
total_cost,
actual_cost,
+ COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) AS account_cost,
COALESCE(duration_ms, 0) AS duration_ms
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
@@ -1563,6 +1631,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
+ COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_account_cost,
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
@@ -1570,7 +1639,8 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
- COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
+ COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost,
+ COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_account_cost
FROM scoped
`
var totalDurationMs int64
@@ -1586,6 +1656,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
+ &stats.TotalAccountCost,
&totalDurationMs,
&stats.TodayRequests,
&stats.TodayInputTokens,
@@ -1594,6 +1665,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
+ &stats.TodayAccountCost,
); err != nil {
return err
}
@@ -1906,7 +1978,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -1936,7 +2008,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -1973,7 +2045,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
account_id,
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
@@ -2532,7 +2604,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
- COALESCE(SUM(actual_cost), 0) as actual_cost
+ COALESCE(SUM(actual_cost), 0) as actual_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY model
@@ -2564,8 +2637,8 @@ type UsageLogFilters = usagestats.UsageLogFilters
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
- conditions := make([]string, 0, 8)
- args := make([]any, 0, 8)
+ conditions := make([]string, 0, 9)
+ args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
@@ -2589,6 +2662,10 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
+ if filters.BillingMode != "" {
+ conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
+ args = append(args, filters.BillingMode)
+ }
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@@ -2933,8 +3010,9 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
+ accountCostExpr := "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost"
modelExpr := resolveModelDimensionExpression(source)
query := fmt.Sprintf(`
@@ -2947,10 +3025,11 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
+ %s,
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
- `, modelExpr, actualCostExpr)
+ `, modelExpr, actualCostExpr, accountCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -3005,7 +3084,8 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
- COALESCE(SUM(ul.actual_cost), 0) as actual_cost
+ COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
+ COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
@@ -3056,6 +3136,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
+ &row.AccountCost,
); err != nil {
return nil, err
}
@@ -3076,7 +3157,8 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
COUNT(*) as requests,
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(ul.total_cost), 0) as cost,
- COALESCE(SUM(ul.actual_cost), 0) as actual_cost
+ COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
+ COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
FROM usage_logs ul
LEFT JOIN users u ON u.id = ul.user_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
@@ -3096,6 +3178,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
args = append(args, dim.Endpoint)
}
+ if dim.UserID > 0 {
+ query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
+ args = append(args, dim.UserID)
+ }
+ if dim.APIKeyID > 0 {
+ query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
+ args = append(args, dim.APIKeyID)
+ }
+ if dim.AccountID > 0 {
+ query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
+ args = append(args, dim.AccountID)
+ }
+ if dim.RequestType != nil {
+ query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
+ args = append(args, *dim.RequestType)
+ }
+ if dim.Stream != nil {
+ query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
+ args = append(args, *dim.Stream)
+ }
+ if dim.BillingType != nil {
+ query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
+ args = append(args, *dim.BillingType)
+ }
query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
if limit > 0 {
@@ -3123,6 +3229,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
+ &row.AccountCost,
); err != nil {
return nil, err
}
@@ -3256,6 +3363,10 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
+ if filters.BillingMode != "" {
+ conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
+ args = append(args, filters.BillingMode)
+ }
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@@ -3273,7 +3384,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
@@ -3297,9 +3408,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
); err != nil {
return nil, err
}
- if filters.AccountID > 0 {
- stats.TotalAccountCost = &totalAccountCost
- }
+ stats.TotalAccountCost = &totalAccountCost
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
start := time.Unix(0, 0).UTC()
@@ -3348,7 +3457,7 @@ type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3415,7 +3524,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 {
- actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
@@ -3506,7 +3615,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
- COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
+ COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
@@ -3686,7 +3795,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
- query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
+ query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
return nil, nil, err
@@ -3701,7 +3810,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
limitPos := len(args) + 1
offsetPos := len(args) + 2
listArgs := append(append([]any{}, args...), limit+1, offset)
- query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
+ query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
if err != nil {
@@ -3723,6 +3832,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
return logs, paginationResultFromTotal(total, params), nil
}
+func usageLogOrderBy(params pagination.PaginationParams) string {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc))
+
+ var column string
+ switch sortBy {
+ case "model":
+ column = "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
+ case "created_at":
+ column = "created_at"
+ default:
+ column = "id"
+ }
+
+ if column == "id" {
+ return fmt.Sprintf("id %s", sortOrder)
+ }
+ return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder)
+}
+
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
@@ -3935,6 +4064,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
+ imageOutputTokens int
+ imageOutputCost float64
inputCost float64
outputCost float64
cacheCreationCost float64
@@ -3953,12 +4084,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
- mediaType sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
inboundEndpoint sql.NullString
upstreamEndpoint sql.NullString
cacheTTLOverridden bool
+ channelID sql.NullInt64
+ modelMappingChain sql.NullString
+ billingTier sql.NullString
+ billingMode sql.NullString
+ accountStatsCost sql.NullFloat64
createdAt time.Time
)
@@ -3979,6 +4114,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&cacheReadTokens,
&cacheCreation5m,
&cacheCreation1h,
+ &imageOutputTokens,
+ &imageOutputCost,
&inputCost,
&outputCost,
&cacheCreationCost,
@@ -3997,12 +4134,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress,
&imageCount,
&imageSize,
- &mediaType,
&serviceTier,
&reasoningEffort,
&inboundEndpoint,
&upstreamEndpoint,
&cacheTTLOverridden,
+ &channelID,
+ &modelMappingChain,
+ &billingTier,
+ &billingMode,
+ &accountStatsCost,
&createdAt,
); err != nil {
return nil, err
@@ -4021,6 +4162,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
CacheReadTokens: cacheReadTokens,
CacheCreation5mTokens: cacheCreation5m,
CacheCreation1hTokens: cacheCreation1h,
+ ImageOutputTokens: imageOutputTokens,
+ ImageOutputCost: imageOutputCost,
InputCost: inputCost,
OutputCost: outputCost,
CacheCreationCost: cacheCreationCost,
@@ -4069,9 +4212,6 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
- if mediaType.Valid {
- log.MediaType = &mediaType.String
- }
if serviceTier.Valid {
log.ServiceTier = &serviceTier.String
}
@@ -4087,6 +4227,22 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if upstreamModel.Valid {
log.UpstreamModel = &upstreamModel.String
}
+ if channelID.Valid {
+ value := channelID.Int64
+ log.ChannelID = &value
+ }
+ if modelMappingChain.Valid {
+ log.ModelMappingChain = &modelMappingChain.String
+ }
+ if billingTier.Valid {
+ log.BillingTier = &billingTier.String
+ }
+ if billingMode.Valid {
+ log.BillingMode = &billingMode.String
+ }
+ if accountStatsCost.Valid {
+ log.AccountStatsCost = &accountStatsCost.Float64
+ }
return log, nil
}
@@ -4130,6 +4286,7 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
&row.TotalTokens,
&row.Cost,
&row.ActualCost,
+ &row.AccountCost,
); err != nil {
return nil, err
}
diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go
index 0383f3bc..ed3050d8 100644
--- a/backend/internal/repository/usage_log_repo_integration_test.go
+++ b/backend/internal/repository/usage_log_repo_integration_test.go
@@ -753,8 +753,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
+ // account_cost falls back to total_cost when account_stats_cost is NULL
+ s.Require().Equal(baseStats.TotalAccountCost+2.3, stats.TotalAccountCost, "TotalAccountCost mismatch")
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
+ s.Require().GreaterOrEqual(stats.TodayAccountCost, 0.0, "expected TodayAccountCost >= 0")
wantRpm, wantTpm, err := s.repo.getPerformanceStats(s.ctx, 0)
s.Require().NoError(err, "getPerformanceStats")
@@ -833,6 +836,8 @@ func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
s.Require().Equal(int64(45), stats.TotalTokens)
s.Require().Equal(1.5, stats.TotalCost)
s.Require().Equal(1.4, stats.TotalActualCost)
+ // account_cost = COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) = total_cost
+ s.Require().Equal(1.5, stats.TotalAccountCost)
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
}
diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go
index ebc8929a..a5ff4bc1 100644
--- a/backend/internal/repository/usage_log_repo_request_type_test.go
+++ b/backend/internal/repository/usage_log_repo_request_type_test.go
@@ -56,6 +56,8 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
+ log.ImageOutputTokens,
+ log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
@@ -74,12 +76,16 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // ip_address
log.ImageCount,
sqlmock.AnyArg(), // image_size
- sqlmock.AnyArg(), // media_type
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
sqlmock.AnyArg(), // upstream_endpoint
log.CacheTTLOverridden,
+ sqlmock.AnyArg(), // channel_id
+ sqlmock.AnyArg(), // model_mapping_chain
+ sqlmock.AnyArg(), // billing_tier
+ sqlmock.AnyArg(), // billing_mode
+ sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@@ -129,6 +135,8 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
log.CacheReadTokens,
log.CacheCreation5mTokens,
log.CacheCreation1hTokens,
+ log.ImageOutputTokens,
+ log.ImageOutputCost,
log.InputCost,
log.OutputCost,
log.CacheCreationCost,
@@ -147,12 +155,16 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
log.ImageCount,
sqlmock.AnyArg(),
- sqlmock.AnyArg(),
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
sqlmock.AnyArg(),
log.CacheTTLOverridden,
+ sqlmock.AnyArg(), // channel_id
+ sqlmock.AnyArg(), // model_mapping_chain
+ sqlmock.AnyArg(), // billing_tier
+ sqlmock.AnyArg(), // billing_mode
+ sqlmock.AnyArg(), // account_stats_cost
createdAt,
).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@@ -289,7 +301,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
WithArgs(start, end, requestType).
- WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
+ WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost", "account_cost"}))
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
require.NoError(t, err)
@@ -320,11 +332,107 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
"total_account_cost",
"avg_duration_ms",
}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
+ mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
+ WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
+ WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
+ mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\), ''\\), 'unknown'\\) AS endpoint").
+ WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
+ WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
+ mock.ExpectQuery("SELECT CONCAT\\(").
+ WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType).
+ WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
stats, err := repo.GetStatsWithFilters(context.Background(), filters)
require.NoError(t, err)
require.Equal(t, int64(1), stats.TotalRequests)
require.Equal(t, int64(9), stats.TotalTokens)
+ require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost should always be returned")
+ require.Equal(t, 1.2, *stats.TotalAccountCost)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageLogRepositoryGetModelStatsAccountCostColumn(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+
+ mock.ExpectQuery("FROM usage_logs").
+ WithArgs(start, end).
+ WillReturnRows(sqlmock.NewRows([]string{
+ "model", "requests", "input_tokens", "output_tokens",
+ "cache_creation_tokens", "cache_read_tokens", "total_tokens",
+ "cost", "actual_cost", "account_cost",
+ }).
+ AddRow("claude-opus-4-6", int64(10), int64(100), int64(200), int64(5), int64(3), int64(308), 2.5, 2.0, 1.8).
+ AddRow("claude-sonnet-4-6", int64(5), int64(50), int64(100), int64(0), int64(0), int64(150), 1.0, 0.8, 0.7))
+
+ results, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil)
+ require.NoError(t, err)
+ require.Len(t, results, 2)
+ require.Equal(t, "claude-opus-4-6", results[0].Model)
+ require.Equal(t, 2.5, results[0].Cost)
+ require.Equal(t, 2.0, results[0].ActualCost)
+ require.Equal(t, 1.8, results[0].AccountCost)
+ require.Equal(t, "claude-sonnet-4-6", results[1].Model)
+ require.Equal(t, 0.7, results[1].AccountCost)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageLogRepositoryGetGroupStatsAccountCostColumn(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+
+ mock.ExpectQuery("FROM usage_logs").
+ WithArgs(start, end).
+ WillReturnRows(sqlmock.NewRows([]string{
+ "group_id", "group_name", "requests", "total_tokens",
+ "cost", "actual_cost", "account_cost",
+ }).
+ AddRow(int64(1), "azure-cc", int64(100), int64(5000), 10.0, 8.5, 7.2).
+ AddRow(int64(2), "max", int64(50), int64(2000), 5.0, 4.0, 3.5))
+
+ results, err := repo.GetGroupStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil)
+ require.NoError(t, err)
+ require.Len(t, results, 2)
+ require.Equal(t, int64(1), results[0].GroupID)
+ require.Equal(t, "azure-cc", results[0].GroupName)
+ require.Equal(t, 10.0, results[0].Cost)
+ require.Equal(t, 8.5, results[0].ActualCost)
+ require.Equal(t, 7.2, results[0].AccountCost)
+ require.Equal(t, int64(2), results[1].GroupID)
+ require.Equal(t, 3.5, results[1].AccountCost)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageLogRepositoryGetStatsWithFiltersAlwaysReturnsAccountCost(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageLogRepository{sql: db}
+
+ // No AccountID filter set - TotalAccountCost should still be returned
+ filters := usagestats.UsageLogFilters{}
+
+ mock.ExpectQuery("FROM usage_logs").
+ WillReturnRows(sqlmock.NewRows([]string{
+ "total_requests", "total_input_tokens", "total_output_tokens",
+ "total_cache_tokens", "total_cost", "total_actual_cost",
+ "total_account_cost", "avg_duration_ms",
+ }).AddRow(int64(50), int64(1000), int64(2000), int64(100), 15.0, 12.5, 11.0, 100.0))
+ mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\)").
+ WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
+ mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\)").
+ WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
+ mock.ExpectQuery("SELECT CONCAT\\(").
+ WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
+
+ stats, err := repo.GetStatsWithFilters(context.Background(), filters)
+ require.NoError(t, err)
+ require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost must always be returned, even without AccountID filter")
+ require.Equal(t, 11.0, *stats.TotalAccountCost)
require.NoError(t, mock.ExpectationsWereMet())
}
@@ -439,6 +547,8 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
4, // cache_read_tokens
5, // cache_creation_5m_tokens
6, // cache_creation_1h_tokens
+ 0, // image_output_tokens
+ 0.0, // image_output_cost
0.1, // input_cost
0.2, // output_cost
0.3, // cache_creation_cost
@@ -457,12 +567,16 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
- sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -487,6 +601,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
+ 0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
@@ -500,12 +615,16 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
- sql.NullString{},
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
@@ -530,6 +649,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullInt64{},
sql.NullInt64{},
1, 2, 3, 4, 5, 6,
+ 0, 0.0, // image_output_tokens, image_output_cost
0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
1.0,
sql.NullFloat64{},
@@ -543,12 +663,16 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
- sql.NullString{},
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
sql.NullString{},
false,
+ sql.NullInt64{}, // channel_id
+ sql.NullString{}, // model_mapping_chain
+ sql.NullString{}, // billing_tier
+ sql.NullString{}, // billing_mode
+ sql.NullFloat64{}, // account_stats_cost
now,
}})
require.NoError(t, err)
diff --git a/backend/internal/repository/usage_log_repo_sort_integration_test.go b/backend/internal/repository/usage_log_repo_sort_integration_test.go
new file mode 100644
index 00000000..4c69f975
--- /dev/null
+++ b/backend/internal/repository/usage_log_repo_sort_integration_test.go
@@ -0,0 +1,61 @@
+//go:build integration
+
+package repository
+
+import (
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/google/uuid"
+)
+
+func (s *UsageLogRepoSuite) TestListWithFilters_SortByModelAsc() {
+ user := mustCreateUser(s.T(), s.client, &service.User{Email: "usage-sort@example.com"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usage-sort", Name: "k"})
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-sort-account"})
+
+ first := &service.UsageLog{
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: uuid.New().String(),
+ Model: "z-model",
+ RequestedModel: "z-model",
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalCost: 0.5,
+ ActualCost: 0.5,
+ CreatedAt: time.Now(),
+ }
+ _, err := s.repo.Create(s.ctx, first)
+ s.Require().NoError(err)
+
+ second := &service.UsageLog{
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: uuid.New().String(),
+ Model: "a-model",
+ RequestedModel: "a-model",
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalCost: 0.5,
+ ActualCost: 0.5,
+ CreatedAt: time.Now().Add(time.Second),
+ }
+ _, err = s.repo.Create(s.ctx, second)
+ s.Require().NoError(err)
+
+ logs, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "model",
+ SortOrder: "asc",
+ }, usagestats.UsageLogFilters{UserID: user.ID})
+ s.Require().NoError(err)
+ s.Require().Len(logs, 2)
+ s.Require().Equal("a-model", logs[0].RequestedModel)
+ s.Require().Equal("z-model", logs[1].RequestedModel)
+}
diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go
index e2471ae5..74d25cb0 100644
--- a/backend/internal/repository/user_group_rate_repo.go
+++ b/backend/internal/repository/user_group_rate_repo.go
@@ -13,14 +13,14 @@ type userGroupRateRepository struct {
sql sqlExecutor
}
-// NewUserGroupRateRepository 创建用户专属分组倍率仓储
+// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
-// GetByUserID 获取用户的所有专属分组倍率
+// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
- query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
+ query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
@@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil
}
-// GetByUserIDs 批量获取多个用户的专属分组倍率。
-// 返回结构:map[userID]map[groupID]rate
+// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目)
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
result := make(map[int64]map[int64]float64, len(userIDs))
if len(userIDs) == 0 {
@@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
rows, err := r.sql.QueryContext(ctx, `
SELECT user_id, group_id, rate_multiplier
FROM user_group_rate_multipliers
- WHERE user_id = ANY($1)
+ WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL
`, pq.Array(uniqueIDs))
if err != nil {
return nil, err
@@ -95,12 +94,12 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in
return result, nil
}
-// GetByGroupID 获取指定分组下所有用户的专属倍率
+// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回)
func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) {
query := `
- SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
+ SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override
FROM user_group_rate_multipliers ugr
- JOIN users u ON u.id = ugr.user_id
+ JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1
ORDER BY ugr.user_id
`
@@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
var result []service.UserGroupRateEntry
for rows.Next() {
var entry service.UserGroupRateEntry
- if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil {
+ var rate sql.NullFloat64
+ var rpm sql.NullInt32
+ if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil {
return nil, err
}
+ if rate.Valid {
+ v := rate.Float64
+ entry.RateMultiplier = &v
+ }
+ if rpm.Valid {
+ v := int(rpm.Int32)
+ entry.RPMOverride = &v
+ }
result = append(result, entry)
}
if err := rows.Err(); err != nil {
@@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
return result, nil
}
-// GetByUserAndGroup 获取用户在特定分组的专属倍率
+// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil)
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
- var rate float64
+ var rate sql.NullFloat64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
@@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID,
if err != nil {
return nil, err
}
- return &rate, nil
+ if !rate.Valid {
+ return nil, nil
+ }
+ v := rate.Float64
+ return &v, nil
}
-// SyncUserGroupRates 同步用户的分组专属倍率
+// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil)
+func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) {
+ query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
+ var rpm sql.NullInt32
+ err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ if !rpm.Valid {
+ return nil, nil
+ }
+ v := int(rpm.Int32)
+ return &v, nil
+}
+
+// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。
+// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。
+// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。
+// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
- // 如果传入空 map,删除该用户的所有专属倍率
- _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE user_id = $1
+ `, userID); err != nil {
+ return err
+ }
+ _, err := r.sql.ExecContext(ctx,
+ `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`,
+ userID)
return err
}
- // 分离需要删除和需要 upsert 的记录
- var toDelete []int64
+ var clearGroupIDs []int64
upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates {
if rate == nil {
- toDelete = append(toDelete, groupID)
+ clearGroupIDs = append(clearGroupIDs, groupID)
} else {
upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate)
}
}
- // 删除指定的记录
- if len(toDelete) > 0 {
+ if len(clearGroupIDs) > 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE user_id = $1 AND group_id = ANY($2)
+ `, userID, pq.Array(clearGroupIDs)); err != nil {
+ return err
+ }
if _, err := r.sql.ExecContext(ctx,
- `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
- userID, pq.Array(toDelete)); err != nil {
+ `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`,
+ userID, pq.Array(clearGroupIDs)); err != nil {
return err
}
}
- // Upsert 记录
- now := time.Now()
if len(upsertGroupIDs) > 0 {
+ now := time.Now()
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
SELECT
@@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
return nil
}
-// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插)
+// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。
+// 语义:
+// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。
+// - 出现的用户行:upsert rate_multiplier。
func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error {
- if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil {
+ keepUserIDs := make([]int64, 0, len(entries))
+ for _, e := range entries {
+ keepUserIDs = append(keepUserIDs, e.UserID)
+ }
+
+ // 未在 entries 列表中的行:清空 rate_multiplier。
+ if len(keepUserIDs) == 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE group_id = $1
+ `, groupID); err != nil {
+ return err
+ }
+ } else {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rate_multiplier = NULL, updated_at = NOW()
+ WHERE group_id = $1 AND user_id <> ALL($2)
+ `, groupID, pq.Array(keepUserIDs)); err != nil {
+ return err
+ }
+ }
+
+ // 清空后若整行 NULL 则删除。
+ if _, err := r.sql.ExecContext(ctx, `
+ DELETE FROM user_group_rate_multipliers
+ WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
+ `, groupID); err != nil {
return err
}
+
if len(entries) == 0 {
return nil
}
+
userIDs := make([]int64, len(entries))
rates := make([]float64, len(entries))
for i, e := range entries {
@@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context,
return err
}
-// DeleteByGroupID 删除指定分组的所有用户专属倍率
+// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。
+// 语义:
+// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。
+// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。
+func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error {
+ keepUserIDs := make([]int64, 0, len(entries))
+ var clearUserIDs []int64
+ upsertUserIDs := make([]int64, 0, len(entries))
+ upsertValues := make([]int32, 0, len(entries))
+ for _, e := range entries {
+ keepUserIDs = append(keepUserIDs, e.UserID)
+ if e.RPMOverride == nil {
+ clearUserIDs = append(clearUserIDs, e.UserID)
+ } else {
+ upsertUserIDs = append(upsertUserIDs, e.UserID)
+ upsertValues = append(upsertValues, int32(*e.RPMOverride))
+ }
+ }
+
+ // 未在 entries 列表中的行:清空 rpm_override。
+ if len(keepUserIDs) == 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1
+ `, groupID); err != nil {
+ return err
+ }
+ } else {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1 AND user_id <> ALL($2)
+ `, groupID, pq.Array(keepUserIDs)); err != nil {
+ return err
+ }
+ }
+
+ // 显式 clear 的行。
+ if len(clearUserIDs) > 0 {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1 AND user_id = ANY($2)
+ `, groupID, pq.Array(clearUserIDs)); err != nil {
+ return err
+ }
+ }
+
+ // 清空后若整行 NULL 则删除。
+ if _, err := r.sql.ExecContext(ctx, `
+ DELETE FROM user_group_rate_multipliers
+ WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
+ `, groupID); err != nil {
+ return err
+ }
+
+ if len(upsertUserIDs) > 0 {
+ now := time.Now()
+ _, err := r.sql.ExecContext(ctx, `
+ INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at)
+ SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz
+ FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override)
+ ON CONFLICT (user_id, group_id)
+ DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at
+ `, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues))
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。
+func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
+ if _, err := r.sql.ExecContext(ctx, `
+ UPDATE user_group_rate_multipliers
+ SET rpm_override = NULL, updated_at = NOW()
+ WHERE group_id = $1
+ `, groupID); err != nil {
+ return err
+ }
+ _, err := r.sql.ExecContext(ctx, `
+ DELETE FROM user_group_rate_multipliers
+ WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL
+ `, groupID)
+ return err
+}
+
+// DeleteByGroupID 删除指定分组的所有用户专属条目
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
-// DeleteByUserID 删除指定用户的所有专属倍率
+// DeleteByUserID 删除指定用户的所有专属条目
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go
new file mode 100644
index 00000000..b2b03746
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo.go
@@ -0,0 +1,880 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "hash/fnv"
+ "reflect"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+ "unsafe"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+var (
+ ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_OWNERSHIP_CONFLICT",
+ "auth identity already belongs to another user",
+ )
+ ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
+ "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
+ "auth identity channel already belongs to another user",
+ )
+ ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
+ "AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
+ "auth identity channel provider must match canonical identity",
+ )
+)
+
+type ProviderGrantReason string
+
+const (
+ ProviderGrantReasonSignup ProviderGrantReason = "signup"
+ ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
+)
+
+type AuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+type AuthIdentityChannelKey struct {
+ ProviderType string
+ ProviderKey string
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+}
+
+type CreateAuthIdentityInput struct {
+ UserID int64
+ Canonical AuthIdentityKey
+ Channel *AuthIdentityChannelKey
+ Issuer *string
+ VerifiedAt *time.Time
+ Metadata map[string]any
+ ChannelMetadata map[string]any
+}
+
+type BindAuthIdentityInput = CreateAuthIdentityInput
+
+type CreateAuthIdentityResult struct {
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
+ if r == nil || r.Identity == nil {
+ return AuthIdentityKey{}
+ }
+ return AuthIdentityKey{
+ ProviderType: r.Identity.ProviderType,
+ ProviderKey: r.Identity.ProviderKey,
+ ProviderSubject: r.Identity.ProviderSubject,
+ }
+}
+
+func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
+ if r == nil || r.Channel == nil {
+ return nil
+ }
+ return &AuthIdentityChannelKey{
+ ProviderType: r.Channel.ProviderType,
+ ProviderKey: r.Channel.ProviderKey,
+ Channel: r.Channel.Channel,
+ ChannelAppID: r.Channel.ChannelAppID,
+ ChannelSubject: r.Channel.ChannelSubject,
+ }
+}
+
+type UserAuthIdentityLookup struct {
+ User *dbent.User
+ Identity *dbent.AuthIdentity
+ Channel *dbent.AuthIdentityChannel
+}
+
+type ProviderGrantRecordInput struct {
+ UserID int64
+ ProviderType string
+ GrantReason ProviderGrantReason
+}
+
+type IdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+type sqlQueryExecutor interface {
+ ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
+ QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
+}
+
+var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
+
+type scopedKeyLockRegistry struct {
+ mu sync.Mutex
+ locks map[string]*scopedKeyLockEntry
+}
+
+type scopedKeyLockEntry struct {
+ mu sync.Mutex
+ refs int
+}
+
+func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
+ return &scopedKeyLockRegistry{
+ locks: make(map[string]*scopedKeyLockEntry),
+ }
+}
+
+func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
+ normalized := normalizeLockKeys(keys...)
+ if len(normalized) == 0 {
+ return func() {}
+ }
+
+ entries := make([]*scopedKeyLockEntry, 0, len(normalized))
+ r.mu.Lock()
+ for _, key := range normalized {
+ entry := r.locks[key]
+ if entry == nil {
+ entry = &scopedKeyLockEntry{}
+ r.locks[key] = entry
+ }
+ entry.refs++
+ entries = append(entries, entry)
+ }
+ r.mu.Unlock()
+
+ for _, entry := range entries {
+ entry.mu.Lock()
+ }
+
+ return func() {
+ for i := len(entries) - 1; i >= 0; i-- {
+ entries[i].mu.Unlock()
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for idx, key := range normalized {
+ entry := entries[idx]
+ entry.refs--
+ if entry.refs == 0 {
+ delete(r.locks, key)
+ }
+ }
+ }
+}
+
+func normalizeLockKeys(keys ...string) []string {
+ if len(keys) == 0 {
+ return nil
+ }
+
+ deduped := make(map[string]struct{}, len(keys))
+ for _, key := range keys {
+ trimmed := strings.TrimSpace(key)
+ if trimmed == "" {
+ continue
+ }
+ deduped[trimmed] = struct{}{}
+ }
+ if len(deduped) == 0 {
+ return nil
+ }
+
+ normalized := make([]string, 0, len(deduped))
+ for key := range deduped {
+ normalized = append(normalized, key)
+ }
+ sort.Strings(normalized)
+ return normalized
+}
+
+func advisoryLockHash(key string) int64 {
+ hasher := fnv.New64a()
+ _, _ = hasher.Write([]byte(key))
+ return int64(hasher.Sum64())
+}
+
+func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
+ release := repositoryScopedKeyLocks.lock(keys...)
+ normalized := normalizeLockKeys(keys...)
+ if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
+ return release, nil
+ }
+
+ for _, key := range normalized {
+ rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
+ if err != nil {
+ release()
+ return nil, err
+ }
+ _ = rows.Close()
+ }
+ return release, nil
+}
+
+func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if dbent.TxFromContext(ctx) != nil {
+ return fn(ctx)
+ }
+
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
+ client := clientFromContext(ctx, r.client)
+
+ create := client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt)
+
+ identity, err := create.Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(ctx)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
+}
+
+func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
+ identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
+ ).
+ WithUser().
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: identity.Edges.User,
+ Identity: identity,
+ }, nil
+}
+
+func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
+ channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
+ authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
+ ).
+ WithIdentity(func(q *dbent.AuthIdentityQuery) {
+ q.WithUser()
+ }).
+ Only(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return &UserAuthIdentityLookup{
+ User: channel.Edges.Identity.Edges.User,
+ Identity: channel.Edges.Identity,
+ Channel: channel,
+ }, nil
+}
+
+func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ identities, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
+ Where(authidentity.UserIDEQ(userID)).
+ All(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ records := make([]service.UserAuthIdentityRecord, 0, len(identities))
+ for _, identity := range identities {
+ if identity == nil {
+ continue
+ }
+ records = append(records, service.UserAuthIdentityRecord{
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: copyMetadata(identity.Metadata),
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ })
+ }
+
+ return records, nil
+}
+
+func (r *userRepository) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
+ provider = strings.ToLower(strings.TrimSpace(provider))
+ if provider == "" || provider == "email" {
+ return service.ErrIdentityProviderInvalid
+ }
+
+ return r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ identityIDs, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ(provider),
+ ).
+ IDs(txCtx)
+ if err != nil {
+ return err
+ }
+ if len(identityIDs) == 0 {
+ return nil
+ }
+
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return err
+ }
+ if _, err := client.AuthIdentityChannel.Delete().
+ Where(authidentitychannel.IdentityIDIn(identityIDs...)).
+ Exec(txCtx); err != nil {
+ return err
+ }
+ _, err = client.AuthIdentity.Delete().
+ Where(
+ authidentity.UserIDEQ(userID),
+ authidentity.ProviderTypeEQ(provider),
+ ).
+ Exec(txCtx)
+ return err
+ })
+}
+
+func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
+ if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
+ return nil, err
+ }
+
+ var result *CreateAuthIdentityResult
+ err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+ client := clientFromContext(txCtx, r.client)
+ canonical := input.Canonical
+
+ identityRecords, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
+ authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
+ authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
+ ).
+ All(txCtx)
+ if err != nil {
+ return err
+ }
+ identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
+ if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
+ return ErrAuthIdentityOwnershipConflict
+ }
+ if identity == nil {
+ identity, err = client.AuthIdentity.Create().
+ SetUserID(input.UserID).
+ SetProviderType(strings.TrimSpace(canonical.ProviderType)).
+ SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
+ SetMetadata(copyMetadata(input.Metadata)).
+ SetNillableIssuer(input.Issuer).
+ SetNillableVerifiedAt(input.VerifiedAt).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
+ update := client.AuthIdentity.UpdateOneID(identity.ID)
+ if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
+ update = update.SetProviderKey(targetProviderKey)
+ }
+ if input.Metadata != nil {
+ update = update.SetMetadata(copyMetadata(input.Metadata))
+ }
+ if input.Issuer != nil {
+ update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
+ }
+ if input.VerifiedAt != nil {
+ update = update.SetVerifiedAt(*input.VerifiedAt)
+ }
+ identity, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if input.Channel != nil {
+ channelRecords, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
+ authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
+ authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
+ authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
+ authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
+ ).
+ WithIdentity().
+ All(txCtx)
+ if err != nil {
+ return err
+ }
+ channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
+ if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
+ return ErrAuthIdentityChannelOwnershipConflict
+ }
+ if channel == nil {
+ channel, err = client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
+ SetChannel(strings.TrimSpace(input.Channel.Channel)).
+ SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
+ SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
+ SetMetadata(copyMetadata(input.ChannelMetadata)).
+ Save(txCtx)
+ if err != nil {
+ return err
+ }
+ } else {
+ targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
+ update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
+ SetIdentityID(identity.ID)
+ if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
+ update = update.SetProviderKey(targetProviderKey)
+ }
+ if input.ChannelMetadata != nil {
+ update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
+ }
+ channel, err = update.Save(txCtx)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return []string{providerKey}
+ }
+ if providerType != "wechat" {
+ return []string{providerKey}
+ }
+ keys := []string{providerKey}
+ if !strings.EqualFold(providerKey, "wechat-main") {
+ keys = append(keys, "wechat-main")
+ }
+ if !strings.EqualFold(providerKey, "wechat") {
+ keys = append(keys, "wechat")
+ }
+ return keys
+}
+
+// canonicalizeCompatibleIdentityProviderKey resolves which provider key a
+// record should carry after a bind. For non-wechat providers the requested
+// key wins when non-empty, otherwise the existing key is kept. For wechat,
+// a record currently stored under the legacy "wechat" alias (or already
+// "wechat-main", or explicitly requested as "wechat-main") is canonicalized
+// to "wechat-main".
+// NOTE(review): a *requested* key of "wechat" is not canonicalized here and
+// is returned as-is when the existing key matches neither alias — confirm
+// that asymmetry is intentional.
+func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
+	providerType = strings.TrimSpace(strings.ToLower(providerType))
+	existingKey = strings.TrimSpace(existingKey)
+	requestedKey = strings.TrimSpace(requestedKey)
+	if providerType != "wechat" {
+		if requestedKey != "" {
+			return requestedKey
+		}
+		return existingKey
+	}
+	if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
+		return "wechat-main"
+	}
+	if requestedKey != "" {
+		return requestedKey
+	}
+	return existingKey
+}
+
+// compatibleIdentityProviderKeyRank orders equivalent wechat provider keys by
+// preference; lower is better. The canonical "wechat-main" ranks first (0),
+// any other key second (1), and the legacy "wechat" alias last (2).
+// Non-wechat providers always rank 0 since they have no aliases to choose
+// between.
+func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
+	providerType = strings.TrimSpace(strings.ToLower(providerType))
+	providerKey = strings.TrimSpace(providerKey)
+	if providerType != "wechat" {
+		return 0
+	}
+	switch {
+	case strings.EqualFold(providerKey, "wechat-main"):
+		return 0
+	case strings.EqualFold(providerKey, "wechat"):
+		return 2
+	default:
+		return 1
+	}
+}
+
+// selectOwnedCompatibleIdentity picks, among records owned by userID, the
+// identity whose provider key has the best (lowest) alias rank. Returns nil
+// when the user owns none of the records.
+func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
+	var selected *dbent.AuthIdentity
+	for _, record := range records {
+		if record.UserID != userID {
+			continue
+		}
+		// Strictly-lower rank wins; ties keep the first record encountered.
+		if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+			selected = record
+		}
+	}
+	return selected
+}
+
+// hasCompatibleIdentityConflict reports whether any of the records is owned
+// by a user other than userID, i.e. binding would collide with another
+// account.
+func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
+	for _, record := range records {
+		if record.UserID != userID {
+			return true
+		}
+	}
+	return false
+}
+
+// selectOwnedCompatibleChannel picks, among channel records whose loaded
+// Identity edge belongs to userID, the one with the best (lowest)
+// provider-key alias rank. Records with a missing Identity edge are skipped,
+// so callers must have eager-loaded the edge. Returns nil when the user owns
+// none of the records.
+func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
+	var selected *dbent.AuthIdentityChannel
+	for _, record := range records {
+		if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
+			continue
+		}
+		// Strictly-lower rank wins; ties keep the first record encountered.
+		if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+			selected = record
+		}
+	}
+	return selected
+}
+
+// hasCompatibleChannelConflict reports whether any channel record's loaded
+// Identity edge belongs to a user other than userID. Records with a nil
+// Identity edge are treated as non-conflicting —
+// NOTE(review): confirm orphaned channels (no identity edge) should not
+// block a bind.
+func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
+	for _, record := range records {
+		if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+			return true
+		}
+	}
+	return false
+}
+
+// RecordProviderGrant inserts a default-grant marker for the
+// (user, provider type, grant reason) triple. It is idempotent: the
+// ON CONFLICT DO NOTHING clause makes repeat calls no-ops, and the returned
+// bool is true only when a new row was actually inserted. The statement runs
+// on the transaction-aware executor so it joins any surrounding ent
+// transaction on ctx.
+func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
+	exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+	if exec == nil {
+		return false, fmt.Errorf("sql executor is not configured")
+	}
+
+	result, err := exec.ExecContext(ctx, `
+INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+		input.UserID,
+		strings.TrimSpace(input.ProviderType),
+		string(input.GrantReason),
+	)
+	if err != nil {
+		return false, err
+	}
+	// RowsAffected is 0 when the conflict clause suppressed the insert.
+	affected, err := result.RowsAffected()
+	if err != nil {
+		return false, err
+	}
+	return affected > 0, nil
+}
+
+// UpsertIdentityAdoptionDecision stores the user's display-name/avatar
+// adoption choice for a pending auth session, keyed by the session ID.
+// Inside a single transaction it:
+//  1. acquires repository-scoped locks for the session and (optionally) the
+//     identity being linked;
+//  2. detaches the identity from any other decision rows referencing it, so
+//     an identity is linked to at most one pending session at a time;
+//  3. upserts the decision row (insert, or update all values on a
+//     pending_auth_session_id conflict);
+//  4. re-reads and returns the persisted row.
+func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+	var result *dbent.IdentityAdoptionDecision
+	err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
+		client := clientFromContext(txCtx, r.client)
+		releaseLocks, err := lockRepositoryScopedKeys(
+			txCtx,
+			client,
+			txAwareSQLExecutor(txCtx, r.sql, r.client),
+			identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
+		)
+		if err != nil {
+			return err
+		}
+		defer releaseLocks()
+
+		if input.IdentityID != nil && *input.IdentityID > 0 {
+			// Clear the identity reference on decisions belonging to other
+			// sessions (or with no session), so the identity only stays
+			// linked to this session's decision.
+			if _, err := client.IdentityAdoptionDecision.Update().
+				Where(
+					identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
+					dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
+						col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
+						s.Where(entsql.Or(
+							entsql.IsNull(col),
+							entsql.NEQ(col, input.PendingAuthSessionID),
+						))
+					}),
+				).
+				ClearIdentityID().
+				Save(txCtx); err != nil {
+				return err
+			}
+		}
+
+		create := client.IdentityAdoptionDecision.Create().
+			SetPendingAuthSessionID(input.PendingAuthSessionID).
+			SetAdoptDisplayName(input.AdoptDisplayName).
+			SetAdoptAvatar(input.AdoptAvatar).
+			SetDecidedAt(time.Now().UTC())
+		if input.IdentityID != nil && *input.IdentityID > 0 {
+			create = create.SetIdentityID(*input.IdentityID)
+		}
+
+		// Upsert keyed by session ID; UpdateNewValues overwrites the existing
+		// row's fields with the ones set above.
+		decisionID, err := create.
+			OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
+			UpdateNewValues().
+			ID(txCtx)
+		if err != nil {
+			return err
+		}
+
+		// Re-read so the caller gets the row as persisted (incl. defaults).
+		result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
+		return err
+	})
+	if err != nil {
+		return nil, err
+	}
+	return result, nil
+}
+
+// identityAdoptionDecisionLockKeys builds the lock-key set for an
+// adoption-decision upsert: always the pending session, plus the identity
+// when a positive identity ID is being linked.
+func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
+	keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
+	if identityID != nil && *identityID > 0 {
+		keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
+	}
+	return keys
+}
+
+// GetIdentityAdoptionDecisionByPendingAuthSessionID loads the single decision
+// recorded for the given pending auth session. Only() surfaces an ent
+// not-found error when no row exists.
+func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
+	return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
+		Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
+		Only(ctx)
+}
+
+// UpdateUserLastLoginAt sets the user's last_login_at column to loginAt.
+func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
+	_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+		SetLastLoginAt(loginAt).
+		Save(ctx)
+	return err
+}
+
+// UpdateUserLastActiveAt sets the user's last_active_at column to activeAt.
+func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+	_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
+		SetLastActiveAt(activeAt).
+		Save(ctx)
+	return err
+}
+
+// GetUserAvatar loads the avatar row for userID from user_avatars via raw
+// SQL. Returns (nil, nil) when the user has no avatar stored — absence is
+// not an error.
+func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+	exec, err := r.userProfileIdentitySQL(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	rows, err := exec.QueryContext(ctx, `
+SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
+FROM user_avatars
+WHERE user_id = $1`, userID)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+
+	if !rows.Next() {
+		// No row: report an iteration error if any, otherwise nil avatar.
+		return nil, rows.Err()
+	}
+
+	var avatar service.UserAvatar
+	if err := rows.Scan(
+		&avatar.StorageProvider,
+		&avatar.StorageKey,
+		&avatar.URL,
+		&avatar.ContentType,
+		&avatar.ByteSize,
+		&avatar.SHA256,
+	); err != nil {
+		return nil, err
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return &avatar, nil
+}
+
+// UpsertUserAvatar inserts or replaces the user's single avatar row
+// (user_avatars keys on user_id). All string fields are trimmed before
+// storage. The returned value echoes the trimmed input rather than
+// re-reading the row from the database.
+func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+	exec, err := r.userProfileIdentitySQL(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	_, err = exec.ExecContext(ctx, `
+INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
+VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
+ON CONFLICT (user_id) DO UPDATE SET
+	storage_provider = EXCLUDED.storage_provider,
+	storage_key = EXCLUDED.storage_key,
+	url = EXCLUDED.url,
+	content_type = EXCLUDED.content_type,
+	byte_size = EXCLUDED.byte_size,
+	sha256 = EXCLUDED.sha256,
+	updated_at = NOW()`,
+		userID,
+		strings.TrimSpace(input.StorageProvider),
+		strings.TrimSpace(input.StorageKey),
+		strings.TrimSpace(input.URL),
+		strings.TrimSpace(input.ContentType),
+		input.ByteSize,
+		strings.TrimSpace(input.SHA256),
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return &service.UserAvatar{
+		StorageProvider: strings.TrimSpace(input.StorageProvider),
+		StorageKey:      strings.TrimSpace(input.StorageKey),
+		URL:             strings.TrimSpace(input.URL),
+		ContentType:     strings.TrimSpace(input.ContentType),
+		ByteSize:        input.ByteSize,
+		SHA256:          strings.TrimSpace(input.SHA256),
+	}, nil
+}
+
+// DeleteUserAvatar removes the user's avatar row; deleting a non-existent
+// row is not an error.
+func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
+	exec, err := r.userProfileIdentitySQL(ctx)
+	if err != nil {
+		return err
+	}
+	_, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
+	return err
+}
+
+// copyMetadata returns a shallow copy of in, normalizing nil/empty input to
+// an empty (non-nil) map so callers can persist it directly. Nested values
+// are shared with the input, not cloned.
+func copyMetadata(in map[string]any) map[string]any {
+	if len(in) == 0 {
+		return map[string]any{}
+	}
+	out := make(map[string]any, len(in))
+	for k, v := range in {
+		out[k] = v
+	}
+	return out
+}
+
+// validateAuthIdentityChannelProviderMatch ensures an optional channel key
+// names the same provider (type and key) as the canonical identity key,
+// returning ErrAuthIdentityChannelProviderMismatch otherwise. A nil channel
+// is always valid. Comparison is on trimmed values and is case-sensitive —
+// NOTE(review): other alias helpers in this file compare with EqualFold;
+// confirm exact-case matching is intended here.
+func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
+	if channel == nil {
+		return nil
+	}
+
+	canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
+	canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
+	channelProviderType := strings.TrimSpace(channel.ProviderType)
+	channelProviderKey := strings.TrimSpace(channel.ProviderKey)
+
+	if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
+		return ErrAuthIdentityChannelProviderMismatch
+	}
+
+	return nil
+}
+
+// txAwareSQLExecutor resolves which SQL executor raw queries should use:
+// the driver of an ent transaction found on ctx takes priority (so raw SQL
+// joins the transaction), then the configured fallback executor, then the
+// driver of the base ent client. May return nil when none is available.
+func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
+			return exec
+		}
+	}
+	if fallback != nil {
+		return fallback
+	}
+	return sqlExecutorFromEntClient(client)
+}
+
+// userProfileIdentitySQL returns the transaction-aware SQL executor used by
+// the profile/identity raw queries, or an error when no executor is
+// configured.
+func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
+	exec := txAwareSQLExecutor(ctx, r.sql, r.client)
+	if exec == nil {
+		return nil, fmt.Errorf("sql executor is not configured")
+	}
+	return exec, nil
+}
+
+// sqlExecutorFromEntClient extracts the underlying driver from an ent client
+// using reflection + unsafe, because ent does not export its config.driver
+// field. It returns nil when the client is nil, the field layout does not
+// match, or the driver does not implement sqlQueryExecutor.
+// NOTE(review): this is tightly coupled to ent's internal struct layout and
+// may break silently (returning nil) on ent upgrades.
+func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
+	if client == nil {
+		return nil
+	}
+
+	clientValue := reflect.ValueOf(client).Elem()
+	configValue := clientValue.FieldByName("config")
+	driverValue := configValue.FieldByName("driver")
+	if !driverValue.IsValid() {
+		return nil
+	}
+
+	// Re-materialize the unexported field so it can be read as an interface.
+	driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
+	exec, ok := driver.(sqlQueryExecutor)
+	if !ok {
+		return nil
+	}
+	return exec
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go
new file mode 100644
index 00000000..d4f9e8b3
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go
@@ -0,0 +1,578 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/suite"
+)
+
+// UserProfileIdentityRepoSuite exercises the user profile/identity
+// repository against a real database (integration build tag).
+type UserProfileIdentityRepoSuite struct {
+	suite.Suite
+	ctx    context.Context // base context for all queries in a test
+	client *dbent.Client   // ent client bound to the integration database
+	repo   *userRepository // repository under test
+}
+
+// TestUserProfileIdentityRepoSuite wires the suite into `go test`.
+func TestUserProfileIdentityRepoSuite(t *testing.T) {
+	suite.Run(t, new(UserProfileIdentityRepoSuite))
+}
+
+// SetupTest runs before every test: it points the repo at a fresh ent client
+// and truncates all identity/profile tables (resetting sequences) so each
+// test starts from an empty state.
+func (s *UserProfileIdentityRepoSuite) SetupTest() {
+	s.ctx = context.Background()
+	s.client = testEntClient(s.T())
+	s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
+
+	_, err := integrationDB.ExecContext(s.ctx, `
+TRUNCATE TABLE
+	identity_adoption_decisions,
+	auth_identity_channels,
+	auth_identities,
+	pending_auth_sessions,
+	user_provider_default_grants,
+	user_avatars
+RESTART IDENTITY`)
+	s.Require().NoError(err)
+}
+
+// mustCreateUser inserts an active user with a unique, label-derived email
+// and fails the test immediately on error.
+func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
+	s.T().Helper()
+
+	user, err := s.client.User.Create().
+		SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
+		SetPasswordHash("test-password-hash").
+		SetRole("user").
+		SetStatus("active").
+		Save(s.ctx)
+	s.Require().NoError(err)
+	return user
+}
+
+// mustCreatePendingAuthSession inserts a bind-intent pending auth session
+// for the given identity key, expiring 15 minutes from now, failing the
+// test on error.
+func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
+	s.T().Helper()
+
+	session, err := s.client.PendingAuthSession.Create().
+		SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
+		SetIntent("bind_current_user").
+		SetProviderType(key.ProviderType).
+		SetProviderKey(key.ProviderKey).
+		SetProviderSubject(key.ProviderSubject).
+		SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
+		SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
+		SetLocalFlowState(map[string]any{"step": "pending"}).
+		Save(s.ctx)
+	s.Require().NoError(err)
+	return session
+}
+
+// TestCreateAndLookupCanonicalAndChannelIdentity verifies that an identity
+// created together with a channel can be looked up both by its canonical key
+// and by its channel key, and that both lookups resolve to the same user and
+// identity rows.
+func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
+	user := s.mustCreateUser("canonical-channel")
+
+	verifiedAt := time.Now().UTC().Truncate(time.Second)
+	created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-open",
+			ProviderSubject: "union-123",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "wechat",
+			ProviderKey:    "wechat-open",
+			Channel:        "mp",
+			ChannelAppID:   "wx-app",
+			ChannelSubject: "openid-123",
+		},
+		Issuer:          stringPtr("https://issuer.example"),
+		VerifiedAt:      &verifiedAt,
+		Metadata:        map[string]any{"unionid": "union-123"},
+		ChannelMetadata: map[string]any{"openid": "openid-123"},
+	})
+	s.Require().NoError(err)
+	s.Require().NotNil(created.Identity)
+	s.Require().NotNil(created.Channel)
+
+	canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
+	s.Require().NoError(err)
+	s.Require().Equal(user.ID, canonical.User.ID)
+	s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
+	s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
+
+	channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
+	s.Require().NoError(err)
+	s.Require().Equal(user.ID, channel.User.ID)
+	s.Require().Equal(created.Identity.ID, channel.Identity.ID)
+	s.Require().Equal(created.Channel.ID, channel.Channel.ID)
+}
+
+// TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners verifies that
+// rebinding the same identity for its owner reuses the existing rows while
+// refreshing metadata, and that binds attempted by a different user fail
+// with the identity/channel ownership-conflict errors.
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
+	owner := s.mustCreateUser("owner")
+	other := s.mustCreateUser("other")
+
+	first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: owner.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-1",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "linuxdo",
+			ProviderKey:    "linuxdo-main",
+			Channel:        "oauth",
+			ChannelAppID:   "linuxdo-web",
+			ChannelSubject: "subject-1",
+		},
+		Metadata:        map[string]any{"username": "first"},
+		ChannelMetadata: map[string]any{"scope": "read"},
+	})
+	s.Require().NoError(err)
+
+	// Same owner, same keys: must reuse rows and overwrite metadata.
+	second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: owner.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-1",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "linuxdo",
+			ProviderKey:    "linuxdo-main",
+			Channel:        "oauth",
+			ChannelAppID:   "linuxdo-web",
+			ChannelSubject: "subject-1",
+		},
+		Metadata:        map[string]any{"username": "second"},
+		ChannelMetadata: map[string]any{"scope": "write"},
+	})
+	s.Require().NoError(err)
+	s.Require().Equal(first.Identity.ID, second.Identity.ID)
+	s.Require().Equal(first.Channel.ID, second.Channel.ID)
+	s.Require().Equal("second", second.Identity.Metadata["username"])
+	s.Require().Equal("write", second.Channel.Metadata["scope"])
+
+	// Canonical identity already owned by someone else.
+	_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: other.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-1",
+		},
+	})
+	s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
+
+	// New canonical subject but a channel subject owned by someone else.
+	_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: other.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "linuxdo",
+			ProviderKey:     "linuxdo-main",
+			ProviderSubject: "subject-2",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "linuxdo",
+			ProviderKey:    "linuxdo-main",
+			Channel:        "oauth",
+			ChannelAppID:   "linuxdo-web",
+			ChannelSubject: "subject-1",
+		},
+	})
+	s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
+}
+
+// TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords verifies that
+// binding with the canonical "wechat-main" key adopts an existing identity
+// and channel stored under the legacy "wechat" alias — rewriting their
+// provider key and metadata in place instead of creating duplicates.
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() {
+	user := s.mustCreateUser("wechat-legacy-alias")
+
+	// Seed an identity + channel under the legacy "wechat" provider key.
+	legacyIdentity, err := s.client.AuthIdentity.Create().
+		SetUserID(user.ID).
+		SetProviderType("wechat").
+		SetProviderKey("wechat").
+		SetProviderSubject("union-legacy-123").
+		SetMetadata(map[string]any{"source": "legacy-alias"}).
+		Save(s.ctx)
+	s.Require().NoError(err)
+
+	legacyChannel, err := s.client.AuthIdentityChannel.Create().
+		SetIdentityID(legacyIdentity.ID).
+		SetProviderType("wechat").
+		SetProviderKey("wechat").
+		SetChannel("oa").
+		SetChannelAppID("wx-app-legacy").
+		SetChannelSubject("openid-legacy-123").
+		SetMetadata(map[string]any{"scene": "legacy-alias"}).
+		Save(s.ctx)
+	s.Require().NoError(err)
+
+	bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-main",
+			ProviderSubject: "union-legacy-123",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "wechat",
+			ProviderKey:    "wechat-main",
+			Channel:        "oa",
+			ChannelAppID:   "wx-app-legacy",
+			ChannelSubject: "openid-legacy-123",
+		},
+		Metadata:        map[string]any{"source": "canonical-bind"},
+		ChannelMetadata: map[string]any{"scene": "canonical-bind"},
+	})
+	s.Require().NoError(err)
+	s.Require().NotNil(bound)
+	s.Require().NotNil(bound.Identity)
+	s.Require().NotNil(bound.Channel)
+	// Same rows as the legacy records, now canonicalized to "wechat-main".
+	s.Require().Equal(legacyIdentity.ID, bound.Identity.ID)
+	s.Require().Equal(legacyChannel.ID, bound.Channel.ID)
+	s.Require().Equal("wechat-main", bound.Identity.ProviderKey)
+	s.Require().Equal("wechat-main", bound.Channel.ProviderKey)
+	s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"])
+	s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"])
+
+	// Exactly one identity and one channel must exist for these subjects.
+	identityCount, err := s.client.AuthIdentity.Query().
+		Where(
+			authidentity.UserIDEQ(user.ID),
+			authidentity.ProviderTypeEQ("wechat"),
+			authidentity.ProviderSubjectEQ("union-legacy-123"),
+		).
+		Count(s.ctx)
+	s.Require().NoError(err)
+	s.Require().Equal(1, identityCount)
+
+	channelCount, err := s.client.AuthIdentityChannel.Query().
+		Where(
+			authidentitychannel.ProviderTypeEQ("wechat"),
+			authidentitychannel.ChannelEQ("oa"),
+			authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
+			authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+		).
+		Count(s.ctx)
+	s.Require().NoError(err)
+	s.Require().Equal(1, channelCount)
+}
+
+// TestCreateAuthIdentity_RejectsChannelProviderMismatch verifies that
+// creating an identity with a channel whose provider type/key differs from
+// the canonical key fails with ErrAuthIdentityChannelProviderMismatch.
+func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
+	user := s.mustCreateUser("provider-mismatch-create")
+
+	_, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-main",
+			ProviderSubject: "union-create-mismatch",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "linuxdo",
+			ProviderKey:    "linuxdo-main",
+			Channel:        "oauth",
+			ChannelAppID:   "app-mismatch",
+			ChannelSubject: "openid-create-mismatch",
+		},
+	})
+	s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
+// TestBindAuthIdentityToUser_RejectsChannelProviderMismatch verifies that a
+// bind whose channel provider key differs from the canonical key (even for
+// the same provider type) fails with ErrAuthIdentityChannelProviderMismatch.
+func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() {
+	user := s.mustCreateUser("provider-mismatch-bind")
+
+	_, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-main",
+			ProviderSubject: "union-bind-mismatch",
+		},
+		Channel: &AuthIdentityChannelKey{
+			ProviderType:   "wechat",
+			ProviderKey:    "wechat-legacy",
+			Channel:        "oa",
+			ChannelAppID:   "wx-app-bind-mismatch",
+			ChannelSubject: "openid-bind-mismatch",
+		},
+	})
+	s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
+}
+
+// TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError verifies
+// that an error returned from the transaction callback rolls back both the
+// ent identity insert and the raw-SQL provider grant made inside it.
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
+	user := s.mustCreateUser("tx-rollback")
+	expectedErr := errors.New("rollback")
+
+	err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+		_, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
+			UserID: user.ID,
+			Canonical: AuthIdentityKey{
+				ProviderType:    "oidc",
+				ProviderKey:     "https://issuer.example",
+				ProviderSubject: "subject-rollback",
+			},
+		})
+		s.Require().NoError(err)
+
+		inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
+			UserID:       user.ID,
+			ProviderType: "oidc",
+			GrantReason:  ProviderGrantReasonFirstBind,
+		})
+		s.Require().NoError(err)
+		s.Require().True(inserted)
+		// Returning an error must abort the whole transaction.
+		return expectedErr
+	})
+	s.Require().ErrorIs(err, expectedErr)
+
+	// Identity insert was rolled back.
+	_, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
+		ProviderType:    "oidc",
+		ProviderKey:     "https://issuer.example",
+		ProviderSubject: "subject-rollback",
+	})
+	s.Require().True(dbent.IsNotFound(err))
+
+	// Grant insert was rolled back too.
+	var count int
+	s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
+		user.ID,
+		"oidc",
+		string(ProviderGrantReasonFirstBind),
+	).Scan(&count))
+	s.Require().Zero(count)
+}
+
+// TestRecordProviderGrant_IsIdempotentPerReason verifies that a repeated
+// grant for the same (user, provider, reason) is a no-op (returns false),
+// while a different reason for the same provider inserts a second row.
+func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
+	user := s.mustCreateUser("grant")
+
+	inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+		UserID:       user.ID,
+		ProviderType: "wechat",
+		GrantReason:  ProviderGrantReasonFirstBind,
+	})
+	s.Require().NoError(err)
+	s.Require().True(inserted)
+
+	// Duplicate (user, provider, reason): ON CONFLICT makes it a no-op.
+	inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+		UserID:       user.ID,
+		ProviderType: "wechat",
+		GrantReason:  ProviderGrantReasonFirstBind,
+	})
+	s.Require().NoError(err)
+	s.Require().False(inserted)
+
+	// Different reason: a new row is inserted.
+	inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
+		UserID:       user.ID,
+		ProviderType: "wechat",
+		GrantReason:  ProviderGrantReasonSignup,
+	})
+	s.Require().NoError(err)
+	s.Require().True(inserted)
+
+	var count int
+	s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT COUNT(*)
+FROM user_provider_default_grants
+WHERE user_id = $1 AND provider_type = $2`,
+		user.ID,
+		"wechat",
+	).Scan(&count))
+	s.Require().Equal(2, count)
+}
+
+// TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity verifies that
+// the first upsert creates an unlinked decision, a second upsert for the
+// same session updates that row in place (same ID) and links the identity,
+// and the result is readable back by pending session ID.
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
+	user := s.mustCreateUser("adoption")
+	identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-open",
+			ProviderSubject: "union-adoption",
+		},
+	})
+	s.Require().NoError(err)
+
+	session := s.mustCreatePendingAuthSession(identity.IdentityRef())
+
+	first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+		PendingAuthSessionID: session.ID,
+		AdoptDisplayName:     true,
+		AdoptAvatar:          false,
+	})
+	s.Require().NoError(err)
+	s.Require().True(first.AdoptDisplayName)
+	s.Require().False(first.AdoptAvatar)
+	s.Require().Nil(first.IdentityID)
+
+	second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+		PendingAuthSessionID: session.ID,
+		IdentityID:           &identity.Identity.ID,
+		AdoptDisplayName:     true,
+		AdoptAvatar:          true,
+	})
+	s.Require().NoError(err)
+	s.Require().Equal(first.ID, second.ID)
+	s.Require().NotNil(second.IdentityID)
+	s.Require().Equal(identity.Identity.ID, *second.IdentityID)
+	s.Require().True(second.AdoptAvatar)
+
+	loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
+	s.Require().NoError(err)
+	s.Require().Equal(second.ID, loaded.ID)
+	s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
+}
+
+// TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference
+// verifies that linking an identity to a second session's decision clears
+// the identity link on the first session's decision, so the identity is
+// referenced by at most one pending session.
+func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference() {
+	user := s.mustCreateUser("adoption-reassign")
+	identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
+		UserID: user.ID,
+		Canonical: AuthIdentityKey{
+			ProviderType:    "wechat",
+			ProviderKey:     "wechat-open",
+			ProviderSubject: "union-adoption-reassign",
+		},
+	})
+	s.Require().NoError(err)
+
+	firstSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
+	firstDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+		PendingAuthSessionID: firstSession.ID,
+		IdentityID:           &identity.Identity.ID,
+		AdoptDisplayName:     true,
+		AdoptAvatar:          false,
+	})
+	s.Require().NoError(err)
+	s.Require().NotNil(firstDecision.IdentityID)
+	s.Require().Equal(identity.Identity.ID, *firstDecision.IdentityID)
+
+	secondSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
+	secondDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
+		PendingAuthSessionID: secondSession.ID,
+		IdentityID:           &identity.Identity.ID,
+		AdoptDisplayName:     false,
+		AdoptAvatar:          true,
+	})
+	s.Require().NoError(err)
+	s.Require().NotNil(secondDecision.IdentityID)
+	s.Require().Equal(identity.Identity.ID, *secondDecision.IdentityID)
+
+	// The first decision must have lost its identity reference.
+	reloadedFirst, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, firstSession.ID)
+	s.Require().NoError(err)
+	s.Require().Nil(reloadedFirst.IdentityID)
+}
+
+// TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate verifies that
+// an avatar upsert and an otherwise-unchanged user update can share one
+// transaction and commit successfully.
+func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() {
+	user := s.mustCreateUser("avatar-only-update")
+
+	model, err := s.repo.GetByID(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().NotNil(model)
+
+	err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
+		_, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{
+			StorageProvider: "remote_url",
+			URL:             "https://cdn.example.com/avatar.png",
+		})
+		if err != nil {
+			return err
+		}
+		// Write the unchanged user model back inside the same tx.
+		return s.repo.Update(txCtx, model)
+	})
+	s.Require().NoError(err)
+
+	avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().NotNil(avatar)
+	s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL)
+}
+
+// TestUserAvatarCRUDAndUserLookup exercises the avatar round-trip: create an
+// inline avatar, read it back, overwrite it with a remote-URL avatar
+// (verifying unset fields reset to zero values), then delete it and confirm
+// the read returns nil.
+func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
+	user := s.mustCreateUser("avatar")
+
+	inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+		StorageProvider: "inline",
+		URL:             "data:image/png;base64,QUJD",
+		ContentType:     "image/png",
+		ByteSize:        3,
+		SHA256:          "902fbdd2b1df0c4f70b4a5d23525e932",
+	})
+	s.Require().NoError(err)
+	s.Require().Equal("inline", inlineAvatar.StorageProvider)
+	s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
+
+	loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().NotNil(loadedAvatar)
+	s.Require().Equal("image/png", loadedAvatar.ContentType)
+	s.Require().Equal(3, loadedAvatar.ByteSize)
+
+	// Overwrite with a remote-URL avatar; omitted fields are replaced.
+	_, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
+		StorageProvider: "remote_url",
+		URL:             "https://cdn.example.com/avatar.png",
+	})
+	s.Require().NoError(err)
+
+	loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().NotNil(loadedAvatar)
+	s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
+	s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
+	s.Require().Zero(loadedAvatar.ByteSize)
+
+	s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
+	loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
+	s.Require().NoError(err)
+	s.Require().Nil(loadedAvatar)
+}
+
+// TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns verifies the two
+// timestamp setters write the last_login_at / last_active_at columns, read
+// back via raw SQL rather than through the repository.
+func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
+	user := s.mustCreateUser("activity")
+	loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
+	activeAt := loginAt.Add(5 * time.Minute)
+
+	s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
+	s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
+
+	var storedLoginAt sqlNullTime
+	var storedActiveAt sqlNullTime
+	s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
+SELECT last_login_at, last_active_at
+FROM users
+WHERE id = $1`,
+		user.ID,
+	).Scan(&storedLoginAt, &storedActiveAt))
+	s.Require().True(storedLoginAt.Valid)
+	s.Require().True(storedActiveAt.Valid)
+	s.Require().True(storedLoginAt.Time.Equal(loginAt))
+	s.Require().True(storedActiveAt.Time.Equal(activeAt))
+}
+
+// sqlNullTime is a minimal nullable-timestamp scanner (analogous to
+// sql.NullTime) used to read possibly-NULL columns in raw test queries.
+type sqlNullTime struct {
+	Time  time.Time // the scanned value when Valid
+	Valid bool      // false when the column was NULL
+}
+
+// Scan implements the sql.Scanner contract for time.Time and NULL values;
+// any other driver-supplied type is rejected with an error.
+func (t *sqlNullTime) Scan(value any) error {
+	switch v := value.(type) {
+	case time.Time:
+		t.Time = v
+		t.Valid = true
+		return nil
+	case nil:
+		// NULL column: zero the time and mark invalid.
+		t.Time = time.Time{}
+		t.Valid = false
+		return nil
+	default:
+		return fmt.Errorf("unsupported scan type %T", value)
+	}
+}
+
+// stringPtr returns a pointer to v, for populating optional string fields.
+func stringPtr(v string) *string {
+	return &v
+}
diff --git a/backend/internal/repository/user_profile_identity_repo_unit_test.go b/backend/internal/repository/user_profile_identity_repo_unit_test.go
new file mode 100644
index 00000000..689f32f9
--- /dev/null
+++ b/backend/internal/repository/user_profile_identity_repo_unit_test.go
@@ -0,0 +1,212 @@
+package repository
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
// TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias
// seeds an identity/channel pair under the legacy provider key "wechat",
// binds the same subject under the canonical key "wechat-main", and asserts
// the existing rows are renamed in place (same IDs, no duplicates) with
// metadata replaced by the bind's values.
func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
	repo, client := newUserEntRepo(t)
	ctx := context.Background()

	user := &service.User{
		Email:        "wechat-legacy@example.com",
		Username:     "wechat-legacy",
		PasswordHash: "hash",
		Role:         service.RoleUser,
		Status:       service.StatusActive,
	}
	require.NoError(t, repo.Create(ctx, user))

	// Legacy rows written under the pre-canonicalization key "wechat".
	legacyIdentity, err := client.AuthIdentity.Create().
		SetUserID(user.ID).
		SetProviderType("wechat").
		SetProviderKey("wechat").
		SetProviderSubject("union-legacy-123").
		SetMetadata(map[string]any{"source": "legacy-alias"}).
		Save(ctx)
	require.NoError(t, err)

	legacyChannel, err := client.AuthIdentityChannel.Create().
		SetIdentityID(legacyIdentity.ID).
		SetProviderType("wechat").
		SetProviderKey("wechat").
		SetChannel("oa").
		SetChannelAppID("wx-app-legacy").
		SetChannelSubject("openid-legacy-123").
		SetMetadata(map[string]any{"scene": "legacy-alias"}).
		Save(ctx)
	require.NoError(t, err)

	// Bind with the canonical provider key for the same subject/channel.
	bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
		UserID: user.ID,
		Canonical: AuthIdentityKey{
			ProviderType:    "wechat",
			ProviderKey:     "wechat-main",
			ProviderSubject: "union-legacy-123",
		},
		Channel: &AuthIdentityChannelKey{
			ProviderType:   "wechat",
			ProviderKey:    "wechat-main",
			Channel:        "oa",
			ChannelAppID:   "wx-app-legacy",
			ChannelSubject: "openid-legacy-123",
		},
		Metadata:        map[string]any{"source": "canonical-bind"},
		ChannelMetadata: map[string]any{"scene": "canonical-bind"},
	})
	require.NoError(t, err)
	require.NotNil(t, bound)
	require.NotNil(t, bound.Identity)
	require.NotNil(t, bound.Channel)
	// The legacy rows must be adopted, not duplicated: same primary keys,
	// now carrying the canonical provider key.
	require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
	require.Equal(t, legacyChannel.ID, bound.Channel.ID)
	require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
	require.Equal(t, "wechat-main", bound.Channel.ProviderKey)

	reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
	require.NoError(t, err)
	require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
	require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])

	reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
	require.NoError(t, err)
	require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
	require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])

	// Exactly one identity and one channel should exist for the subject.
	identityCount, err := client.AuthIdentity.Query().
		Where(
			authidentity.UserIDEQ(user.ID),
			authidentity.ProviderTypeEQ("wechat"),
			authidentity.ProviderSubjectEQ("union-legacy-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Equal(t, 1, identityCount)

	channelCount, err := client.AuthIdentityChannel.Query().
		Where(
			authidentitychannel.ProviderTypeEQ("wechat"),
			authidentitychannel.ChannelEQ("oa"),
			authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
			authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
		).
		Count(ctx)
	require.NoError(t, err)
	require.Equal(t, 1, channelCount)
}
+
// TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency
// races two UpsertIdentityAdoptionDecision calls for the same pending auth
// session. An ent mutation hook parks the first Create mid-flight so the
// second call runs while the first row is not yet committed; both calls must
// succeed and converge on a single decision row.
func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
	repo, client := newUserEntRepo(t)
	ctx := context.Background()

	user := &service.User{
		Email:        "repo-adoption@example.com",
		Username:     "repo-adoption",
		PasswordHash: "hash",
		Role:         service.RoleUser,
		Status:       service.StatusActive,
	}
	require.NoError(t, repo.Create(ctx, user))

	identity, err := client.AuthIdentity.Create().
		SetUserID(user.ID).
		SetProviderType("wechat").
		SetProviderKey("wechat-main").
		SetProviderSubject("union-repo-adoption").
		SetMetadata(map[string]any{}).
		Save(ctx)
	require.NoError(t, err)

	session, err := client.PendingAuthSession.Create().
		SetSessionToken("pending-repo-adoption").
		SetIntent("bind_current_user").
		SetProviderType("wechat").
		SetProviderKey("wechat-main").
		SetProviderSubject("union-repo-adoption").
		SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
		SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
		SetLocalFlowState(map[string]any{"step": "pending"}).
		Save(ctx)
	require.NoError(t, err)

	// Hook: the very first Create mutation on IdentityAdoptionDecision
	// signals firstCreateStarted, then blocks until released. sync.Once
	// guarantees only that one mutation is parked.
	firstCreateStarted := make(chan struct{})
	releaseFirstCreate := make(chan struct{})
	var firstCreate sync.Once
	client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
		return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
			blocked := false
			if m.Op().Is(dbent.OpCreate) {
				firstCreate.Do(func() {
					blocked = true
					close(firstCreateStarted)
				})
			}
			if blocked {
				<-releaseFirstCreate
			}
			return next.Mutate(ctx, m)
		})
	})

	type adoptionResult struct {
		decision *dbent.IdentityAdoptionDecision
		err      error
	}

	input := IdentityAdoptionDecisionInput{
		PendingAuthSessionID: session.ID,
		IdentityID:           &identity.ID,
		AdoptDisplayName:     true,
		AdoptAvatar:          true,
	}

	results := make(chan adoptionResult, 2)
	go func() {
		decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
		results <- adoptionResult{decision: decision, err: err}
	}()

	// NOTE(review): if the first upsert errors before reaching its Create
	// mutation, firstCreateStarted is never closed and this receive blocks
	// the test forever — consider a select with a timeout.
	<-firstCreateStarted

	go func() {
		decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
		results <- adoptionResult{decision: decision, err: err}
	}()

	// NOTE(review): sleep-based ordering is best-effort; 100ms is assumed
	// long enough for the second upsert to reach the contended point.
	time.Sleep(100 * time.Millisecond)
	close(releaseFirstCreate)

	first := <-results
	second := <-results

	require.NoError(t, first.err)
	require.NoError(t, second.err)
	require.NotNil(t, first.decision)
	require.NotNil(t, second.decision)
	// Both callers must observe the same decision row.
	require.Equal(t, first.decision.ID, second.decision.ID)

	count, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
		Count(ctx)
	require.NoError(t, err)
	require.Equal(t, 1, count)

	loaded, err := client.IdentityAdoptionDecision.Query().
		Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
		Only(ctx)
	require.NoError(t, err)
	require.NotNil(t, loaded.IdentityID)
	require.Equal(t, identity.ID, *loaded.IdentityID)
	require.True(t, loaded.AdoptDisplayName)
	require.True(t, loaded.AdoptAvatar)
}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 575754e0..d1f10cbd 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -11,12 +11,19 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/lib/pq"
+
+ entsql "entgo.io/ent/dialect/sql"
)
type userRepository struct {
@@ -45,12 +52,33 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
+ txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
} else {
- // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
- txClient = r.client
+ // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
+ }
+
+ releaseEmailLock, err := lockRepositoryScopedKeys(
+ txCtx,
+ txClient,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ normalizedEmailUniquenessLockKey(userIn.Email),
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseEmailLock()
+
+ if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
+ return err
}
created, err := txClient.User.Create().
@@ -62,13 +90,19 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
- Save(ctx)
+ SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
+ SetNillableLastLoginAt(userIn.LastLoginAt).
+ SetNillableLastActiveAt(userIn.LastActiveAt).
+ SetRpmLimit(userIn.RPMLimit).
+ Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
- if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
+ if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
+ return err
+ }
+ if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
@@ -100,10 +134,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
- m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
+ matches, err := r.client.User.Query().
+ Where(userEmailLookupPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
if err != nil {
- return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
+ return nil, err
}
+ if len(matches) == 0 {
+ return nil, service.ErrUserNotFound
+ }
+ if len(matches) > 1 {
+ return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
+ }
+ m := matches[0]
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
@@ -128,15 +172,42 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
+ txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
} else {
- // 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
- txClient = r.client
+ // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
+ if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ txClient = existingTx.Client()
+ } else {
+ txClient = r.client
+ }
}
- updated, err := txClient.User.UpdateOneID(userIn.ID).
+ releaseEmailLock, err := lockRepositoryScopedKeys(
+ txCtx,
+ txClient,
+ txAwareSQLExecutor(txCtx, r.sql, r.client),
+ normalizedEmailUniquenessLockKey(userIn.Email),
+ )
+ if err != nil {
+ return err
+ }
+ defer releaseEmailLock()
+
+ if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
+ return err
+ }
+
+ existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ oldEmail := existing.Email
+
+ updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
@@ -145,14 +216,33 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
- SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
- SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
- Save(ctx)
+ SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
+ SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
+ SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
+ SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
+ SetTotalRecharged(userIn.TotalRecharged).
+ SetRpmLimit(userIn.RPMLimit)
+ if userIn.SignupSource != "" {
+ updateOp = updateOp.SetSignupSource(userIn.SignupSource)
+ }
+ if userIn.LastLoginAt != nil {
+ updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
+ }
+ if userIn.LastActiveAt != nil {
+ updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
+ }
+ if userIn.BalanceNotifyThreshold == nil {
+ updateOp = updateOp.ClearBalanceNotifyThreshold()
+ }
+ updated, err := updateOp.Save(txCtx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
- if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
+ if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
+ return err
+ }
+ if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
@@ -166,14 +256,146 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
// ensureEmailAuthIdentityWithClient idempotently creates the email-provider
// auth identity row for a user. Synthetic connector emails (and blank emails)
// are skipped entirely. If the identity already exists but belongs to a
// different user, ErrAuthIdentityOwnershipConflict is returned.
//
// source is recorded in the identity metadata to attribute who created it.
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
	client = clientFromContext(ctx, client)
	if client == nil || userID <= 0 {
		return nil
	}

	subject := normalizeEmailAuthIdentitySubject(email)
	if subject == "" {
		// Synthetic or blank email: no identity row is wanted.
		return nil
	}

	// Upsert with DO NOTHING; a no-op conflict surfaces as sql.ErrNoRows,
	// which is expected and swallowed.
	if err := client.AuthIdentity.Create().
		SetUserID(userID).
		SetProviderType("email").
		SetProviderKey("email").
		SetProviderSubject(subject).
		SetVerifiedAt(time.Now().UTC()).
		SetMetadata(map[string]any{"source": source}).
		OnConflictColumns(
			authidentity.FieldProviderType,
			authidentity.FieldProviderKey,
			authidentity.FieldProviderSubject,
		).
		DoNothing().
		Exec(ctx); err != nil {
		if !isSQLNoRowsError(err) {
			return err
		}
	}

	// Re-read to verify ownership: the row may pre-date this call and be
	// attached to another user.
	identity, err := client.AuthIdentity.Query().
		Where(
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ(subject),
		).
		Only(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			// Row vanished between upsert and read; treat as success.
			return nil
		}
		return err
	}
	if identity.UserID != userID {
		return ErrAuthIdentityOwnershipConflict
	}
	return nil
}
+
// replaceEmailAuthIdentityWithClient moves a user's email auth identity from
// oldEmail to newEmail: it first ensures the new identity exists (failing on
// ownership conflicts), then deletes the old one. When the normalized
// subjects are equal, or the old email was synthetic/blank, nothing is
// deleted.
func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
	newSubject := normalizeEmailAuthIdentitySubject(newEmail)
	// Create the new identity before removing the old one so a failure
	// never leaves the user without any email identity.
	if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
		return err
	}

	oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
	if oldSubject == "" || oldSubject == newSubject {
		return nil
	}

	_, err := clientFromContext(ctx, client).AuthIdentity.Delete().
		Where(
			authidentity.UserIDEQ(userID),
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ(oldSubject),
		).
		Exec(ctx)
	return err
}
+
+func normalizeEmailAuthIdentitySubject(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ if normalized == "" {
+ return ""
+ }
+ if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
+ return ""
+ }
+ return normalized
+}
+
// Delete removes a user and all dependent auth rows in one transaction:
// adoption decisions are detached (identity_id cleared), identity channels
// and identities deleted, then the user row itself. Returns
// service.ErrUserNotFound when no user row was deleted. If the context
// already carries a transaction (dbent.ErrTxStarted), that transaction's
// client is reused and commit/rollback is left to the caller.
func (r *userRepository) Delete(ctx context.Context, id int64) error {
	tx, err := r.client.Tx(ctx)
	if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	var txClient *dbent.Client
	if err == nil {
		// We own the transaction: roll back unless Commit below succeeds.
		defer func() { _ = tx.Rollback() }()
		txClient = tx.Client()
	} else {
		// Already inside an external transaction; reuse its client.
		if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
			txClient = existingTx.Client()
		} else {
			txClient = r.client
		}
	}

	identityIDs, err := txClient.AuthIdentity.Query().
		Where(authidentity.UserIDEQ(id)).
		IDs(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	if len(identityIDs) > 0 {
		// Clear decision references first so the identity deletes below do
		// not violate foreign keys.
		if _, err := txClient.IdentityAdoptionDecision.Update().
			Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
			ClearIdentityID().
			Save(ctx); err != nil {
			return translatePersistenceError(err, service.ErrUserNotFound, nil)
		}
		if _, err := txClient.AuthIdentityChannel.Delete().
			Where(authidentitychannel.IdentityIDIn(identityIDs...)).
			Exec(ctx); err != nil {
			return translatePersistenceError(err, service.ErrUserNotFound, nil)
		}
		if _, err := txClient.AuthIdentity.Delete().
			Where(authidentity.UserIDEQ(id)).
			Exec(ctx); err != nil {
			return translatePersistenceError(err, service.ErrUserNotFound, nil)
		}
	}

	affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, nil)
	}
	if affected == 0 {
		return service.ErrUserNotFound
	}

	// Only commit the transaction we opened ourselves (tx is nil when an
	// external transaction was already active).
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return translatePersistenceError(err, service.ErrUserNotFound, nil)
		}
	}
	return nil
}
@@ -227,11 +449,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return nil, nil, err
}
- users, err := q.
+ usersQuery := q.
Offset(params.Offset()).
- Limit(params.Limit()).
- Order(dbent.Desc(dbuser.FieldID)).
- All(ctx)
+ Limit(params.Limit())
+ for _, order := range userListOrder(params) {
+ usersQuery = usersQuery.Order(order)
+ }
+
+ users, err := usersQuery.All(ctx)
if err != nil {
return nil, nil, err
}
@@ -284,6 +509,137 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
+func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
+ sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
+ sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
+
+ if sortBy == "last_used_at" {
+ return userLastUsedAtOrder(sortOrder)
+ }
+
+ var field string
+ defaultField := true
+ nullsLastField := false
+ switch sortBy {
+ case "email":
+ field = dbuser.FieldEmail
+ defaultField = false
+ case "username":
+ field = dbuser.FieldUsername
+ defaultField = false
+ case "role":
+ field = dbuser.FieldRole
+ defaultField = false
+ case "balance":
+ field = dbuser.FieldBalance
+ defaultField = false
+ case "concurrency":
+ field = dbuser.FieldConcurrency
+ defaultField = false
+ case "status":
+ field = dbuser.FieldStatus
+ defaultField = false
+ case "created_at":
+ field = dbuser.FieldCreatedAt
+ defaultField = false
+ case "last_active_at":
+ field = dbuser.FieldLastActiveAt
+ defaultField = false
+ nullsLastField = true
+ default:
+ field = dbuser.FieldID
+ }
+
+ if sortOrder == pagination.SortOrderAsc {
+ if defaultField && field == dbuser.FieldID {
+ return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
+ }
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
+ dbent.Asc(dbuser.FieldID),
+ }
+ }
+ return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
+ }
+ if defaultField && field == dbuser.FieldID {
+ return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
+ }
+ if nullsLastField {
+ return []func(*entsql.Selector){
+ entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
+ dbent.Desc(dbuser.FieldID),
+ }
+ }
+ return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
+}
+
// GetLatestUsedAtByUserIDs returns, for each requested user, the UTC
// timestamp of their most recent usage_logs row. Users with no usage rows
// are absent from the returned map; an empty input yields an empty map.
// Requires the raw SQL executor (r.sql); errors out if it is unset.
func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
	result := make(map[int64]*time.Time, len(userIDs))
	if len(userIDs) == 0 {
		return result, nil
	}
	if r.sql == nil {
		return nil, fmt.Errorf("sql executor is not configured")
	}

	// pq.Array binds the slice as a Postgres array for ANY($1).
	const query = `
	SELECT user_id, MAX(created_at) AS last_used_at
	FROM usage_logs
	WHERE user_id = ANY($1)
	GROUP BY user_id
	`

	rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var (
			userID     int64
			lastUsedAt time.Time
		)
		if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil {
			return nil, scanErr
		}
		// ts is a fresh per-iteration variable, so each map entry gets its
		// own address.
		ts := lastUsedAt.UTC()
		result[userID] = &ts
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
+
+func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
+ if err != nil {
+ return nil, err
+ }
+ return latestByUserID[userID], nil
+}
+
// userLastUsedAtOrder orders users by their most recent usage_logs entry via
// a correlated MAX(created_at) subquery, with the user ID as a tie-breaker
// in the same direction. ASC puts never-used users (NULL) first; DESC puts
// them last.
func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
	orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) {
		return func(s *entsql.Selector) {
			// s.C() yields a properly quoted column reference, so the only
			// interpolated value is ent's own identifier — no injection risk.
			subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID))
			s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls))
			s.OrderBy(tieOrder(s.C(dbuser.FieldID)))
		}
	}

	if sortOrder == pagination.SortOrderAsc {
		return []func(*entsql.Selector){
			orderExpr("ASC", "FIRST", entsql.Asc),
		}
	}
	return []func(*entsql.Selector){
		orderExpr("DESC", "LAST", entsql.Desc),
	}
}
+
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
@@ -336,7 +692,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
- n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
+ update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount)
+ // Track cumulative recharge amount for percentage-based notifications
+ if amount > 0 {
+ update = update.AddTotalRecharged(amount)
+ }
+ n, err := update.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
@@ -376,77 +737,69 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
-// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。
-func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
- if deltaBytes <= 0 {
- user, err := r.GetByID(ctx, userID)
- if err != nil {
- return 0, err
- }
- return user.SoraStorageUsedBytes, nil
- }
- var newUsed int64
- err := scanSingleRow(ctx, r.sql, `
- UPDATE users
- SET sora_storage_used_bytes = sora_storage_used_bytes + $2
- WHERE id = $1
- AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
- RETURNING sora_storage_used_bytes
- `, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
- if err == nil {
- return newUsed, nil
- }
- if errors.Is(err, sql.ErrNoRows) {
- // 区分用户不存在和配额冲突
- exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
- if existsErr != nil {
- return 0, existsErr
- }
- if !exists {
- return 0, service.ErrUserNotFound
- }
- return 0, service.ErrSoraStorageQuotaExceeded
- }
- return 0, err
-}
-
-// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。
-func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
- if deltaBytes <= 0 {
- user, err := r.GetByID(ctx, userID)
- if err != nil {
- return 0, err
- }
- return user.SoraStorageUsedBytes, nil
- }
- var newUsed int64
- err := scanSingleRow(ctx, r.sql, `
- UPDATE users
- SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
- WHERE id = $1
- RETURNING sora_storage_used_bytes
- `, []any{userID, deltaBytes}, &newUsed)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return 0, service.ErrUserNotFound
- }
- return 0, err
- }
- return newUsed, nil
-}
-
// ExistsByEmail reports whether any user matches the given email after
// normalization (case- and surrounding-whitespace-insensitive).
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
}
+
+func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
+ client = clientFromContext(ctx, client)
+ if client == nil {
+ return nil
+ }
+
+ matches, err := client.User.Query().
+ Where(userEmailLookupPredicate(email)).
+ All(ctx)
+ if err != nil {
+ return err
+ }
+ for _, match := range matches {
+ if match.ID != userID {
+ return service.ErrEmailExists
+ }
+ }
+ return nil
+}
+
// userEmailLookupPredicate matches users whose stored email equals the given
// one after LOWER(TRIM(...)) normalization on both sides: the input is
// normalized in Go, the column in SQL, so legacy rows with stray casing or
// whitespace still match. A blank normalized input degrades to an exact
// comparison against the raw value.
func userEmailLookupPredicate(email string) predicate.User {
	normalized := normalizeEmailLookupValue(email)
	if normalized == "" {
		return dbuser.EmailEQ(email)
	}
	return predicate.User(func(s *entsql.Selector) {
		s.Where(entsql.P(func(b *entsql.Builder) {
			// Emits LOWER(TRIM("email")) = $n; the normalized value is a
			// bind argument, never interpolated into the SQL text.
			b.WriteString("LOWER(TRIM(").
				Ident(s.C(dbuser.FieldEmail)).
				WriteString(")) = ").
				Arg(normalized)
		}))
	})
}
+
// normalizeEmailLookupValue canonicalizes an email for case- and
// whitespace-insensitive comparison: surrounding spaces are trimmed, then
// the result is lower-cased.
func normalizeEmailLookupValue(email string) string {
	trimmed := strings.TrimSpace(email)
	return strings.ToLower(trimmed)
}
+
+func normalizedEmailUniquenessLockKey(email string) string {
+ normalized := normalizeEmailLookupValue(email)
+ if normalized == "" {
+ return ""
+ }
+ return "users:normalized-email:" + normalized
}
// AddGroupToAllowedGroups grants groupID to userID via an upsert; an
// already-existing pair is treated as success.
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	client := clientFromContext(ctx, r.client)
	err := client.UserAllowedGroup.Create().
		SetUserID(userID).
		SetGroupID(groupID).
		OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
		DoNothing().
		Exec(ctx)
	// DO NOTHING upserts surface sql.ErrNoRows when the row already
	// existed; that is not a failure here.
	if isSQLNoRowsError(err) {
		return nil
	}
	return err
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
@@ -546,6 +899,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
+ if isSQLNoRowsError(err) {
+ return nil
+ }
return err
}
}
@@ -558,10 +914,29 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return
}
dst.ID = src.ID
+ dst.SignupSource = src.SignupSource
+ dst.LastLoginAt = src.LastLoginAt
+ dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
// userSignupSourceOrDefault maps a free-form signup-source value onto the
// closed set persisted on the users row. Matching is case-insensitive and
// ignores surrounding whitespace; blank or unrecognized sources fall back to
// "email".
func userSignupSourceOrDefault(signupSource string) string {
	// Normalize exactly once instead of re-deriving the value both in the
	// switch expression and again in the matching return.
	normalized := strings.ToLower(strings.TrimSpace(signupSource))
	switch normalized {
	case "linuxdo", "wechat", "oidc":
		return normalized
	default:
		// Covers "", "email", and any unknown source.
		return "email"
	}
}
+
// marshalExtraEmails serializes balance-notification email entries to the
// JSON string stored on the users row. It delegates to the service layer's
// canonical encoder so writes stay symmetric with the service-side reads.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
	return service.MarshalNotifyEmails(entries)
}
+
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client)
diff --git a/backend/internal/repository/user_repo_email_identity_integration_test.go b/backend/internal/repository/user_repo_email_identity_integration_test.go
new file mode 100644
index 00000000..fddd82c5
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_identity_integration_test.go
@@ -0,0 +1,86 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
// TestCreate_CreatesEmailAuthIdentityForNormalEmail verifies that creating a
// user with a real mailbox address also creates exactly one email-provider
// auth identity keyed by that address.
func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() {
	user := &service.User{
		Email:        "repo-create@example.com",
		PasswordHash: "test-password-hash",
		Role:         service.RoleUser,
		Status:       service.StatusActive,
		Concurrency:  2,
	}

	s.Require().NoError(s.repo.Create(s.ctx, user))

	// Only(...) also asserts uniqueness: more than one match would fail.
	identity, err := s.client.AuthIdentity.Query().
		Where(
			authidentity.UserIDEQ(user.ID),
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ("repo-create@example.com"),
		).
		Only(s.ctx)
	s.Require().NoError(err)
	s.Require().Equal(user.ID, identity.UserID)
}
+
// TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail verifies that a
// user created with a synthetic LinuxDo-connector address gets no
// email-provider auth identity at all.
func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() {
	user := &service.User{
		// Suffix matches the LinuxDo synthetic email domain, which the
		// repository treats as "not a real mailbox".
		Email:        "linuxdo-legacy-user@linuxdo-connect.invalid",
		PasswordHash: "test-password-hash",
		Role:         service.RoleUser,
		Status:       service.StatusActive,
		Concurrency:  2,
	}

	s.Require().NoError(s.repo.Create(s.ctx, user))

	count, err := s.client.AuthIdentity.Query().
		Where(
			authidentity.UserIDEQ(user.ID),
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
		).
		Count(s.ctx)
	s.Require().NoError(err)
	s.Require().Zero(count)
}
+
// TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges verifies that
// updating a user's email swaps the email auth identity: a new identity
// exists for the new address and the old one is gone.
func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() {
	user := s.mustCreateUser(&service.User{
		Email: "before-update@example.com",
	})

	user.Email = "after-update@example.com"
	s.Require().NoError(s.repo.Update(s.ctx, user))

	newIdentity, err := s.client.AuthIdentity.Query().
		Where(
			authidentity.UserIDEQ(user.ID),
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ("after-update@example.com"),
		).
		Only(s.ctx)
	s.Require().NoError(err)
	s.Require().Equal(user.ID, newIdentity.UserID)

	oldCount, err := s.client.AuthIdentity.Query().
		Where(
			authidentity.UserIDEQ(user.ID),
			authidentity.ProviderTypeEQ("email"),
			authidentity.ProviderKeyEQ("email"),
			authidentity.ProviderSubjectEQ("before-update@example.com"),
		).
		// NOTE(review): context.Background() here is inconsistent with the
		// s.ctx used everywhere else in this test — confirm intentional.
		Count(context.Background())
	s.Require().NoError(err)
	s.Require().Zero(oldCount)
}
diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go
new file mode 100644
index 00000000..7da3db9b
--- /dev/null
+++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go
@@ -0,0 +1,227 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+ db.SetMaxOpenConns(10)
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return newUserRepositoryWithSQL(client, db), client
+}
+
+func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ got, err := repo.GetByEmail(ctx, "legacy@example.com")
+ require.NoError(t, err)
+ require.Equal(t, " Legacy@Example.com ", got.Email)
+}
+
+func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Legacy@Example.com ",
+ Username: "legacy-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ")
+ require.NoError(t, err)
+ require.True(t, exists)
+}
+
+func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ err := repo.Create(ctx, &service.User{
+ Email: " Existing@Example.com ",
+ Username: "existing-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.NoError(t, err)
+
+ err = repo.Create(ctx, &service.User{
+ Email: "existing@example.com",
+ Username: "duplicate-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })
+ require.ErrorIs(t, err, service.ErrEmailExists)
+}
+
+func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) {
+ repo, _ := newUserEntRepo(t)
+ ctx := context.Background()
+
+ first := &service.User{
+ Email: " Existing@Example.com ",
+ Username: "existing-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, first))
+
+ second := &service.User{
+ Email: "second@example.com",
+ Username: "second-user",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ }
+ require.NoError(t, repo.Create(ctx, second))
+
+ second.Email = " existing@example.com "
+ err := repo.Update(ctx, second)
+ require.ErrorIs(t, err, service.ErrEmailExists)
+}
+
+func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ _, err := client.User.Create().
+ SetEmail("Conflict@Example.com").
+ SetUsername("conflict-user-1").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.User.Create().
+ SetEmail(" conflict@example.com ").
+ SetUsername("conflict-user-2").
+ SetPasswordHash("hash").
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = repo.GetByEmail(ctx, "conflict@example.com")
+ require.Error(t, err)
+ require.ErrorContains(t, err, "normalized email lookup matched multiple users")
+}
+
+func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
+ repo, client := newUserEntRepo(t)
+ ctx := context.Background()
+
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.User.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type createResult struct {
+ err error
+ }
+
+ results := make(chan createResult, 2)
+ go func() {
+ results <- createResult{err: repo.Create(ctx, &service.User{
+ Email: " Race@Example.com ",
+ Username: "race-user-1",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ results <- createResult{err: repo.Create(ctx, &service.User{
+ Email: "race@example.com",
+ Username: "race-user-2",
+ PasswordHash: "hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ })}
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ errors := []error{first.err, second.err}
+ successes := 0
+ conflicts := 0
+ for _, err := range errors {
+ switch err {
+ case nil:
+ successes++
+ case service.ErrEmailExists:
+ conflicts++
+ default:
+ t.Fatalf("unexpected create error: %v", err)
+ }
+ }
+ require.Equal(t, 1, successes)
+ require.Equal(t, 1, conflicts)
+
+ count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+}
diff --git a/backend/internal/repository/user_repo_integration_test.go b/backend/internal/repository/user_repo_integration_test.go
index f5d0f9ff..13a605a2 100644
--- a/backend/internal/repository/user_repo_integration_test.go
+++ b/backend/internal/repository/user_repo_integration_test.go
@@ -8,6 +8,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -26,6 +28,8 @@ func (s *UserRepoSuite) SetupTest() {
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels")
+ _, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
@@ -122,11 +126,27 @@ func (s *UserRepoSuite) TestGetByEmail() {
s.Require().Equal(user.ID, got.ID)
}
+func (s *UserRepoSuite) TestGetByEmail_NormalizesSpacingAndCaseOnPostgres() {
+ user := s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
+
+ got, err := s.repo.GetByEmail(s.ctx, " legacy@example.com ")
+ s.Require().NoError(err, "GetByEmail normalized lookup")
+ s.Require().Equal(user.ID, got.ID)
+}
+
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
+func (s *UserRepoSuite) TestExistsByEmail_NormalizesSpacingAndCaseOnPostgres() {
+ s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
+
+ exists, err := s.repo.ExistsByEmail(s.ctx, " LEGACY@example.com ")
+ s.Require().NoError(err, "ExistsByEmail normalized lookup")
+ s.Require().True(exists)
+}
+
func (s *UserRepoSuite) TestUpdate() {
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
@@ -140,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() {
s.Require().Equal("updated", updated.Username)
}
+func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() {
+ user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"})
+
+ identityCount, err := s.client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("update-existing-identity@test.com"),
+ ).
+ Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Equal(1, identityCount)
+
+ got, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ got.Username = "updated"
+ s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows")
+
+ updated, err := s.repo.GetByID(s.ctx, user.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("updated", updated.Username)
+}
+
func (s *UserRepoSuite) TestDelete() {
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
@@ -150,6 +194,39 @@ func (s *UserRepoSuite) TestDelete() {
s.Require().Error(err, "expected error after delete")
}
+func (s *UserRepoSuite) TestDeleteRemovesAuthIdentitiesAndChannels() {
+ user := s.mustCreateUser(&service.User{Email: "delete-oauth@test.com"})
+
+ identity, err := s.client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("linuxdo").
+ SetProviderKey("linuxdo").
+ SetProviderSubject("delete-oauth-subject").
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ _, err = s.client.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("open").
+ SetChannelAppID("app-id").
+ SetChannelSubject("openid-123").
+ Save(s.ctx)
+ s.Require().NoError(err)
+
+ err = s.repo.Delete(s.ctx, user.ID)
+ s.Require().NoError(err)
+
+ identityCount, err := s.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(user.ID)).Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(identityCount)
+
+ channelCount, err := s.client.AuthIdentityChannel.Query().Where(authidentitychannel.IdentityIDEQ(identity.ID)).Count(s.ctx)
+ s.Require().NoError(err)
+ s.Require().Zero(channelCount)
+}
+
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go
new file mode 100644
index 00000000..3a15bc10
--- /dev/null
+++ b/backend/internal/repository/user_repo_sort_integration_test.go
@@ -0,0 +1,164 @@
+//go:build integration
+
+package repository
+
+import (
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) {
+ s.T().Helper()
+
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID})
+
+ _, err := integrationDB.ExecContext(
+ s.ctx,
+ `INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at)
+ VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`,
+ userID,
+ apiKey.ID,
+ account.ID,
+ createdAt.UTC(),
+ )
+ s.Require().NoError(err)
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
+ s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
+ s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "email",
+ SortOrder: "asc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 2)
+ s.Require().Equal("a-first@example.com", users[0].Email)
+ s.Require().Equal("z-last@example.com", users[1].Email)
+}
+
+func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
+ first := s.mustCreateUser(&service.User{Email: "first@example.com"})
+ second := s.mustCreateUser(&service.User{Email: "second@example.com"})
+
+ users, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
+ s.Require().NoError(err)
+ s.Require().Len(users, 2)
+ s.Require().Equal(second.ID, users[0].ID)
+ s.Require().Equal(first.ID, users[1].ID)
+}
+
+func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
+ lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created := s.mustCreateUser(&service.User{
+ Email: "identity-meta@example.com",
+ SignupSource: "linuxdo",
+ LastLoginAt: &lastLoginAt,
+ LastActiveAt: &lastActiveAt,
+ })
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("linuxdo", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
+ created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
+ lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
+ lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ created.SignupSource = "oidc"
+ created.LastLoginAt = &lastLoginAt
+ created.LastActiveAt = &lastActiveAt
+
+ s.Require().NoError(s.repo.Update(s.ctx, created))
+
+ got, err := s.repo.GetByID(s.ctx, created.ID)
+ s.Require().NoError(err)
+ s.Require().Equal("oidc", got.SignupSource)
+ s.Require().NotNil(got.LastLoginAt)
+ s.Require().NotNil(got.LastActiveAt)
+ s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
+ s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
+ earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
+ later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
+
+ s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
+ s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
+ s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_active_at",
+ SortOrder: "asc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal("earlier-active@example.com", users[0].Email)
+ s.Require().Equal("later-active@example.com", users[1].Email)
+ s.Require().Equal("nil-active@example.com", users[2].Email)
+}
+
+func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() {
+ older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second)
+ newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second)
+
+ userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"})
+ userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"})
+ s.mustInsertUsageLog(userWithUsage.ID, older)
+ s.mustInsertUsageLog(userWithUsage.ID, newer)
+
+ got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID})
+ s.Require().NoError(err)
+ s.Require().Contains(got, userWithUsage.ID)
+ s.Require().NotContains(got, userWithoutUsage.ID)
+ s.Require().NotNil(got[userWithUsage.ID])
+ s.Require().True(got[userWithUsage.ID].Equal(newer))
+}
+
+func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() {
+ lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second)
+ lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second)
+ lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second)
+
+ nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"})
+ wrongSource := s.mustCreateUser(&service.User{
+ Email: "active-not-usage@example.com",
+ LastActiveAt: &lastActiveVeryRecent,
+ })
+ rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"})
+
+ s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder)
+ s.mustInsertUsageLog(rightSource.ID, lastUsedNewer)
+
+ users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
+ Page: 1,
+ PageSize: 10,
+ SortBy: "last_used_at",
+ SortOrder: "desc",
+ }, service.UserListFilters{})
+ s.Require().NoError(err)
+ s.Require().Len(users, 3)
+ s.Require().Equal(rightSource.ID, users[0].ID)
+ s.Require().Equal(wrongSource.ID, users[1].ID)
+ s.Require().Equal(nilUsage.ID, users[2].ID)
+}
+
+func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
diff --git a/backend/internal/repository/user_rpm_cache.go b/backend/internal/repository/user_rpm_cache.go
new file mode 100644
index 00000000..42bf9332
--- /dev/null
+++ b/backend/internal/repository/user_rpm_cache.go
@@ -0,0 +1,108 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+// 用户/分组级 RPM 计数器 Redis 实现。
+//
+// 设计说明:
+// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute}
+// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。
+// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。
+// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。
+// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。
+const (
+ userGroupRPMKeyPrefix = "rpm:ug:"
+ userRPMKeyPrefix = "rpm:u:"
+
+ userRPMKeyTTL = 120 * time.Second
+)
+
+type userRPMCacheImpl struct {
+ rdb *redis.Client
+}
+
+// NewUserRPMCache 创建用户/分组级 RPM 计数器。
+func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache {
+ return &userRPMCacheImpl{rdb: rdb}
+}
+
+// minuteTS 获取当前 Redis 服务端分钟时间戳。
+func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) {
+ t, err := c.rdb.Time(ctx).Result()
+ if err != nil {
+ return 0, fmt.Errorf("redis TIME: %w", err)
+ }
+ return t.Unix() / 60, nil
+}
+
+// atomicIncr 原子 INCR+EXPIRE。
+func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) {
+ pipe := c.rdb.TxPipeline()
+ incr := pipe.Incr(ctx, key)
+ pipe.Expire(ctx, key, userRPMKeyTTL)
+ if _, err := pipe.Exec(ctx); err != nil {
+ return 0, fmt.Errorf("user rpm increment: %w", err)
+ }
+ return int(incr.Val()), nil
+}
+
+// IncrementUserGroupRPM 递增 (user, group) 分钟计数。
+func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
+ return c.atomicIncr(ctx, key)
+}
+
+// IncrementUserRPM 递增用户分钟计数。
+func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
+ return c.atomicIncr(ctx, key)
+}
+
+// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。
+func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute)
+ val, err := c.rdb.Get(ctx, key).Int()
+ if err == redis.Nil {
+ return 0, nil
+ }
+ if err != nil {
+ return 0, fmt.Errorf("user group rpm get: %w", err)
+ }
+ return val, nil
+}
+
+// GetUserRPM 获取用户当前分钟已用 RPM(只读)。
+func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) {
+ minute, err := c.minuteTS(ctx)
+ if err != nil {
+ return 0, err
+ }
+ key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute)
+ val, err := c.rdb.Get(ctx, key).Int()
+ if err == redis.Nil {
+ return 0, nil
+ }
+ if err != nil {
+ return 0, fmt.Errorf("user rpm get: %w", err)
+ }
+ return val, nil
+}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 49d47bf6..f07bbb33 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -47,13 +47,27 @@ func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.Ses
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
}
+// ProvideSchedulerCache 创建调度快照缓存,并注入快照分块参数。
+func ProvideSchedulerCache(rdb *redis.Client, cfg *config.Config) service.SchedulerCache {
+ mgetChunkSize := defaultSchedulerSnapshotMGetChunkSize
+ writeChunkSize := defaultSchedulerSnapshotWriteChunkSize
+ if cfg != nil {
+ if cfg.Gateway.Scheduling.SnapshotMGetChunkSize > 0 {
+ mgetChunkSize = cfg.Gateway.Scheduling.SnapshotMGetChunkSize
+ }
+ if cfg.Gateway.Scheduling.SnapshotWriteChunkSize > 0 {
+ writeChunkSize = cfg.Gateway.Scheduling.SnapshotWriteChunkSize
+ }
+ }
+ return newSchedulerCacheWithChunkSizes(rdb, mgetChunkSize, writeChunkSize)
+}
+
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewAPIKeyRepository,
NewGroupRepository,
NewAccountRepository,
- NewSoraAccountRepository, // Sora 账号扩展表仓储
NewScheduledTestPlanRepository, // 定时测试计划仓储
NewScheduledTestResultRepository, // 定时测试结果仓储
NewProxyRepository,
@@ -74,6 +88,10 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
+ NewChannelRepository,
+ NewChannelMonitorRepository,
+ NewChannelMonitorRequestTemplateRepository,
+ NewAffiliateRepository,
// Cache implementations
NewGatewayCache,
@@ -81,10 +99,12 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
+ NewOpenAI403CounterCache,
NewInternal500CounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewRPMCache,
+ NewUserRPMCache,
NewUserMsgQueueCache,
NewDashboardCache,
NewEmailCache,
@@ -92,7 +112,7 @@ var ProviderSet = wire.NewSet(
NewRedeemCache,
NewUpdateCache,
NewGeminiTokenCache,
- NewSchedulerCache,
+ ProvideSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
NewTotpCache,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 7059cb76..9d1dfa6b 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -50,14 +50,138 @@ func TestAPIContracts(t *testing.T) {
"data": {
"id": 1,
"email": "alice@example.com",
+ "email_bound": true,
"username": "alice",
"role": "user",
"balance": 12.5,
"concurrency": 5,
+ "rpm_limit": 0,
"status": "active",
"allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z",
+ "balance_notify_enabled": false,
+ "balance_notify_threshold_type": "",
+ "balance_notify_threshold": null,
+ "balance_notify_extra_emails": null,
+ "total_recharged": 0,
+ "linuxdo_bound": false,
+ "oidc_bound": false,
+ "wechat_bound": false,
+ "identities": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
+ "identity_bindings": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
+ "auth_bindings": {
+ "email": {
+ "provider": "email",
+ "provider_key": "email",
+ "bound": true,
+ "bound_count": 1,
+ "can_bind": false,
+ "can_unbind": false,
+ "display_name": "alice@example.com",
+ "subject_hint": "a***e@example.com",
+ "note_key": "profile.authBindings.notes.emailManagedFromProfile",
+ "note": "Primary account email is managed from the profile form."
+ },
+ "linuxdo": {
+ "provider": "linuxdo",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "oidc": {
+ "provider": "oidc",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "wechat": {
+ "provider": "wechat",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ }
+ },
"run_mode": "standard"
}
}`,
@@ -204,18 +328,13 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
- "sora_image_price_360": null,
- "sora_image_price_540": null,
- "sora_storage_quota_bytes": 0,
- "sora_video_price_per_request": null,
- "sora_video_price_per_request_hd": null,
- "claude_code_only": false,
+ "claude_code_only": false,
"allow_messages_dispatch": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
- "allow_messages_dispatch": false,
"require_oauth_only": false,
"require_privacy_set": false,
+ "rpm_limit": 0,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
@@ -467,6 +586,28 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTurnstileSiteKey: "site-key",
service.SettingKeyTurnstileSecretKey: "secret-key",
+ service.SettingKeyOIDCConnectEnabled: "false",
+ service.SettingKeyOIDCConnectProviderName: "OIDC",
+ service.SettingKeyOIDCConnectClientID: "",
+ service.SettingKeyOIDCConnectIssuerURL: "",
+ service.SettingKeyOIDCConnectDiscoveryURL: "",
+ service.SettingKeyOIDCConnectAuthorizeURL: "",
+ service.SettingKeyOIDCConnectTokenURL: "",
+ service.SettingKeyOIDCConnectUserInfoURL: "",
+ service.SettingKeyOIDCConnectJWKSURL: "",
+ service.SettingKeyOIDCConnectScopes: "openid email profile",
+ service.SettingKeyOIDCConnectRedirectURL: "",
+ service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ service.SettingKeyOIDCConnectUsePKCE: "true",
+ service.SettingKeyOIDCConnectValidateIDToken: "true",
+ service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
+ service.SettingKeyOIDCConnectClockSkewSeconds: "120",
+ service.SettingKeyOIDCConnectRequireEmailVerified: "false",
+ service.SettingKeyOIDCConnectUserInfoEmailPath: "",
+ service.SettingKeyOIDCConnectUserInfoIDPath: "",
+ service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
+
service.SettingKeySiteName: "TianShuAPI",
service.SettingKeySiteLogo: "",
service.SettingKeySiteSubtitle: "Subtitle",
@@ -474,13 +615,20 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com",
- service.SettingKeyDefaultConcurrency: "5",
- service.SettingKeyDefaultBalance: "1.25",
+ service.SettingKeyDefaultConcurrency: "5",
+ service.SettingKeyDefaultBalance: "1.25",
+ service.SettingKeyTableDefaultPageSize: "20",
+ service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- service.SettingKeyOpsMonitoringEnabled: "false",
- service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
- service.SettingKeyOpsQueryModeDefault: "auto",
- service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingKeyOpsMonitoringEnabled: "false",
+ service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
+ service.SettingKeyOpsQueryModeDefault: "auto",
+ service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay,
+ service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat,
+ service.SettingPaymentVisibleMethodAlipayEnabled: "true",
+ service.SettingPaymentVisibleMethodWxpayEnabled: "false",
+ "openai_advanced_scheduler_enabled": "true",
})
},
method: http.MethodGet,
@@ -508,10 +656,32 @@ func TestAPIContracts(t *testing.T) {
"turnstile_enabled": true,
"turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true,
- "linuxdo_connect_enabled": false,
+ "linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
+ "oidc_connect_enabled": false,
+ "oidc_connect_provider_name": "OIDC",
+ "oidc_connect_client_id": "",
+ "oidc_connect_client_secret_configured": false,
+ "oidc_connect_issuer_url": "",
+ "oidc_connect_discovery_url": "",
+ "oidc_connect_authorize_url": "",
+ "oidc_connect_token_url": "",
+ "oidc_connect_userinfo_url": "",
+ "oidc_connect_jwks_url": "",
+ "oidc_connect_scopes": "openid email profile",
+ "oidc_connect_redirect_url": "",
+ "oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
+ "oidc_connect_token_auth_method": "client_secret_post",
+ "oidc_connect_use_pkce": true,
+ "oidc_connect_validate_id_token": true,
+ "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
+ "oidc_connect_clock_skew_seconds": 120,
+ "oidc_connect_require_email_verified": false,
+ "oidc_connect_userinfo_email_path": "",
+ "oidc_connect_userinfo_id_path": "",
+ "oidc_connect_userinfo_username_path": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
@@ -522,8 +692,34 @@ func TestAPIContracts(t *testing.T) {
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
+ "affiliate_rebate_rate": 20,
+ "affiliate_rebate_freeze_hours": 0,
+ "affiliate_rebate_duration_days": 0,
+ "affiliate_rebate_per_invitee_cap": 0,
+ "default_user_rpm_limit": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
@@ -532,20 +728,274 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": "",
- "sora_client_enabled": false,
"invitation_code_enabled": false,
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": "",
+ "table_default_page_size": 20,
+ "table_page_size_options": [10, 20, 50, 100],
+ "min_claude_code_version": "",
+ "max_claude_code_version": "",
+ "allow_ungrouped_key_scheduling": false,
+ "backend_mode_enabled": false,
+ "enable_cch_signing": false,
+ "enable_fingerprint_unification": true,
+ "enable_metadata_passthrough": false,
+ "web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "easypay_alipay",
+ "payment_visible_method_wxpay_source": "official_wxpay",
+ "payment_visible_method_alipay_enabled": true,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": true,
+ "custom_menu_items": [],
+ "custom_endpoints": [],
+ "payment_enabled": false,
+ "payment_min_amount": 0,
+ "payment_max_amount": 0,
+ "payment_daily_limit": 0,
+ "payment_order_timeout_minutes": 0,
+ "payment_max_pending_orders": 0,
+ "payment_balance_disabled": false,
+ "payment_balance_recharge_multiplier": 0,
+ "payment_recharge_fee_rate": 0,
+ "payment_load_balance_strategy": "",
+ "payment_product_name_prefix": "",
+ "payment_product_name_suffix": "",
+ "payment_help_image_url": "",
+ "payment_help_text": "",
+ "payment_enabled_types": null,
+ "payment_cancel_rate_limit_enabled": false,
+ "payment_cancel_rate_limit_max": 0,
+ "payment_cancel_rate_limit_window": 0,
+ "payment_cancel_rate_limit_unit": "",
+ "payment_cancel_rate_limit_window_mode": "",
+ "balance_low_notify_enabled": false,
+ "account_quota_notify_enabled": false,
+ "balance_low_notify_threshold": 0,
+ "balance_low_notify_recharge_url": "",
+ "account_quota_notify_emails": [],
+ "channel_monitor_enabled": true,
+ "channel_monitor_default_interval_seconds": 60,
+ "available_channels_enabled": false,
+ "affiliate_enabled": false,
+ "wechat_connect_enabled": false,
+ "wechat_connect_app_id": "",
+ "wechat_connect_app_secret_configured": false,
+ "wechat_connect_mode": "open",
+ "wechat_connect_open_enabled": false,
+ "wechat_connect_open_app_id": "",
+ "wechat_connect_open_app_secret_configured": false,
+ "wechat_connect_mp_enabled": false,
+ "wechat_connect_mp_app_id": "",
+ "wechat_connect_mp_app_secret_configured": false,
+ "wechat_connect_mobile_enabled": false,
+ "wechat_connect_mobile_app_id": "",
+ "wechat_connect_mobile_app_secret_configured": false,
+ "wechat_connect_redirect_url": "",
+ "wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
+ "wechat_connect_scopes": "snsapi_login"
+ }
+ }`,
+ },
+ {
+ name: "GET /api/v1/admin/settings falls back to config oauth defaults",
+ setup: func(t *testing.T, deps *contractDeps) {
+ t.Helper()
+ deps.cfg.OIDC = config.OIDCConnectConfig{
+ Enabled: true,
+ ProviderName: "ConfigOIDC",
+ ClientID: "oidc-config-client",
+ ClientSecret: "oidc-config-secret",
+ IssuerURL: "https://issuer.example.com",
+ RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ FrontendRedirectURL: "/auth/oidc/callback",
+ Scopes: "openid email profile",
+ TokenAuthMethod: "client_secret_post",
+ UsePKCE: true,
+ ValidateIDToken: true,
+ AllowedSigningAlgs: "RS256,ES256,PS256",
+ ClockSkewSeconds: 120,
+ }
+ deps.cfg.WeChat = config.WeChatConnectConfig{
+ Enabled: true,
+ OpenEnabled: true,
+ OpenAppID: "wx-open-config",
+ OpenAppSecret: "wx-open-secret",
+ Mode: "open",
+ Scopes: "snsapi_login",
+ FrontendRedirectURL: "/auth/wechat/callback",
+ }
+ deps.settingRepo.SetAll(map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyEmailVerifyEnabled: "false",
+ service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ })
+ },
+ method: http.MethodGet,
+ path: "/api/v1/admin/settings",
+ wantStatus: http.StatusOK,
+ wantJSON: `{
+ "code": 0,
+ "message": "success",
+ "data": {
+ "registration_enabled": true,
+ "email_verify_enabled": false,
+ "registration_email_suffix_whitelist": [],
+ "promo_code_enabled": true,
+ "password_reset_enabled": false,
+ "frontend_url": "",
+ "invitation_code_enabled": false,
+ "totp_enabled": false,
+ "totp_encryption_key_configured": false,
+ "smtp_host": "",
+ "smtp_port": 587,
+ "smtp_username": "",
+ "smtp_password_configured": false,
+ "smtp_from_email": "",
+ "smtp_from_name": "",
+ "smtp_use_tls": false,
+ "turnstile_enabled": false,
+ "turnstile_site_key": "",
+ "turnstile_secret_key_configured": false,
+ "linuxdo_connect_enabled": false,
+ "linuxdo_connect_client_id": "",
+ "linuxdo_connect_client_secret_configured": false,
+ "linuxdo_connect_redirect_url": "",
+ "oidc_connect_enabled": true,
+ "oidc_connect_provider_name": "ConfigOIDC",
+ "oidc_connect_client_id": "oidc-config-client",
+ "oidc_connect_client_secret_configured": true,
+ "oidc_connect_issuer_url": "https://issuer.example.com",
+ "oidc_connect_discovery_url": "",
+ "oidc_connect_authorize_url": "",
+ "oidc_connect_token_url": "",
+ "oidc_connect_userinfo_url": "",
+ "oidc_connect_jwks_url": "",
+ "oidc_connect_scopes": "openid email profile",
+ "oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback",
+ "oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
+ "oidc_connect_token_auth_method": "client_secret_post",
+ "oidc_connect_use_pkce": true,
+ "oidc_connect_validate_id_token": true,
+ "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
+ "oidc_connect_clock_skew_seconds": 120,
+ "oidc_connect_require_email_verified": false,
+ "oidc_connect_userinfo_email_path": "",
+ "oidc_connect_userinfo_id_path": "",
+ "oidc_connect_userinfo_username_path": "",
+ "site_name": "Sub2API",
+ "site_logo": "",
+ "site_subtitle": "Subscription to API Conversion Platform",
+ "api_base_url": "",
+ "contact_info": "",
+ "doc_url": "",
+ "home_content": "",
+ "hide_ccs_import_button": false,
+ "purchase_subscription_enabled": false,
+ "purchase_subscription_url": "",
+ "table_default_page_size": 20,
+ "table_page_size_options": [10, 20, 50],
+ "custom_menu_items": [],
+ "custom_endpoints": [],
+ "default_concurrency": 0,
+ "default_balance": 0,
+ "affiliate_rebate_rate": 20,
+ "affiliate_rebate_freeze_hours": 0,
+ "affiliate_rebate_duration_days": 0,
+ "affiliate_rebate_per_invitee_cap": 0,
+ "default_user_rpm_limit": 0,
+ "default_subscriptions": [],
+ "enable_model_fallback": false,
+ "fallback_model_anthropic": "claude-3-5-sonnet-20241022",
+ "fallback_model_openai": "gpt-4o",
+ "fallback_model_gemini": "gemini-2.5-pro",
+ "fallback_model_antigravity": "gemini-2.5-pro",
+ "enable_identity_patch": true,
+ "identity_patch_prompt": "",
+ "ops_monitoring_enabled": false,
+ "ops_realtime_monitoring_enabled": true,
+ "ops_query_mode_default": "auto",
+ "ops_metrics_interval_seconds": 60,
"min_claude_code_version": "",
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
- "custom_menu_items": [],
- "custom_endpoints": []
+ "enable_cch_signing": false,
+ "web_search_emulation_enabled": false,
+ "payment_visible_method_alipay_source": "",
+ "payment_visible_method_wxpay_source": "",
+ "payment_visible_method_alipay_enabled": false,
+ "payment_visible_method_wxpay_enabled": false,
+ "openai_advanced_scheduler_enabled": false,
+ "payment_enabled": false,
+ "payment_min_amount": 0,
+ "payment_max_amount": 0,
+ "payment_daily_limit": 0,
+ "payment_order_timeout_minutes": 0,
+ "payment_max_pending_orders": 0,
+ "payment_enabled_types": null,
+ "payment_balance_disabled": false,
+ "payment_balance_recharge_multiplier": 0,
+ "payment_recharge_fee_rate": 0,
+ "payment_load_balance_strategy": "",
+ "payment_product_name_prefix": "",
+ "payment_product_name_suffix": "",
+ "payment_help_image_url": "",
+ "payment_help_text": "",
+ "payment_cancel_rate_limit_enabled": false,
+ "payment_cancel_rate_limit_max": 0,
+ "payment_cancel_rate_limit_window": 0,
+ "payment_cancel_rate_limit_unit": "",
+ "payment_cancel_rate_limit_window_mode": "",
+ "balance_low_notify_enabled": false,
+ "account_quota_notify_enabled": false,
+ "balance_low_notify_threshold": 0,
+ "balance_low_notify_recharge_url": "",
+ "account_quota_notify_emails": [],
+ "channel_monitor_enabled": true,
+ "channel_monitor_default_interval_seconds": 60,
+ "available_channels_enabled": false,
+ "affiliate_enabled": false,
+ "wechat_connect_enabled": true,
+ "wechat_connect_app_id": "wx-open-config",
+ "wechat_connect_app_secret_configured": true,
+ "wechat_connect_mode": "open",
+ "wechat_connect_open_enabled": true,
+ "wechat_connect_open_app_id": "wx-open-config",
+ "wechat_connect_open_app_secret_configured": true,
+ "wechat_connect_mp_enabled": false,
+ "wechat_connect_mp_app_id": "wx-open-config",
+ "wechat_connect_mp_app_secret_configured": true,
+ "wechat_connect_mobile_enabled": false,
+ "wechat_connect_mobile_app_id": "wx-open-config",
+ "wechat_connect_mobile_app_secret_configured": true,
+ "wechat_connect_redirect_url": "",
+ "wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
+ "wechat_connect_scopes": "snsapi_login",
+ "auth_source_default_email_balance": 0,
+ "auth_source_default_email_concurrency": 5,
+ "auth_source_default_email_subscriptions": [],
+ "auth_source_default_email_grant_on_signup": false,
+ "auth_source_default_email_grant_on_first_bind": false,
+ "auth_source_default_linuxdo_balance": 0,
+ "auth_source_default_linuxdo_concurrency": 5,
+ "auth_source_default_linuxdo_subscriptions": [],
+ "auth_source_default_linuxdo_grant_on_signup": false,
+ "auth_source_default_linuxdo_grant_on_first_bind": false,
+ "auth_source_default_oidc_balance": 0,
+ "auth_source_default_oidc_concurrency": 5,
+ "auth_source_default_oidc_subscriptions": [],
+ "auth_source_default_oidc_grant_on_signup": false,
+ "auth_source_default_oidc_grant_on_first_bind": false,
+ "auth_source_default_wechat_balance": 0,
+ "auth_source_default_wechat_concurrency": 5,
+ "auth_source_default_wechat_subscriptions": [],
+ "auth_source_default_wechat_grant_on_signup": false,
+ "auth_source_default_wechat_grant_on_first_bind": false,
+ "force_email_on_third_party_signup": false
}
}`,
},
@@ -592,6 +1042,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct {
now time.Time
router http.Handler
+ cfg *config.Config
apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
@@ -638,7 +1089,7 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
- userService := service.NewUserService(userRepo, nil, nil)
+ userService := service.NewUserService(userRepo, nil, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
@@ -653,11 +1104,11 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
- adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
- adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
+ adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
@@ -712,6 +1163,7 @@ func newContractDeps(t *testing.T) *contractDeps {
return &contractDeps{
now: now,
router: r,
+ cfg: cfg,
apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
@@ -785,6 +1237,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
+func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
@@ -821,6 +1285,26 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
return errors.New("not implemented")
}
+func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (r *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ return nil
+}
+
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go
index a8034e98..023e40bb 100644
--- a/backend/internal/server/http.go
+++ b/backend/internal/server/http.go
@@ -2,12 +2,15 @@
package server
import (
+ "context"
"log"
+ "log/slog"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -56,6 +59,42 @@ func ProvideRouter(
}
}
+ // Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
+ settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) {
+ if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
+ service.SetWebSearchManager(nil)
+ return
+ }
+ configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
+ for _, p := range cfg.Providers {
+ if p.APIKey == "" {
+ continue
+ }
+ pc := websearch.ProviderConfig{
+ Type: p.Type,
+ APIKey: p.APIKey,
+ QuotaLimit: derefInt64(p.QuotaLimit),
+ ExpiresAt: p.ExpiresAt,
+ }
+ if p.SubscribedAt != nil {
+ pc.SubscribedAt = p.SubscribedAt
+ }
+ if p.ProxyID != nil {
+ pc.ProxyID = *p.ProxyID
+ if u, ok := proxyURLs[*p.ProxyID]; ok {
+ pc.ProxyURL = u
+ } else {
+ // Proxy configured but not found — skip this provider to prevent direct connection.
+ slog.Warn("websearch: proxy not found for provider, skipping",
+ "provider", p.Type, "proxy_id", *p.ProxyID)
+ continue
+ }
+ }
+ configs = append(configs, pc)
+ }
+ service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
+ })
+
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
}
@@ -102,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
}
}
+
+func derefInt64(p *int64) int64 {
+ if p == nil {
+ return 0
+ }
+ return *p
+}
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index aafe4a58..dde92dfd 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -19,7 +20,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
- authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
@@ -39,7 +40,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
return &clone, nil
},
}
- userService := service.NewUserService(userRepo, nil, nil)
+ userService := service.NewUserService(userRepo, nil, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
@@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
+func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -189,6 +214,14 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
+func (s *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go
index 46482af3..ae53037e 100644
--- a/backend/internal/server/middleware/backend_mode_guard.go
+++ b/backend/internal/server/middleware/backend_mode_guard.go
@@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
}
}
+func backendModeAllowsAuthPath(path string) bool {
+ path = strings.ToLower(strings.TrimSpace(path))
+ for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} {
+ if strings.HasSuffix(path, suffix) {
+ return true
+ }
+ }
+
+ for _, suffix := range []string{
+ "/auth/oauth/linuxdo/callback",
+ "/auth/oauth/wechat/callback",
+ "/auth/oauth/wechat/payment/callback",
+ "/auth/oauth/oidc/callback",
+ "/auth/oauth/linuxdo/complete-registration",
+ "/auth/oauth/wechat/complete-registration",
+ "/auth/oauth/oidc/complete-registration",
+ "/auth/oauth/linuxdo/create-account",
+ "/auth/oauth/wechat/create-account",
+ "/auth/oauth/oidc/create-account",
+ "/auth/oauth/linuxdo/bind-login",
+ "/auth/oauth/wechat/bind-login",
+ "/auth/oauth/oidc/bind-login",
+ } {
+ if strings.HasSuffix(path, suffix) {
+ return true
+ }
+ }
+
+ return strings.Contains(path, "/auth/oauth/pending/")
+}
+
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
-// Allows: login, login/2fa, logout, refresh (admin needs these).
-// Blocks: register, forgot-password, reset-password, OAuth, etc.
+// Allows the minimal auth surface admins still need in backend mode, including
+// OAuth callbacks and pending continuations. Handler-level backend mode checks
+// still enforce admin-only login and forbid self-service registration.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
- path := c.Request.URL.Path
- // Allow login, 2FA, logout, refresh, public settings
- allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
- for _, suffix := range allowedSuffixes {
- if strings.HasSuffix(path, suffix) {
- c.Next()
- return
- }
+ if backendModeAllowsAuthPath(c.Request.URL.Path) {
+ c.Next()
+ return
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go
index 8878ebc9..bd77677b 100644
--- a/backend/internal/server/middleware/backend_mode_guard_test.go
+++ b/backend/internal/server/middleware/backend_mode_guard_test.go
@@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
+ {
+ name: "enabled_blocks_linuxdo_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_linuxdo_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_wechat_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_wechat_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_wechat_payment_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/payment/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_wechat_payment_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/payment/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_blocks_oidc_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_oidc_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_exchange",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/exchange",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_send_verify_code",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/send-verify-code",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_create_account",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/create-account",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_oauth_pending_bind_login",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/pending/bind-login",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_provider_bind_login",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/oidc/bind-login",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_provider_create_account",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/wechat/create-account",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_legacy_complete_registration",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/linuxdo/complete-registration",
+ wantStatus: http.StatusOK,
+ },
{
name: "enabled_blocks_register",
enabled: "true",
diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go
index 4aceb355..48cb9004 100644
--- a/backend/internal/server/middleware/jwt_auth.go
+++ b/backend/internal/server/middleware/jwt_auth.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "context"
"errors"
"strings"
@@ -11,11 +12,19 @@ import (
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
- return JWTAuthMiddleware(jwtAuth(authService, userService))
+ return JWTAuthMiddleware(jwtAuth(authService, userService, userService))
+}
+
+type jwtUserReader interface {
+ GetByID(ctx context.Context, id int64) (*service.User, error)
+}
+
+type userActivityToucher interface {
+ TouchLastActiveForUser(ctx context.Context, user *service.User)
}
// jwtAuth JWT认证中间件实现
-func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
+func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
@@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
+ if activityToucher != nil {
+ activityToucher.TouchLastActiveForUser(c.Request.Context(), user)
+ }
c.Next()
}
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index ad9c1b5b..a643d3bc 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e
return u, nil
}
+func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error {
+ return nil
+}
+
+type recordingActivityToucher struct {
+ userIDs []int64
+}
+
+func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) {
+ if user == nil {
+ return
+ }
+ r.userIDs = append(r.userIDs, user.ID)
+}
+
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
@@ -40,8 +60,8 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
- authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
- userSvc := service.NewUserService(userRepo, nil, nil)
+ authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
r := gin.New()
@@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
+func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
+ user := &service.User{
+ ID: 1,
+ Email: "test@example.com",
+ Role: "user",
+ Status: service.StatusActive,
+ Concurrency: 5,
+ TokenVersion: 1,
+ }
+
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
+ cfg.JWT.AccessTokenExpireMinutes = 60
+
+ userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
+ authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil, nil)
+ userSvc := service.NewUserService(userRepo, nil, nil, nil)
+ toucher := &recordingActivityToucher{}
+
+ r := gin.New()
+ r.Use(jwtAuth(authSvc, userSvc, toucher))
+ r.GET("/protected", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ token, err := authSvc.GenerateToken(user)
+ require.NoError(t, err)
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/protected", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ r.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusOK, w.Code)
+ require.Equal(t, []int64{1}, toucher.userIDs)
+}
+
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index d9ec951e..398c0351 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -18,6 +18,8 @@ const (
NonceTemplate = "__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
+ // StripeDomain is the domain for Stripe.js SDK
+ StripeDomain = "https://*.stripe.com"
)
// GenerateNonce generates a cryptographically secure random nonce.
@@ -94,12 +96,13 @@ func isAPIRoutePath(c *gin.Context) bool {
return strings.HasPrefix(path, "/v1/") ||
strings.HasPrefix(path, "/v1beta/") ||
strings.HasPrefix(path, "/antigravity/") ||
- strings.HasPrefix(path, "/sora/") ||
- strings.HasPrefix(path, "/responses")
+ strings.HasPrefix(path, "/responses") ||
+ strings.HasPrefix(path, "/images")
}
-// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
-// This allows the application to work correctly even if the config file has an older CSP policy.
+// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
+// and Stripe.js domains. This allows the application to work correctly even if the
+// config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
@@ -111,6 +114,12 @@ func enhanceCSPPolicy(policy string) string {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
}
+ // Add Stripe.js domain to script-src and frame-src if not present
+ if !strings.Contains(policy, "stripe.com") {
+ policy = addToDirective(policy, "script-src", StripeDomain)
+ policy = addToDirective(policy, "frame-src", StripeDomain)
+ }
+
return policy
}
diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go
index 99701531..a507b6f8 100644
--- a/backend/internal/server/router.go
+++ b/backend/internal/server/router.go
@@ -109,7 +109,7 @@ func registerRoutes(
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
- routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
+ routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService)
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index e04dae85..1c786f50 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -34,8 +34,6 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
- // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
- registerSoraOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
@@ -87,6 +85,15 @@ func RegisterAdminRoutes(
// 定时测试计划
registerScheduledTestRoutes(admin, h)
+
+ // 渠道管理
+ registerChannelRoutes(admin, h)
+
+ // 渠道监控
+ registerChannelMonitorRoutes(admin, h)
+
+ // 邀请返利(专属用户管理)
+ registerAffiliateRoutes(admin, h)
}
}
@@ -211,6 +218,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
+ users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
@@ -219,6 +227,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory)
users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup)
+ users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
@@ -242,6 +251,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
+ groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides)
+ groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}
@@ -318,19 +329,6 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
-func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
- sora := admin.Group("/sora")
- {
- sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
- sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
- sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
- sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
- sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
- sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
- sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
- }
-}
-
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
@@ -419,15 +417,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
- // Sora S3 存储配置
- adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
- adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
- adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection)
- adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles)
- adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile)
- adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile)
- adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile)
- adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile)
+ // Web Search 模拟配置
+ adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
+ adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
+ adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation)
+ adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage)
}
}
@@ -567,3 +561,54 @@ func registerTLSFingerprintProfileRoutes(admin *gin.RouterGroup, h *handler.Hand
profiles.DELETE("/:id", h.Admin.TLSFingerprintProfile.Delete)
}
}
+
+func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ channels := admin.Group("/channels")
+ {
+ channels.GET("", h.Admin.Channel.List)
+ channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
+ channels.GET("/:id", h.Admin.Channel.GetByID)
+ channels.POST("", h.Admin.Channel.Create)
+ channels.PUT("/:id", h.Admin.Channel.Update)
+ channels.DELETE("/:id", h.Admin.Channel.Delete)
+ }
+}
+
+func registerChannelMonitorRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ monitors := admin.Group("/channel-monitors")
+ {
+ monitors.GET("", h.Admin.ChannelMonitor.List)
+ monitors.POST("", h.Admin.ChannelMonitor.Create)
+ monitors.GET("/:id", h.Admin.ChannelMonitor.Get)
+ monitors.PUT("/:id", h.Admin.ChannelMonitor.Update)
+ monitors.DELETE("/:id", h.Admin.ChannelMonitor.Delete)
+ monitors.POST("/:id/run", h.Admin.ChannelMonitor.Run)
+ monitors.GET("/:id/history", h.Admin.ChannelMonitor.History)
+ }
+
+ templates := admin.Group("/channel-monitor-templates")
+ {
+ templates.GET("", h.Admin.ChannelMonitorTemplate.List)
+ templates.POST("", h.Admin.ChannelMonitorTemplate.Create)
+ templates.GET("/:id", h.Admin.ChannelMonitorTemplate.Get)
+ templates.PUT("/:id", h.Admin.ChannelMonitorTemplate.Update)
+ templates.DELETE("/:id", h.Admin.ChannelMonitorTemplate.Delete)
+ templates.GET("/:id/monitors", h.Admin.ChannelMonitorTemplate.AssociatedMonitors)
+ templates.POST("/:id/apply", h.Admin.ChannelMonitorTemplate.Apply)
+ }
+}
+
+// registerAffiliateRoutes 注册邀请返利的管理端路由(专属用户配置)
+func registerAffiliateRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ affiliates := admin.Group("/affiliates")
+ {
+ users := affiliates.Group("/users")
+ {
+ users.GET("", h.Admin.Affiliate.ListUsers)
+ users.GET("/lookup", h.Admin.Affiliate.LookupUsers)
+ users.POST("/batch-rate", h.Admin.Affiliate.BatchSetRate)
+ users.PUT("/:user_id", h.Admin.Affiliate.UpdateUserSettings)
+ users.DELETE("/:user_id", h.Admin.Affiliate.ClearUserSettings)
+ }
+ }
+}
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index a6c0ecf5..642a2103 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -63,13 +63,109 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
+ auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.LinuxDoOAuthStart(c)
+ })
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
+ auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
+ auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.WeChatOAuthStart(c)
+ })
+ auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
+ auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart)
+ auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback)
+ auth.POST("/oauth/pending/exchange",
+ rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.ExchangePendingOAuthCompletion,
+ )
+ auth.POST("/oauth/pending/send-verify-code",
+ rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.SendPendingOAuthVerifyCode,
+ )
+ auth.POST("/oauth/pending/create-account",
+ rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreatePendingOAuthAccount,
+ )
+ auth.POST("/oauth/pending/bind-login",
+ rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindPendingOAuthLogin,
+ )
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
+ auth.POST("/oauth/linuxdo/bind-login",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindLinuxDoOAuthLogin,
+ )
+ auth.POST("/oauth/linuxdo/create-account",
+ rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateLinuxDoOAuthAccount,
+ )
+ auth.POST("/oauth/wechat/complete-registration",
+ rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CompleteWeChatOAuthRegistration,
+ )
+ auth.POST("/oauth/wechat/bind-login",
+ rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindWeChatOAuthLogin,
+ )
+ auth.POST("/oauth/wechat/create-account",
+ rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateWeChatOAuthAccount,
+ )
+ auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
+ auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.OIDCOAuthStart(c)
+ })
+ auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
+ auth.POST("/oauth/oidc/complete-registration",
+ rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CompleteOIDCOAuthRegistration,
+ )
+ auth.POST("/oauth/oidc/bind-login",
+ rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindOIDCOAuthLogin,
+ )
+ auth.POST("/oauth/oidc/create-account",
+ rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateOIDCOAuthAccount,
+ )
}
// 公开设置(无需认证)
@@ -86,5 +182,6 @@ func RegisterAuthRoutes(
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
+ authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie)
}
}
diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go
index 4f411cec..07a66efb 100644
--- a/backend/internal/server/routes/auth_rate_limit_test.go
+++ b/backend/internal/server/routes/auth_rate_limit_test.go
@@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
"/api/v1/auth/login",
"/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code",
+ "/api/v1/auth/oauth/pending/send-verify-code",
}
for _, path := range paths {
diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go
index 072cfdee..9541cda1 100644
--- a/backend/internal/server/routes/gateway.go
+++ b/backend/internal/server/routes/gateway.go
@@ -23,11 +23,6 @@ func RegisterGatewayRoutes(
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
- soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
- if soraMaxBodySize <= 0 {
- soraMaxBodySize = cfg.Gateway.MaxBodySize
- }
- soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
endpointNorm := handler.InboundEndpointMiddleware()
@@ -93,6 +88,30 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
+ gateway.POST("/images/generations", func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
+ gateway.POST("/images/edits", func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
}
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
@@ -121,6 +140,13 @@ func RegisterGatewayRoutes(
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, responsesHandler)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
+ codexDirect := r.Group("/backend-api/codex")
+ codexDirect.Use(bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic)
+ {
+ codexDirect.POST("/responses", responsesHandler)
+ codexDirect.POST("/responses/*subpath", responsesHandler)
+ codexDirect.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
+ }
// OpenAI Chat Completions API(不带v1前缀的别名)— auto-route based on group platform
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
if getGroupPlatform(c) == service.PlatformOpenAI {
@@ -129,6 +155,30 @@ func RegisterGatewayRoutes(
}
h.Gateway.ChatCompletions(c)
})
+ r.POST("/images/generations", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
+ r.POST("/images/edits", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, func(c *gin.Context) {
+ if getGroupPlatform(c) != service.PlatformOpenAI {
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": gin.H{
+ "type": "not_found_error",
+ "message": "Images API is not supported for this platform",
+ },
+ })
+ return
+ }
+ h.OpenAIGateway.Images(c)
+ })
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
@@ -163,28 +213,6 @@ func RegisterGatewayRoutes(
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
- // Sora 专用路由(强制使用 sora 平台)
- soraV1 := r.Group("/sora/v1")
- soraV1.Use(soraBodyLimit)
- soraV1.Use(clientRequestID)
- soraV1.Use(opsErrorLogger)
- soraV1.Use(endpointNorm)
- soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
- soraV1.Use(gin.HandlerFunc(apiKeyAuth))
- soraV1.Use(requireGroupAnthropic)
- {
- soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
- soraV1.GET("/models", h.Gateway.Models)
- }
-
- // Sora 媒体代理(可选 API Key 验证)
- if cfg.Gateway.SoraMediaRequireAPIKey {
- r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
- } else {
- r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
- }
- // Sora 媒体代理(签名 URL,无需 API Key)
- r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
}
// getGroupPlatform extracts the group platform from the API Key stored in context.
diff --git a/backend/internal/server/routes/gateway_test.go b/backend/internal/server/routes/gateway_test.go
index 00edd31b..19ef5686 100644
--- a/backend/internal/server/routes/gateway_test.go
+++ b/backend/internal/server/routes/gateway_test.go
@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -22,9 +23,13 @@ func newGatewayRoutesTestRouter() *gin.Engine {
&handler.Handlers{
Gateway: &handler.GatewayHandler{},
OpenAIGateway: &handler.OpenAIGatewayHandler{},
- SoraGateway: &handler.SoraGatewayHandler{},
},
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
+ groupID := int64(1)
+ c.Set(string(servermiddleware.ContextKeyAPIKey), &service.APIKey{
+ GroupID: &groupID,
+ Group: &service.Group{Platform: service.PlatformOpenAI},
+ })
c.Next()
}),
nil,
@@ -40,7 +45,12 @@ func newGatewayRoutesTestRouter() *gin.Engine {
func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
router := newGatewayRoutesTestRouter()
- for _, path := range []string{"/v1/responses/compact", "/responses/compact"} {
+ for _, path := range []string{
+ "/v1/responses/compact",
+ "/responses/compact",
+ "/backend-api/codex/responses",
+ "/backend-api/codex/responses/compact",
+ } {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
@@ -49,3 +59,21 @@ func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
}
}
+
+func TestGatewayRoutesOpenAIImagesPathsAreRegistered(t *testing.T) {
+ router := newGatewayRoutesTestRouter()
+
+ for _, path := range []string{
+ "/v1/images/generations",
+ "/v1/images/edits",
+ "/images/generations",
+ "/images/edits",
+ } {
+ req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-image-2","prompt":"draw a cat"}`))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ router.ServeHTTP(w, req)
+ require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI images handler", path)
+ }
+}
diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go
new file mode 100644
index 00000000..e4828ead
--- /dev/null
+++ b/backend/internal/server/routes/payment.go
@@ -0,0 +1,106 @@
+package routes
+
+import (
+ "github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/handler/admin"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// RegisterPaymentRoutes registers all payment-related routes:
+// user-facing endpoints, webhook endpoints, and admin endpoints.
+func RegisterPaymentRoutes(
+ v1 *gin.RouterGroup,
+ paymentHandler *handler.PaymentHandler,
+ webhookHandler *handler.PaymentWebhookHandler,
+ adminPaymentHandler *admin.PaymentHandler,
+ jwtAuth middleware.JWTAuthMiddleware,
+ adminAuth middleware.AdminAuthMiddleware,
+ settingService *service.SettingService,
+) {
+ // --- User-facing payment endpoints (authenticated) ---
+ authenticated := v1.Group("/payment")
+ authenticated.Use(gin.HandlerFunc(jwtAuth))
+ authenticated.Use(middleware.BackendModeUserGuard(settingService))
+ {
+ authenticated.GET("/config", paymentHandler.GetPaymentConfig)
+ authenticated.GET("/checkout-info", paymentHandler.GetCheckoutInfo)
+ authenticated.GET("/plans", paymentHandler.GetPlans)
+ authenticated.GET("/channels", paymentHandler.GetChannels)
+ authenticated.GET("/limits", paymentHandler.GetLimits)
+
+ orders := authenticated.Group("/orders")
+ {
+ orders.POST("", paymentHandler.CreateOrder)
+ orders.POST("/verify", paymentHandler.VerifyOrder)
+ orders.GET("/my", paymentHandler.GetMyOrders)
+ orders.GET("/:id", paymentHandler.GetOrder)
+ orders.POST("/:id/cancel", paymentHandler.CancelOrder)
+ orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
+ orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders)
+ }
+ }
+
+ // --- Public payment endpoints (no auth) ---
+ // Signed resume-token recovery is the preferred public lookup path.
+ // The legacy anonymous out_trade_no verify endpoint remains available as a
+ // persisted-state compatibility path for staggered upgrades.
+ public := v1.Group("/payment/public")
+ {
+ public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
+ public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken)
+ }
+
+ // --- Webhook endpoints (no auth) ---
+ webhook := v1.Group("/payment/webhook")
+ {
+ // EasyPay sends GET callbacks with query params
+ webhook.GET("/easypay", webhookHandler.EasyPayNotify)
+ webhook.POST("/easypay", webhookHandler.EasyPayNotify)
+ webhook.POST("/alipay", webhookHandler.AlipayNotify)
+ webhook.POST("/wxpay", webhookHandler.WxpayNotify)
+ webhook.POST("/stripe", webhookHandler.StripeWebhook)
+ }
+
+ // --- Admin payment endpoints (admin auth) ---
+ adminGroup := v1.Group("/admin/payment")
+ adminGroup.Use(gin.HandlerFunc(adminAuth))
+ {
+ // Dashboard
+ adminGroup.GET("/dashboard", adminPaymentHandler.GetDashboard)
+
+ // Config
+ adminGroup.GET("/config", adminPaymentHandler.GetConfig)
+ adminGroup.PUT("/config", adminPaymentHandler.UpdateConfig)
+
+ // Orders
+ adminOrders := adminGroup.Group("/orders")
+ {
+ adminOrders.GET("", adminPaymentHandler.ListOrders)
+ adminOrders.GET("/:id", adminPaymentHandler.GetOrderDetail)
+ adminOrders.POST("/:id/cancel", adminPaymentHandler.CancelOrder)
+ adminOrders.POST("/:id/retry", adminPaymentHandler.RetryFulfillment)
+ adminOrders.POST("/:id/refund", adminPaymentHandler.ProcessRefund)
+ }
+
+ // Subscription Plans
+ plans := adminGroup.Group("/plans")
+ {
+ plans.GET("", adminPaymentHandler.ListPlans)
+ plans.POST("", adminPaymentHandler.CreatePlan)
+ plans.PUT("/:id", adminPaymentHandler.UpdatePlan)
+ plans.DELETE("/:id", adminPaymentHandler.DeletePlan)
+ }
+
+ // Provider Instances
+ providers := adminGroup.Group("/providers")
+ {
+ providers.GET("", adminPaymentHandler.ListProviders)
+ providers.POST("", adminPaymentHandler.CreateProvider)
+ providers.PUT("/:id", adminPaymentHandler.UpdateProvider)
+ providers.DELETE("/:id", adminPaymentHandler.DeleteProvider)
+ }
+ }
+}
diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go
deleted file mode 100644
index 13fceb81..00000000
--- a/backend/internal/server/routes/sora_client.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package routes
-
-import (
- "github.com/Wei-Shaw/sub2api/internal/handler"
- "github.com/Wei-Shaw/sub2api/internal/server/middleware"
- "github.com/Wei-Shaw/sub2api/internal/service"
-
- "github.com/gin-gonic/gin"
-)
-
-// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。
-func RegisterSoraClientRoutes(
- v1 *gin.RouterGroup,
- h *handler.Handlers,
- jwtAuth middleware.JWTAuthMiddleware,
- settingService *service.SettingService,
-) {
- if h.SoraClient == nil {
- return
- }
-
- authenticated := v1.Group("/sora")
- authenticated.Use(gin.HandlerFunc(jwtAuth))
- authenticated.Use(middleware.BackendModeUserGuard(settingService))
- {
- authenticated.POST("/generate", h.SoraClient.Generate)
- authenticated.GET("/generations", h.SoraClient.ListGenerations)
- authenticated.GET("/generations/:id", h.SoraClient.GetGeneration)
- authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration)
- authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration)
- authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage)
- authenticated.GET("/quota", h.SoraClient.GetQuota)
- authenticated.GET("/models", h.SoraClient.GetModels)
- authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus)
- }
-}
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index c3b82742..9976954c 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -25,6 +25,21 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+ user.GET("/aff", h.User.GetAffiliate)
+ user.POST("/aff/transfer", h.User.TransferAffiliateQuota)
+ user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
+ user.POST("/account-bindings/email", h.User.BindEmailIdentity)
+ user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
+ user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
+
+ // 通知邮箱管理
+ notifyEmail := user.Group("/notify-email")
+ {
+ notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode)
+ notifyEmail.POST("/verify", h.User.VerifyNotifyEmail)
+ notifyEmail.PUT("/toggle", h.User.ToggleNotifyEmail)
+ notifyEmail.DELETE("", h.User.RemoveNotifyEmail)
+ }
// TOTP 双因素认证
totp := user.Group("/totp")
@@ -55,6 +70,12 @@ func RegisterUserRoutes(
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
+ // 用户可用渠道(非管理员接口)
+ channels := authenticated.Group("/channels")
+ {
+ channels.GET("/available", h.AvailableChannel.List)
+ }
+
// 使用记录
usage := authenticated.Group("/usage")
{
@@ -90,5 +111,12 @@ func RegisterUserRoutes(
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
+
+ // 渠道监控(用户只读)
+ monitors := authenticated.Group("/channel-monitors")
+ {
+ monitors.GET("", h.ChannelMonitor.List)
+ monitors.GET("/:id/status", h.ChannelMonitor.GetStatus)
+ }
}
}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 512195e3..cd06ffa3 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"hash/fnv"
+ "log/slog"
"reflect"
"sort"
"strconv"
@@ -120,6 +121,9 @@ func (a *Account) IsSchedulable() bool {
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
+ if a.IsAPIKeyOrBedrock() && a.IsQuotaExceeded() {
+ return false
+ }
return true
}
@@ -389,6 +393,56 @@ func parseTempUnschedInt(value any) int {
return 0
}
+const (
+ // OpenAICompactModeAuto follows compact-probe results when deciding compact eligibility.
+ OpenAICompactModeAuto = "auto"
+ // OpenAICompactModeForceOn always treats the account as compact-supported.
+ OpenAICompactModeForceOn = "force_on"
+ // OpenAICompactModeForceOff always treats the account as compact-unsupported.
+ OpenAICompactModeForceOff = "force_off"
+)
+
+func normalizeOpenAICompactMode(mode string) string {
+ switch strings.ToLower(strings.TrimSpace(mode)) {
+ case OpenAICompactModeForceOn:
+ return OpenAICompactModeForceOn
+ case OpenAICompactModeForceOff:
+ return OpenAICompactModeForceOff
+ default:
+ return OpenAICompactModeAuto
+ }
+}
+
+func stringMappingFromRaw(raw any) map[string]string {
+ switch mapping := raw.(type) {
+ case map[string]any:
+ if len(mapping) == 0 {
+ return nil
+ }
+ result := make(map[string]string, len(mapping))
+ for key, value := range mapping {
+ if str, ok := value.(string); ok {
+ result[key] = str
+ }
+ }
+ if len(result) == 0 {
+ return nil
+ }
+ return result
+ case map[string]string:
+ if len(mapping) == 0 {
+ return nil
+ }
+ result := make(map[string]string, len(mapping))
+ for key, value := range mapping {
+ result[key] = value
+ }
+ return result
+ default:
+ return nil
+ }
+}
+
func (a *Account) GetModelMapping() map[string]string {
credentialsPtr := mapPtr(a.Credentials)
rawMapping, _ := a.Credentials["model_mapping"].(map[string]any)
@@ -594,6 +648,77 @@ func (a *Account) ResolveMappedModel(requestedModel string) (mappedModel string,
return requestedModel, false
}
+// GetOpenAICompactMode returns the compact routing mode for an OpenAI account.
+// Missing or invalid values fall back to "auto".
+func (a *Account) GetOpenAICompactMode() string {
+ if a == nil || !a.IsOpenAI() || a.Extra == nil {
+ return OpenAICompactModeAuto
+ }
+ mode, _ := a.Extra["openai_compact_mode"].(string)
+ return normalizeOpenAICompactMode(mode)
+}
+
+// OpenAICompactSupportKnown reports whether compact capability is known for this
+// account and, when known, whether it is supported.
+func (a *Account) OpenAICompactSupportKnown() (supported bool, known bool) {
+ if a == nil || !a.IsOpenAI() {
+ return false, false
+ }
+
+ switch a.GetOpenAICompactMode() {
+ case OpenAICompactModeForceOn:
+ return true, true
+ case OpenAICompactModeForceOff:
+ return false, true
+ }
+
+ if a.Extra == nil {
+ return false, false
+ }
+ supported, ok := a.Extra["openai_compact_supported"].(bool)
+ if !ok {
+ return false, false
+ }
+ return supported, true
+}
+
+// AllowsOpenAICompact reports whether the account may be considered for compact
+// requests. Unknown capability remains allowed to avoid breaking older accounts
+// before an explicit probe has been run.
+func (a *Account) AllowsOpenAICompact() bool {
+ if a == nil || !a.IsOpenAI() {
+ return false
+ }
+ supported, known := a.OpenAICompactSupportKnown()
+ if !known {
+ return true
+ }
+ return supported
+}
+
+// GetCompactModelMapping returns compact-only model remapping configuration.
+// This mapping is intended for /responses/compact only and does not affect
+// normal /responses traffic.
+func (a *Account) GetCompactModelMapping() map[string]string {
+ if a == nil || a.Credentials == nil {
+ return nil
+ }
+ return stringMappingFromRaw(a.Credentials["compact_model_mapping"])
+}
+
+// ResolveCompactMappedModel resolves compact-only model remapping and reports
+// whether a compact-specific mapping rule matched.
+func (a *Account) ResolveCompactMappedModel(requestedModel string) (mappedModel string, matched bool) {
+ mapping := a.GetCompactModelMapping()
+ if len(mapping) == 0 {
+ return requestedModel, false
+ }
+ if mappedModel, matched := resolveRequestedModelInMapping(mapping, requestedModel); matched {
+ return mappedModel, true
+ }
+ return requestedModel, false
+}
+
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
return ""
@@ -907,6 +1032,32 @@ func (a *Account) GetChatGPTAccountID() string {
return a.GetCredential("chatgpt_account_id")
}
+func (a *Account) GetOpenAIDeviceID() string {
+ if !a.IsOpenAIOAuth() {
+ return ""
+ }
+ return strings.TrimSpace(a.GetExtraString("openai_device_id"))
+}
+
+func (a *Account) GetOpenAISessionID() string {
+ if !a.IsOpenAIOAuth() {
+ return ""
+ }
+ return strings.TrimSpace(a.GetExtraString("openai_session_id"))
+}
+
+func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapability) bool {
+ if !a.IsOpenAI() {
+ return false
+ }
+ switch capability {
+ case OpenAIImagesCapabilityBasic, OpenAIImagesCapabilityNative:
+ return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey
+ default:
+ return true
+ }
+}
+
func (a *Account) GetChatGPTUserID() string {
if !a.IsOpenAIOAuth() {
return ""
@@ -969,7 +1120,7 @@ func (a *Account) IsOveragesEnabled() bool {
return false
}
-// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
+// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
//
// 新字段:accounts.extra.openai_passthrough。
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
@@ -1133,7 +1284,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return resolvedDefault
}
-// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
+// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
// 字段:accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil {
@@ -1158,7 +1309,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
}
-// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
+// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
// 字段:accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
@@ -1169,7 +1320,42 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled
}
-// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
+// WebSearch 模拟三态常量
+const (
+ WebSearchModeDefault = "default" // 跟随渠道配置
+ WebSearchModeEnabled = "enabled" // 强制开启
+ WebSearchModeDisabled = "disabled" // 强制关闭
+)
+
+// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
+// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。
+// 兼容旧 bool 值:true→enabled, false→default(并记录 debug 日志)。
+func (a *Account) GetWebSearchEmulationMode() string {
+ if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
+ return WebSearchModeDefault
+ }
+ raw := a.Extra[featureKeyWebSearchEmulation]
+ // Tolerant: legacy bool values (pre-migration or stale writes)
+ if b, ok := raw.(bool); ok {
+ slog.Debug("legacy bool web_search_emulation value", "account_id", a.ID, "value", b)
+ if b {
+ return WebSearchModeEnabled
+ }
+ return WebSearchModeDefault
+ }
+ mode, ok := raw.(string)
+ if !ok {
+ return WebSearchModeDefault
+ }
+ switch mode {
+ case WebSearchModeEnabled, WebSearchModeDisabled:
+ return mode
+ default:
+ return WebSearchModeDefault
+ }
+}
+
+// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段:accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false(关闭)处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool {
@@ -1395,6 +1581,19 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{}
}
+// getExtraBool 从 Extra 中读取指定 key 的 bool 值
+func (a *Account) getExtraBool(key string) bool {
+ if a.Extra == nil {
+ return false
+ }
+ if v, ok := a.Extra[key]; ok {
+ if b, ok := v.(bool); ok {
+ return b
+ }
+ }
+ return false
+}
+
// getExtraString 从 Extra 中读取指定 key 的字符串值
func (a *Account) getExtraString(key string) string {
if a.Extra == nil {
@@ -1408,6 +1607,14 @@ func (a *Account) getExtraString(key string) string {
return ""
}
+// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal
+func (a *Account) getExtraStringDefault(key, defaultVal string) string {
+ if v := a.getExtraString(key); v != "" {
+ return v
+ }
+ return defaultVal
+}
+
// getExtraInt 从 Extra 中读取指定 key 的 int 值
func (a *Account) getExtraInt(key string) int {
if a.Extra == nil {
@@ -1464,6 +1671,62 @@ func (a *Account) GetQuotaResetTimezone() string {
return "UTC"
}
+// --- Quota Notification Getters ---
+
+// QuotaNotifyConfig returns the notify configuration for a given quota dimension.
+// dim must be one of quotaDimDaily, quotaDimWeekly, quotaDimTotal.
+func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64, thresholdType string) {
+ enabled = a.getExtraBool("quota_notify_" + dim + "_enabled")
+ threshold = a.getExtraFloat64("quota_notify_" + dim + "_threshold")
+ thresholdType = a.getExtraStringDefault("quota_notify_"+dim+"_threshold_type", thresholdTypeFixed)
+ return
+}
+
+func (a *Account) GetQuotaNotifyDailyEnabled() bool {
+ e, _, _ := a.QuotaNotifyConfig(quotaDimDaily)
+ return e
+}
+
+func (a *Account) GetQuotaNotifyDailyThreshold() float64 {
+ _, t, _ := a.QuotaNotifyConfig(quotaDimDaily)
+ return t
+}
+
+func (a *Account) GetQuotaNotifyDailyThresholdType() string {
+ _, _, tt := a.QuotaNotifyConfig(quotaDimDaily)
+ return tt
+}
+
+func (a *Account) GetQuotaNotifyWeeklyEnabled() bool {
+ e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly)
+ return e
+}
+
+func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 {
+ _, t, _ := a.QuotaNotifyConfig(quotaDimWeekly)
+ return t
+}
+
+func (a *Account) GetQuotaNotifyWeeklyThresholdType() string {
+ _, _, tt := a.QuotaNotifyConfig(quotaDimWeekly)
+ return tt
+}
+
+func (a *Account) GetQuotaNotifyTotalEnabled() bool {
+ e, _, _ := a.QuotaNotifyConfig(quotaDimTotal)
+ return e
+}
+
+func (a *Account) GetQuotaNotifyTotalThreshold() float64 {
+ _, t, _ := a.QuotaNotifyConfig(quotaDimTotal)
+ return t
+}
+
+func (a *Account) GetQuotaNotifyTotalThresholdType() string {
+ _, _, tt := a.QuotaNotifyConfig(quotaDimTotal)
+ return tt
+}
+
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz)
diff --git a/backend/internal/service/account_openai_compact_test.go b/backend/internal/service/account_openai_compact_test.go
new file mode 100644
index 00000000..442b00da
--- /dev/null
+++ b/backend/internal/service/account_openai_compact_test.go
@@ -0,0 +1,369 @@
+package service
+
+import "testing"
+
+func TestAccountGetOpenAICompactMode(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ want string
+ }{
+ {
+ name: "nil account defaults to auto",
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "non openai account defaults to auto",
+ account: &Account{
+ Platform: PlatformAnthropic,
+ Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
+ },
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "missing extra defaults to auto",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ },
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "invalid mode falls back to auto",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": " invalid "},
+ },
+ want: OpenAICompactModeAuto,
+ },
+ {
+ name: "force on is normalized",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": " FORCE_ON "},
+ },
+ want: OpenAICompactModeForceOn,
+ },
+ {
+ name: "force off is normalized",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": "force_off"},
+ },
+ want: OpenAICompactModeForceOff,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.account.GetOpenAICompactMode(); got != tt.want {
+ t.Fatalf("GetOpenAICompactMode() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAccountOpenAICompactSupportKnown(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ wantSupported bool
+ wantKnown bool
+ }{
+ {
+ name: "nil account is unknown",
+ wantSupported: false,
+ wantKnown: false,
+ },
+ {
+ name: "non openai account is unknown",
+ account: &Account{
+ Platform: PlatformAnthropic,
+ Extra: map[string]any{"openai_compact_supported": true},
+ },
+ wantSupported: false,
+ wantKnown: false,
+ },
+ {
+ name: "force on overrides probe state",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{
+ "openai_compact_mode": OpenAICompactModeForceOn,
+ "openai_compact_supported": false,
+ },
+ },
+ wantSupported: true,
+ wantKnown: true,
+ },
+ {
+ name: "force off overrides probe state",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{
+ "openai_compact_mode": OpenAICompactModeForceOff,
+ "openai_compact_supported": true,
+ },
+ },
+ wantSupported: false,
+ wantKnown: true,
+ },
+ {
+ name: "auto true is known supported",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": true},
+ },
+ wantSupported: true,
+ wantKnown: true,
+ },
+ {
+ name: "auto false is known unsupported",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": false},
+ },
+ wantSupported: false,
+ wantKnown: true,
+ },
+ {
+ name: "auto without probe state remains unknown",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{},
+ },
+ wantSupported: false,
+ wantKnown: false,
+ },
+ {
+ name: "invalid probe field remains unknown",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": "true"},
+ },
+ wantSupported: false,
+ wantKnown: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotSupported, gotKnown := tt.account.OpenAICompactSupportKnown()
+ if gotSupported != tt.wantSupported || gotKnown != tt.wantKnown {
+ t.Fatalf("OpenAICompactSupportKnown() = (%v, %v), want (%v, %v)", gotSupported, gotKnown, tt.wantSupported, tt.wantKnown)
+ }
+ })
+ }
+}
+
+func TestAccountAllowsOpenAICompact(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {
+ name: "nil account does not allow compact",
+ want: false,
+ },
+ {
+ name: "non openai account does not allow compact",
+ account: &Account{
+ Platform: PlatformAnthropic,
+ },
+ want: false,
+ },
+ {
+ name: "unknown openai account remains allowed",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{},
+ },
+ want: true,
+ },
+ {
+ name: "supported openai account is allowed",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": true},
+ },
+ want: true,
+ },
+ {
+ name: "unsupported openai account is rejected",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_supported": false},
+ },
+ want: false,
+ },
+ {
+ name: "force on is allowed",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOn},
+ },
+ want: true,
+ },
+ {
+ name: "force off is rejected",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ Extra: map[string]any{"openai_compact_mode": OpenAICompactModeForceOff},
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.account.AllowsOpenAICompact(); got != tt.want {
+ t.Fatalf("AllowsOpenAICompact() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAccountGetCompactModelMapping(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ want map[string]string
+ }{
+ {
+ name: "nil account returns nil",
+ want: nil,
+ },
+ {
+ name: "missing credentials returns nil",
+ account: &Account{
+ Platform: PlatformOpenAI,
+ },
+ want: nil,
+ },
+ {
+ name: "map any is converted",
+ account: &Account{
+ Credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4-openai-compact",
+ "invalid": 1,
+ },
+ },
+ },
+ want: map[string]string{
+ "gpt-5.4": "gpt-5.4-openai-compact",
+ },
+ },
+ {
+ name: "map string string is copied",
+ account: &Account{
+ Credentials: map[string]any{
+ "compact_model_mapping": map[string]string{
+ "gpt-*": "compact-*",
+ },
+ },
+ },
+ want: map[string]string{
+ "gpt-*": "compact-*",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.account.GetCompactModelMapping()
+ if !equalStringMap(got, tt.want) {
+ t.Fatalf("GetCompactModelMapping() = %#v, want %#v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestAccountResolveCompactMappedModel(t *testing.T) {
+ tests := []struct {
+ name string
+ credentials map[string]any
+ requestedModel string
+ expectedModel string
+ expectedMatch bool
+ }{
+ {
+ name: "no compact mapping reports unmatched",
+ credentials: nil,
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ expectedMatch: false,
+ },
+ {
+ name: "exact compact mapping matches",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4-openai-compact",
+ },
+ },
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4-openai-compact",
+ expectedMatch: true,
+ },
+ {
+ name: "exact passthrough counts as match",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.4": "gpt-5.4",
+ },
+ },
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ expectedMatch: true,
+ },
+ {
+ name: "longest wildcard wins",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-*": "fallback-compact",
+ "gpt-5.4*": "gpt-5.4-openai-compact",
+ "gpt-5.4-mini*": "gpt-5.4-mini-openai-compact",
+ },
+ },
+ requestedModel: "gpt-5.4-mini",
+ expectedModel: "gpt-5.4-mini-openai-compact",
+ expectedMatch: true,
+ },
+ {
+ name: "missing compact mapping reports unmatched",
+ credentials: map[string]any{
+ "compact_model_mapping": map[string]any{
+ "gpt-5.3": "gpt-5.3-openai-compact",
+ },
+ },
+ requestedModel: "gpt-5.4",
+ expectedModel: "gpt-5.4",
+ expectedMatch: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Credentials: tt.credentials,
+ }
+ gotModel, gotMatch := account.ResolveCompactMappedModel(tt.requestedModel)
+ if gotModel != tt.expectedModel || gotMatch != tt.expectedMatch {
+ t.Fatalf("ResolveCompactMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, gotModel, gotMatch, tt.expectedModel, tt.expectedMatch)
+ }
+ })
+ }
+}
+
+func equalStringMap(left, right map[string]string) bool {
+ if len(left) != len(right) {
+ return false
+ }
+ for key, want := range right {
+ if got, ok := left[key]; !ok || got != want {
+ return false
+ }
+ }
+ return true
+}
diff --git a/backend/internal/service/account_quota_schedulable_test.go b/backend/internal/service/account_quota_schedulable_test.go
new file mode 100644
index 00000000..2895b34c
--- /dev/null
+++ b/backend/internal/service/account_quota_schedulable_test.go
@@ -0,0 +1,123 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountIsSchedulable_QuotaExceeded(t *testing.T) {
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ account *Account
+ want bool
+ }{
+ {
+ name: "apikey daily quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey weekly quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_weekly_limit": 50.0,
+ "quota_weekly_used": 50.0,
+ "quota_weekly_start": now.Add(-2 * 24 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey total quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_limit": 100.0,
+ "quota_used": 100.0,
+ },
+ },
+ want: false,
+ },
+ {
+ name: "apikey quota not exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 5.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "apikey expired daily period restores schedulable",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-25 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "oauth ignores quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "quota_daily_limit": 10.0,
+ "quota_daily_used": 10.0,
+ "quota_daily_start": now.Add(-1 * time.Hour).Format(time.RFC3339),
+ },
+ },
+ want: true,
+ },
+ {
+ name: "bedrock quota exceeded",
+ account: &Account{
+ Status: StatusActive,
+ Schedulable: true,
+ Type: AccountTypeBedrock,
+ Extra: map[string]any{
+ "quota_limit": 200.0,
+ "quota_used": 200.0,
+ },
+ },
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, tt.account.IsSchedulable())
+ })
+ }
+}
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index 328790a8..3189a729 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -28,8 +28,7 @@ type AccountRepository interface {
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
- // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
- // 用于查找通过 linked_openai_account_id 关联的 Sora 账号
+ // FindByExtraField 根据 extra 字段中的键值对查找账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go
new file mode 100644
index 00000000..90ff450f
--- /dev/null
+++ b/backend/internal/service/account_stats_pricing.go
@@ -0,0 +1,236 @@
+package service
+
+import (
+ "context"
+ "strings"
+)
+
+// resolveAccountStatsCost 计算账号统计定价费用。
+// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
+//
+// 优先级(先命中为准):
+// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
+// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost)
+// 3. 模型定价文件(LiteLLM)中上游模型的默认价格
+// 4. nil → 走默认公式(total_cost × account_rate_multiplier)
+//
+// upstreamModel 是最终发往上游的模型 ID。
+// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
+func resolveAccountStatsCost(
+ ctx context.Context,
+ channelService *ChannelService,
+ billingService *BillingService,
+ accountID int64,
+ groupID int64,
+ upstreamModel string,
+ tokens UsageTokens,
+ requestCount int,
+ totalCost float64,
+) *float64 {
+ if channelService == nil || upstreamModel == "" {
+ return nil
+ }
+ channel, err := channelService.GetChannelForGroup(ctx, groupID)
+ if err != nil || channel == nil {
+ return nil
+ }
+
+ platform := channelService.GetGroupPlatform(ctx, groupID)
+
+ // 优先级 1:自定义规则(始终尝试)
+ if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
+ return cost
+ }
+
+ // 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
+ if channel.ApplyPricingToAccountStats {
+ cost := totalCost
+ if cost <= 0 {
+ return nil
+ }
+ return &cost
+ }
+
+ // 优先级 3:模型定价文件(LiteLLM)默认价格
+ if billingService != nil {
+ return tryModelFilePricing(billingService, upstreamModel, tokens)
+ }
+
+ return nil
+}
+
+// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
+func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
+ pricing, err := billingService.GetModelPricing(model)
+ if err != nil || pricing == nil {
+ return nil
+ }
+ cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
+ float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
+ float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
+ float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
+ float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
+ if cost <= 0 {
+ return nil
+ }
+ return &cost
+}
+
+// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
+func tryCustomRules(
+ channel *Channel, accountID, groupID int64,
+ platform, model string, tokens UsageTokens, requestCount int,
+) *float64 {
+ modelLower := strings.ToLower(model)
+ for _, rule := range channel.AccountStatsPricingRules {
+ if !matchAccountStatsRule(&rule, accountID, groupID) {
+ continue
+ }
+ pricing := findPricingForModel(rule.Pricing, platform, modelLower)
+ if pricing == nil {
+ continue // 规则匹配但模型不在规则定价中,继续下一条
+ }
+ return calculateStatsCost(pricing, tokens, requestCount)
+ }
+ return nil
+}
+
+// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
+// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
+// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
+func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
+ if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
+ return false
+ }
+ for _, id := range rule.AccountIDs {
+ if id == accountID {
+ return true
+ }
+ }
+ for _, id := range rule.GroupIDs {
+ if id == groupID {
+ return true
+ }
+ }
+ return false
+}
+
+// findPricingForModel 在定价列表中查找匹配的模型定价。
+// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
+func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
+ // 精确匹配优先
+ for i := range pricingList {
+ p := &pricingList[i]
+ if !isPlatformMatch(platform, p.Platform) {
+ continue
+ }
+ for _, m := range p.Models {
+ if strings.ToLower(m) == modelLower {
+ return p
+ }
+ }
+ }
+ // 通配符匹配:按配置顺序,先匹配先使用
+ for i := range pricingList {
+ p := &pricingList[i]
+ if !isPlatformMatch(platform, p.Platform) {
+ continue
+ }
+ for _, m := range p.Models {
+ ml := strings.ToLower(m)
+ if !strings.HasSuffix(ml, "*") {
+ continue
+ }
+ prefix := strings.TrimSuffix(ml, "*")
+ if strings.HasPrefix(modelLower, prefix) {
+ return p
+ }
+ }
+ }
+ return nil
+}
+
+// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
+func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
+ if queryPlatform == "" || pricingPlatform == "" {
+ return true
+ }
+ return queryPlatform == pricingPlatform
+}
+
+// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
+func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
+ if pricing == nil {
+ return nil
+ }
+ switch pricing.BillingMode {
+ case BillingModePerRequest, BillingModeImage:
+ return calculatePerRequestStatsCost(pricing, requestCount)
+ default:
+ return calculateTokenStatsCost(pricing, tokens)
+ }
+}
+
+// calculatePerRequestStatsCost 按次/图片计费。
+func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
+ if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
+ return nil
+ }
+ cost := *pricing.PerRequestPrice * float64(requestCount)
+ return &cost
+}
+
+// calculateTokenStatsCost Token 计费。
+// If the pricing has intervals, find the matching interval by total token count
+// and use its prices instead of the flat pricing fields.
+func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
+ p := pricing
+ if len(pricing.Intervals) > 0 {
+ totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
+ if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
+ p = &ChannelModelPricing{
+ InputPrice: iv.InputPrice,
+ OutputPrice: iv.OutputPrice,
+ CacheWritePrice: iv.CacheWritePrice,
+ CacheReadPrice: iv.CacheReadPrice,
+ PerRequestPrice: iv.PerRequestPrice,
+ }
+ }
+ }
+ deref := func(ptr *float64) float64 {
+ if ptr == nil {
+ return 0
+ }
+ return *ptr
+ }
+ cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
+ float64(tokens.OutputTokens)*deref(p.OutputPrice) +
+ float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
+ float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
+ float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
+ if cost <= 0 {
+ return nil
+ }
+ return &cost
+}
+
+// applyAccountStatsCost resolves the account stats cost for a usage log entry.
+// It resolves the upstream model (falling back to the requested model) and calls
+// the 4-level priority chain via resolveAccountStatsCost.
+func applyAccountStatsCost(
+ ctx context.Context,
+ usageLog *UsageLog,
+ cs *ChannelService, bs *BillingService,
+ accountID int64, groupID int64,
+ upstreamModel, requestedModel string,
+ tokens UsageTokens,
+ totalCost float64,
+) {
+ model := upstreamModel
+ if model == "" {
+ model = requestedModel
+ }
+ usageLog.AccountStatsCost = resolveAccountStatsCost(
+ ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
+ )
+}
diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go
new file mode 100644
index 00000000..36e5eb74
--- /dev/null
+++ b/backend/internal/service/account_stats_pricing_test.go
@@ -0,0 +1,771 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// matchAccountStatsRule
+// ---------------------------------------------------------------------------
+
+func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{}
+ require.False(t, matchAccountStatsRule(rule, 1, 10))
+}
+
+func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
+ require.True(t, matchAccountStatsRule(rule, 2, 999))
+}
+
+func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
+ require.True(t, matchAccountStatsRule(rule, 999, 20))
+}
+
+func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.True(t, matchAccountStatsRule(rule, 2, 999))
+}
+
+func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.True(t, matchAccountStatsRule(rule, 999, 10))
+}
+
+func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
+ rule := &AccountStatsPricingRule{
+ AccountIDs: []int64{1, 2},
+ GroupIDs: []int64{10, 20},
+ }
+ require.False(t, matchAccountStatsRule(rule, 999, 999))
+}
+
+// ---------------------------------------------------------------------------
+// findPricingForModel
+// ---------------------------------------------------------------------------
+
+func TestFindPricingForModel(t *testing.T) {
+ exactPricing := ChannelModelPricing{
+ ID: 1,
+ Models: []string{"claude-opus-4"},
+ }
+ wildcardPricing := ChannelModelPricing{
+ ID: 2,
+ Models: []string{"claude-*"},
+ }
+ platformPricing := ChannelModelPricing{
+ ID: 3,
+ Platform: "openai",
+ Models: []string{"gpt-4o"},
+ }
+ emptyPlatformPricing := ChannelModelPricing{
+ ID: 4,
+ Models: []string{"gemini-2.5-pro"},
+ }
+
+ tests := []struct {
+ name string
+ list []ChannelModelPricing
+ platform string
+ model string
+ wantID int64
+ wantNil bool
+ }{
+ {
+ name: "exact match",
+ list: []ChannelModelPricing{exactPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 1,
+ },
+ {
+ name: "exact match case insensitive",
+ list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
+ platform: "",
+ model: "claude-opus-4",
+ wantID: 5,
+ },
+ {
+ name: "wildcard match",
+ list: []ChannelModelPricing{wildcardPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 2,
+ },
+ {
+ name: "exact match takes priority over wildcard",
+ list: []ChannelModelPricing{wildcardPricing, exactPricing},
+ platform: "anthropic",
+ model: "claude-opus-4",
+ wantID: 1,
+ },
+ {
+ name: "platform mismatch skipped",
+ list: []ChannelModelPricing{platformPricing},
+ platform: "anthropic",
+ model: "gpt-4o",
+ wantNil: true,
+ },
+ {
+ name: "empty platform in pricing matches any",
+ list: []ChannelModelPricing{emptyPlatformPricing},
+ platform: "gemini",
+ model: "gemini-2.5-pro",
+ wantID: 4,
+ },
+ {
+ name: "empty platform in query matches any pricing platform",
+ list: []ChannelModelPricing{platformPricing},
+ platform: "",
+ model: "gpt-4o",
+ wantID: 3,
+ },
+ {
+ name: "no match at all",
+ list: []ChannelModelPricing{exactPricing, wildcardPricing},
+ platform: "anthropic",
+ model: "gpt-4o",
+ wantNil: true,
+ },
+ {
+ name: "empty list returns nil",
+ list: nil,
+ model: "claude-opus-4",
+ wantNil: true,
+ },
+ {
+ name: "wildcard matches by config order (first match wins)",
+ list: []ChannelModelPricing{
+ {ID: 10, Models: []string{"claude-*"}},
+ {ID: 11, Models: []string{"claude-opus-*"}},
+ },
+ platform: "",
+ model: "claude-opus-4",
+ wantID: 10, // config order: "claude-*" is first and matches, so it wins
+ },
+ {
+ name: "shorter wildcard used when longer does not match",
+ list: []ChannelModelPricing{
+ {ID: 10, Models: []string{"claude-*"}},
+ {ID: 11, Models: []string{"claude-opus-*"}},
+ },
+ platform: "",
+ model: "claude-sonnet-4",
+ wantID: 10, // only "claude-*" matches
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := findPricingForModel(tt.list, tt.platform, tt.model)
+ if tt.wantNil {
+ require.Nil(t, result)
+ return
+ }
+ require.NotNil(t, result)
+ require.Equal(t, tt.wantID, result.ID)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// calculateStatsCost
+// ---------------------------------------------------------------------------
+
+func TestCalculateStatsCost_NilPricing(t *testing.T) {
+ result := calculateStatsCost(nil, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_TokenBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ CacheWritePrice: testPtrFloat64(0.003),
+ CacheReadPrice: testPtrFloat64(0.0005),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ CacheReadTokens: 300,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
+ // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
+ require.InDelta(t, 0.95, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ ImageOutputPrice: testPtrFloat64(0.01),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ ImageOutputTokens: 10,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
+ require.InDelta(t, 0.3, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ // OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ // Only input contributes: 100*0.001 = 0.1
+ require.InDelta(t, 0.1, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{} // all zeros
+ result := calculateStatsCost(pricing, tokens, 1)
+ // totalCost == 0 → returns nil (does not override, falls back to default formula)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ PerRequestPrice: testPtrFloat64(0.05),
+ }
+ tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
+ result := calculateStatsCost(pricing, tokens, 3)
+ require.NotNil(t, result)
+ // 0.05 * 3 = 0.15
+ require.InDelta(t, 0.15, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ // PerRequestPrice is nil
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModePerRequest,
+ PerRequestPrice: testPtrFloat64(0),
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ // price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_ImageBilling(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ PerRequestPrice: testPtrFloat64(0.10),
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 2)
+ require.NotNil(t, result)
+ // 0.10 * 2 = 0.20
+ require.InDelta(t, 0.20, *result, 1e-12)
+}
+
+func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
+ pricing := &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ // PerRequestPrice is nil
+ }
+ result := calculateStatsCost(pricing, UsageTokens{}, 1)
+ require.Nil(t, result)
+}
+
+func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
+ // BillingMode is empty string (default) → falls into token billing
+ pricing := &ChannelModelPricing{
+ InputPrice: testPtrFloat64(0.001),
+ OutputPrice: testPtrFloat64(0.002),
+ }
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ }
+ result := calculateStatsCost(pricing, tokens, 1)
+ require.NotNil(t, result)
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+// ---------------------------------------------------------------------------
+// tryCustomRules — multi-rule ordering tests
+// ---------------------------------------------------------------------------
+
+func TestTryCustomRules_FirstMatchWins(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
+ },
+ },
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+	// Should use the first rule's prices: 100*0.01 + 50*0.02 = 2.0
+ require.InDelta(t, 2.0, *result, 1e-12)
+}
+
+func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+				AccountIDs: []int64{888}, // does not match
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
+ },
+ },
+ {
+				GroupIDs: []int64{1}, // matches
+ Pricing: []ChannelModelPricing{
+ {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+	// Rule 1 is skipped (account does not match); rule 2 applies: 100*0.05 = 5.0
+ require.InDelta(t, 5.0, *result, 1e-12)
+}
+
+func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ AccountIDs: []int64{888},
+ Pricing: []ChannelModelPricing{
+ {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
+	require.Nil(t, result) // neither the account nor the group matches
+}
+
+func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
+ channel := &Channel{
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+					{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // model does not match
+ },
+ },
+ {
+ GroupIDs: []int64{1},
+ Pricing: []ChannelModelPricing{
+					{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // model matches
+ },
+ },
+ },
+ }
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
+ require.NotNil(t, result)
+	require.InDelta(t, 5.0, *result, 1e-12) // uses rule 2
+}
+
+// ---------------------------------------------------------------------------
+// tryModelFilePricing
+// ---------------------------------------------------------------------------
+
+// newTestBillingServiceWithPrices creates a BillingService with pre-populated
+// fallback prices for testing. No config or pricing service is needed.
+// The key must match what getFallbackPricing resolves to for a given model name.
+// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
+func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService {
+ return &BillingService{
+ fallbackPrices: prices,
+ }
+}
+
+func TestTryModelFilePricing_Success(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestTryModelFilePricing_PricingNotFound(t *testing.T) {
+ // "nonexistent-model" does not match any fallback pattern
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+ result := tryModelFilePricing(bs, "nonexistent-model", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_NilFallback(t *testing.T) {
+ // getFallbackPricing returns nil when key maps to nil
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": nil,
+ })
+ tokens := UsageTokens{InputTokens: 100}
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_ZeroCost(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+ tokens := UsageTokens{} // all zero tokens → cost = 0 → nil
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.Nil(t, result)
+}
+
+func TestTryModelFilePricing_WithImageOutput(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ ImageOutputPricePerToken: 0.01,
+ },
+ })
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ ImageOutputTokens: 10,
+ }
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
+ require.InDelta(t, 0.3, *result, 1e-12)
+}
+
+func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ CacheCreationPricePerToken: 0.003,
+ CacheReadPricePerToken: 0.0005,
+ },
+ })
+ tokens := UsageTokens{
+ InputTokens: 100,
+ OutputTokens: 50,
+ CacheCreationTokens: 200,
+ CacheReadTokens: 300,
+ }
+ result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
+ // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
+ require.InDelta(t, 0.95, *result, 1e-12)
+}
+
+// ---------------------------------------------------------------------------
+// resolveAccountStatsCost — integration tests covering the 4-level priority chain
+// ---------------------------------------------------------------------------
+
+func TestResolveAccountStatsCost_NilChannelService(t *testing.T) {
+ result := resolveAccountStatsCost(
+ context.Background(),
+ nil, // channelService is nil
+ newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
+ 1, 1, "claude-sonnet-4",
+ UsageTokens{InputTokens: 100}, 1, 0.5,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) {
+ cs := newTestChannelServiceForStats(t, &Channel{
+ ID: 1,
+ Status: StatusActive,
+ }, 1, "")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs,
+ newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
+ 1, 1, "", // empty upstream model
+ UsageTokens{InputTokens: 100}, 1, 0.5,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) {
+ // Group 99 is NOT in the cache, so GetChannelForGroup returns nil
+ cs := newTestChannelServiceForStats(t, &Channel{
+ ID: 1,
+ Status: StatusActive,
+ }, 1, "")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs,
+ newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
+ 1, 99, "claude-sonnet-4", // groupID 99 has no channel
+ UsageTokens{InputTokens: 100}, 1, 0.5,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{10},
+ Pricing: []ChannelModelPricing{
+ {
+ ID: 100,
+ Models: []string{"claude-sonnet-4"},
+ InputPrice: testPtrFloat64(0.01),
+ OutputPrice: testPtrFloat64(0.02),
+ },
+ },
+ },
+ },
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil, // billingService not needed when custom rule hits
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 999.0, // totalCost ignored because custom rule hits
+ )
+ require.NotNil(t, result)
+ // 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0
+ require.InDelta(t, 2.0, *result, 1e-12)
+}
+
+func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: true,
+ // No custom rules
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil,
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 0.75, // totalCost = 0.75
+ )
+ require.NotNil(t, result)
+ require.InDelta(t, 0.75, *result, 1e-12)
+}
+
+func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: true,
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil,
+ 1, 10, "claude-sonnet-4",
+ UsageTokens{}, 1, 0.0, // totalCost = 0
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: false, // not enabled
+ // No custom rules
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
+ "claude-sonnet-4": {
+ InputPricePerToken: 0.001,
+ OutputPricePerToken: 0.002,
+ },
+ })
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, bs,
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 999.0, // totalCost ignored
+ )
+ require.NotNil(t, result)
+ // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
+ require.InDelta(t, 0.2, *result, 1e-12)
+}
+
+func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: false,
+ // No custom rules
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ // BillingService with no pricing for the model
+ bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
+
+ tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, bs,
+ 1, 10, "totally-unknown-model",
+ tokens, 1, 0.0,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) {
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: false,
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil, // billingService is nil
+ 1, 10, "claude-sonnet-4",
+ UsageTokens{InputTokens: 100}, 1, 0.0,
+ )
+ require.Nil(t, result)
+}
+
+func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) {
+ // Both custom rule and ApplyPricingToAccountStats are configured;
+ // custom rule should take precedence.
+ channel := &Channel{
+ ID: 1,
+ Status: StatusActive,
+ ApplyPricingToAccountStats: true,
+ AccountStatsPricingRules: []AccountStatsPricingRule{
+ {
+ GroupIDs: []int64{10},
+ Pricing: []ChannelModelPricing{
+ {
+ ID: 100,
+ Models: []string{"claude-sonnet-4"},
+ InputPrice: testPtrFloat64(0.05),
+ },
+ },
+ },
+ },
+ }
+ cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
+
+ tokens := UsageTokens{InputTokens: 100}
+
+ result := resolveAccountStatsCost(
+ context.Background(),
+ cs, nil,
+ 1, 10, "claude-sonnet-4",
+ tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins)
+ )
+ require.NotNil(t, result)
+ // Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost)
+ require.InDelta(t, 5.0, *result, 1e-12)
+}
+
+// ---------------------------------------------------------------------------
+// helpers for resolveAccountStatsCost tests
+// ---------------------------------------------------------------------------
+
+// newTestChannelServiceForStats creates a ChannelService with a single channel
+// mapped to the given groupID, suitable for resolveAccountStatsCost tests.
+func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService {
+ t.Helper()
+ cache := newEmptyChannelCache()
+ cache.channelByGroupID[groupID] = channel
+ cache.groupPlatform[groupID] = platform
+ cs := &ChannelService{}
+ cache.loadedAt = time.Now()
+ cs.cache.Store(cache)
+ return cs
+}
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go
index fec98e12..c0bbc6dc 100644
--- a/backend/internal/service/account_test_service.go
+++ b/backend/internal/service/account_test_service.go
@@ -13,18 +13,14 @@ import (
"log"
"net/http"
"net/http/httptest"
- "net/url"
"regexp"
"strings"
- "sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
- "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
- "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -37,11 +33,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
- soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接
- soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions"
- soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine"
- soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap"
- soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check"
)
// TestEvent represents a SSE event for account testing
@@ -61,8 +52,14 @@ type TestEvent struct {
const (
defaultGeminiTextTestPrompt = "hi"
defaultGeminiImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
+ defaultOpenAIImageTestPrompt = "Generate a cute orange cat astronaut sticker on a clean pastel background."
)
+// isOpenAIImageModel checks if the model is an OpenAI image generation model (e.g. gpt-image-2).
+func isOpenAIImageModel(model string) bool {
+ return strings.HasPrefix(strings.ToLower(model), "gpt-image-")
+}
+
// AccountTestService handles account testing operations
type AccountTestService struct {
accountRepo AccountRepository
@@ -71,13 +68,8 @@ type AccountTestService struct {
httpUpstream HTTPUpstream
cfg *config.Config
tlsFPProfileService *TLSFingerprintProfileService
- soraTestGuardMu sync.Mutex
- soraTestLastRun map[int64]time.Time
- soraTestCooldown time.Duration
}
-const defaultSoraTestCooldown = 10 * time.Second
-
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(
accountRepo AccountRepository,
@@ -94,8 +86,6 @@ func NewAccountTestService(
httpUpstream: httpUpstream,
cfg: cfg,
tlsFPProfileService: tlsFPProfileService,
- soraTestLastRun: make(map[int64]time.Time),
- soraTestCooldown: defaultSoraTestCooldown,
}
}
@@ -175,7 +165,8 @@ func createTestPayload(modelID string) (map[string]any, error) {
// TestAccountConnection tests an account's connection by sending a test request
// All account types use full Claude Code client characteristics, only auth header differs
// modelID is optional - if empty, defaults to claude.DefaultTestModel
-func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string) error {
+// mode is optional - "compact" routes OpenAI accounts to the /responses/compact probe path
+func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
// Get account
@@ -186,7 +177,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
// Route to platform-specific test method
if account.IsOpenAI() {
- return s.testOpenAIAccountConnection(c, account, modelID)
+ return s.testOpenAIAccountConnection(c, account, modelID, prompt, normalizeAccountTestMode(mode))
}
if account.IsGemini() {
@@ -197,10 +188,6 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
return s.routeAntigravityTest(c, account, modelID, prompt)
}
- if account.Platform == PlatformSora {
- return s.testSoraAccountConnection(c, account)
- }
-
return s.testClaudeAccountConnection(c, account, modelID)
}
@@ -430,8 +417,10 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
}
// testOpenAIAccountConnection tests an OpenAI account's connection
-func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
+func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
+ _ = prompt
+ mode = normalizeAccountTestMode(mode)
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
@@ -439,14 +428,24 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
testModelID = openai.DefaultTestModel
}
- // For API Key accounts with model mapping, map the model
- if account.Type == "apikey" {
- mapping := account.GetModelMapping()
- if len(mapping) > 0 {
- if mappedModel, exists := mapping[testModelID]; exists {
- testModelID = mappedModel
- }
+ // Align test routing with gateway behavior: OpenAI accounts apply normal
+ // account model mapping, and compact mode applies compact-only mapping on top.
+ testModelID = account.GetMappedModel(testModelID)
+ if mode == AccountTestModeCompact {
+ testModelID = resolveOpenAICompactForwardModel(account, testModelID)
+ return s.testOpenAICompactConnection(c, account, testModelID)
+ }
+
+ // Route to image generation test if an image model is selected
+ if isOpenAIImageModel(testModelID) {
+ imagePrompt := strings.TrimSpace(prompt)
+ if imagePrompt == "" {
+ imagePrompt = defaultOpenAIImageTestPrompt
}
+ if account.Type == "apikey" {
+ return s.testOpenAIImageAPIKey(c, ctx, account, testModelID, imagePrompt)
+ }
+ return s.testOpenAIImageOAuth(c, ctx, account, testModelID, imagePrompt)
}
// Determine authentication method and API URL
@@ -535,21 +534,17 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
mergeAccountExtra(account, updates)
}
- if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
- if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
- _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
- account.RateLimitResetAt = resetAt
- }
- }
}
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
- if isOAuth && s.accountRepo != nil {
- if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
- _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
- account.RateLimitResetAt = resetAt
- }
+ if resp.StatusCode == http.StatusTooManyRequests {
+ s.reconcileOpenAI429State(ctx, account, resp.Header, body)
+ }
+		// 401 Unauthorized: mark the account with a permanent error
+ if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
+ errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
+ _ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
@@ -558,6 +553,154 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
+// testOpenAICompactConnection probes /responses/compact and persists the
+// resulting capability state on the account.
+func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
+ ctx := c.Request.Context()
+
+ authToken := ""
+ apiURL := ""
+ isOAuth := false
+ chatgptAccountID := ""
+
+ switch {
+ case account.IsOAuth():
+ isOAuth = true
+ authToken = account.GetOpenAIAccessToken()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No access token available")
+ }
+ apiURL = chatgptCodexAPIURL + "/compact"
+ chatgptAccountID = account.GetChatGPTAccountID()
+ case account.Type == AccountTypeAPIKey:
+ authToken = account.GetOpenAIApiKey()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No API key available")
+ }
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.openai.com"
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL = appendOpenAIResponsesRequestPathSuffix(buildOpenAIResponsesURL(normalizedBaseURL), "/compact")
+ default:
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
+ }
+
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ payloadBytes, _ := json.Marshal(createOpenAICompactProbePayload(testModelID))
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
+
+ req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+authToken)
+ req.Header.Set("OpenAI-Beta", "responses=experimental")
+ req.Header.Set("Originator", "codex_cli_rs")
+ req.Header.Set("User-Agent", codexCLIUserAgent)
+ req.Header.Set("Version", codexCLIVersion)
+ probeSessionID := compactProbeSessionID(account.ID)
+ req.Header.Set("Session_ID", probeSessionID)
+ req.Header.Set("Conversation_ID", probeSessionID)
+
+ if isOAuth {
+ req.Host = "chatgpt.com"
+ if chatgptAccountID != "" {
+ req.Header.Set("chatgpt-account-id", chatgptAccountID)
+ }
+ }
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+ if err != nil {
+ if s.accountRepo != nil {
+ updates := buildOpenAICompactProbeExtraUpdates(nil, nil, err, time.Now())
+ _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
+ mergeAccountExtra(account, updates)
+ }
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+
+ if s.accountRepo != nil {
+ updates := buildOpenAICompactProbeExtraUpdates(resp, body, nil, time.Now())
+ if codexUpdates, err := extractOpenAICodexProbeUpdates(resp); err == nil && len(codexUpdates) > 0 {
+ updates = mergeExtraUpdates(updates, codexUpdates)
+ }
+ if len(updates) > 0 {
+ _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
+ mergeAccountExtra(account, updates)
+ }
+		// If the probe returns 429, proactively sync the rate-limit state so the account is not re-selected shortly afterwards.
+ if resp.StatusCode == http.StatusTooManyRequests {
+ s.reconcileOpenAI429State(ctx, account, resp.Header, body)
+ }
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
+ errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
+ _ = s.accountRepo.SetError(ctx, account.ID, errMsg)
+ }
+ return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
+ }
+
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Compact probe succeeded"})
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
+func (s *AccountTestService) reconcileOpenAI429State(ctx context.Context, account *Account, headers http.Header, body []byte) {
+ if s == nil || s.accountRepo == nil || account == nil {
+ return
+ }
+
+ var resetAt *time.Time
+ if calculated := calculateOpenAI429ResetTime(headers); calculated != nil {
+ resetAt = calculated
+ } else if unixTs := parseOpenAIRateLimitResetTime(body); unixTs != nil {
+ t := time.Unix(*unixTs, 0)
+ resetAt = &t
+ }
+ if resetAt == nil {
+ return
+ }
+
+ if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
+ return
+ }
+
+ now := time.Now()
+ account.RateLimitedAt = &now
+ account.RateLimitResetAt = resetAt
+
+ if account.Status == StatusError {
+ if err := s.accountRepo.ClearError(ctx, account.ID); err != nil {
+ return
+ }
+ account.Status = StatusActive
+ account.ErrorMessage = ""
+ }
+}
+
// testGeminiAccountConnection tests a Gemini account's connection
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
@@ -629,698 +772,6 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
return s.processGeminiStream(c, resp.Body)
}
-type soraProbeStep struct {
- Name string `json:"name"`
- Status string `json:"status"`
- HTTPStatus int `json:"http_status,omitempty"`
- ErrorCode string `json:"error_code,omitempty"`
- Message string `json:"message,omitempty"`
-}
-
-type soraProbeSummary struct {
- Status string `json:"status"`
- Steps []soraProbeStep `json:"steps"`
-}
-
-type soraProbeRecorder struct {
- steps []soraProbeStep
-}
-
-func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) {
- r.steps = append(r.steps, soraProbeStep{
- Name: name,
- Status: status,
- HTTPStatus: httpStatus,
- ErrorCode: strings.TrimSpace(errorCode),
- Message: strings.TrimSpace(message),
- })
-}
-
-func (r *soraProbeRecorder) finalize() soraProbeSummary {
- meSuccess := false
- partial := false
- for _, step := range r.steps {
- if step.Name == "me" {
- meSuccess = strings.EqualFold(step.Status, "success")
- continue
- }
- if strings.EqualFold(step.Status, "failed") {
- partial = true
- }
- }
-
- status := "success"
- if !meSuccess {
- status = "failed"
- } else if partial {
- status = "partial_success"
- }
-
- return soraProbeSummary{
- Status: status,
- Steps: append([]soraProbeStep(nil), r.steps...),
- }
-}
-
-func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) {
- if rec == nil {
- return
- }
- summary := rec.finalize()
- code := ""
- for _, step := range summary.Steps {
- if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" {
- code = step.ErrorCode
- break
- }
- }
- s.sendEvent(c, TestEvent{
- Type: "sora_test_result",
- Status: summary.Status,
- Code: code,
- Data: summary,
- })
-}
-
-func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) {
- if accountID <= 0 {
- return 0, true
- }
- s.soraTestGuardMu.Lock()
- defer s.soraTestGuardMu.Unlock()
-
- if s.soraTestLastRun == nil {
- s.soraTestLastRun = make(map[int64]time.Time)
- }
- cooldown := s.soraTestCooldown
- if cooldown <= 0 {
- cooldown = defaultSoraTestCooldown
- }
-
- now := time.Now()
- if lastRun, ok := s.soraTestLastRun[accountID]; ok {
- elapsed := now.Sub(lastRun)
- if elapsed < cooldown {
- return cooldown - elapsed, false
- }
- }
- s.soraTestLastRun[accountID] = now
- return 0, true
-}
-
-func ceilSeconds(d time.Duration) int {
- if d <= 0 {
- return 1
- }
- sec := int(d / time.Second)
- if d%time.Second != 0 {
- sec++
- }
- if sec < 1 {
- sec = 1
- }
- return sec
-}
-
-// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。
-// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。
-func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error {
- ctx := c.Request.Context()
-
- apiKey := account.GetCredential("api_key")
- if apiKey == "" {
- return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证")
- }
-
- baseURL := account.GetBaseURL()
- if baseURL == "" {
- return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url")
- }
-
- // 验证 base_url 格式
- normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
- if err != nil {
- return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error()))
- }
- upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions"
-
- // 设置 SSE 头
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- c.Writer.Flush()
-
- if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
- msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
- return s.sendErrorAndEnd(c, msg)
- }
-
- s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"})
-
- // 构建轻量级 prompt-enhance 请求作为连通性测试
- testPayload := map[string]any{
- "model": "prompt-enhance-short-10s",
- "messages": []map[string]string{{"role": "user", "content": "test"}},
- "stream": false,
- }
- payloadBytes, _ := json.Marshal(testPayload)
-
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes))
- if err != nil {
- return s.sendErrorAndEnd(c, "构建测试请求失败")
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+apiKey)
-
- // 获取代理 URL
- proxyURL := ""
- if account.ProxyID != nil && account.Proxy != nil {
- proxyURL = account.Proxy.URL()
- }
-
- resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
- if err != nil {
- return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error()))
- }
- defer func() { _ = resp.Body.Close() }()
-
- respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
-
- if resp.StatusCode == http.StatusOK {
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)})
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
-
- if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
- return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode))
- }
-
- // 其他错误但能连通(如 400 参数错误)也算连通性测试通过
- if resp.StatusCode == http.StatusBadRequest {
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)})
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)})
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
- }
-
- return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256)))
-}
-
-// testSoraAccountConnection 测试 Sora 账号的连接
-// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性
-// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性
-func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
- // apikey 类型走独立测试流程
- if account.Type == AccountTypeAPIKey {
- return s.testSoraAPIKeyAccountConnection(c, account)
- }
-
- ctx := c.Request.Context()
- recorder := &soraProbeRecorder{}
-
- authToken := account.GetCredential("access_token")
- if authToken == "" {
- recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "No access token available")
- }
-
- // Set SSE headers
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- c.Writer.Flush()
-
- if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
- msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
- recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, msg)
- }
-
- // Send test_start event
- s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})
-
- req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
- if err != nil {
- recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "Failed to create request")
- }
-
- // 使用 Sora 客户端标准请求头
- req.Header.Set("Authorization", "Bearer "+authToken)
- req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Accept-Language", "en-US,en;q=0.9")
- req.Header.Set("Origin", "https://sora.chatgpt.com")
- req.Header.Set("Referer", "https://sora.chatgpt.com/")
-
- // Get proxy URL
- proxyURL := ""
- if account.ProxyID != nil && account.Proxy != nil {
- proxyURL = account.Proxy.URL()
- }
- soraTLSProfile := s.resolveSoraTLSProfile()
-
- resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
- if err != nil {
- recorder.addStep("me", "failed", 0, "network_error", err.Error())
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
- }
- defer func() { _ = resp.Body.Close() }()
-
- body, _ := io.ReadAll(resp.Body)
-
- if resp.StatusCode != http.StatusOK {
- if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
- recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
- s.emitSoraProbeSummary(c, recorder)
- s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
- return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
- }
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
- switch {
- case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
- recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
- case strings.EqualFold(upstreamCode, "unsupported_country_code"):
- recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
- case strings.TrimSpace(upstreamMessage) != "":
- recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
- default:
- recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
- s.emitSoraProbeSummary(c, recorder)
- return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
- }
- }
- recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")
-
- // 解析 /me 响应,提取用户信息
- var meResp map[string]any
- if err := json.Unmarshal(body, &meResp); err != nil {
- // 能收到 200 就说明 token 有效
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
- } else {
- // 尝试提取用户名或邮箱信息
- info := "Sora connection OK"
- if name, ok := meResp["name"].(string); ok && name != "" {
- info = fmt.Sprintf("Sora connection OK - User: %s", name)
- } else if email, ok := meResp["email"].(string); ok && email != "" {
- info = fmt.Sprintf("Sora connection OK - Email: %s", email)
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: info})
- }
-
- // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试)
- subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
- if err == nil {
- subReq.Header.Set("Authorization", "Bearer "+authToken)
- subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- subReq.Header.Set("Accept", "application/json")
- subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
- subReq.Header.Set("Origin", "https://sora.chatgpt.com")
- subReq.Header.Set("Referer", "https://sora.chatgpt.com/")
-
- subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, soraTLSProfile)
- if subErr != nil {
- recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
- } else {
- subBody, _ := io.ReadAll(subResp.Body)
- _ = subResp.Body.Close()
- if subResp.StatusCode == http.StatusOK {
- recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
- if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
- s.sendEvent(c, TestEvent{Type: "content", Text: summary})
- } else {
- s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
- }
- } else {
- if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
- recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
- s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
- } else {
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
- recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
- }
- }
- }
- }
-
- // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。
- s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, soraTLSProfile, recorder)
-
- s.emitSoraProbeSummary(c, recorder)
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
-}
-
-func (s *AccountTestService) testSora2Capabilities(
- c *gin.Context,
- ctx context.Context,
- account *Account,
- authToken string,
- proxyURL string,
- tlsProfile *tlsfingerprint.Profile,
- recorder *soraProbeRecorder,
-) {
- inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraInviteMineURL,
- proxyURL,
- tlsProfile,
- )
- if err != nil {
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())})
- return
- }
-
- if inviteStatus == http.StatusUnauthorized {
- bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraBootstrapURL,
- proxyURL,
- tlsProfile,
- )
- if bootstrapErr == nil && bootstrapStatus == http.StatusOK {
- if recorder != nil {
- recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok")
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"})
- inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraInviteMineURL,
- proxyURL,
- tlsProfile,
- )
- if err != nil {
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error())
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())})
- return
- }
- } else if recorder != nil {
- code := ""
- msg := ""
- if bootstrapErr != nil {
- code = "network_error"
- msg = bootstrapErr.Error()
- }
- recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg)
- }
- }
-
- if inviteStatus != http.StatusOK {
- if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) {
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected")
- }
- s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody)
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)})
- return
- }
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody)
- if recorder != nil {
- recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage)
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)})
- return
- }
- if recorder != nil {
- recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok")
- }
-
- if summary := parseSoraInviteSummary(inviteBody); summary != "" {
- s.sendEvent(c, TestEvent{Type: "content", Text: summary})
- } else {
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"})
- }
-
- remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint(
- ctx,
- account,
- authToken,
- soraRemainingURL,
- proxyURL,
- tlsProfile,
- )
- if remainingErr != nil {
- if recorder != nil {
- recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error())
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())})
- return
- }
- if remainingStatus != http.StatusOK {
- if isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) {
- if recorder != nil {
- recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected")
- }
- s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody)
- s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)})
- return
- }
- upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody)
- if recorder != nil {
- recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage)
- }
- s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)})
- return
- }
- if recorder != nil {
- recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok")
- }
- if summary := parseSoraRemainingSummary(remainingBody); summary != "" {
- s.sendEvent(c, TestEvent{Type: "content", Text: summary})
- } else {
- s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"})
- }
-}
-
-func (s *AccountTestService) fetchSoraTestEndpoint(
- ctx context.Context,
- account *Account,
- authToken string,
- url string,
- proxyURL string,
- tlsProfile *tlsfingerprint.Profile,
-) (int, http.Header, []byte, error) {
- req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
- if err != nil {
- return 0, nil, nil, err
- }
- req.Header.Set("Authorization", "Bearer "+authToken)
- req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- req.Header.Set("Accept", "application/json")
- req.Header.Set("Accept-Language", "en-US,en;q=0.9")
- req.Header.Set("Origin", "https://sora.chatgpt.com")
- req.Header.Set("Referer", "https://sora.chatgpt.com/")
-
- resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, tlsProfile)
- if err != nil {
- return 0, nil, nil, err
- }
- defer func() { _ = resp.Body.Close() }()
-
- body, readErr := io.ReadAll(resp.Body)
- if readErr != nil {
- return resp.StatusCode, resp.Header, nil, readErr
- }
- return resp.StatusCode, resp.Header, body, nil
-}
-
-func parseSoraSubscriptionSummary(body []byte) string {
- var subResp struct {
- Data []struct {
- Plan struct {
- ID string `json:"id"`
- Title string `json:"title"`
- } `json:"plan"`
- EndTS string `json:"end_ts"`
- } `json:"data"`
- }
- if err := json.Unmarshal(body, &subResp); err != nil {
- return ""
- }
- if len(subResp.Data) == 0 {
- return ""
- }
-
- first := subResp.Data[0]
- parts := make([]string, 0, 3)
- if first.Plan.Title != "" {
- parts = append(parts, first.Plan.Title)
- }
- if first.Plan.ID != "" {
- parts = append(parts, first.Plan.ID)
- }
- if first.EndTS != "" {
- parts = append(parts, "end="+first.EndTS)
- }
- if len(parts) == 0 {
- return ""
- }
- return "Subscription: " + strings.Join(parts, " | ")
-}
-
-func parseSoraInviteSummary(body []byte) string {
- var inviteResp struct {
- InviteCode string `json:"invite_code"`
- RedeemedCount int64 `json:"redeemed_count"`
- TotalCount int64 `json:"total_count"`
- }
- if err := json.Unmarshal(body, &inviteResp); err != nil {
- return ""
- }
-
- parts := []string{"Sora2: supported"}
- if inviteResp.InviteCode != "" {
- parts = append(parts, "invite="+inviteResp.InviteCode)
- }
- if inviteResp.TotalCount > 0 {
- parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount))
- }
- return strings.Join(parts, " | ")
-}
-
-func parseSoraRemainingSummary(body []byte) string {
- var remainingResp struct {
- RateLimitAndCreditBalance struct {
- EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"`
- RateLimitReached bool `json:"rate_limit_reached"`
- AccessResetsInSeconds int64 `json:"access_resets_in_seconds"`
- } `json:"rate_limit_and_credit_balance"`
- }
- if err := json.Unmarshal(body, &remainingResp); err != nil {
- return ""
- }
- info := remainingResp.RateLimitAndCreditBalance
- parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)}
- if info.RateLimitReached {
- parts = append(parts, "rate_limited=true")
- }
- if info.AccessResetsInSeconds > 0 {
- parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds))
- }
- return strings.Join(parts, " | ")
-}
-
-func (s *AccountTestService) resolveSoraTLSProfile() *tlsfingerprint.Profile {
- if s == nil || s.cfg == nil || !s.cfg.Sora.Client.DisableTLSFingerprint {
- // Sora TLS fingerprint enabled — use built-in default profile
- return &tlsfingerprint.Profile{Name: "Built-in Default (Sora)"}
- }
- return nil // disabled
-}
-
-func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
- return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
-}
-
-func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
- return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
-}
-
-func extractCloudflareRayID(headers http.Header, body []byte) string {
- return soraerror.ExtractCloudflareRayID(headers, body)
-}
-
-func extractSoraEgressIPHint(headers http.Header) string {
- if headers == nil {
- return "unknown"
- }
- candidates := []string{
- "x-openai-public-ip",
- "x-envoy-external-address",
- "cf-connecting-ip",
- "x-forwarded-for",
- }
- for _, key := range candidates {
- if value := strings.TrimSpace(headers.Get(key)); value != "" {
- return value
- }
- }
- return "unknown"
-}
-
-func sanitizeProxyURLForLog(raw string) string {
- raw = strings.TrimSpace(raw)
- if raw == "" {
- return ""
- }
- u, err := url.Parse(raw)
- if err != nil {
- return ""
- }
- if u.User != nil {
- u.User = nil
- }
- return u.String()
-}
-
-func endpointPathForLog(endpoint string) string {
- parsed, err := url.Parse(strings.TrimSpace(endpoint))
- if err != nil || parsed.Path == "" {
- return endpoint
- }
- return parsed.Path
-}
-
-func (s *AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) {
- accountID := int64(0)
- platform := ""
- proxyID := "none"
- if account != nil {
- accountID = account.ID
- platform = account.Platform
- if account.ProxyID != nil {
- proxyID = fmt.Sprintf("%d", *account.ProxyID)
- }
- }
- cfRay := extractCloudflareRayID(headers, body)
- if cfRay == "" {
- cfRay = "unknown"
- }
- log.Printf(
- "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s",
- accountID,
- platform,
- endpoint,
- endpointPathForLog(endpoint),
- proxyID,
- sanitizeProxyURLForLog(proxyURL),
- cfRay,
- extractSoraEgressIPHint(headers),
- )
-}
-
-func truncateSoraErrorBody(body []byte, max int) string {
- return soraerror.TruncateBody(body, max)
-}
-
// routeAntigravityTest 路由 Antigravity 账号的测试请求。
// APIKey 类型走原生协议(与 gateway_handler 路由一致),OAuth/Upstream 走 CRS 中转。
func (s *AccountTestService) routeAntigravityTest(c *gin.Context, account *Account, modelID string, prompt string) error {
@@ -1694,13 +1145,17 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
+ seenCompleted := false
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
+ if seenCompleted {
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+ return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
@@ -1712,8 +1167,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" {
- s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
- return nil
+ if seenCompleted {
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+ }
+ return s.sendErrorAndEnd(c, "Stream ended before response.completed")
}
var data map[string]any
@@ -1729,9 +1187,19 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
if delta, ok := data["delta"].(string); ok && delta != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
}
- case "response.completed":
+ case "response.completed", "response.done":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
+ case "response.failed":
+ errorMsg := "OpenAI response failed"
+ if responseData, ok := data["response"].(map[string]any); ok {
+ if errData, ok := responseData["error"].(map[string]any); ok {
+ if msg, ok := errData["message"].(string); ok && msg != "" {
+ errorMsg = msg
+ }
+ }
+ }
+ return s.sendErrorAndEnd(c, errorMsg)
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]any); ok {
@@ -1744,7 +1212,198 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
}
}
-// sendEvent sends a SSE event to the client
+// testOpenAIImageAPIKey tests OpenAI image generation using an API Key account.
+func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+ authToken := account.GetOpenAIApiKey()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No API key available")
+ }
+
+ baseURL := account.GetOpenAIBaseURL()
+ if baseURL == "" {
+ baseURL = "https://api.openai.com"
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
+ }
+ apiURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/images/generations"
+
+ // Set SSE headers
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+
+ payload := map[string]any{
+ "model": modelID,
+ "prompt": prompt,
+ "n": 1,
+ "response_format": "b64_json",
+ }
+ payloadBytes, _ := json.Marshal(payload)
+
+ req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+authToken)
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read response: %s", err.Error()))
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
+ }
+
+ // Parse {"data": [{"b64_json": "...", "revised_prompt": "..."}]}
+ var result struct {
+ Data []struct {
+ B64JSON string `json:"b64_json"`
+ RevisedPrompt string `json:"revised_prompt"`
+ } `json:"data"`
+ }
+ if err := json.Unmarshal(body, &result); err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse response: %s", err.Error()))
+ }
+
+ if len(result.Data) == 0 {
+ return s.sendErrorAndEnd(c, "No images returned from API")
+ }
+
+ for _, item := range result.Data {
+ if item.RevisedPrompt != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
+ }
+ if item.B64JSON != "" {
+ s.sendEvent(c, TestEvent{
+ Type: "image",
+ ImageURL: "data:image/png;base64," + item.B64JSON,
+ MimeType: "image/png",
+ })
+ }
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
+// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API.
+func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error {
+ authToken := account.GetOpenAIAccessToken()
+ if authToken == "" {
+ return s.sendErrorAndEnd(c, "No access token available")
+ }
+
+ // Set SSE headers
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.Flush()
+
+ s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID})
+ s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"})
+
+ parsed := &OpenAIImagesRequest{
+ Endpoint: openAIImagesGenerationsEndpoint,
+ Model: strings.TrimSpace(modelID),
+ Prompt: prompt,
+ }
+ applyOpenAIImagesDefaults(parsed)
+
+ responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error()))
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody))
+ if err != nil {
+ return s.sendErrorAndEnd(c, "Failed to create request")
+ }
+ req.Host = "chatgpt.com"
+ req.Header.Set("Authorization", "Bearer "+authToken)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("OpenAI-Beta", "responses=experimental")
+ req.Header.Set("originator", "opencode")
+ if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" {
+ req.Header.Set("User-Agent", customUA)
+ } else {
+ req.Header.Set("User-Agent", codexCLIUserAgent)
+ }
+ if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" {
+ req.Header.Set("chatgpt-account-id", chatgptAccountID)
+ }
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+ resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error()))
+ }
+ defer func() {
+ if resp != nil && resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+ }()
+ if resp.StatusCode >= 400 {
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ message := strings.TrimSpace(extractUpstreamErrorMessage(body))
+ if message == "" {
+ message = fmt.Sprintf("Responses API returned %d", resp.StatusCode)
+ }
+ return s.sendErrorAndEnd(c, message)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error()))
+ }
+
+ results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body)
+ if err != nil {
+ return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error()))
+ }
+ if len(results) == 0 {
+ return s.sendErrorAndEnd(c, "No images returned from responses API")
+ }
+
+ for _, item := range results {
+ if item.RevisedPrompt != "" {
+ s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt})
+ }
+ mimeType := openAIImageOutputMIMEType(item.OutputFormat)
+ s.sendEvent(c, TestEvent{
+ Type: "image",
+ ImageURL: "data:" + mimeType + ";base64," + item.Result,
+ MimeType: mimeType,
+ })
+ }
+
+ s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
+ return nil
+}
+
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event)
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
@@ -1770,7 +1429,7 @@ func (s *AccountTestService) RunTestBackground(ctx context.Context, accountID in
ginCtx, _ := gin.CreateTestContext(w)
ginCtx.Request = (&http.Request{}).WithContext(ctx)
- testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "")
+ testErr := s.TestAccountConnection(ginCtx, accountID, modelID, "", AccountTestModeDefault)
finishedAt := time.Now()
body := w.Body.String()
diff --git a/backend/internal/service/account_test_service_gemini_test.go b/backend/internal/service/account_test_service_gemini_test.go
index 5ba04c69..f38264a2 100644
--- a/backend/internal/service/account_test_service_gemini_test.go
+++ b/backend/internal/service/account_test_service_gemini_test.go
@@ -42,7 +42,7 @@ func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
- ctx, recorder := newSoraTestContext()
+ ctx, recorder := newTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
diff --git a/backend/internal/service/account_test_service_openai_compact_test.go b/backend/internal/service/account_test_service_openai_compact_test.go
new file mode 100644
index 00000000..9eb98fdc
--- /dev/null
+++ b/backend/internal/service/account_test_service_openai_compact_test.go
@@ -0,0 +1,199 @@
+package service
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactOAuthSuccessPersistsSupport(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid-probe"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe","status":"completed"}`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.NoError(t, err)
+
+ require.Equal(t, chatgptCodexAPIURL+"/compact", upstream.lastReq.URL.String())
+ require.Equal(t, "chatgpt.com", upstream.lastReq.Host)
+ require.Equal(t, "application/json", upstream.lastReq.Header.Get("Accept"))
+ require.Equal(t, codexCLIVersion, upstream.lastReq.Header.Get("Version"))
+ require.NotEmpty(t, upstream.lastReq.Header.Get("Session_Id"))
+ require.Equal(t, codexCLIUserAgent, upstream.lastReq.Header.Get("User-Agent"))
+ require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id"))
+ require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
+
+ updates := <-updateCalls
+ require.Equal(t, true, updates["openai_compact_supported"])
+ require.Equal(t, http.StatusOK, updates["openai_compact_last_status"])
+ require.Contains(t, rec.Body.String(), `"type":"test_complete"`)
+}
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactOAuth404MarksUnsupported(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 2,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusNotFound,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`404 page not found`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/2/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.Error(t, err)
+
+ updates := <-updateCalls
+ require.Equal(t, false, updates["openai_compact_supported"])
+ require.Equal(t, http.StatusNotFound, updates["openai_compact_last_status"])
+ require.Contains(t, rec.Body.String(), `"type":"error"`)
+}
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyUsesCompactPath(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 3,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": "https://example.com/v1",
+ "compact_model_mapping": map[string]any{"gpt-5.4": "gpt-5.4-openai-compact"},
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey","status":"completed"}`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/3/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.NoError(t, err)
+
+ require.Equal(t, "https://example.com/v1/responses/compact", upstream.lastReq.URL.String())
+ require.Equal(t, "gpt-5.4-openai-compact", gjson.GetBytes(upstream.lastBody, "model").String())
+ updates := <-updateCalls
+ require.Equal(t, true, updates["openai_compact_supported"])
+}
+
+func TestAccountTestService_TestAccountConnection_OpenAICompactAPIKeyDefaultBaseURLUsesV1Path(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ updateCalls := make(chan map[string]any, 1)
+ account := Account{
+ ID: 4,
+ Name: "openai-apikey-default",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ }
+ repo := &snapshotUpdateAccountRepo{
+ stubOpenAIAccountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
+ updateExtraCalls: updateCalls,
+ }
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"cmp_probe_apikey_default","status":"completed"}`)),
+ }}
+ svc := &AccountTestService{
+ accountRepo: repo,
+ httpUpstream: upstream,
+ cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/4/test", bytes.NewReader(nil))
+
+ err := svc.TestAccountConnection(c, account.ID, "gpt-5.4", "", AccountTestModeCompact)
+ require.NoError(t, err)
+ require.Equal(t, "https://api.openai.com/v1/responses/compact", upstream.lastReq.URL.String())
+ <-updateCalls
+}
diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go
new file mode 100644
index 00000000..80a2fc31
--- /dev/null
+++ b/backend/internal/service/account_test_service_openai_image_test.go
@@ -0,0 +1,50 @@
+package service
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{
+ "Content-Type": []string{"text/event-stream"},
+ },
+ Body: io.NopCloser(strings.NewReader(
+ "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" +
+ "data: [DONE]\n\n",
+ )),
+ },
+ }
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 53,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "token-123",
+ },
+ }
+
+ err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat")
+ require.NoError(t, err)
+ require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool")
+ require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=")
+ require.Contains(t, rec.Body.String(), "\"success\":true")
+}
diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go
index efa6f7da..56204be3 100644
--- a/backend/internal/service/account_test_service_openai_test.go
+++ b/backend/internal/service/account_test_service_openai_test.go
@@ -4,21 +4,69 @@ package service
import (
"context"
+ "fmt"
"io"
"net/http"
+ "net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
)
+// --- shared test helpers ---
+
+type queuedHTTPUpstream struct {
+ responses []*http.Response
+ requests []*http.Request
+ tlsFlags []bool
+}
+
+func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
+ return nil, fmt.Errorf("unexpected Do call")
+}
+
+func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
+ u.requests = append(u.requests, req)
+ u.tlsFlags = append(u.tlsFlags, profile != nil)
+ if len(u.responses) == 0 {
+ return nil, fmt.Errorf("no mocked response")
+ }
+ resp := u.responses[0]
+ u.responses = u.responses[1:]
+ return resp, nil
+}
+
+func newJSONResponse(status int, body string) *http.Response {
+ return &http.Response{
+ StatusCode: status,
+ Header: make(http.Header),
+ Body: io.NopCloser(strings.NewReader(body)),
+ }
+}
+
+// --- gin test-context helper ---
+
+func newTestContext() (*gin.Context, *httptest.ResponseRecorder) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
+ return c, rec
+}
+
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
- updatedExtra map[string]any
- rateLimitedID int64
- rateLimitedAt *time.Time
+ updatedExtra map[string]any
+ rateLimitedID int64
+ rateLimitedAt *time.Time
+ clearedErrorID int64
+ setErrorID int64
+ setErrorMsg string
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
@@ -32,9 +80,20 @@ func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, rese
return nil
}
+func (r *openAIAccountTestRepo) ClearError(_ context.Context, id int64) error {
+ r.clearedErrorID = id
+ return nil
+}
+
+func (r *openAIAccountTestRepo) SetError(_ context.Context, id int64, errorMsg string) error {
+ r.setErrorID = id
+ r.setErrorMsg = errorMsg
+ return nil
+}
+
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
- ctx, recorder := newSoraTestContext()
+ ctx, recorder := newTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
@@ -58,7 +117,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
@@ -66,11 +125,36 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
require.Contains(t, recorder.Body.String(), "test_complete")
}
-func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
+func TestAccountTestService_OpenAIStreamEOFBeforeCompletedFails(t *testing.T) {
gin.SetMode(gin.TestMode)
- ctx, _ := newSoraTestContext()
+ ctx, recorder := newTestContext()
- resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
+ resp := newJSONResponse(http.StatusOK, "")
+ resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.output_text.delta","delta":"hi"}
+
+`))
+
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{httpUpstream: upstream}
+ account := &Account{
+ ID: 90,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Contains(t, recorder.Body.String(), "response.completed")
+ require.NotContains(t, recorder.Body.String(), `"success":true`)
+}
+
+func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimitState(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":1777283883}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
@@ -85,18 +169,132 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T)
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
+ Status: StatusError,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
- err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
- require.Equal(t, int64(88), repo.rateLimitedID)
+ require.Equal(t, account.ID, repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
+ require.Equal(t, account.ID, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.Empty(t, account.ErrorMessage)
require.NotNil(t, account.RateLimitResetAt)
- if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
- require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
- }
+}
+
+func TestAccountTestService_OpenAI429BodyOnlyPersistsRateLimitAndClearsStaleError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_at":"1777283883"}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 77,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "Access forbidden (403): account may be suspended or lack permissions",
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Equal(t, account.ID, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.Empty(t, account.ErrorMessage)
+ require.NotNil(t, account.RateLimitResetAt)
+ require.Empty(t, repo.updatedExtra)
+}
+
+func TestAccountTestService_OpenAI429ActiveAccountDoesNotClearError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached","resets_in_seconds":3600}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 78,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.rateLimitedID)
+ require.NotNil(t, repo.rateLimitedAt)
+ require.Zero(t, repo.clearedErrorID)
+ require.Equal(t, StatusActive, account.Status)
+ require.NotNil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI429WithoutResetSignalDoesNotMutateRuntimeState(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 79,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "stale 403",
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Zero(t, repo.rateLimitedID)
+ require.Nil(t, repo.rateLimitedAt)
+ require.Zero(t, repo.clearedErrorID)
+ require.Equal(t, StatusError, account.Status)
+ require.Equal(t, "stale 403", account.ErrorMessage)
+ require.Nil(t, account.RateLimitResetAt)
+}
+
+func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ ctx, _ := newTestContext()
+
+ resp := newJSONResponse(http.StatusUnauthorized, `{"error":"bad token"}`)
+
+ repo := &openAIAccountTestRepo{}
+ upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
+ svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
+ account := &Account{
+ ID: 80,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Concurrency: 1,
+ Credentials: map[string]any{"access_token": "test-token"},
+ }
+
+ err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
+ require.Error(t, err)
+ require.Equal(t, account.ID, repo.setErrorID)
+ require.Contains(t, repo.setErrorMsg, "Authentication failed (401)")
+ require.Zero(t, repo.rateLimitedID)
+ require.Zero(t, repo.clearedErrorID)
+ require.Nil(t, account.RateLimitResetAt)
}
diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go
deleted file mode 100644
index 52f832a9..00000000
--- a/backend/internal/service/account_test_service_sora_test.go
+++ /dev/null
@@ -1,320 +0,0 @@
-package service
-
-import (
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
-)
-
-type queuedHTTPUpstream struct {
- responses []*http.Response
- requests []*http.Request
- tlsFlags []bool
-}
-
-func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
- return nil, fmt.Errorf("unexpected Do call")
-}
-
-func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, profile *tlsfingerprint.Profile) (*http.Response, error) {
- u.requests = append(u.requests, req)
- u.tlsFlags = append(u.tlsFlags, profile != nil)
- if len(u.responses) == 0 {
- return nil, fmt.Errorf("no mocked response")
- }
- resp := u.responses[0]
- u.responses = u.responses[1:]
- return resp, nil
-}
-
-func newJSONResponse(status int, body string) *http.Response {
- return &http.Response{
- StatusCode: status,
- Header: make(http.Header),
- Body: io.NopCloser(strings.NewReader(body)),
- }
-}
-
-func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
- resp := newJSONResponse(status, body)
- resp.Header.Set(key, value)
- return resp
-}
-
-func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
- gin.SetMode(gin.TestMode)
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
- return c, rec
-}
-
-func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
- newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
- newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
- newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
- },
- }
- svc := &AccountTestService{
- httpUpstream: upstream,
- cfg: &config.Config{
- Gateway: config.GatewayConfig{
- TLSFingerprint: config.TLSFingerprintConfig{
- Enabled: true,
- },
- },
- Sora: config.SoraConfig{
- Client: config.SoraClientConfig{
- DisableTLSFingerprint: false,
- },
- },
- },
- }
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.NoError(t, err)
- require.Len(t, upstream.requests, 4)
- require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
- require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
- require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
- require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
- require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
- require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
- require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
-
- body := rec.Body.String()
- require.Contains(t, body, `"type":"test_start"`)
- require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
- require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
- require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
- require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"status":"success"`)
- require.Contains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
- newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
- newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
- newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.NoError(t, err)
- require.Len(t, upstream.requests, 4)
- body := rec.Body.String()
- require.Contains(t, body, "Sora connection OK - User: demo-user")
- require.Contains(t, body, "Subscription check returned 403")
- require.Contains(t, body, "Sora2 invite check returned 401")
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"status":"partial_success"`)
- require.Contains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponseWithHeader(http.StatusForbidden, `Just a moment... Enable JavaScript and cookies to continue `, "cf-ray", "9cff2d62d83bb98d"),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.Error(t, err)
- require.Contains(t, err.Error(), "Cloudflare challenge")
- require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
- body := rec.Body.String()
- require.Contains(t, body, `"type":"error"`)
- require.Contains(t, body, "Cloudflare challenge")
- require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
-}
-
-func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment... `, "cf-mitigated", "challenge"),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.Error(t, err)
- require.Contains(t, err.Error(), "Cloudflare challenge")
- require.Contains(t, err.Error(), "HTTP 429")
- body := rec.Body.String()
- require.Contains(t, body, "Cloudflare challenge")
-}
-
-func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.Error(t, err)
- require.Contains(t, err.Error(), "token_invalidated")
- body := rec.Body.String()
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"status":"failed"`)
- require.Contains(t, body, "token_invalidated")
- require.NotContains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
- },
- }
- svc := &AccountTestService{
- httpUpstream: upstream,
- soraTestCooldown: time.Hour,
- }
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c1, _ := newSoraTestContext()
- err := svc.testSoraAccountConnection(c1, account)
- require.NoError(t, err)
-
- c2, rec2 := newSoraTestContext()
- err = svc.testSoraAccountConnection(c2, account)
- require.Error(t, err)
- require.Contains(t, err.Error(), "测试过于频繁")
- body := rec2.Body.String()
- require.Contains(t, body, `"type":"sora_test_result"`)
- require.Contains(t, body, `"code":"test_rate_limited"`)
- require.Contains(t, body, `"status":"failed"`)
- require.NotContains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
- upstream := &queuedHTTPUpstream{
- responses: []*http.Response{
- newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
- newJSONResponse(http.StatusForbidden, `Just a moment... Enable JavaScript and cookies to continue `),
- newJSONResponse(http.StatusForbidden, `Just a moment... Enable JavaScript and cookies to continue `),
- },
- }
- svc := &AccountTestService{httpUpstream: upstream}
- account := &Account{
- ID: 1,
- Platform: PlatformSora,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "test_token",
- },
- }
-
- c, rec := newSoraTestContext()
- err := svc.testSoraAccountConnection(c, account)
-
- require.NoError(t, err)
- body := rec.Body.String()
- require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
- require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
- require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
- require.Contains(t, body, `"type":"test_complete","success":true`)
-}
-
-func TestSanitizeProxyURLForLog(t *testing.T) {
- require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
- require.Equal(t, "", sanitizeProxyURLForLog(""))
- require.Equal(t, "", sanitizeProxyURLForLog("://invalid"))
-}
-
-func TestExtractSoraEgressIPHint(t *testing.T) {
- h := make(http.Header)
- h.Set("x-openai-public-ip", "203.0.113.10")
- require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
-
- h2 := make(http.Header)
- h2.Set("x-envoy-external-address", "198.51.100.9")
- require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
-
- require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
- require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
-}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index 0e5741d8..68ba8f8c 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -110,7 +110,7 @@ const (
apiQueryMaxJitter = 800 * time.Millisecond // 用量查询最大随机延迟
windowStatsCacheTTL = 1 * time.Minute
openAIProbeCacheTTL = 10 * time.Minute
- openAICodexProbeVersion = "0.104.0"
+ openAICodexProbeVersion = "0.125.0"
)
// UsageCache 封装账户使用量相关的缓存
@@ -499,7 +499,6 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
if account == nil {
return usage, nil
}
- syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
usage.FiveHour = progress
@@ -509,11 +508,8 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
}
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
- if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
+ if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
mergeAccountExtra(account, updates)
- if resetAt != nil {
- account.RateLimitResetAt = resetAt
- }
if usage.UpdatedAt == nil {
usage.UpdatedAt = &now
}
@@ -594,26 +590,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
return true
}
-func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
+func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
if account == nil || !account.IsOAuth() {
- return nil, nil, nil
+ return nil, nil
}
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
- return nil, nil, fmt.Errorf("no access token available")
+ return nil, fmt.Errorf("no access token available")
}
modelID := openaipkg.DefaultTestModel
payload := createOpenAITestPayload(modelID, true)
payloadBytes, err := json.Marshal(payload)
if err != nil {
- return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
+ return nil, fmt.Errorf("marshal openai probe payload: %w", err)
}
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
if err != nil {
- return nil, nil, fmt.Errorf("create openai probe request: %w", err)
+ return nil, fmt.Errorf("create openai probe request: %w", err)
}
req.Host = "chatgpt.com"
req.Header.Set("Content-Type", "application/json")
@@ -642,67 +638,51 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
ResponseHeaderTimeout: 10 * time.Second,
})
if err != nil {
- return nil, nil, fmt.Errorf("build openai probe client: %w", err)
+ return nil, fmt.Errorf("build openai probe client: %w", err)
}
resp, err := client.Do(req)
if err != nil {
- return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
+ return nil, fmt.Errorf("openai codex probe request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
- updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
+ updates, err := extractOpenAICodexProbeUpdates(resp)
if err != nil {
- return nil, nil, err
+ return nil, err
}
- if len(updates) > 0 || resetAt != nil {
- s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
- return updates, resetAt, nil
+ if len(updates) > 0 {
+ s.persistOpenAICodexProbeSnapshot(account.ID, updates)
+ return updates, nil
}
- return nil, nil, nil
+ return nil, nil
}
-func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
+func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any) {
if s == nil || s.accountRepo == nil || accountID <= 0 {
return
}
- if len(updates) == 0 && resetAt == nil {
+ if len(updates) == 0 {
return
}
go func() {
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer updateCancel()
- if len(updates) > 0 {
- _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
- }
- if resetAt != nil {
- _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
- }
+ _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}
-func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
+func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
if resp == nil {
- return nil, nil, nil
+ return nil, nil
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
- baseTime := time.Now()
- updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
- resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
- if len(updates) > 0 {
- return updates, resetAt, nil
- }
- return nil, resetAt, nil
+ return buildCodexUsageExtraUpdates(snapshot, time.Now()), nil
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
+ return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
}
- return nil, nil, nil
-}
-
-func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
- updates, _, err := extractOpenAICodexProbeSnapshot(resp)
- return updates, err
+ return nil, nil
}
func mergeAccountExtra(account *Account, updates map[string]any) {
diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go
index fe255225..28b49838 100644
--- a/backend/internal/service/account_usage_service_test.go
+++ b/backend/internal/service/account_usage_service_test.go
@@ -92,30 +92,7 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
}
}
-func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
- t.Parallel()
-
- headers := make(http.Header)
- headers.Set("x-codex-primary-used-percent", "100")
- headers.Set("x-codex-primary-reset-after-seconds", "604800")
- headers.Set("x-codex-primary-window-minutes", "10080")
- headers.Set("x-codex-secondary-used-percent", "100")
- headers.Set("x-codex-secondary-reset-after-seconds", "18000")
- headers.Set("x-codex-secondary-window-minutes", "300")
-
- updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
- if err != nil {
- t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
- }
- if len(updates) == 0 {
- t.Fatal("expected codex probe updates from 429 headers")
- }
- if resetAt == nil {
- t.Fatal("expected resetAt from exhausted codex headers")
- }
-}
-
-func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
+func TestAccountUsageService_PersistOpenAICodexProbeSnapshotOnlyUpdatesExtra(t *testing.T) {
t.Parallel()
repo := &accountUsageCodexProbeRepo{
@@ -123,12 +100,10 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
- resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
-
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
"codex_7d_used_percent": 100.0,
- "codex_7d_reset_at": resetAt.Format(time.RFC3339),
- }, &resetAt)
+ "codex_7d_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339),
+ })
select {
case updates := <-repo.updateExtraCh:
@@ -136,16 +111,49 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
case <-time.After(2 * time.Second):
- t.Fatal("waiting for codex probe extra persistence timed out")
+ t.Fatal("等待 codex 探测快照写入 extra 超时")
}
select {
case got := <-repo.rateLimitCh:
- if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
- t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
- }
- case <-time.After(2 * time.Second):
- t.Fatal("waiting for codex probe rate limit persistence timed out")
+ t.Fatalf("不应将探测快照写入运行时限流状态: %v", got)
+ case <-time.After(200 * time.Millisecond):
+ }
+}
+
+func TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit(t *testing.T) {
+ t.Parallel()
+
+ resetAt := time.Now().Add(6 * 24 * time.Hour).UTC().Truncate(time.Second)
+ repo := &accountUsageCodexProbeRepo{
+ rateLimitCh: make(chan time.Time, 1),
+ }
+ svc := &AccountUsageService{accountRepo: repo}
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "codex_5h_used_percent": 1.0,
+ "codex_5h_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339),
+ "codex_7d_used_percent": 100.0,
+ "codex_7d_reset_at": resetAt.Format(time.RFC3339),
+ },
+ }
+
+ usage, err := svc.getOpenAIUsage(context.Background(), account)
+ if err != nil {
+ t.Fatalf("getOpenAIUsage() error = %v", err)
+ }
+ if usage.SevenDay == nil || usage.SevenDay.Utilization != 100.0 {
+ t.Fatalf("预期 7 天用量仍然可见,实际为 %#v", usage.SevenDay)
+ }
+ if account.RateLimitResetAt != nil {
+ t.Fatalf("不应让已耗尽的 codex extra 改写运行时限流状态: %v", account.RateLimitResetAt)
+ }
+ select {
+ case got := <-repo.rateLimitCh:
+ t.Fatalf("不应将已耗尽的 codex extra 持久化为运行时限流状态: %v", got)
+ case <-time.After(200 * time.Millisecond):
}
}
diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go
new file mode 100644
index 00000000..6ed69d4c
--- /dev/null
+++ b/backend/internal/service/account_websearch_test.go
@@ -0,0 +1,105 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetWebSearchEmulationMode_Enabled(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
+ }
+ require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_Disabled(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"},
+ }
+ require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_Default(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "default"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: true},
+ }
+ // bool true → tolerant fallback → enabled (not default)
+ require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: false},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) {
+ var a *Account
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: nil,
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_MissingField(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) {
+ a := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
+
+func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) {
+ a := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
+ }
+ require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
+}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 0620d7ca..434f1f38 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -2,40 +2,46 @@ package service
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
+ "sort"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
- "github.com/Wei-Shaw/sub2api/internal/util/soraerror"
+ "github.com/Wei-Shaw/sub2api/internal/util/httputil"
)
// AdminService interface defines admin management operations
type AdminService interface {
// User management
- ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
+ ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
- GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
+ GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
+ GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error)
// GetUserBalanceHistory returns paginated balance/concurrency change records for a user.
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
+ BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
// Group management
- ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
+ ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error)
@@ -46,6 +52,8 @@ type AdminService interface {
GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error)
ClearGroupRateMultipliers(ctx context.Context, groupID int64) error
BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error
+ ClearGroupRPMOverrides(ctx context.Context, groupID int64) error
+ BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
@@ -55,7 +63,7 @@ type AdminService interface {
ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error)
// Account management
- ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error)
+ ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
@@ -77,8 +85,8 @@ type AdminService interface {
CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error
// Proxy management
- ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
- ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error)
+ ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error)
+ ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
@@ -93,7 +101,7 @@ type AdminService interface {
CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error)
// Redeem code management
- ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
+ ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error
@@ -104,14 +112,14 @@ type AdminService interface {
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
- Email string
- Password string
- Username string
- Notes string
- Balance float64
- Concurrency int
- AllowedGroups []int64
- SoraStorageQuotaBytes int64
+ Email string
+ Password string
+ Username string
+ Notes string
+ Balance float64
+ Concurrency int
+ RPMLimit int
+ AllowedGroups []int64
}
type UpdateUserInput struct {
@@ -121,12 +129,50 @@ type UpdateUserInput struct {
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
+ RPMLimit *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*rate,nil 表示删除该分组的专属倍率
- GroupRates map[int64]*float64
- SoraStorageQuotaBytes *int64
+ GroupRates map[int64]*float64
+}
+
+type AdminBindAuthIdentityInput struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+ Issuer *string
+ Metadata map[string]any
+ Channel *AdminBindAuthIdentityChannelInput
+}
+
+type AdminBindAuthIdentityChannelInput struct {
+ Channel string
+ ChannelAppID string
+ ChannelSubject string
+ Metadata map[string]any
+}
+
+type AdminBoundAuthIdentity struct {
+ UserID int64 `json:"user_id"`
+ ProviderType string `json:"provider_type"`
+ ProviderKey string `json:"provider_key"`
+ ProviderSubject string `json:"provider_subject"`
+ VerifiedAt *time.Time `json:"verified_at,omitempty"`
+ Issuer *string `json:"issuer,omitempty"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+ Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"`
+}
+
+type AdminBoundAuthIdentityChannel struct {
+ Channel string `json:"channel"`
+ ChannelAppID string `json:"channel_app_id"`
+ ChannelSubject string `json:"channel_subject"`
+ Metadata map[string]any `json:"metadata"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
}
type CreateGroupInput struct {
@@ -140,16 +186,11 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
- ImagePrice1K *float64
- ImagePrice2K *float64
- ImagePrice4K *float64
- // Sora 按次计费配置
- SoraImagePrice360 *float64
- SoraImagePrice540 *float64
- SoraVideoPricePerRequest *float64
- SoraVideoPricePerRequestHD *float64
- ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
- FallbackGroupID *int64 // 降级分组 ID
+ ImagePrice1K *float64
+ ImagePrice2K *float64
+ ImagePrice4K *float64
+ ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
+ FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@@ -158,13 +199,14 @@ type CreateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string
- // Sora 存储配额
- SoraStorageQuotaBytes int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
- AllowMessagesDispatch bool
- DefaultMappedModel string
- RequireOAuthOnly bool
- RequirePrivacySet bool
+ AllowMessagesDispatch bool
+ DefaultMappedModel string
+ RequireOAuthOnly bool
+ RequirePrivacySet bool
+ MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig
+ // RPMLimit 分组 RPM 上限(0 = 不限制)
+ RPMLimit int
// 从指定分组复制账号(创建分组后在同一事务内绑定)
CopyAccountsFromGroupIDs []int64
}
@@ -181,16 +223,11 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
- ImagePrice1K *float64
- ImagePrice2K *float64
- ImagePrice4K *float64
- // Sora 按次计费配置
- SoraImagePrice360 *float64
- SoraImagePrice540 *float64
- SoraVideoPricePerRequest *float64
- SoraVideoPricePerRequestHD *float64
- ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
- FallbackGroupID *int64 // 降级分组 ID
+ ImagePrice1K *float64
+ ImagePrice2K *float64
+ ImagePrice4K *float64
+ ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
+ FallbackGroupID *int64 // 降级分组 ID
// 无效请求兜底分组 ID(仅 anthropic 平台使用)
FallbackGroupIDOnInvalidRequest *int64
// 模型路由配置(仅 anthropic 平台使用)
@@ -199,13 +236,14 @@ type UpdateGroupInput struct {
MCPXMLInject *bool
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string
- // Sora 存储配额
- SoraStorageQuotaBytes *int64
// OpenAI Messages 调度配置(仅 openai 平台使用)
- AllowMessagesDispatch *bool
- DefaultMappedModel *string
- RequireOAuthOnly *bool
- RequirePrivacySet *bool
+ AllowMessagesDispatch *bool
+ DefaultMappedModel *string
+ RequireOAuthOnly *bool
+ RequirePrivacySet *bool
+ MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig
+ // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。
+ RPMLimit *int
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64
}
@@ -289,6 +327,22 @@ type ReplaceUserGroupResult struct {
MigratedKeys int64 // 迁移的 Key 数量
}
+// UserRPMStatus describes a user's current per-minute RPM usage.
+type UserRPMStatus struct {
+ UserRPMUsed int `json:"user_rpm_used"`
+ UserRPMLimit int `json:"user_rpm_limit"`
+ PerGroup []UserGroupRPMStatus `json:"per_group"`
+}
+
+// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair.
+type UserGroupRPMStatus struct {
+ GroupID int64 `json:"group_id"`
+ GroupName string `json:"group_name"`
+ Used int `json:"used"`
+ Limit int `json:"limit"`
+ Source string `json:"source"` // "group" | "override"
+}
+
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
@@ -426,14 +480,6 @@ var proxyQualityTargets = []proxyQualityTarget{
http.StatusOK: {},
},
},
- {
- Target: "sora",
- URL: "https://sora.chatgpt.com/backend/me",
- Method: http.MethodGet,
- AllowedStatuses: map[int]struct{}{
- http.StatusUnauthorized: {},
- },
- },
}
const (
@@ -443,16 +489,18 @@ const (
proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36"
)
+var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available")
+
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
- soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
+ userRPMCache UserRPMCache
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
@@ -473,11 +521,11 @@ func NewAdminService(
userRepo UserRepository,
groupRepo GroupRepository,
accountRepo AccountRepository,
- soraAccountRepo SoraAccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
+ userRPMCache UserRPMCache,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
@@ -492,11 +540,11 @@ func NewAdminService(
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
- soraAccountRepo: soraAccountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
+ userRPMCache: userRPMCache,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
@@ -510,12 +558,26 @@ func NewAdminService(
}
// User management implementations
-func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, 0, err
}
+ if len(users) > 0 {
+ userIDs := make([]int64, 0, len(users))
+ for i := range users {
+ userIDs = append(userIDs, users[i].ID)
+ }
+ lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr)
+ } else {
+ for i := range users {
+ users[i].LastUsedAt = lastUsedByUserID[users[i].ID]
+ }
+ }
+ }
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
@@ -560,6 +622,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
if err != nil {
return nil, err
}
+ lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id)
+ if latestErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr)
+ } else {
+ user.LastUsedAt = lastUsedAt
+ }
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
@@ -574,15 +642,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
user := &User{
- Email: input.Email,
- Username: input.Username,
- Notes: input.Notes,
- Role: RoleUser, // Always create as regular user, never admin
- Balance: input.Balance,
- Concurrency: input.Concurrency,
- Status: StatusActive,
- AllowedGroups: input.AllowedGroups,
- SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
+ Email: input.Email,
+ Username: input.Username,
+ Notes: input.Notes,
+ Role: RoleUser, // Always create as regular user, never admin
+ Balance: input.Balance,
+ Concurrency: input.Concurrency,
+ RPMLimit: input.RPMLimit,
+ Status: StatusActive,
+ AllowedGroups: input.AllowedGroups,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
@@ -612,6 +680,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
+ // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率)
+ if input.GroupRates != nil {
+ for groupID, rate := range input.GroupRates {
+ if rate != nil && *rate <= 0 {
+ return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID)
+ }
+ }
+ }
+
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
@@ -625,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
+ oldRPMLimit := user.RPMLimit
if input.Email != "" {
user.Email = input.Email
@@ -650,12 +728,12 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
user.Concurrency = *input.Concurrency
}
- if input.AllowedGroups != nil {
- user.AllowedGroups = *input.AllowedGroups
+ if input.RPMLimit != nil {
+ user.RPMLimit = *input.RPMLimit
}
- if input.SoraStorageQuotaBytes != nil {
- user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
+ if input.AllowedGroups != nil {
+ user.AllowedGroups = *input.AllowedGroups
}
if err := s.userRepo.Update(ctx, user); err != nil {
@@ -670,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
if s.authCacheInvalidator != nil {
- if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
+ // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联,
+ // 不失效缓存会让修改在一个 L2 TTL 内失去效果。
+ if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
@@ -783,8 +863,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
-func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{})
if err != nil {
return nil, 0, err
@@ -792,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag
return keys, result.Total, nil
}
+func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) {
+ if s.userRPMCache == nil {
+ return nil, ErrRPMStatusUnavailable
+ }
+
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID)
+ if err != nil {
+ logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err)
+ }
+
+ keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "")
+ if err != nil {
+ return nil, err
+ }
+
+ groupIDSet := make(map[int64]struct{})
+ for _, key := range keys {
+ if key.GroupID != nil && *key.GroupID > 0 {
+ groupIDSet[*key.GroupID] = struct{}{}
+ }
+ }
+
+ groupIDs := make([]int64, 0, len(groupIDSet))
+ for groupID := range groupIDSet {
+ groupIDs = append(groupIDs, groupID)
+ }
+ sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] })
+
+ var perGroup []UserGroupRPMStatus
+ for _, groupID := range groupIDs {
+ used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID)
+ if getErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr)
+ }
+
+ entry := UserGroupRPMStatus{
+ GroupID: groupID,
+ Used: used,
+ }
+
+ if s.groupRepo != nil {
+ if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil {
+ entry.GroupName = group.Name
+ entry.Limit = group.RPMLimit
+ entry.Source = "group"
+ } else if groupErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr)
+ }
+ }
+
+ if s.userGroupRateRepo != nil {
+ override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID)
+ if overrideErr != nil {
+ logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr)
+ } else if override != nil {
+ entry.Limit = *override
+ entry.Source = "override"
+ }
+ }
+
+ perGroup = append(perGroup, entry)
+ }
+
+ return &UserRPMStatus{
+ UserRPMUsed: userRPMUsed,
+ UserRPMLimit: user.RPMLimit,
+ PerGroup: perGroup,
+ }, nil
+}
+
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
// Return mock data for now
return map[string]any{
@@ -818,9 +973,337 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
return codes, result.Total, totalRecharged, nil
}
+func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
+ if userID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
+ }
+ if s == nil || s.entClient == nil || s.userRepo == nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable")
+ }
+ if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+ return nil, err
+ }
+
+ providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType)
+ providerKey := strings.TrimSpace(input.ProviderKey)
+ providerSubject := strings.TrimSpace(input.ProviderSubject)
+ if providerType == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
+ }
+ if providerKey == "" || providerSubject == "" {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
+ }
+ canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
+ compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
+
+ var issuer *string
+ if input.Issuer != nil {
+ trimmed := strings.TrimSpace(*input.Issuer)
+ if trimmed != "" {
+ issuer = &trimmed
+ }
+ }
+
+ channelInput := normalizeAdminBindChannelInput(input.Channel)
+ if input.Channel != nil && channelInput == nil {
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided")
+ }
+
+ verifiedAt := time.Now().UTC()
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err)
+ }
+ defer func() { _ = tx.Rollback() }()
+
+ identityRecords, err := tx.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ(providerType),
+ authidentity.ProviderKeyIn(compatibleProviderKeys...),
+ authidentity.ProviderSubjectEQ(providerSubject),
+ ).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
+ }
+ if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
+ }
+ identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
+
+ if identity == nil {
+ create := tx.AuthIdentity.Create().
+ SetUserID(userID).
+ SetProviderType(providerType).
+ SetProviderKey(canonicalProviderKey).
+ SetProviderSubject(providerSubject).
+ SetVerifiedAt(verifiedAt)
+ if issuer != nil {
+ create = create.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentity.UpdateOneID(identity.ID).
+ SetVerifiedAt(verifiedAt).
+ SetProviderKey(canonicalProviderKey)
+ if issuer != nil {
+ update = update.SetIssuer(*issuer)
+ }
+ if input.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
+ }
+ identity, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
+ }
+ }
+
+ var channel *dbent.AuthIdentityChannel
+ if channelInput != nil {
+ channelRecords, err := tx.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ(providerType),
+ authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
+ authidentitychannel.ChannelEQ(channelInput.Channel),
+ authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
+ authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
+ ).
+ WithIdentity().
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
+ }
+ if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
+ return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
+ }
+ channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
+ if channel == nil {
+ create := tx.AuthIdentityChannel.Create().
+ SetIdentityID(identity.ID).
+ SetProviderType(providerType).
+ SetProviderKey(canonicalProviderKey).
+ SetChannel(channelInput.Channel).
+ SetChannelAppID(channelInput.ChannelAppID).
+ SetChannelSubject(channelInput.ChannelSubject)
+ if channelInput.Metadata != nil {
+ create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = create.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ } else {
+ update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
+ SetIdentityID(identity.ID).
+ SetProviderKey(canonicalProviderKey)
+ if channelInput.Metadata != nil {
+ update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
+ }
+ channel, err = update.Save(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
+ }
+ }
+ }
+
+ if err := tx.Commit(); err != nil {
+ return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err)
+ }
+ return buildAdminBoundAuthIdentity(identity, channel), nil
+}
+
+func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerKey == "" {
+ return []string{providerKey}
+ }
+ if providerType != "wechat" {
+ return []string{providerKey}
+ }
+
+ keys := []string{providerKey}
+ if !strings.EqualFold(providerKey, "wechat-main") {
+ keys = append(keys, "wechat-main")
+ }
+ if !strings.EqualFold(providerKey, "wechat") {
+ keys = append(keys, "wechat")
+ }
+ return keys
+}
+
+func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ existingKey = strings.TrimSpace(existingKey)
+ requestedKey = strings.TrimSpace(requestedKey)
+ if providerType != "wechat" {
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+ }
+ if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
+ return "wechat-main"
+ }
+ if requestedKey != "" {
+ return requestedKey
+ }
+ return existingKey
+}
+
+func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
+ providerType = strings.TrimSpace(strings.ToLower(providerType))
+ providerKey = strings.TrimSpace(providerKey)
+ if providerType != "wechat" {
+ return 0
+ }
+ switch {
+ case strings.EqualFold(providerKey, "wechat-main"):
+ return 0
+ case strings.EqualFold(providerKey, "wechat"):
+ return 2
+ default:
+ return 1
+ }
+}
+
+func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
+ var selected *dbent.AuthIdentity
+ for _, record := range records {
+ if record.UserID != userID {
+ continue
+ }
+ if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
+ for _, record := range records {
+ if record.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
+ var selected *dbent.AuthIdentityChannel
+ for _, record := range records {
+ if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
+ continue
+ }
+ if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
+ selected = record
+ }
+ }
+ return selected
+}
+
+func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
+ for _, record := range records {
+ if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
+ return true
+ }
+ }
+ return false
+}
+
+func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
+ if input == nil {
+ return nil
+ }
+ channel := &AdminBindAuthIdentityChannelInput{
+ Channel: strings.TrimSpace(input.Channel),
+ ChannelAppID: strings.TrimSpace(input.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(input.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(input.Metadata),
+ }
+ if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" {
+ return nil
+ }
+ return channel
+}
+
+func normalizeAdminAuthIdentityProviderType(input string) string {
+ switch strings.ToLower(strings.TrimSpace(input)) {
+ case "email":
+ return "email"
+ case "linuxdo":
+ return "linuxdo"
+ case "oidc":
+ return "oidc"
+ case "wechat":
+ return "wechat"
+ default:
+ return ""
+ }
+}
+
+func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity {
+ if identity == nil {
+ return nil
+ }
+ result := &AdminBoundAuthIdentity{
+ UserID: identity.UserID,
+ ProviderType: strings.TrimSpace(identity.ProviderType),
+ ProviderKey: strings.TrimSpace(identity.ProviderKey),
+ ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
+ VerifiedAt: identity.VerifiedAt,
+ Issuer: identity.Issuer,
+ Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata),
+ CreatedAt: identity.CreatedAt,
+ UpdatedAt: identity.UpdatedAt,
+ }
+ if channel != nil {
+ result.Channel = &AdminBoundAuthIdentityChannel{
+ Channel: strings.TrimSpace(channel.Channel),
+ ChannelAppID: strings.TrimSpace(channel.ChannelAppID),
+ ChannelSubject: strings.TrimSpace(channel.ChannelSubject),
+ Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata),
+ CreatedAt: channel.CreatedAt,
+ UpdatedAt: channel.UpdatedAt,
+ }
+ }
+ return result
+}
+
+func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
+ if input == nil {
+ return nil
+ }
+ if len(input) == 0 {
+ return map[string]any{}
+ }
+ data, err := json.Marshal(input)
+ if err != nil {
+ out := make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ return out
+ }
+ var out map[string]any
+ if err := json.Unmarshal(data, &out); err != nil {
+ out = make(map[string]any, len(input))
+ for key, value := range input {
+ out[key] = value
+ }
+ }
+ return out
+}
+
// Group management implementations
-func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil {
return nil, 0, err
@@ -841,6 +1324,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
+ if input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
+
platform := input.Platform
if platform == "" {
platform = PlatformAnthropic
@@ -860,10 +1347,6 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
- soraImagePrice360 := normalizePrice(input.SoraImagePrice360)
- soraImagePrice540 := normalizePrice(input.SoraImagePrice540)
- soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest)
- soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD)
// 校验降级分组
if input.FallbackGroupID != nil {
@@ -934,22 +1417,20 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
- SoraImagePrice360: soraImagePrice360,
- SoraImagePrice540: soraImagePrice540,
- SoraVideoPricePerRequest: soraVideoPrice,
- SoraVideoPricePerRequestHD: soraVideoPriceHD,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest,
ModelRouting: input.ModelRouting,
MCPXMLInject: mcpXMLInject,
SupportedModelScopes: input.SupportedModelScopes,
- SoraStorageQuotaBytes: input.SoraStorageQuotaBytes,
AllowMessagesDispatch: input.AllowMessagesDispatch,
RequireOAuthOnly: input.RequireOAuthOnly,
RequirePrivacySet: input.RequirePrivacySet,
DefaultMappedModel: input.DefaultMappedModel,
+ MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig),
+ RPMLimit: input.RPMLimit,
}
+ sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
@@ -1087,6 +1568,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
+ if *input.RateMultiplier <= 0 {
+ return nil, errors.New("rate_multiplier must be > 0")
+ }
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
@@ -1115,21 +1599,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.ImagePrice4K != nil {
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
- if input.SoraImagePrice360 != nil {
- group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360)
- }
- if input.SoraImagePrice540 != nil {
- group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540)
- }
- if input.SoraVideoPricePerRequest != nil {
- group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest)
- }
- if input.SoraVideoPricePerRequestHD != nil {
- group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD)
- }
- if input.SoraStorageQuotaBytes != nil {
- group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes
- }
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
@@ -1191,11 +1660,22 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.DefaultMappedModel != nil {
group.DefaultMappedModel = *input.DefaultMappedModel
}
+ if input.MessagesDispatchModelConfig != nil {
+ group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig)
+ }
+ if input.RPMLimit != nil {
+ group.RPMLimit = *input.RPMLimit
+ }
+ sanitizeGroupMessagesDispatchFields(group)
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
+ }
+
// 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号)
if len(input.CopyAccountsFromGroupIDs) > 0 {
// 去重源分组 IDs
@@ -1264,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
- if s.authCacheInvalidator != nil {
- s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
- }
return group, nil
}
@@ -1334,9 +1811,47 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro
if s.userGroupRateRepo == nil {
return nil
}
+ for _, e := range entries {
+ if e.RateMultiplier <= 0 {
+ return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID)
+ }
+ }
return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries)
}
+func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error {
+ if s.userGroupRateRepo == nil {
+ return nil
+ }
+ if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil {
+ return err
+ }
+ // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
+ }
+ return nil
+}
+
+func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
+ if s.userGroupRateRepo == nil {
+ return nil
+ }
+ for _, e := range entries {
+ if e.RPMOverride != nil && *e.RPMOverride < 0 {
+ return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID))
+ }
+ }
+ if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil {
+ return err
+ }
+ // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID)
+ }
+ return nil
+}
+
func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
@@ -1512,16 +2027,12 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou
}
// Account management implementations
-func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) {
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode)
if err != nil {
return nil, 0, err
}
- now := time.Now()
- for i := range accounts {
- syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
- }
return accounts, result.Total, nil
}
@@ -1566,18 +2077,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
- // Sora apikey 账号的 base_url 必填校验
- if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey {
- baseURL, _ := input.Credentials["base_url"].(string)
- baseURL = strings.TrimSpace(baseURL)
- if baseURL == "" {
- return nil, errors.New("sora apikey 账号必须设置 base_url")
- }
- if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
- return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
- }
- }
-
account := &Account{
Name: input.Name,
Notes: normalizeAccountNotes(input.Notes),
@@ -1623,18 +2122,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return nil, err
}
- // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录
- if account.Platform == PlatformSora && s.soraAccountRepo != nil {
- soraUpdates := map[string]any{
- "access_token": account.GetCredential("access_token"),
- "refresh_token": account.GetCredential("refresh_token"),
- }
- if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil {
- // 只记录警告日志,不阻塞账号创建
- logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err)
- }
- }
-
// 绑定分组
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
@@ -1642,16 +2129,29 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
}
- // Antigravity OAuth 账号:创建后异步设置隐私
- if account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth {
- go func() {
- defer func() {
- if r := recover(); r != nil {
- slog.Error("create_account_antigravity_privacy_panic", "account_id", account.ID, "recover", r)
- }
+ // OAuth 账号:创建后异步设置隐私。
+ // 使用 Ensure(幂等)而非 Force:新建账号 Extra 为空时效果相同,但更安全。
+ if account.Type == AccountTypeOAuth {
+ switch account.Platform {
+ case PlatformOpenAI:
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ slog.Error("create_account_openai_privacy_panic", "account_id", account.ID, "recover", r)
+ }
+ }()
+ s.EnsureOpenAIPrivacy(context.Background(), account)
}()
- s.EnsureAntigravityPrivacy(context.Background(), account)
- }()
+ case PlatformAntigravity:
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ slog.Error("create_account_antigravity_privacy_panic", "account_id", account.ID, "recover", r)
+ }
+ }()
+ s.EnsureAntigravityPrivacy(context.Background(), account)
+ }()
+ }
}
return account, nil
@@ -1750,18 +2250,6 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
- // Sora apikey 账号的 base_url 必填校验
- if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey {
- baseURL, _ := account.Credentials["base_url"].(string)
- baseURL = strings.TrimSpace(baseURL)
- if baseURL == "" {
- return nil, errors.New("sora apikey 账号必须设置 base_url")
- }
- if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
- return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
- }
- }
-
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil {
@@ -1964,8 +2452,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
}
// Proxy management implementations
-func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
@@ -1973,8 +2461,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil
}
-func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) {
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
@@ -2111,8 +2599,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
}
// Redeem code management implementations
-func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
- params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) {
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
return nil, 0, err
@@ -2364,10 +2852,11 @@ func runProxyQualityTarget(ctx context.Context, client *http.Client, target prox
body = body[:proxyQualityMaxBodyBytes]
}
- if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
+ // Cloudflare challenge 检测
+ if httputil.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
item.Status = "challenge"
- item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body)
- item.Message = "Sora 命中 Cloudflare challenge"
+ item.CFRay = httputil.ExtractCloudflareRayID(resp.Header, body)
+ item.Message = "命中 Cloudflare challenge"
return item
}
@@ -2783,16 +3272,14 @@ func (s *adminServiceImpl) ForceOpenAIPrivacy(ctx context.Context, account *Acco
}
// EnsureAntigravityPrivacy 检查 Antigravity OAuth 账号隐私状态。
-// 如果 Extra["privacy_mode"] 已存在(无论成功或失败),直接跳过。
-// 仅对从未设置过隐私的账号执行 setUserSettings + fetchUserInfo 流程。
-// 用户可通过前端 ForceAntigravityPrivacy(SetPrivacy 按钮)强制重新设置。
+// 仅当 privacy_mode 已成功设置("privacy_set")时跳过;
+// 未设置或之前失败("privacy_set_failed")均会重试。
func (s *adminServiceImpl) EnsureAntigravityPrivacy(ctx context.Context, account *Account) string {
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
return ""
}
- // 已设置过则跳过(无论成功或失败),用户可通过 Force 手动重试
if account.Extra != nil {
- if existing, ok := account.Extra["privacy_mode"].(string); ok && existing != "" {
+ if existing, ok := account.Extra["privacy_mode"].(string); ok && existing == AntigravityPrivacySet {
return existing
}
}
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
index f9fd6742..fcde5cbf 100644
--- a/backend/internal/service/admin_service_apikey_test.go
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected")
+}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
@@ -65,14 +74,31 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
-func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
- panic("unexpected")
-}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected")
+}
+
+func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected")
+}
+
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ panic("unexpected")
+}
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct {
@@ -131,9 +157,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
-func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
- panic("unexpected")
-}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
@@ -158,6 +181,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected")
}
+func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
+ panic("unexpected")
+}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct {
diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go
new file mode 100644
index 00000000..719199f2
--- /dev/null
+++ b/backend/internal/service/admin_service_auth_identity_binding_test.go
@@ -0,0 +1,302 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+ return client
+}
+
+func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("bind-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-123",
+ Metadata: map[string]any{"source": "admin-repair"},
+ Channel: &AdminBindAuthIdentityChannelInput{
+ Channel: "open",
+ ChannelAppID: "wx-open",
+ ChannelSubject: "openid-123",
+ Metadata: map[string]any{"scene": "migration"},
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, user.ID, result.UserID)
+ require.Equal(t, "wechat", result.ProviderType)
+ require.Equal(t, "wechat-main", result.ProviderKey)
+ require.NotNil(t, result.VerifiedAt)
+ require.NotNil(t, result.Channel)
+ require.Equal(t, "open", result.Channel.Channel)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderKeyEQ("wechat-main"),
+ authidentity.ProviderSubjectEQ("union-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.Equal(t, "admin-repair", identity.Metadata["source"])
+ require.NotNil(t, identity.VerifiedAt)
+
+ channel, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ProviderKeyEQ("wechat-main"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open"),
+ authidentitychannel.ChannelSubjectEQ("openid-123"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, identity.ID, channel.IdentityID)
+ require.Equal(t, "migration", channel.Metadata["scene"])
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ owner, err := client.User.Create().
+ SetEmail("owner@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ target, err := client.User.Create().
+ SetEmail("target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.AuthIdentity.Create().
+ SetUserID(owner.ID).
+ SetProviderType("oidc").
+ SetProviderKey("https://issuer.example").
+ SetProviderSubject("subject-1").
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ })
+ require.Error(t, err)
+ require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err))
+}
+
+func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("same-user@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "first"},
+ })
+ require.NoError(t, err)
+
+ second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-2",
+ Metadata: map[string]any{"source": "second"},
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.UserID, second.UserID)
+ require.Equal(t, "second", second.Metadata["source"])
+
+ identities, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("oidc"),
+ authidentity.ProviderKeyEQ("https://issuer.example"),
+ authidentity.ProviderSubjectEQ("subject-2"),
+ ).
+ All(ctx)
+ require.NoError(t, err)
+ require.Len(t, identities, 1)
+ require.Equal(t, "second", identities[0].Metadata["source"])
+}
+
+func TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("wechat-alias@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyIdentity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetProviderSubject("union-legacy-123").
+ SetMetadata(map[string]any{"source": "legacy"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ legacyChannel, err := client.AuthIdentityChannel.Create().
+ SetIdentityID(legacyIdentity.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat").
+ SetChannel("open").
+ SetChannelAppID("wx-open").
+ SetChannelSubject("openid-legacy-123").
+ SetMetadata(map[string]any{"scene": "legacy"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-legacy-123",
+ Metadata: map[string]any{"source": "admin-repair"},
+ Channel: &AdminBindAuthIdentityChannelInput{
+ Channel: "open",
+ ChannelAppID: "wx-open",
+ ChannelSubject: "openid-legacy-123",
+ Metadata: map[string]any{"scene": "admin-repair"},
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "wechat-main", result.ProviderKey)
+ require.NotNil(t, result.Channel)
+ require.Equal(t, "open", result.Channel.Channel)
+
+ identity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", identity.ProviderKey)
+ require.Equal(t, "admin-repair", identity.Metadata["source"])
+
+ channel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
+ require.NoError(t, err)
+ require.Equal(t, "wechat-main", channel.ProviderKey)
+ require.Equal(t, legacyIdentity.ID, channel.IdentityID)
+ require.Equal(t, "admin-repair", channel.Metadata["scene"])
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("wechat"),
+ authidentity.ProviderSubjectEQ("union-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ channelCount, err := client.AuthIdentityChannel.Query().
+ Where(
+ authidentitychannel.ProviderTypeEQ("wechat"),
+ authidentitychannel.ChannelEQ("open"),
+ authidentitychannel.ChannelAppIDEQ("wx-open"),
+ authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, channelCount)
+}
+
+func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
+ client := newAdminServiceAuthIdentityBindingTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("invalid-provider@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ svc := &adminServiceImpl{
+ userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
+ entClient: client,
+ }
+
+ _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
+ ProviderType: "github",
+ ProviderKey: "github-main",
+ ProviderSubject: "subject-3",
+ })
+ require.Error(t, err)
+ require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err))
+}
diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go
index f039612c..141466dc 100644
--- a/backend/internal/service/admin_service_clear_error_test.go
+++ b/backend/internal/service/admin_service_clear_error_test.go
@@ -12,12 +12,12 @@ import (
type accountRepoStubForClearAccountError struct {
mockAccountRepoForGemini
- account *Account
- clearErrorCalls int
- clearRateLimitCalls int
- clearAntigravityCalls int
+ account *Account
+ clearErrorCalls int
+ clearRateLimitCalls int
+ clearAntigravityCalls int
clearModelRateLimitCalls int
- clearTempUnschedCalls int
+ clearTempUnschedCalls int
}
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
@@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
resetAt := time.Now().Add(5 * time.Minute)
repo := &accountRepoStubForClearAccountError{
account: &Account{
- ID: 31,
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Status: StatusError,
- ErrorMessage: "refresh failed",
- RateLimitResetAt: &resetAt,
- TempUnschedulableUntil: &until,
+ ID: 31,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusError,
+ ErrorMessage: "refresh failed",
+ RateLimitResetAt: &resetAt,
+ TempUnschedulableUntil: &until,
TempUnschedulableReason: "missing refresh token",
},
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index fbc856cf..fe9e7701 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -13,15 +13,18 @@ import (
)
type userRepoStub struct {
- user *User
- getErr error
- createErr error
- deleteErr error
- exists bool
- existsErr error
- nextID int64
- created []*User
- deletedIDs []int64
+ user *User
+ getErr error
+ createErr error
+ deleteErr error
+ exists bool
+ existsErr error
+ nextID int64
+ created []*User
+ updated []*User
+ deletedIDs []int64
+ usersByEmail map[string]*User
+ getByEmailErr error
}
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
@@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error {
user.ID = s.nextID
}
s.created = append(s.created, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
return nil
}
@@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
}
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
- panic("unexpected GetByEmail call")
+ if s.getByEmailErr != nil {
+ return nil, s.getByEmailErr
+ }
+ if s.usersByEmail != nil {
+ if user, ok := s.usersByEmail[email]; ok {
+ return user, nil
+ }
+ }
+ if s.user != nil && s.user.Email == email {
+ return s.user, nil
+ }
+ return nil, ErrUserNotFound
}
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
@@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
}
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
- panic("unexpected Update call")
+ s.updated = append(s.updated, user)
+ if s.usersByEmail == nil {
+ s.usersByEmail = make(map[string]*User)
+ }
+ s.usersByEmail[user.Email] = user
+ s.user = user
+ return nil
}
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
@@ -62,6 +87,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
+func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
+ panic("unexpected GetUserAvatar call")
+}
+
+func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
@@ -70,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
+func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserIDs call")
+}
+
+func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
+ panic("unexpected GetLatestUsedAtByUserID call")
+}
+
+func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
+ panic("unexpected UpdateUserLastActiveAt call")
+}
+
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
@@ -101,6 +150,14 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
+func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
+ panic("unexpected ListUserAuthIdentities call")
+}
+
+func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ panic("unexpected UnbindUserAuthProvider call")
+}
+
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go
new file mode 100644
index 00000000..2232c9c3
--- /dev/null
+++ b/backend/internal/service/admin_service_email_identity_sync_test.go
@@ -0,0 +1,187 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type ensureEmailCall struct {
+ userID int64
+ email string
+}
+
+type replaceEmailCall struct {
+ userID int64
+ oldEmail string
+ newEmail string
+}
+
+type emailSyncRepoStub struct {
+ user *User
+ nextID int64
+ updateCalls int
+ created []*User
+ updated []*User
+ ensureCalls []ensureEmailCall
+ replaceCalls []replaceEmailCall
+ ensureErr error
+ replaceErr error
+}
+
+func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
+ if s.nextID != 0 && user.ID == 0 {
+ user.ID = s.nextID
+ }
+ s.created = append(s.created, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
+ if s.user == nil {
+ return nil, ErrUserNotFound
+ }
+ cloned := *s.user
+ return &cloned, nil
+}
+
+func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) {
+ return nil, ErrUserNotFound
+}
+
+func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) {
+ return nil, fmt.Errorf("unexpected GetFirstAdmin call")
+}
+
+func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error {
+ s.updateCalls++
+ s.updated = append(s.updated, user)
+ s.user = user
+ return nil
+}
+
+func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected GetUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
+ return nil, fmt.Errorf("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ return fmt.Errorf("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected List call")
+}
+
+func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ return nil, nil, fmt.Errorf("unexpected ListWithFilters call")
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
+
+func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil }
+
+func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+
+func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
+ s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
+ return s.ensureErr
+}
+
+func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
+ s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
+ userID: userID,
+ oldEmail: oldEmail,
+ newEmail: newEmail,
+ })
+ return s.replaceErr
+}
+
+func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ nextID: 55,
+ ensureErr: fmt.Errorf("unexpected email resync"),
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ user, err := svc.CreateUser(context.Background(), &CreateUserInput{
+ Email: "admin-created@example.com",
+ Password: "strong-pass",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, int64(55), user.ID)
+ require.Empty(t, repo.ensureCalls)
+ require.Empty(t, repo.replaceCalls)
+}
+
+func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
+ repo := &emailSyncRepoStub{
+ user: &User{
+ ID: 91,
+ Email: "before@example.com",
+ Role: RoleUser,
+ Status: StatusActive,
+ Concurrency: 3,
+ },
+ replaceErr: fmt.Errorf("unexpected email resync"),
+ }
+ svc := &adminServiceImpl{userRepo: repo}
+
+ updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{
+ Email: "after@example.com",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, "after@example.com", updated.Email)
+ require.Empty(t, repo.replaceCalls)
+ require.Empty(t, repo.ensureCalls)
+}
diff --git a/backend/internal/service/admin_service_group_rate_test.go b/backend/internal/service/admin_service_group_rate_test.go
index 77635247..d2efb644 100644
--- a/backend/internal/service/admin_service_group_rate_test.go
+++ b/backend/internal/service/admin_service_group_rate_test.go
@@ -5,8 +5,10 @@ package service
import (
"context"
"errors"
+ "net/http"
"testing"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
@@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct {
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
+
+ rpmSyncedGroupID int64
+ rpmSyncedEntries []GroupRPMOverrideInput
+ rpmSyncErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
@@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
+func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
+ panic("unexpected GetRPMOverrideByUserAndGroup call")
+}
+
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
@@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C
return s.syncGroupErr
}
+func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error {
+ s.rpmSyncedGroupID = groupID
+ s.rpmSyncedEntries = entries
+ return s.rpmSyncErr
+}
+
+func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
+ panic("unexpected ClearGroupRPMOverrides call")
+}
+
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
@@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
- {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
- {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
+ {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)},
+ {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)},
},
},
}
@@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
- require.Equal(t, 1.5, entries[0].RateMultiplier)
+ require.NotNil(t, entries[0].RateMultiplier)
+ require.Equal(t, 1.5, *entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
- require.Equal(t, 0.8, entries[1].RateMultiplier)
+ require.NotNil(t, entries[1].RateMultiplier)
+ require.Equal(t, 0.8, *entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
@@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
require.Contains(t, err.Error(), "sync failed")
})
}
+
+func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) {
+ t.Run("syncs entries to repo", func(t *testing.T) {
+ repo := &userGroupRateRepoStubForGroupRate{}
+ svc := &adminServiceImpl{userGroupRateRepo: repo}
+ override := 20
+ entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}}
+
+ err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries)
+ require.NoError(t, err)
+ require.Equal(t, int64(10), repo.rpmSyncedGroupID)
+ require.Equal(t, entries, repo.rpmSyncedEntries)
+ })
+
+ t.Run("rejects negative override as bad request", func(t *testing.T) {
+ repo := &userGroupRateRepoStubForGroupRate{}
+ svc := &adminServiceImpl{userGroupRateRepo: repo}
+ negative := -1
+
+ err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{
+ {UserID: 2, RPMOverride: &negative},
+ })
+ require.Error(t, err)
+ require.Equal(t, http.StatusBadRequest, infraerrors.Code(err))
+ require.Zero(t, repo.rpmSyncedGroupID)
+ })
+}
diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go
index 536be0b5..eef02240 100644
--- a/backend/internal/service/admin_service_group_test.go
+++ b/backend/internal/service/admin_service_group_test.go
@@ -10,6 +10,11 @@ import (
"github.com/stretchr/testify/require"
)
+func ptrString[T ~string](v T) *string {
+ s := string(v)
+ return &s
+}
+
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数
@@ -120,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor
return nil
}
+func TestAdminService_ListGroups_PassesSortParams(t *testing.T) {
+ repo := &groupRepoStubForAdmin{
+ listWithFiltersGroups: []Group{{ID: 1, Name: "g1"}},
+ }
+ svc := &adminServiceImpl{groupRepo: repo}
+
+ _, _, err := svc.ListGroups(context.Background(), 3, 25, PlatformOpenAI, StatusActive, "needle", nil, "account_count", "ASC")
+ require.NoError(t, err)
+ require.Equal(t, pagination.PaginationParams{
+ Page: 3,
+ PageSize: 25,
+ SortBy: "account_count",
+ SortOrder: "ASC",
+ }, repo.listWithFiltersParams)
+}
+
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
@@ -245,6 +266,141 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.Nil(t, repo.updated.ImagePrice4K)
}
+func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
+ existingGroup := &Group{
+ ID: 1,
+ Name: "existing-group",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ RPMLimit: 10,
+ }
+ repo := &groupRepoStubForAdmin{getByID: existingGroup}
+ invalidator := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{
+ groupRepo: repo,
+ authCacheInvalidator: invalidator,
+ }
+
+ rpmLimit := 60
+ group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
+ RPMLimit: &rpmLimit,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, group)
+ require.Equal(t, 60, repo.updated.RPMLimit)
+ require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存")
+}
+
+func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
+ repo := &groupRepoStubForAdmin{}
+ svc := &adminServiceImpl{groupRepo: repo}
+
+ group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
+ Name: "dispatch-group",
+ Description: "dispatch config",
+ Platform: PlatformOpenAI,
+ RateMultiplier: 1.0,
+ MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
+ OpusMappedModel: " gpt-5.4-high ",
+ SonnetMappedModel: " gpt-5.3-codex ",
+ HaikuMappedModel: " gpt-5.4-mini-medium ",
+ ExactModelMappings: map[string]string{
+ " claude-sonnet-4-5-20250929 ": " gpt-5.2-high ",
+ },
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, group)
+ require.NotNil(t, repo.created)
+ require.Equal(t, OpenAIMessagesDispatchModelConfig{
+ OpusMappedModel: "gpt-5.4",
+ SonnetMappedModel: "gpt-5.3-codex",
+ HaikuMappedModel: "gpt-5.4-mini",
+ ExactModelMappings: map[string]string{
+ "claude-sonnet-4-5-20250929": "gpt-5.2",
+ },
+ }, repo.created.MessagesDispatchModelConfig)
+}
+
+func TestAdminService_UpdateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) {
+ existingGroup := &Group{
+ ID: 1,
+ Name: "existing-group",
+ Platform: PlatformOpenAI,
+ Status: StatusActive,
+ }
+ repo := &groupRepoStubForAdmin{getByID: existingGroup}
+ svc := &adminServiceImpl{groupRepo: repo}
+
+ group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
+ MessagesDispatchModelConfig: &OpenAIMessagesDispatchModelConfig{
+ SonnetMappedModel: " gpt-5.4-medium ",
+ ExactModelMappings: map[string]string{
+ " claude-haiku-4-5-20251001 ": " gpt-5.4-mini-high ",
+ },
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, group)
+ require.NotNil(t, repo.updated)
+ require.Equal(t, OpenAIMessagesDispatchModelConfig{
+ SonnetMappedModel: "gpt-5.4",
+ ExactModelMappings: map[string]string{
+ "claude-haiku-4-5-20251001": "gpt-5.4-mini",
+ },
+ }, repo.updated.MessagesDispatchModelConfig)
+}
+
+func TestAdminService_CreateGroup_ClearsMessagesDispatchFieldsForNonOpenAIPlatform(t *testing.T) {
+ repo := &groupRepoStubForAdmin{}
+ svc := &adminServiceImpl{groupRepo: repo}
+
+ group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
+ Name: "anthropic-group",
+ Description: "non-openai",
+ Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
+ AllowMessagesDispatch: true,
+ DefaultMappedModel: "gpt-5.4",
+ MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
+ OpusMappedModel: "gpt-5.4",
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, group)
+ require.NotNil(t, repo.created)
+ require.False(t, repo.created.AllowMessagesDispatch)
+ require.Empty(t, repo.created.DefaultMappedModel)
+ require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.created.MessagesDispatchModelConfig)
+}
+
+func TestAdminService_UpdateGroup_ClearsMessagesDispatchFieldsWhenPlatformChangesAwayFromOpenAI(t *testing.T) {
+ existingGroup := &Group{
+ ID: 1,
+ Name: "existing-openai-group",
+ Platform: PlatformOpenAI,
+ Status: StatusActive,
+ AllowMessagesDispatch: true,
+ DefaultMappedModel: "gpt-5.4",
+ MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
+ SonnetMappedModel: "gpt-5.3-codex",
+ },
+ }
+ repo := &groupRepoStubForAdmin{getByID: existingGroup}
+ svc := &adminServiceImpl{groupRepo: repo}
+
+ group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{
+ Platform: PlatformAnthropic,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, group)
+ require.NotNil(t, repo.updated)
+ require.Equal(t, PlatformAnthropic, repo.updated.Platform)
+ require.False(t, repo.updated.AllowMessagesDispatch)
+ require.Empty(t, repo.updated.DefaultMappedModel)
+ require.Equal(t, OpenAIMessagesDispatchModelConfig{}, repo.updated.MessagesDispatchModelConfig)
+}
+
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层
@@ -258,7 +414,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{groupRepo: repo}
- groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
+ groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil, "", "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
@@ -276,7 +432,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{groupRepo: repo}
- groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
+ groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil, "", "")
require.NoError(t, err)
require.Empty(t, groups)
require.Equal(t, int64(0), total)
@@ -295,7 +451,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{groupRepo: repo}
- groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
+ groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive, "", "")
require.NoError(t, err)
require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
@@ -490,6 +646,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -510,6 +667,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -564,6 +722,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -582,6 +741,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -602,6 +762,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
@@ -619,6 +780,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
+ RateMultiplier: 1.0,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go
index 37f348df..ff3f65a8 100644
--- a/backend/internal/service/admin_service_list_users_test.go
+++ b/backend/internal/service/admin_service_list_users_test.go
@@ -6,6 +6,7 @@ import (
"context"
"errors"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
@@ -13,11 +14,15 @@ import (
type userRepoStubForListUsers struct {
userRepoStub
- users []User
- err error
+ users []User
+ err error
+ listWithFiltersParams pagination.PaginationParams
+ lastUsedByUserID map[int64]*time.Time
+ lastUsedErr error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ s.listWithFiltersParams = params
if s.err != nil {
return nil, nil, s.err
}
@@ -30,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag
}, nil
}
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ result := make(map[int64]*time.Time, len(userIDs))
+ for _, userID := range userIDs {
+ if ts, ok := s.lastUsedByUserID[userID]; ok {
+ result[userID] = ts
+ }
+ }
+ return result, nil
+}
+
+func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) {
+ if s.lastUsedErr != nil {
+ return nil, s.lastUsedErr
+ }
+ return s.lastUsedByUserID[userID], nil
+}
+
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
@@ -64,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context,
panic("unexpected GetByUserAndGroup call")
}
+func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) {
+ panic("unexpected GetRPMOverrideByUserAndGroup call")
+}
+
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
@@ -76,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C
panic("unexpected SyncGroupRateMultipliers call")
}
+func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error {
+ panic("unexpected SyncGroupRPMOverrides call")
+}
+
+func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error {
+ panic("unexpected ClearGroupRPMOverrides call")
+}
+
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call")
}
@@ -103,7 +140,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userGroupRateRepo: rateRepo,
}
- users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
+ users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, users, 2)
@@ -112,3 +149,37 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22])
}
+
+func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
+ userRepo := &userRepoStubForListUsers{
+ users: []User{{ID: 1, Email: "a@example.com"}},
+ }
+ svc := &adminServiceImpl{userRepo: userRepo}
+
+ _, _, err := svc.ListUsers(context.Background(), 2, 50, UserListFilters{}, "email", "ASC")
+ require.NoError(t, err)
+ require.Equal(t, pagination.PaginationParams{
+ Page: 2,
+ PageSize: 50,
+ SortBy: "email",
+ SortOrder: "ASC",
+ }, userRepo.listWithFiltersParams)
+}
+
+func TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) {
+ lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second)
+ userRepo := &userRepoStubForListUsers{
+ users: []User{{ID: 101, Email: "u@example.com"}},
+ lastUsedByUserID: map[int64]*time.Time{
+ 101: &lastUsed,
+ },
+ }
+ svc := &adminServiceImpl{userRepo: userRepo}
+
+ users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
+ require.NoError(t, err)
+ require.Equal(t, int64(1), total)
+ require.Len(t, users, 1)
+ require.NotNil(t, users[0].LastUsedAt)
+ require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second)
+}
diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go
index 5a43cd9c..d3b3f61b 100644
--- a/backend/internal/service/admin_service_proxy_quality_test.go
+++ b/backend/internal/service/admin_service_proxy_quality_test.go
@@ -27,7 +27,7 @@ func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
require.Contains(t, result.Summary, "挑战 1 项")
}
-func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
+func TestRunProxyQualityTarget_CloudflareChallenge(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("cf-ray", "test-ray-123")
@@ -37,7 +37,7 @@ func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
defer server.Close()
target := proxyQualityTarget{
- Target: "sora",
+ Target: "openai",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
diff --git a/backend/internal/service/admin_service_rpm_status_test.go b/backend/internal/service/admin_service_rpm_status_test.go
new file mode 100644
index 00000000..c298f69b
--- /dev/null
+++ b/backend/internal/service/admin_service_rpm_status_test.go
@@ -0,0 +1,112 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type rpmStatusUserRepoStub struct {
+ UserRepository
+ user *User
+}
+
+func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
+ return s.user, nil
+}
+
+type rpmStatusAPIKeyRepoStub struct {
+ APIKeyRepository
+ keys []APIKey
+}
+
+func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
+ return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil
+}
+
+type rpmStatusGroupRepoStub struct {
+ GroupRepository
+ groups map[int64]*Group
+}
+
+func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) {
+ return s.groups[id], nil
+}
+
+type rpmStatusRateRepoStub struct {
+ UserGroupRateRepository
+ overrides map[int64]*int
+}
+
+func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) {
+ return s.overrides[groupID], nil
+}
+
+type rpmStatusCacheStub struct {
+ UserRPMCache
+ userUsed int
+ groupUsed map[int64]int
+}
+
+func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) {
+ return 0, nil
+}
+
+func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) {
+ return 0, nil
+}
+
+func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) {
+ return s.groupUsed[groupID], nil
+}
+
+func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) {
+ return s.userUsed, nil
+}
+
+func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) {
+ groupOneID := int64(1)
+ groupTwoID := int64(2)
+ override := 7
+ svc := &adminServiceImpl{
+ userRepo: &rpmStatusUserRepoStub{user: &User{
+ ID: 42,
+ RPMLimit: 20,
+ }},
+ apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{
+ {ID: 100, UserID: 42, GroupID: &groupTwoID},
+ {ID: 101, UserID: 42, GroupID: &groupOneID},
+ {ID: 102, UserID: 42, GroupID: &groupTwoID},
+ {ID: 103, UserID: 42},
+ }},
+ groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{
+ groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10},
+ groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60},
+ }},
+ userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{
+ groupTwoID: &override,
+ }},
+ userRPMCache: &rpmStatusCacheStub{
+ userUsed: 5,
+ groupUsed: map[int64]int{
+ groupOneID: 3,
+ groupTwoID: 4,
+ },
+ },
+ }
+
+ status, err := svc.GetUserRPMStatus(context.Background(), 42)
+ require.NoError(t, err)
+ require.Equal(t, &UserRPMStatus{
+ UserRPMUsed: 5,
+ UserRPMLimit: 20,
+ PerGroup: []UserGroupRPMStatus{
+ {GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"},
+ {GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"},
+ },
+ }, status)
+}
diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go
index eb213e6a..595e99e3 100644
--- a/backend/internal/service/admin_service_search_test.go
+++ b/backend/internal/service/admin_service_search_test.go
@@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{accountRepo: repo}
- accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "")
+ accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls)
- require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
+ require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
@@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) {
}
svc := &adminServiceImpl{accountRepo: repo}
- accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked)
+ accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts)
@@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{proxyRepo: repo}
- proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
+ proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1", "name", "ASC")
require.NoError(t, err)
require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls)
- require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
+ require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch)
@@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{proxyRepo: repo}
- proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
+ proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2", "account_count", "DESC")
require.NoError(t, err)
require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
- require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
+ require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10, SortBy: "account_count", SortOrder: "DESC"}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
@@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
- codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
+ codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC", "value", "ASC")
require.NoError(t, err)
require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls)
- require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
+ require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "value", SortOrder: "ASC"}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch)
diff --git a/backend/internal/service/admin_service_update_user_rpm_test.go b/backend/internal/service/admin_service_update_user_rpm_test.go
new file mode 100644
index 00000000..cb4c3986
--- /dev/null
+++ b/backend/internal/service/admin_service_update_user_rpm_test.go
@@ -0,0 +1,69 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构,
+// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。
+type rpmUserRepoStub struct {
+ *userRepoStub
+ lastUpdated *User
+}
+
+func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error {
+ if user == nil {
+ return nil
+ }
+ clone := *user
+ s.lastUpdated = &clone
+ if s.userRepoStub != nil {
+ s.userRepoStub.user = &clone
+ }
+ return nil
+}
+
+func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) {
+ base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}}
+ repo := &rpmUserRepoStub{userRepoStub: base}
+ invalidator := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{
+ userRepo: repo,
+ redeemCodeRepo: &redeemRepoStub{},
+ authCacheInvalidator: invalidator,
+ }
+
+ newRPM := 60
+ updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
+ RPMLimit: &newRPM,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, 60, updated.RPMLimit)
+ require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存")
+}
+
+func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) {
+ base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}}
+ repo := &rpmUserRepoStub{userRepoStub: base}
+ invalidator := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{
+ userRepo: repo,
+ redeemCodeRepo: &redeemRepoStub{},
+ authCacheInvalidator: invalidator,
+ }
+
+ newName := "new"
+ sameRPM := 10
+ _, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{
+ Username: &newName,
+ RPMLimit: &sameRPM,
+ })
+ require.NoError(t, err)
+ require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效")
+}
diff --git a/backend/internal/service/affiliate_service.go b/backend/internal/service/affiliate_service.go
new file mode 100644
index 00000000..5a4e91e7
--- /dev/null
+++ b/backend/internal/service/affiliate_service.go
@@ -0,0 +1,490 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "math"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+var (
+ ErrAffiliateProfileNotFound = infraerrors.NotFound("AFFILIATE_PROFILE_NOT_FOUND", "affiliate profile not found")
+ ErrAffiliateCodeInvalid = infraerrors.BadRequest("AFFILIATE_CODE_INVALID", "invalid affiliate code")
+ ErrAffiliateCodeTaken = infraerrors.Conflict("AFFILIATE_CODE_TAKEN", "affiliate code already in use")
+ ErrAffiliateAlreadyBound = infraerrors.Conflict("AFFILIATE_ALREADY_BOUND", "affiliate inviter already bound")
+ ErrAffiliateQuotaEmpty = infraerrors.BadRequest("AFFILIATE_QUOTA_EMPTY", "no affiliate quota available to transfer")
+)
+
+const (
+ affiliateInviteesLimit = 100
+ // AffiliateCodeMinLength / AffiliateCodeMaxLength bound both system-generated
+ // 12-char codes and admin-customized codes (e.g. "VIP2026").
+ AffiliateCodeMinLength = 4
+ AffiliateCodeMaxLength = 32
+)
+
+// affiliateCodeValidChar accepts uppercase letters, digits, underscore and dash.
+// All input passes through strings.ToUpper before validation, so lowercase from
+// users is normalized — admins may supply mixed case in their UI.
+var affiliateCodeValidChar = func() [256]bool {
+ var tbl [256]bool
+ for c := byte('A'); c <= 'Z'; c++ {
+ tbl[c] = true
+ }
+ for c := byte('0'); c <= '9'; c++ {
+ tbl[c] = true
+ }
+ tbl['_'] = true
+ tbl['-'] = true
+ return tbl
+}()
+
+// isValidAffiliateCodeFormat validates code format for both binding (user input)
+// and admin updates. Caller is expected to upper-case the input first.
+func isValidAffiliateCodeFormat(code string) bool {
+ if len(code) < AffiliateCodeMinLength || len(code) > AffiliateCodeMaxLength {
+ return false
+ }
+ for i := 0; i < len(code); i++ {
+ if !affiliateCodeValidChar[code[i]] {
+ return false
+ }
+ }
+ return true
+}
+
+type AffiliateSummary struct {
+ UserID int64 `json:"user_id"`
+ AffCode string `json:"aff_code"`
+ AffCodeCustom bool `json:"aff_code_custom"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
+ InviterID *int64 `json:"inviter_id,omitempty"`
+ AffCount int `json:"aff_count"`
+ AffQuota float64 `json:"aff_quota"`
+ AffFrozenQuota float64 `json:"aff_frozen_quota"`
+ AffHistoryQuota float64 `json:"aff_history_quota"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+type AffiliateInvitee struct {
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ CreatedAt *time.Time `json:"created_at,omitempty"`
+ TotalRebate float64 `json:"total_rebate"`
+}
+
+type AffiliateDetail struct {
+ UserID int64 `json:"user_id"`
+ AffCode string `json:"aff_code"`
+ InviterID *int64 `json:"inviter_id,omitempty"`
+ AffCount int `json:"aff_count"`
+ AffQuota float64 `json:"aff_quota"`
+ AffFrozenQuota float64 `json:"aff_frozen_quota"`
+ AffHistoryQuota float64 `json:"aff_history_quota"`
+ // EffectiveRebateRatePercent 是当前用户作为邀请人时实际生效的返利比例:
+ // 优先用户自己的专属比例(aff_rebate_rate_percent),否则回退到全局比例。
+ // 用于在用户的 /affiliate 页面直观展示「分享后能拿到多少」。
+ EffectiveRebateRatePercent float64 `json:"effective_rebate_rate_percent"`
+ Invitees []AffiliateInvitee `json:"invitees"`
+}
+
+type AffiliateRepository interface {
+ EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error)
+ GetAffiliateByCode(ctx context.Context, code string) (*AffiliateSummary, error)
+ BindInviter(ctx context.Context, userID, inviterID int64) (bool, error)
+ AccrueQuota(ctx context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int) (bool, error)
+ GetAccruedRebateFromInvitee(ctx context.Context, inviterID, inviteeUserID int64) (float64, error)
+ ThawFrozenQuota(ctx context.Context, userID int64) (float64, error)
+ TransferQuotaToBalance(ctx context.Context, userID int64) (float64, float64, error)
+ ListInvitees(ctx context.Context, inviterID int64, limit int) ([]AffiliateInvitee, error)
+
+ // 管理端:用户级专属配置
+ UpdateUserAffCode(ctx context.Context, userID int64, newCode string) error
+ ResetUserAffCode(ctx context.Context, userID int64) (string, error)
+ SetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error
+ BatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error
+ ListUsersWithCustomSettings(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error)
+}
+
+// AffiliateAdminFilter 列表筛选条件
+type AffiliateAdminFilter struct {
+ Search string
+ Page int
+ PageSize int
+}
+
+// AffiliateAdminEntry 专属用户列表条目
+type AffiliateAdminEntry struct {
+ UserID int64 `json:"user_id"`
+ Email string `json:"email"`
+ Username string `json:"username"`
+ AffCode string `json:"aff_code"`
+ AffCodeCustom bool `json:"aff_code_custom"`
+ AffRebateRatePercent *float64 `json:"aff_rebate_rate_percent,omitempty"`
+ AffCount int `json:"aff_count"`
+}
+
+type AffiliateService struct {
+ repo AffiliateRepository
+ settingService *SettingService
+ authCacheInvalidator APIKeyAuthCacheInvalidator
+ billingCacheService *BillingCacheService
+}
+
+func NewAffiliateService(repo AffiliateRepository, settingService *SettingService, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCacheService *BillingCacheService) *AffiliateService {
+ return &AffiliateService{
+ repo: repo,
+ settingService: settingService,
+ authCacheInvalidator: authCacheInvalidator,
+ billingCacheService: billingCacheService,
+ }
+}
+
+// IsEnabled reports whether the affiliate (邀请返利) feature is turned on.
+func (s *AffiliateService) IsEnabled(ctx context.Context) bool {
+ if s == nil || s.settingService == nil {
+ return AffiliateEnabledDefault
+ }
+ return s.settingService.IsAffiliateEnabled(ctx)
+}
+
+func (s *AffiliateService) EnsureUserAffiliate(ctx context.Context, userID int64) (*AffiliateSummary, error) {
+ if userID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_USER", "invalid user")
+ }
+ if s == nil || s.repo == nil {
+ return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.EnsureUserAffiliate(ctx, userID)
+}
+
+func (s *AffiliateService) GetAffiliateDetail(ctx context.Context, userID int64) (*AffiliateDetail, error) {
+ // Lazy thaw: move any matured frozen quota to available before reading.
+ if s != nil && s.repo != nil {
+ // best-effort: thaw failure is non-fatal
+ _, _ = s.repo.ThawFrozenQuota(ctx, userID)
+ }
+
+ summary, err := s.EnsureUserAffiliate(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ invitees, err := s.listInvitees(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ return &AffiliateDetail{
+ UserID: summary.UserID,
+ AffCode: summary.AffCode,
+ InviterID: summary.InviterID,
+ AffCount: summary.AffCount,
+ AffQuota: summary.AffQuota,
+ AffFrozenQuota: summary.AffFrozenQuota,
+ AffHistoryQuota: summary.AffHistoryQuota,
+ EffectiveRebateRatePercent: s.resolveRebateRatePercent(ctx, summary),
+ Invitees: invitees,
+ }, nil
+}
+
+func (s *AffiliateService) BindInviterByCode(ctx context.Context, userID int64, rawCode string) error {
+ code := strings.ToUpper(strings.TrimSpace(rawCode))
+ if code == "" {
+ return nil
+ }
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ // 总开关关闭时,注册阶段静默忽略 aff 参数(不报错,避免阻断注册流程)
+ if !s.IsEnabled(ctx) {
+ return nil
+ }
+ if !isValidAffiliateCodeFormat(code) {
+ return ErrAffiliateCodeInvalid
+ }
+
+ selfSummary, err := s.repo.EnsureUserAffiliate(ctx, userID)
+ if err != nil {
+ return err
+ }
+ if selfSummary.InviterID != nil {
+ return nil
+ }
+
+ inviterSummary, err := s.repo.GetAffiliateByCode(ctx, code)
+ if err != nil {
+ if errors.Is(err, ErrAffiliateProfileNotFound) {
+ return ErrAffiliateCodeInvalid
+ }
+ return err
+ }
+ if inviterSummary == nil || inviterSummary.UserID <= 0 || inviterSummary.UserID == userID {
+ return ErrAffiliateCodeInvalid
+ }
+
+ bound, err := s.repo.BindInviter(ctx, userID, inviterSummary.UserID)
+ if err != nil {
+ return err
+ }
+ if !bound {
+ return ErrAffiliateAlreadyBound
+ }
+ return nil
+}
+
+func (s *AffiliateService) AccrueInviteRebate(ctx context.Context, inviteeUserID int64, baseRechargeAmount float64) (float64, error) {
+ if s == nil || s.repo == nil {
+ return 0, nil
+ }
+ if inviteeUserID <= 0 || baseRechargeAmount <= 0 || math.IsNaN(baseRechargeAmount) || math.IsInf(baseRechargeAmount, 0) {
+ return 0, nil
+ }
+ // 总开关关闭时,新充值不再产生返利
+ if !s.IsEnabled(ctx) {
+ return 0, nil
+ }
+
+ inviteeSummary, err := s.repo.EnsureUserAffiliate(ctx, inviteeUserID)
+ if err != nil {
+ return 0, err
+ }
+ if inviteeSummary.InviterID == nil || *inviteeSummary.InviterID <= 0 {
+ return 0, nil
+ }
+
+ // 加载邀请人 profile,优先使用专属比例(覆盖全局)
+ inviterSummary, err := s.repo.EnsureUserAffiliate(ctx, *inviteeSummary.InviterID)
+ if err != nil {
+ return 0, err
+ }
+ // 有效期检查:超过返利有效期后不再产生返利
+ if s.settingService != nil {
+ if durationDays := s.settingService.GetAffiliateRebateDurationDays(ctx); durationDays > 0 {
+ if time.Now().After(inviteeSummary.CreatedAt.AddDate(0, 0, durationDays)) {
+ return 0, nil
+ }
+ }
+ }
+
+ rebateRatePercent := s.resolveRebateRatePercent(ctx, inviterSummary)
+ rebate := roundTo(baseRechargeAmount*(rebateRatePercent/100), 8)
+ if rebate <= 0 {
+ return 0, nil
+ }
+
+ // 单人上限检查:精确截断到剩余额度
+ if s.settingService != nil {
+ if perInviteeCap := s.settingService.GetAffiliateRebatePerInviteeCap(ctx); perInviteeCap > 0 {
+ existing, err := s.repo.GetAccruedRebateFromInvitee(ctx, *inviteeSummary.InviterID, inviteeUserID)
+ if err != nil {
+ return 0, err
+ }
+ if existing >= perInviteeCap {
+ return 0, nil
+ }
+ if remaining := perInviteeCap - existing; rebate > remaining {
+ rebate = roundTo(remaining, 8)
+ }
+ }
+ }
+
+ var freezeHours int
+ if s.settingService != nil {
+ freezeHours = s.settingService.GetAffiliateRebateFreezeHours(ctx)
+ }
+
+ applied, err := s.repo.AccrueQuota(ctx, *inviteeSummary.InviterID, inviteeUserID, rebate, freezeHours)
+ if err != nil {
+ return 0, err
+ }
+ if !applied {
+ return 0, nil
+ }
+ return rebate, nil
+}
+
+// resolveRebateRatePercent returns the inviter's exclusive rate when set,
+// otherwise the global setting value (clamped to [Min, Max]).
+func (s *AffiliateService) resolveRebateRatePercent(ctx context.Context, inviter *AffiliateSummary) float64 {
+ if inviter != nil && inviter.AffRebateRatePercent != nil {
+ v := *inviter.AffRebateRatePercent
+ if math.IsNaN(v) || math.IsInf(v, 0) {
+ return s.globalRebateRatePercent(ctx)
+ }
+ return clampAffiliateRebateRate(v)
+ }
+ return s.globalRebateRatePercent(ctx)
+}
+
+// globalRebateRatePercent reads the system-wide rebate rate via SettingService,
+// returning the documented default when SettingService is unavailable.
+func (s *AffiliateService) globalRebateRatePercent(ctx context.Context) float64 {
+ if s == nil || s.settingService == nil {
+ return AffiliateRebateRateDefault
+ }
+ return s.settingService.GetAffiliateRebateRatePercent(ctx)
+}
+
+func (s *AffiliateService) TransferAffiliateQuota(ctx context.Context, userID int64) (float64, float64, error) {
+ if s == nil || s.repo == nil {
+ return 0, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+
+ transferred, balance, err := s.repo.TransferQuotaToBalance(ctx, userID)
+ if err != nil {
+ return 0, 0, err
+ }
+ if transferred > 0 {
+ s.invalidateAffiliateCaches(ctx, userID)
+ }
+ return transferred, balance, nil
+}
+
+func (s *AffiliateService) listInvitees(ctx context.Context, inviterID int64) ([]AffiliateInvitee, error) {
+ if s == nil || s.repo == nil {
+ return nil, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ invitees, err := s.repo.ListInvitees(ctx, inviterID, affiliateInviteesLimit)
+ if err != nil {
+ return nil, err
+ }
+ for i := range invitees {
+ invitees[i].Email = maskEmail(invitees[i].Email)
+ }
+ return invitees, nil
+}
+
+func roundTo(v float64, scale int) float64 {
+ factor := math.Pow10(scale)
+ return math.Round(v*factor) / factor
+}
+
+func maskEmail(email string) string {
+ email = strings.TrimSpace(email)
+ if email == "" {
+ return ""
+ }
+ at := strings.Index(email, "@")
+ if at <= 0 || at >= len(email)-1 {
+ return "***"
+ }
+
+ local := email[:at]
+ domain := email[at+1:]
+ dot := strings.LastIndex(domain, ".")
+
+ maskedLocal := maskSegment(local)
+ if dot <= 0 || dot >= len(domain)-1 {
+ return maskedLocal + "@" + maskSegment(domain)
+ }
+
+ domainName := domain[:dot]
+ tld := domain[dot:]
+ return maskedLocal + "@" + maskSegment(domainName) + tld
+}
+
+func maskSegment(s string) string {
+ r := []rune(s)
+ if len(r) == 0 {
+ return "***"
+ }
+ if len(r) == 1 {
+ return string(r[0]) + "***"
+ }
+ return string(r[0]) + "***"
+}
+
+func (s *AffiliateService) invalidateAffiliateCaches(ctx context.Context, userID int64) {
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+ }
+ if s.billingCacheService != nil {
+ if err := s.billingCacheService.InvalidateUserBalance(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.affiliate", "[Affiliate] Failed to invalidate billing cache for user %d: %v", userID, err)
+ }
+ }
+}
+
+// =========================
+// Admin: 专属配置管理
+// =========================
+
+// validateExclusiveRate ensures a per-user override is finite and within
+// [Min, Max]. nil is always valid (means "clear / fall back to global").
+func validateExclusiveRate(ratePercent *float64) error {
+ if ratePercent == nil {
+ return nil
+ }
+ v := *ratePercent
+ if math.IsNaN(v) || math.IsInf(v, 0) {
+ return infraerrors.BadRequest("INVALID_RATE", "invalid rebate rate")
+ }
+ if v < AffiliateRebateRateMin || v > AffiliateRebateRateMax {
+ return infraerrors.BadRequest("INVALID_RATE", "rebate rate out of range")
+ }
+ return nil
+}
+
+// AdminUpdateUserAffCode 管理员改写用户的邀请码(专属邀请码)。
+func (s *AffiliateService) AdminUpdateUserAffCode(ctx context.Context, userID int64, rawCode string) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ code := strings.ToUpper(strings.TrimSpace(rawCode))
+ if !isValidAffiliateCodeFormat(code) {
+ return ErrAffiliateCodeInvalid
+ }
+ return s.repo.UpdateUserAffCode(ctx, userID, code)
+}
+
+// AdminResetUserAffCode 重置用户邀请码为系统随机码。
+func (s *AffiliateService) AdminResetUserAffCode(ctx context.Context, userID int64) (string, error) {
+ if s == nil || s.repo == nil {
+ return "", infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.ResetUserAffCode(ctx, userID)
+}
+
+// AdminSetUserRebateRate 设置/清除用户专属返利比例。ratePercent==nil 表示清除。
+func (s *AffiliateService) AdminSetUserRebateRate(ctx context.Context, userID int64, ratePercent *float64) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ if err := validateExclusiveRate(ratePercent); err != nil {
+ return err
+ }
+ return s.repo.SetUserRebateRate(ctx, userID, ratePercent)
+}
+
+// AdminBatchSetUserRebateRate 批量设置/清除用户专属返利比例。
+func (s *AffiliateService) AdminBatchSetUserRebateRate(ctx context.Context, userIDs []int64, ratePercent *float64) error {
+ if s == nil || s.repo == nil {
+ return infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ if err := validateExclusiveRate(ratePercent); err != nil {
+ return err
+ }
+ cleaned := make([]int64, 0, len(userIDs))
+ for _, uid := range userIDs {
+ if uid > 0 {
+ cleaned = append(cleaned, uid)
+ }
+ }
+ if len(cleaned) == 0 {
+ return nil
+ }
+ return s.repo.BatchSetUserRebateRate(ctx, cleaned, ratePercent)
+}
+
+// AdminListCustomUsers 列出有专属配置的用户。
+func (s *AffiliateService) AdminListCustomUsers(ctx context.Context, filter AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) {
+ if s == nil || s.repo == nil {
+ return nil, 0, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "affiliate service unavailable")
+ }
+ return s.repo.ListUsersWithCustomSettings(ctx, filter)
+}
diff --git a/backend/internal/service/affiliate_service_test.go b/backend/internal/service/affiliate_service_test.go
new file mode 100644
index 00000000..c02a4dd7
--- /dev/null
+++ b/backend/internal/service/affiliate_service_test.go
@@ -0,0 +1,131 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "math"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestResolveRebateRatePercent_PerUserOverride verifies that per-inviter
+// AffRebateRatePercent overrides the global rate, that NULL falls back to the
+// global rate, and that out-of-range exclusive rates are clamped silently.
+//
+// SettingService is left nil here so globalRebateRatePercent returns the
+// documented default (AffiliateRebateRateDefault = 20%) — this exercises the
+// fallback path without spinning up a settings stub.
+func TestResolveRebateRatePercent_PerUserOverride(t *testing.T) {
+	t.Parallel()
+	// Zero-value service: no repo, no settings — only the pure resolution
+	// logic is under test.
+	svc := &AffiliateService{}
+
+	// nil exclusive rate → falls back to global default (20%)
+	require.InDelta(t, AffiliateRebateRateDefault,
+		svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{}), 1e-9)
+
+	// exclusive rate set → overrides global
+	rate := 50.0
+	require.InDelta(t, 50.0,
+		svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &rate}), 1e-9)
+
+	// exclusive rate 0 → returns 0 (no rebate, intentional)
+	zero := 0.0
+	require.InDelta(t, 0.0,
+		svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &zero}), 1e-9)
+
+	// exclusive rate above max → clamped to Max
+	tooHigh := 250.0
+	require.InDelta(t, AffiliateRebateRateMax,
+		svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooHigh}), 1e-9)
+
+	// exclusive rate below min → clamped to Min
+	tooLow := -5.0
+	require.InDelta(t, AffiliateRebateRateMin,
+		svc.resolveRebateRatePercent(context.Background(), &AffiliateSummary{AffRebateRatePercent: &tooLow}), 1e-9)
+}
+
+// TestIsEnabled_NilSettingServiceReturnsDefault verifies that IsEnabled
+// safely handles a nil settingService dependency by returning the default
+// (off). This protects callers from nil-pointer crashes in misconfigured
+// environments.
+func TestIsEnabled_NilSettingServiceReturnsDefault(t *testing.T) {
+	t.Parallel()
+	svc := &AffiliateService{}
+	// Both assertions pin the same fact: with no settings backend the feature
+	// reports the compiled-in default, which is disabled.
+	require.False(t, svc.IsEnabled(context.Background()))
+	require.Equal(t, AffiliateEnabledDefault, svc.IsEnabled(context.Background()))
+}
+
+// TestValidateExclusiveRate_BoundaryAndInvalid covers the validator used by
+// admin-facing rate setters: nil is always valid (clear), in-range values
+// are accepted, NaN/Inf and out-of-range values produce a typed BadRequest.
+func TestValidateExclusiveRate_BoundaryAndInvalid(t *testing.T) {
+	t.Parallel()
+	// nil means "clear the override" and must always pass.
+	require.NoError(t, validateExclusiveRate(nil))
+
+	// Inclusive boundaries: 0 and 100 are both valid.
+	for _, v := range []float64{0, 0.01, 50, 99.99, 100} {
+		v := v
+		require.NoError(t, validateExclusiveRate(&v), "value %v should be valid", v)
+	}
+
+	// Just-outside-boundary and far-out-of-range values are rejected.
+	for _, v := range []float64{-0.01, 100.01, -100, 200} {
+		v := v
+		require.Error(t, validateExclusiveRate(&v), "value %v should be rejected", v)
+	}
+
+	// Non-finite values must also be rejected.
+	nan := math.NaN()
+	require.Error(t, validateExclusiveRate(&nan))
+	posInf := math.Inf(1)
+	require.Error(t, validateExclusiveRate(&posInf))
+	negInf := math.Inf(-1)
+	require.Error(t, validateExclusiveRate(&negInf))
+}
+
+// TestMaskEmail pins the masking format shown by the cases below: keep the
+// first character of the local part and of the domain, replace the remainder
+// of each with "***" (the final dot-suffix survives in the first case), and
+// map empty input to empty output.
+func TestMaskEmail(t *testing.T) {
+	t.Parallel()
+	require.Equal(t, "a***@g***.com", maskEmail("alice@gmail.com"))
+	require.Equal(t, "x***@d***", maskEmail("x@domain"))
+	require.Equal(t, "", maskEmail(""))
+}
+
+// TestIsValidAffiliateCodeFormat table-tests the affiliate-code format check
+// across system-generated and admin-assigned codes, including boundary
+// lengths, case handling, and charset rejections.
+func TestIsValidAffiliateCodeFormat(t *testing.T) {
+	t.Parallel()
+
+	// The format check serves two producers of invite codes:
+	//   1) system-generated 12-char random codes (A-Z minus I/O, 2-9 minus 0/1)
+	//   2) admin-assigned custom codes (e.g. "VIP2026", "NEW_USER-1")
+	// so validation is relaxed to [A-Z0-9_-]{4,32}; callers must ToUpper the
+	// input first.
+	cases := []struct {
+		name string
+		in   string
+		want bool
+	}{
+		{"valid canonical 12-char", "ABCDEFGHJKLM", true},
+		{"valid all digits 2-9", "234567892345", true},
+		{"valid mixed", "A2B3C4D5E6F7", true},
+		{"valid admin custom short", "VIP1", true},
+		{"valid admin custom with hyphen", "NEW-USER", true},
+		{"valid admin custom with underscore", "VIP_2026", true},
+		{"valid 32-char max", "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345", true},
+		// Previously-excluded chars (I/O/0/1) are now allowed since admins may use them.
+		{"letter I now allowed", "IBCDEFGHJKLM", true},
+		{"letter O now allowed", "OBCDEFGHJKLM", true},
+		{"digit 0 now allowed", "0BCDEFGHJKLM", true},
+		{"digit 1 now allowed", "1BCDEFGHJKLM", true},
+		{"too short (3 chars)", "ABC", false},
+		{"too long (33 chars)", "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456", false},
+		{"lowercase rejected (caller must ToUpper first)", "abcdefghjklm", false},
+		{"empty", "", false},
+		{"utf8 non-ascii", "ÄÄÄÄÄÄ", false}, // bytes out of charset
+		{"ascii punctuation .", "ABCDEFGHJK.M", false},
+		{"whitespace", "ABCDEFGHJK M", false},
+	}
+	for _, tc := range cases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tc.want, isValidAffiliateCodeFormat(tc.in))
+		})
+	}
+}
diff --git a/backend/internal/service/announcement.go b/backend/internal/service/announcement.go
index 25c66eb4..02741d37 100644
--- a/backend/internal/service/announcement.go
+++ b/backend/internal/service/announcement.go
@@ -5,6 +5,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -34,8 +35,23 @@ const (
)
var (
- ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
- ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
+ ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
+ ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
+ ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required")
+ ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid")
+ ErrAnnouncementContentRequired = infraerrors.BadRequest(
+ "ANNOUNCEMENT_CONTENT_REQUIRED",
+ "announcement content is required",
+ )
+ ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid")
+ ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest(
+ "ANNOUNCEMENT_NOTIFY_MODE_INVALID",
+ "announcement notify_mode is invalid",
+ )
+ ErrAnnouncementInvalidSchedule = infraerrors.BadRequest(
+ "ANNOUNCEMENT_TIME_RANGE_INVALID",
+ "starts_at must be before ends_at",
+ )
)
type AnnouncementTargeting = domain.AnnouncementTargeting
diff --git a/backend/internal/service/announcement_service.go b/backend/internal/service/announcement_service.go
index c0a0681a..12479041 100644
--- a/backend/internal/service/announcement_service.go
+++ b/backend/internal/service/announcement_service.go
@@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct {
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
- return nil, fmt.Errorf("create announcement: nil input")
+ return nil, ErrAnnouncementNilInput
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
- return nil, fmt.Errorf("create announcement: invalid title")
+ return nil, ErrAnnouncementInvalidTitle
}
if content == "" {
- return nil, fmt.Errorf("create announcement: content is required")
+ return nil, ErrAnnouncementContentRequired
}
status := strings.TrimSpace(input.Status)
@@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
- return nil, fmt.Errorf("create announcement: invalid status")
+ return nil, ErrAnnouncementInvalidStatus
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
@@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
notifyMode = AnnouncementNotifyModeSilent
}
if !isValidAnnouncementNotifyMode(notifyMode) {
- return nil, fmt.Errorf("create announcement: invalid notify_mode")
+ return nil, ErrAnnouncementInvalidNotifyMode
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
- return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
+ return nil, ErrAnnouncementInvalidSchedule
}
}
@@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
- return nil, fmt.Errorf("update announcement: nil input")
+ return nil, ErrAnnouncementNilInput
}
a, err := s.announcementRepo.GetByID(ctx, id)
@@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
- return nil, fmt.Errorf("update announcement: invalid title")
+ return nil, ErrAnnouncementInvalidTitle
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
- return nil, fmt.Errorf("update announcement: content is required")
+ return nil, ErrAnnouncementContentRequired
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
- return nil, fmt.Errorf("update announcement: invalid status")
+ return nil, ErrAnnouncementInvalidStatus
}
a.Status = status
}
@@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.NotifyMode != nil {
notifyMode := strings.TrimSpace(*input.NotifyMode)
if !isValidAnnouncementNotifyMode(notifyMode) {
- return nil, fmt.Errorf("update announcement: invalid notify_mode")
+ return nil, ErrAnnouncementInvalidNotifyMode
}
a.NotifyMode = notifyMode
}
@@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
- return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
+ return nil, ErrAnnouncementInvalidSchedule
}
}
diff --git a/backend/internal/service/announcement_service_test.go b/backend/internal/service/announcement_service_test.go
new file mode 100644
index 00000000..77fb9896
--- /dev/null
+++ b/backend/internal/service/announcement_service_test.go
@@ -0,0 +1,81 @@
+package service
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+// announcementRepoStub is a minimal in-memory test double for the
+// announcement repository: it holds at most one announcement in item.
+type announcementRepoStub struct {
+	item *Announcement
+}
+
+// Create stores the announcement as the stub's single item.
+func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error {
+	s.item = a
+	return nil
+}
+
+// GetByID returns the stored item regardless of the requested id, or
+// ErrAnnouncementNotFound when nothing has been stored.
+func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) {
+	if s.item == nil {
+		return nil, ErrAnnouncementNotFound
+	}
+	return s.item, nil
+}
+
+// Update replaces the stored item.
+func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error {
+	s.item = a
+	return nil
+}
+
+// Delete is a no-op; deletion is not exercised by these tests.
+func (*announcementRepoStub) Delete(context.Context, int64) error {
+	return nil
+}
+
+// List returns empty results; pagination is not exercised by these tests.
+func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
+	return nil, nil, nil
+}
+
+// ListActive returns no active announcements.
+func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) {
+	return nil, nil
+}
+
+// TestAnnouncementServiceCreateRejectsEqualStartEndTimes verifies that Create
+// rejects a schedule where StartsAt == EndsAt with the typed
+// ErrAnnouncementInvalidSchedule (starts_at must be strictly before ends_at).
+func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) {
+	repo := &announcementRepoStub{}
+	svc := NewAnnouncementService(repo, nil, nil, nil)
+	// Fixed timestamp keeps the test deterministic.
+	now := time.Unix(1776790020, 0)
+
+	_, err := svc.Create(context.Background(), &CreateAnnouncementInput{
+		Title:      "公告",
+		Content:    "内容",
+		Status:     AnnouncementStatusActive,
+		NotifyMode: AnnouncementNotifyModePopup,
+		StartsAt:   &now,
+		EndsAt:     &now, // same instant as StartsAt → must be rejected
+	})
+	require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
+}
+
+// TestAnnouncementServiceUpdateRejectsEqualStartEndTimes verifies that Update
+// rejects equal start/end times on an existing announcement with
+// ErrAnnouncementInvalidSchedule.
+//
+// NOTE(review): the update input takes the times behind an extra pointer
+// level (&startsAt, i.e. **time.Time) — presumably to distinguish "field
+// omitted" from "field set"; confirm against UpdateAnnouncementInput.
+func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) {
+	// Seed the stub with an existing active announcement to update.
+	repo := &announcementRepoStub{
+		item: &Announcement{
+			ID:         1,
+			Title:      "公告",
+			Content:    "内容",
+			Status:     AnnouncementStatusActive,
+			NotifyMode: AnnouncementNotifyModePopup,
+		},
+	}
+	svc := NewAnnouncementService(repo, nil, nil, nil)
+	now := time.Unix(1776790020, 0)
+	startsAt := &now
+	endsAt := &now
+
+	_, err := svc.Update(context.Background(), 1, &UpdateAnnouncementInput{
+		StartsAt: &startsAt,
+		EndsAt:   &endsAt, // equal to StartsAt → must be rejected
+	})
+	require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
+}
diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go
index a300d898..3a4600db 100644
--- a/backend/internal/service/antigravity_oauth_service.go
+++ b/backend/internal/service/antigravity_oauth_service.go
@@ -91,6 +91,7 @@ type AntigravityTokenInfo struct {
ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"`
PlanType string `json:"-"`
+ PrivacyMode string `json:"-"`
}
// ExchangeCode 用 authorization code 交换 token
@@ -159,6 +160,9 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
}
}
+ // 令牌刚获取,立即设置隐私(不依赖后续账号创建流程)
+ result.PrivacyMode = setAntigravityPrivacy(ctx, result.AccessToken, result.ProjectID, proxyURL)
+
return result, nil
}
@@ -248,6 +252,9 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
}
}
+ // 令牌刚获取,立即设置隐私
+ tokenInfo.PrivacyMode = setAntigravityPrivacy(ctx, tokenInfo.AccessToken, tokenInfo.ProjectID, proxyURL)
+
return tokenInfo, nil
}
diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go
index ecaffcbc..e3b60a27 100644
--- a/backend/internal/service/antigravity_smart_retry_test.go
+++ b/backend/internal/service/antigravity_smart_retry_test.go
@@ -5,13 +5,12 @@ package service
import (
"bytes"
"context"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
+ "github.com/stretchr/testify/require"
"io"
"net/http"
"strings"
"testing"
-
- "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
- "github.com/stretchr/testify/require"
)
// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock
@@ -81,17 +80,12 @@ func (m *mockSmartRetryUpstream) Do(req *http.Request, proxyURL string, accountI
m.responseBodies[respIdx] = bodyBytes
}
- // 用缓存的 body 字节重建新的 reader
- var body io.ReadCloser
+ // 用缓存的 body 重建 reader(支持重试场景多次读取)
+ cloned := *resp
if m.responseBodies[respIdx] != nil {
- body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
+ cloned.Body = io.NopCloser(bytes.NewReader(m.responseBodies[respIdx]))
}
-
- return &http.Response{
- StatusCode: resp.StatusCode,
- Header: resp.Header.Clone(),
- Body: body,
- }, respErr
+ return &cloned, respErr
}
func (m *mockSmartRetryUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go
index e8ad5c9c..1a1c78b8 100644
--- a/backend/internal/service/api_key_auth_cache.go
+++ b/backend/internal/service/api_key_auth_cache.go
@@ -4,6 +4,7 @@ import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct {
+ Version int `json:"version"`
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
@@ -33,6 +34,22 @@ type APIKeyAuthUserSnapshot struct {
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
+
+ // Balance notification fields (required for CheckBalanceAfterDeduction)
+ Email string `json:"email"`
+ Username string `json:"username"`
+ BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
+ BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
+ BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
+ BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
+ TotalRecharged float64 `json:"total_recharged"`
+
+ // RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。
+ RPMLimit int `json:"rpm_limit"`
+
+ // UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。
+ // nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。
+ UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"`
}
// APIKeyAuthGroupSnapshot 分组快照
@@ -49,10 +66,6 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
- SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
- SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
- SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
- SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
@@ -67,8 +80,12 @@ type APIKeyAuthGroupSnapshot struct {
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
- AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
- DefaultMappedModel string `json:"default_mapped_model,omitempty"`
+ AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
+ DefaultMappedModel string `json:"default_mapped_model,omitempty"`
+ MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"`
+
+ // RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。
+ RPMLimit int `json:"rpm_limit"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index f727ab10..974ea66e 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -6,6 +6,7 @@ import (
"encoding/hex"
"errors"
"fmt"
+ "log/slog"
"math/rand/v2"
"time"
@@ -13,6 +14,8 @@ import (
"github.com/dgraph-io/ristretto"
)
+const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot
+
type apiKeyAuthCacheConfig struct {
l1Size int
l1TTL time.Duration
@@ -97,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
s.authCacheL1.Del(cacheKey)
}); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
- println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
+ slog.Warn("failed to start auth cache invalidation subscriber", "error", err)
}
}
@@ -173,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
- snapshot := s.snapshotFromAPIKey(apiKey)
+ snapshot := s.snapshotFromAPIKey(ctx, apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
@@ -192,14 +195,18 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn
if entry.Snapshot == nil {
return nil, false, nil
}
+ if entry.Snapshot.Version != apiKeyAuthSnapshotVersion {
+ return nil, false, nil
+ }
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
-func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
+func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
snapshot := &APIKeyAuthSnapshot{
+ Version: apiKeyAuthSnapshotVersion,
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
@@ -213,13 +220,30 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{
- ID: apiKey.User.ID,
- Status: apiKey.User.Status,
- Role: apiKey.User.Role,
- Balance: apiKey.User.Balance,
- Concurrency: apiKey.User.Concurrency,
+ ID: apiKey.User.ID,
+ Status: apiKey.User.Status,
+ Role: apiKey.User.Role,
+ Balance: apiKey.User.Balance,
+ Concurrency: apiKey.User.Concurrency,
+ Email: apiKey.User.Email,
+ Username: apiKey.User.Username,
+ BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
+ TotalRecharged: apiKey.User.TotalRecharged,
+ RPMLimit: apiKey.User.RPMLimit,
},
}
+
+ // 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。
+ if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil {
+ override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID)
+ if err == nil && override != nil {
+ snapshot.User.UserGroupRPMOverride = override
+ }
+ // 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询
+ }
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
@@ -234,10 +258,6 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
- SoraImagePrice360: apiKey.Group.SoraImagePrice360,
- SoraImagePrice540: apiKey.Group.SoraImagePrice540,
- SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
@@ -247,6 +267,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
+ MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig,
+ RPMLimit: apiKey.Group.RPMLimit,
}
}
return snapshot
@@ -271,11 +293,20 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{
- ID: snapshot.User.ID,
- Status: snapshot.User.Status,
- Role: snapshot.User.Role,
- Balance: snapshot.User.Balance,
- Concurrency: snapshot.User.Concurrency,
+ ID: snapshot.User.ID,
+ Status: snapshot.User.Status,
+ Role: snapshot.User.Role,
+ Balance: snapshot.User.Balance,
+ Concurrency: snapshot.User.Concurrency,
+ Email: snapshot.User.Email,
+ Username: snapshot.User.Username,
+ BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled,
+ BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
+ BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
+ BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
+ TotalRecharged: snapshot.User.TotalRecharged,
+ RPMLimit: snapshot.User.RPMLimit,
+ UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride,
},
}
if snapshot.Group != nil {
@@ -293,10 +324,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
- SoraImagePrice360: snapshot.Group.SoraImagePrice360,
- SoraImagePrice540: snapshot.Group.SoraImagePrice540,
- SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
- SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
@@ -306,6 +333,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
+ MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig,
+ RPMLimit: snapshot.Group.RPMLimit,
}
}
s.compileAPIKeyIPRules(apiKey)
diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go
index 357f8def..8cb1b8c4 100644
--- a/backend/internal/service/api_key_service_cache_test.go
+++ b/backend/internal/service/api_key_service_cache_test.go
@@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
+ Version: apiKeyAuthSnapshotVersion,
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
@@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
}
+// TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig
+// verifies that snapshotFromAPIKey followed by snapshotToAPIKey preserves the
+// group's MessagesDispatchModelConfig (including the exact-model mapping
+// table), so auth-cache round-trips do not drop OpenAI Messages dispatch
+// configuration.
+func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t *testing.T) {
+	// Dependencies are nil: only the pure snapshot conversion is under test.
+	svc := NewAPIKeyService(nil, nil, nil, nil, nil, nil, &config.Config{})
+	groupID := int64(9)
+	apiKey := &APIKey{
+		ID:      1,
+		UserID:  2,
+		GroupID: &groupID,
+		Key:     "k-roundtrip",
+		Status:  StatusActive,
+		User: &User{
+			ID:          2,
+			Status:      StatusActive,
+			Role:        RoleUser,
+			Balance:     10,
+			Concurrency: 3,
+		},
+		Group: &Group{
+			ID:                    groupID,
+			Name:                  "openai",
+			Platform:              PlatformOpenAI,
+			Status:                StatusActive,
+			SubscriptionType:      SubscriptionTypeStandard,
+			RateMultiplier:        1,
+			AllowMessagesDispatch: true,
+			DefaultMappedModel:    "gpt-5.4",
+			MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
+				OpusMappedModel:   "gpt-5.4-nano",
+				SonnetMappedModel: "gpt-5.3-codex",
+				HaikuMappedModel:  "gpt-5.4-mini",
+				ExactModelMappings: map[string]string{
+					"claude-sonnet-4.5": "gpt-5.4-nano",
+				},
+			},
+		},
+	}
+
+	snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey)
+	roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot)
+
+	require.NotNil(t, roundTrip)
+	require.NotNil(t, roundTrip.Group)
+	// Deep-equality on the whole config catches any field dropped by either
+	// conversion direction.
+	require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig)
+}
+
+// TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig
+// seeds the L2 cache with a snapshot that omits the Version field (zero
+// value) and expects GetByKey to treat it as a miss: exactly one repo load
+// occurs, and the returned group carries the repo's
+// MessagesDispatchModelConfig rather than the stale cached shape.
+func TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig(t *testing.T) {
+	cache := &authCacheStub{}
+	var repoCalls int32
+	// Repo returns the authoritative key, counting calls atomically so the
+	// test can assert the cache entry was NOT used.
+	repo := &authRepoStub{
+		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
+			atomic.AddInt32(&repoCalls, 1)
+			groupID := int64(9)
+			return &APIKey{
+				ID:      1,
+				UserID:  2,
+				GroupID: &groupID,
+				Status:  StatusActive,
+				User: &User{
+					ID:          2,
+					Status:      StatusActive,
+					Role:        RoleUser,
+					Balance:     10,
+					Concurrency: 3,
+				},
+				Group: &Group{
+					ID:                    groupID,
+					Name:                  "openai",
+					Platform:              PlatformOpenAI,
+					Status:                StatusActive,
+					Hydrated:              true,
+					SubscriptionType:      SubscriptionTypeStandard,
+					RateMultiplier:        1,
+					AllowMessagesDispatch: true,
+					DefaultMappedModel:    "gpt-5.4",
+					MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{
+						OpusMappedModel: "gpt-5.4-nano",
+					},
+				},
+			}, nil
+		},
+	}
+	cfg := &config.Config{
+		APIKeyAuth: config.APIKeyAuthCacheConfig{
+			L2TTLSeconds: 60,
+		},
+	}
+	svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
+
+	groupID := int64(9)
+	// Legacy cache entry: no Version field and no MessagesDispatchModelConfig
+	// on the group — the shape written before the versioned snapshot existed.
+	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
+		return &APIKeyAuthCacheEntry{
+			Snapshot: &APIKeyAuthSnapshot{
+				APIKeyID: 1,
+				UserID:   2,
+				GroupID:  &groupID,
+				Status:   StatusActive,
+				User: APIKeyAuthUserSnapshot{
+					ID:          2,
+					Status:      StatusActive,
+					Role:        RoleUser,
+					Balance:     10,
+					Concurrency: 3,
+				},
+				Group: &APIKeyAuthGroupSnapshot{
+					ID:                    groupID,
+					Name:                  "openai",
+					Platform:              PlatformOpenAI,
+					Status:                StatusActive,
+					SubscriptionType:      SubscriptionTypeStandard,
+					RateMultiplier:        1,
+					AllowMessagesDispatch: true,
+					DefaultMappedModel:    "gpt-5.4",
+				},
+			},
+		}, nil
+	}
+
+	apiKey, err := svc.GetByKey(context.Background(), "k-legacy")
+	require.NoError(t, err)
+	// Exactly one repo call proves the stale snapshot was discarded, not used.
+	require.Equal(t, int32(1), atomic.LoadInt32(&repoCalls))
+	require.NotNil(t, apiKey.Group)
+	require.Equal(t, "gpt-5.4-nano", apiKey.Group.MessagesDispatchModelConfig.OpusMappedModel)
+}
+
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
diff --git a/backend/internal/service/auth_email_binding.go b/backend/internal/service/auth_email_binding.go
new file mode 100644
index 00000000..78f1185d
--- /dev/null
+++ b/backend/internal/service/auth_email_binding.go
@@ -0,0 +1,319 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/mail"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+// BindEmailIdentity verifies and binds a local email/password identity to the
+// current user, or replaces the existing bound primary email.
+//
+// Steps: normalize and validate the email, verify the emailed code, enforce
+// the password rule (minimum length on first real bind, current-password check
+// on re-bind), reject emails owned by another user, persist the change (inside
+// a transaction when an ent client is available), and finally revoke all of
+// the user's refresh sessions so stale tokens do not survive the change.
+func (s *AuthService) BindEmailIdentity(
+	ctx context.Context,
+	userID int64,
+	email string,
+	verifyCode string,
+	password string,
+) (*User, error) {
+	if s == nil {
+		return nil, ErrServiceUnavailable
+	}
+
+	normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+	if err != nil {
+		return nil, err
+	}
+	if isReservedEmail(normalizedEmail) {
+		return nil, ErrEmailReserved
+	}
+	if strings.TrimSpace(password) == "" {
+		return nil, ErrPasswordRequired
+	}
+	if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
+		return nil, err
+	}
+
+	currentUser, err := s.userRepo.GetByID(ctx, userID)
+	if err != nil {
+		return nil, err
+	}
+	// First "real" bind means the account currently has no usable email
+	// (empty or a reserved placeholder address).
+	firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
+	if firstRealEmailBind && len(password) < 6 {
+		return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters")
+	}
+	// On re-bind the caller must prove ownership with the current password.
+	if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) {
+		return nil, ErrPasswordIncorrect
+	}
+
+	// The address may already belong to this same user (re-bind); only a
+	// different owner makes it unavailable.
+	existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+	switch {
+	case err == nil && existingUser != nil && existingUser.ID != userID:
+		return nil, ErrEmailExists
+	case err != nil && !errors.Is(err, ErrUserNotFound):
+		return nil, ErrServiceUnavailable
+	}
+
+	hashedPassword, err := s.HashPassword(password)
+	if err != nil {
+		return nil, fmt.Errorf("hash password: %w", err)
+	}
+
+	// Preferred path: transactional update that also maintains auth identity
+	// rows (and applies first-bind defaults inside the same transaction).
+	if s.entClient != nil {
+		if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
+			return nil, err
+		}
+		s.revokeEmailIdentitySessions(ctx, userID)
+		return currentUser, nil
+	}
+
+	// Fallback path without ent: plain repository update.
+	currentUser.Email = normalizedEmail
+	currentUser.PasswordHash = hashedPassword
+	if err := s.userRepo.Update(ctx, currentUser); err != nil {
+		if errors.Is(err, ErrEmailExists) {
+			return nil, ErrEmailExists
+		}
+		return nil, ErrServiceUnavailable
+	}
+
+	if firstRealEmailBind {
+		if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
+			return nil, fmt.Errorf("apply email first bind defaults: %w", err)
+		}
+	}
+
+	s.revokeEmailIdentitySessions(ctx, userID)
+	return currentUser, nil
+}
+
+// SendEmailIdentityBindCode sends a verification code for authenticated email
+// binding flows. It validates the target address, confirms the requesting user
+// exists, and rejects addresses already owned by another user before sending.
+func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
+	if s == nil {
+		return ErrServiceUnavailable
+	}
+
+	normalizedEmail, err := normalizeEmailForIdentityBinding(email)
+	if err != nil {
+		return err
+	}
+	if isReservedEmail(normalizedEmail) {
+		return ErrEmailReserved
+	}
+	if s.emailService == nil {
+		return ErrServiceUnavailable
+	}
+	// Ensure the requesting account actually exists before sending anything.
+	if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
+		if errors.Is(err, ErrUserNotFound) {
+			return ErrUserNotFound
+		}
+		return ErrServiceUnavailable
+	}
+
+	// The address may already belong to the same user (re-bind); only a
+	// different owner makes it unavailable.
+	existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
+	switch {
+	case err == nil && existingUser != nil && existingUser.ID != userID:
+		return ErrEmailExists
+	case err != nil && !errors.Is(err, ErrUserNotFound):
+		return ErrServiceUnavailable
+	}
+
+	// Brand the verification mail with the configured site name when available.
+	siteName := "Sub2API"
+	if s.settingService != nil {
+		siteName = s.settingService.GetSiteName(ctx)
+	}
+	return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
+}
+
+// normalizeEmailForIdentityBinding lowercases and trims the input, rejecting
+// values that are empty, longer than 255 characters, or not parseable as an
+// address. Returns the normalized email or a BadRequest error.
+func normalizeEmailForIdentityBinding(email string) (string, error) {
+	candidate := strings.ToLower(strings.TrimSpace(email))
+	invalid := candidate == "" || len(candidate) > 255
+	if !invalid {
+		_, parseErr := mail.ParseAddress(candidate)
+		invalid = parseErr != nil
+	}
+	if invalid {
+		return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
+	}
+	return candidate, nil
+}
+
+// hasBindableEmailIdentitySubject reports whether the stored email is a real,
+// user-owned address (non-empty and not one of the reserved placeholders).
+func hasBindableEmailIdentitySubject(email string) bool {
+	subject := strings.ToLower(strings.TrimSpace(email))
+	if subject == "" {
+		return false
+	}
+	return !isReservedEmail(subject)
+}
+
+// updateBoundEmailIdentityTx wraps updateBoundEmailIdentityWithClient in an
+// ent transaction. An ambient transaction on the context is reused as-is;
+// otherwise a new one is opened, rolled back on error, committed on success.
+func (s *AuthService) updateBoundEmailIdentityTx(
+	ctx context.Context,
+	currentUser *User,
+	email string,
+	hashedPassword string,
+	applyFirstBindDefaults bool,
+) error {
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults)
+	}
+
+	tx, err := s.entClient.Tx(ctx)
+	if err != nil {
+		return ErrServiceUnavailable
+	}
+	// Rollback after a successful Commit is a harmless no-op.
+	defer func() { _ = tx.Rollback() }()
+
+	txCtx := dbent.NewTxContext(ctx, tx)
+	if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil {
+		return err
+	}
+	if err := tx.Commit(); err != nil {
+		return ErrServiceUnavailable
+	}
+	return nil
+}
+
+// updateBoundEmailIdentityWithClient persists the new email/password on the
+// user row, swaps the "email" auth identity to the new address, optionally
+// applies first-bind default grants, and refreshes currentUser in place from
+// the stored row so the caller returns up-to-date data.
+func (s *AuthService) updateBoundEmailIdentityWithClient(
+	ctx context.Context,
+	client *dbent.Client,
+	currentUser *User,
+	email string,
+	hashedPassword string,
+	applyFirstBindDefaults bool,
+) error {
+	if client == nil || currentUser == nil || currentUser.ID <= 0 {
+		return ErrServiceUnavailable
+	}
+
+	oldEmail := currentUser.Email
+	if _, err := client.User.UpdateOneID(currentUser.ID).
+		SetEmail(email).
+		SetPasswordHash(hashedPassword).
+		Save(ctx); err != nil {
+		// A uniqueness violation means another row already owns this email.
+		if dbent.IsConstraintError(err) {
+			return ErrEmailExists
+		}
+		return ErrServiceUnavailable
+	}
+
+	if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
+		if errors.Is(err, ErrEmailExists) {
+			return ErrEmailExists
+		}
+		return ErrServiceUnavailable
+	}
+
+	if applyFirstBindDefaults {
+		if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
+			return fmt.Errorf("apply email first bind defaults: %w", err)
+		}
+	}
+
+	// Re-read the row so the in-memory user reflects DB-side changes (e.g.
+	// balance/concurrency adjusted by first-bind defaults).
+	updatedUser, err := client.User.Get(ctx, currentUser.ID)
+	if err != nil {
+		return ErrServiceUnavailable
+	}
+	currentUser.Email = updatedUser.Email
+	currentUser.PasswordHash = updatedUser.PasswordHash
+	currentUser.Balance = updatedUser.Balance
+	currentUser.Concurrency = updatedUser.Concurrency
+	currentUser.UpdatedAt = updatedUser.UpdatedAt
+	return nil
+}
+
+// revokeEmailIdentitySessions best-effort revokes every refresh session of the
+// user after an email (re-)bind; failures are logged, never propagated.
+func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) {
+	err := s.RevokeAllUserSessions(ctx, userID)
+	if err == nil {
+		return
+	}
+	logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err)
+}
+
+// replaceBoundEmailAuthIdentityWithClient ensures an "email" auth identity row
+// exists for the new address and removes the row for the old address when it
+// differs. Returns ErrEmailExists if the new subject belongs to another user.
+func replaceBoundEmailAuthIdentityWithClient(
+	ctx context.Context,
+	client *dbent.Client,
+	userID int64,
+	oldEmail string,
+	newEmail string,
+	source string,
+) error {
+	newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
+	if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
+		return err
+	}
+
+	// Reserved/empty old addresses never had an identity row, and identical
+	// subjects mean there is nothing to clean up.
+	oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
+	if oldSubject == "" || oldSubject == newSubject {
+		return nil
+	}
+
+	_, err := client.AuthIdentity.Delete().
+		Where(
+			authidentity.UserIDEQ(userID),
+			authidentity.ProviderTypeEQ("email"),
+			authidentity.ProviderKeyEQ("email"),
+			authidentity.ProviderSubjectEQ(oldSubject),
+		).
+		Exec(ctx)
+	return err
+}
+
+// ensureBoundEmailAuthIdentityWithClient idempotently creates the "email"
+// identity row for subject and verifies ownership: if the row already exists
+// for a different user, ErrEmailExists is returned. Empty subjects (reserved
+// or blank emails) and missing clients are a silent no-op.
+func ensureBoundEmailAuthIdentityWithClient(
+	ctx context.Context,
+	client *dbent.Client,
+	userID int64,
+	subject string,
+	source string,
+) error {
+	if client == nil || userID <= 0 || subject == "" {
+		return nil
+	}
+
+	if strings.TrimSpace(source) == "" {
+		source = "auth_service_email_bind"
+	}
+
+	// Upsert guarded by the (provider_type, provider_key, provider_subject)
+	// unique constraint; an existing row is left untouched (DO NOTHING).
+	if err := client.AuthIdentity.Create().
+		SetUserID(userID).
+		SetProviderType("email").
+		SetProviderKey("email").
+		SetProviderSubject(subject).
+		SetVerifiedAt(time.Now().UTC()).
+		SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
+		OnConflictColumns(
+			authidentity.FieldProviderType,
+			authidentity.FieldProviderKey,
+			authidentity.FieldProviderSubject,
+		).
+		DoNothing().
+		Exec(ctx); err != nil {
+		// DO NOTHING can surface a no-rows error on some drivers; that simply
+		// means the row already existed.
+		if !isSQLNoRowsError(err) {
+			return err
+		}
+	}
+
+	// Read the row back to confirm which user owns the subject.
+	identity, err := client.AuthIdentity.Query().
+		Where(
+			authidentity.ProviderTypeEQ("email"),
+			authidentity.ProviderKeyEQ("email"),
+			authidentity.ProviderSubjectEQ(subject),
+		).
+		Only(ctx)
+	if err != nil {
+		if dbent.IsNotFound(err) {
+			return nil
+		}
+		return err
+	}
+	if identity.UserID != userID {
+		return ErrEmailExists
+	}
+	return nil
+}
+
+// normalizeBoundEmailAuthIdentitySubject lowercases/trims the email and maps
+// empty or reserved placeholder addresses to "" (meaning: no subject to bind).
+func normalizeBoundEmailAuthIdentitySubject(email string) string {
+	subject := strings.ToLower(strings.TrimSpace(email))
+	if subject != "" && !isReservedEmail(subject) {
+		return subject
+	}
+	return ""
+}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
new file mode 100644
index 00000000..9815f31b
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -0,0 +1,387 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/mail"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/redeemcode"
+)
+
+// normalizeOAuthSignupSource maps the caller-provided signup source to one of
+// the known values; anything unrecognized (or empty) falls back to "email".
+func normalizeOAuthSignupSource(signupSource string) string {
+	normalized := strings.ToLower(strings.TrimSpace(signupSource))
+	for _, known := range []string{"linuxdo", "wechat", "oidc"} {
+		if normalized == known {
+			return normalized
+		}
+	}
+	return "email"
+}
+
+// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
+// account-creation flows without relying on the public registration gate.
+// The address is validated and reserved placeholders rejected before dispatch;
+// the returned result carries the resend cooldown for the client UI.
+func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
+	email = strings.TrimSpace(strings.ToLower(email))
+	if email == "" {
+		return nil, ErrEmailVerifyRequired
+	}
+	if _, err := mail.ParseAddress(email); err != nil {
+		return nil, ErrEmailVerifyRequired
+	}
+	if isReservedEmail(email) {
+		return nil, ErrEmailReserved
+	}
+	if s == nil || s.emailService == nil {
+		return nil, ErrServiceUnavailable
+	}
+
+	// Brand the verification mail with the configured site name when available.
+	siteName := "Sub2API"
+	if s.settingService != nil {
+		siteName = s.settingService.GetSiteName(ctx)
+	}
+	if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
+		return nil, err
+	}
+	return &SendVerifyCodeResult{
+		Countdown: int(verifyCodeCooldown / time.Second),
+	}, nil
+}
+
+// validateOAuthRegistrationInvitation enforces the invitation-code requirement
+// for OAuth-driven registrations when the feature is enabled. It returns the
+// matching unused invitation (without consuming it), nil when the feature is
+// disabled, or an error when the code is missing, unknown, already used, or
+// the backing store is unreachable.
+func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+	if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+		return nil, nil
+	}
+	if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
+		return nil, ErrServiceUnavailable
+	}
+
+	invitationCode = strings.TrimSpace(invitationCode)
+	if invitationCode == "" {
+		return nil, ErrInvitationCodeRequired
+	}
+
+	redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
+	switch {
+	case errors.Is(err, ErrRedeemCodeNotFound):
+		// Unknown code: report it as invalid.
+		return nil, ErrInvitationCodeInvalid
+	case err != nil:
+		// Infrastructure failure: do not mislead the user into believing the
+		// code itself is wrong (mirrors restoreOAuthRegistrationInvitation).
+		return nil, ErrServiceUnavailable
+	}
+	if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+		return nil, ErrInvitationCodeInvalid
+	}
+	return redeemCode, nil
+}
+
+// VerifyOAuthEmailCode verifies the locally entered email verification code
+// for third-party signup and binding flows. It is intentionally independent
+// from the global registration email-verification toggle.
+func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
+	normalizedEmail := strings.TrimSpace(strings.ToLower(email))
+	trimmedCode := strings.TrimSpace(verifyCode)
+	if normalizedEmail == "" || trimmedCode == "" {
+		return ErrEmailVerifyRequired
+	}
+	if s == nil || s.emailService == nil {
+		return ErrServiceUnavailable
+	}
+	return s.emailService.VerifyCode(ctx, normalizedEmail, trimmedCode)
+}
+
+// RegisterOAuthEmailAccount creates a local account from a third-party first
+// login after the user has verified a local email address.
+//
+// Order matters: all policy/verification checks run before the user row is
+// created, and a failure to mint tokens rolls the freshly created user back
+// so no orphaned account survives. Invitation consumption is deferred to
+// FinalizeOAuthEmailAccount; here the code is only validated.
+func (s *AuthService) RegisterOAuthEmailAccount(
+	ctx context.Context,
+	email string,
+	password string,
+	verifyCode string,
+	invitationCode string,
+	signupSource string,
+) (*TokenPair, *User, error) {
+	if s == nil {
+		return nil, nil, ErrServiceUnavailable
+	}
+	if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+		return nil, nil, ErrRegDisabled
+	}
+
+	email = strings.TrimSpace(strings.ToLower(email))
+	if isReservedEmail(email) {
+		return nil, nil, ErrEmailReserved
+	}
+	if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+		return nil, nil, err
+	}
+	if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
+		return nil, nil, err
+	}
+
+	// Validate (but do not consume) the invitation code.
+	if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
+		return nil, nil, err
+	}
+
+	existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
+	if err != nil {
+		return nil, nil, ErrServiceUnavailable
+	}
+	if existsEmail {
+		return nil, nil, ErrEmailExists
+	}
+
+	hashedPassword, err := s.HashPassword(password)
+	if err != nil {
+		return nil, nil, fmt.Errorf("hash password: %w", err)
+	}
+
+	signupSource = normalizeOAuthSignupSource(signupSource)
+	grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+
+	user := &User{
+		Email:        email,
+		PasswordHash: hashedPassword,
+		Role:         RoleUser,
+		Balance:      grantPlan.Balance,
+		Concurrency:  grantPlan.Concurrency,
+		Status:       StatusActive,
+		SignupSource: signupSource,
+	}
+
+	if err := s.userRepo.Create(ctx, user); err != nil {
+		if errors.Is(err, ErrEmailExists) {
+			return nil, nil, ErrEmailExists
+		}
+		return nil, nil, ErrServiceUnavailable
+	}
+
+	tokenPair, err := s.GenerateTokenPair(ctx, user, "")
+	if err != nil {
+		// No invitation was consumed yet, so rollback only removes the user.
+		_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
+		return nil, nil, fmt.Errorf("generate token pair: %w", err)
+	}
+	return tokenPair, user, nil
+}
+
+// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
+// only after the pending OAuth flow has fully reached its last reversible step.
+// It consumes the invitation (when the feature is enabled), records the signup
+// source, assigns source-specific default subscriptions, and binds affiliates.
+func (s *AuthService) FinalizeOAuthEmailAccount(
+	ctx context.Context,
+	user *User,
+	invitationCode string,
+	signupSource string,
+	affiliateCode string,
+) error {
+	if s == nil || user == nil || user.ID <= 0 {
+		return ErrServiceUnavailable
+	}
+
+	signupSource = normalizeOAuthSignupSource(signupSource)
+	invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
+	if err != nil {
+		return err
+	}
+	if invitationRedeemCode != nil {
+		// Mark the code used; any failure aborts finalization.
+		if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
+			return ErrInvitationCodeInvalid
+		}
+	}
+
+	// The remaining bootstrap steps surface no errors to the caller.
+	s.updateOAuthSignupSource(ctx, user.ID, signupSource)
+	grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+	s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+	s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
+	return nil
+}
+
+// RollbackOAuthEmailAccountCreation removes a partially-created local account
+// and restores any invitation code already consumed by that account.
+func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
+	if s == nil || s.userRepo == nil || userID <= 0 {
+		return ErrServiceUnavailable
+	}
+	if restoreErr := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); restoreErr != nil {
+		return restoreErr
+	}
+	if deleteErr := s.userRepo.Delete(ctx, userID); deleteErr != nil {
+		return fmt.Errorf("delete created oauth user: %w", deleteErr)
+	}
+	return nil
+}
+
+// restoreOAuthRegistrationInvitation puts an invitation code back into the
+// unused state when (and only when) it was consumed by the given user. It is a
+// no-op when the invitation feature is disabled, the code is empty/unknown, or
+// the code was used by someone else.
+func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
+	if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
+		return nil
+	}
+
+	invitationCode = strings.TrimSpace(invitationCode)
+	if invitationCode == "" || userID <= 0 {
+		// Nothing to restore; do not require a reachable store for a no-op.
+		return nil
+	}
+	if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
+		return ErrServiceUnavailable
+	}
+
+	redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
+	if err != nil {
+		if errors.Is(err, ErrRedeemCodeNotFound) {
+			return nil
+		}
+		return fmt.Errorf("load invitation code: %w", err)
+	}
+	// Only restore a used invitation that this specific user consumed.
+	if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
+		return nil
+	}
+
+	redeemCode.Status = StatusUnused
+	redeemCode.UsedBy = nil
+	redeemCode.UsedAt = nil
+	if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil {
+		return fmt.Errorf("restore invitation code: %w", err)
+	}
+	return nil
+}
+
+// oauthEmailFlowClient returns the ent client to use for this flow: the
+// ambient transaction's client when one is on the context, the service-level
+// client otherwise, or nil when ent is not configured.
+func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client {
+	if s == nil || s.entClient == nil {
+		return nil
+	}
+	tx := dbent.TxFromContext(ctx)
+	if tx == nil {
+		return s.entClient
+	}
+	return tx.Client()
+}
+
+// loadOAuthRegistrationInvitation fetches an invitation redeem code by its
+// literal code, preferring the ent client (honoring any ambient transaction)
+// and falling back to the redeem repository. Missing codes map to
+// ErrRedeemCodeNotFound.
+func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
+	if client := s.oauthEmailFlowClient(ctx); client != nil {
+		entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx)
+		if err != nil {
+			if dbent.IsNotFound(err) {
+				return nil, ErrRedeemCodeNotFound
+			}
+			return nil, err
+		}
+		// Map the ent entity onto the service-level RedeemCode model.
+		return &RedeemCode{
+			ID:           entity.ID,
+			Code:         entity.Code,
+			Type:         entity.Type,
+			Value:        entity.Value,
+			Status:       entity.Status,
+			UsedBy:       entity.UsedBy,
+			UsedAt:       entity.UsedAt,
+			Notes:        oauthEmailFlowStringValue(entity.Notes),
+			CreatedAt:    entity.CreatedAt,
+			GroupID:      entity.GroupID,
+			ValidityDays: entity.ValidityDays,
+		}, nil
+	}
+	// NOTE(review): callers are expected to have checked store availability
+	// first; s.redeemRepo could be nil here otherwise — confirm call sites.
+	return s.redeemRepo.GetByCode(ctx, invitationCode)
+}
+
+// useOAuthRegistrationInvitation consumes an invitation code: the conditional
+// update only succeeds while the code is still unused, so a lost race surfaces
+// as ErrRedeemCodeUsed instead of double consumption.
+func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
+	if client := s.oauthEmailFlowClient(ctx); client != nil {
+		affected, err := client.RedeemCode.Update().
+			Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
+			SetStatus(StatusUsed).
+			SetUsedBy(userID).
+			SetUsedAt(time.Now().UTC()).
+			Save(ctx)
+		if err != nil {
+			return err
+		}
+		// Zero affected rows: the code was no longer in the unused state.
+		if affected == 0 {
+			return ErrRedeemCodeUsed
+		}
+		return nil
+	}
+	return s.redeemRepo.Use(ctx, invitationID, userID)
+}
+
+// updateOAuthRegistrationInvitation writes back every mutable field of the
+// redeem code via ent (clearing nullable columns when the value is nil), or
+// delegates to the redeem repository when no ent client is configured.
+func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error {
+	if code == nil {
+		return nil
+	}
+	if client := s.oauthEmailFlowClient(ctx); client != nil {
+		// NOTE(review): code.Notes is a plain string here, so an originally
+		// NULL notes column is rewritten as "" — confirm this is acceptable.
+		update := client.RedeemCode.UpdateOneID(code.ID).
+			SetCode(code.Code).
+			SetType(code.Type).
+			SetValue(code.Value).
+			SetStatus(code.Status).
+			SetNotes(code.Notes).
+			SetValidityDays(code.ValidityDays)
+		if code.UsedBy != nil {
+			update = update.SetUsedBy(*code.UsedBy)
+		} else {
+			update = update.ClearUsedBy()
+		}
+		if code.UsedAt != nil {
+			update = update.SetUsedAt(*code.UsedAt)
+		} else {
+			update = update.ClearUsedAt()
+		}
+		if code.GroupID != nil {
+			update = update.SetGroupID(*code.GroupID)
+		} else {
+			update = update.ClearGroupID()
+		}
+		_, err := update.Save(ctx)
+		return err
+	}
+	return s.redeemRepo.Update(ctx, code)
+}
+
+// updateOAuthSignupSource best-effort records the signup source on the user
+// row; it silently skips invalid input or a missing ent client, and discards
+// the write error by design.
+func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) {
+	if userID <= 0 || strings.TrimSpace(signupSource) == "" {
+		return
+	}
+	client := s.oauthEmailFlowClient(ctx)
+	if client == nil {
+		return
+	}
+	_ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx)
+}
+
+// oauthEmailFlowStringValue dereferences an optional string, mapping nil to "".
+func oauthEmailFlowStringValue(value *string) string {
+	if value != nil {
+		return *value
+	}
+	return ""
+}
+
+// ValidatePasswordCredentials checks the local password without completing the
+// login flow. This is used by pending third-party account adoption flows before
+// the external identity has been bound.
+//
+// It deliberately maps "user not found" to ErrInvalidCredentials so callers
+// cannot distinguish unknown accounts from wrong passwords.
+func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
+	// Guard the repo as well (consistent with RollbackOAuthEmailAccountCreation).
+	if s == nil || s.userRepo == nil {
+		return nil, ErrServiceUnavailable
+	}
+
+	user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
+	if err != nil {
+		if errors.Is(err, ErrUserNotFound) {
+			return nil, ErrInvalidCredentials
+		}
+		return nil, ErrServiceUnavailable
+	}
+	if !user.IsActive() {
+		return nil, ErrUserNotActive
+	}
+	if !s.CheckPassword(password, user.PasswordHash) {
+		return nil, ErrInvalidCredentials
+	}
+	return user, nil
+}
+
+// RecordSuccessfulLogin updates last-login activity after a non-standard login
+// flow finishes with a real session. Best effort: lookup failures are ignored.
+func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
+	if s != nil && s.userRepo != nil && userID > 0 {
+		user, err := s.userRepo.GetByID(ctx, userID)
+		if err == nil && user != nil && !isReservedEmail(user.Email) {
+			// Opportunistically backfill the email auth identity row for
+			// accounts that predate identity bookkeeping.
+			s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
+		}
+	}
+	// NOTE(review): touchUserLogin is invoked even when s is nil — assumes the
+	// method is nil-receiver-safe; confirm against its implementation.
+	s.touchUserLogin(ctx, userID)
+}
diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go
new file mode 100644
index 00000000..21d9d6e9
--- /dev/null
+++ b/backend/internal/service/auth_oauth_email_flow_test.go
@@ -0,0 +1,326 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+// redeemCodeRepoStub is an in-memory RedeemCodeRepository test double: only
+// the methods exercised by these tests are implemented, the rest panic.
+type redeemCodeRepoStub struct {
+	codesByCode map[string]*RedeemCode
+	// useCalls records the (code ID, user ID) pairs passed to Use.
+	useCalls []struct {
+		id     int64
+		userID int64
+	}
+	// updateCalls records a copy of every code passed to Update.
+	updateCalls []*RedeemCode
+}
+
+// Create is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
+	panic("unexpected Create call")
+}
+
+// CreateBatch is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
+	panic("unexpected CreateBatch call")
+}
+
+// GetByID is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
+	panic("unexpected GetByID call")
+}
+
+// GetByCode returns a defensive copy of the stored code, or
+// ErrRedeemCodeNotFound when the map is nil or the code is absent.
+func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
+	if s.codesByCode == nil {
+		return nil, ErrRedeemCodeNotFound
+	}
+	redeemCode, ok := s.codesByCode[code]
+	if !ok {
+		return nil, ErrRedeemCodeNotFound
+	}
+	// Copy so callers cannot mutate the stub's internal state.
+	cloned := *redeemCode
+	return &cloned, nil
+}
+
+// Update records the call and stores a copy of the code, creating the backing
+// map on first use. A nil code is silently ignored.
+func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
+	if code == nil {
+		return nil
+	}
+	cloned := *code
+	s.updateCalls = append(s.updateCalls, &cloned)
+	if s.codesByCode == nil {
+		s.codesByCode = make(map[string]*RedeemCode)
+	}
+	s.codesByCode[cloned.Code] = &cloned
+	return nil
+}
+
+// Delete is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
+	panic("unexpected Delete call")
+}
+
+// Use marks the code with the given ID as used by userID and records the call;
+// unknown IDs yield ErrRedeemCodeNotFound.
+func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
+	for code, redeemCode := range s.codesByCode {
+		if redeemCode.ID != id {
+			continue
+		}
+		now := time.Now().UTC()
+		redeemCode.Status = StatusUsed
+		redeemCode.UsedBy = &userID
+		redeemCode.UsedAt = &now
+		s.codesByCode[code] = redeemCode
+		s.useCalls = append(s.useCalls, struct {
+			id     int64
+			userID int64
+		}{id: id, userID: userID})
+		return nil
+	}
+	return ErrRedeemCodeNotFound
+}
+
+// List is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
+	panic("unexpected List call")
+}
+
+// ListWithFilters is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+	panic("unexpected ListWithFilters call")
+}
+
+// ListByUser is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
+	panic("unexpected ListByUser call")
+}
+
+// ListByUserPaginated is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
+	panic("unexpected ListByUserPaginated call")
+}
+
+// SumPositiveBalanceByUser is unused in these tests and panics to flag unexpected calls.
+func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
+	panic("unexpected SumPositiveBalanceByUser call")
+}
+
+// newOAuthEmailFlowAuthService wires an AuthService with the given test
+// doubles, a fixed JWT/default config, and setting/email services backed by
+// the provided settings map — everything these flow tests need, nothing more.
+func newOAuthEmailFlowAuthService(
+	userRepo UserRepository,
+	redeemRepo RedeemCodeRepository,
+	refreshTokenCache RefreshTokenCache,
+	settings map[string]string,
+	emailCache EmailCache,
+) *AuthService {
+	cfg := &config.Config{
+		JWT: config.JWTConfig{
+			Secret:                   "test-secret",
+			ExpireHour:               1,
+			AccessTokenExpireMinutes: 60,
+			RefreshTokenExpireDays:   7,
+		},
+		Default: config.DefaultConfig{
+			UserBalance:     3.5,
+			UserConcurrency: 2,
+		},
+	}
+
+	settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
+	emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
+
+	// Remaining collaborators are nil: these tests never reach them.
+	return NewAuthService(
+		nil,
+		userRepo,
+		redeemRepo,
+		refreshTokenCache,
+		cfg,
+		settingService,
+		emailService,
+		nil,
+		nil,
+		nil,
+		nil,
+		nil,
+	)
+}
+
+// A token-pair failure after user creation must delete the freshly created
+// user and leave the invitation untouched (it is only consumed in Finalize).
+// The nil refresh-token cache is what forces GenerateTokenPair to fail here.
+func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
+	userRepo := &userRepoStub{nextID: 42}
+	redeemRepo := &redeemCodeRepoStub{
+		codesByCode: map[string]*RedeemCode{
+			"INVITE123": {
+				ID:     7,
+				Code:   "INVITE123",
+				Type:   RedeemTypeInvitation,
+				Status: StatusUnused,
+			},
+		},
+	}
+	emailCache := &emailCacheStub{
+		data: &VerificationCodeData{
+			Code:      "246810",
+			Attempts:  0,
+			CreatedAt: time.Now().UTC(),
+			ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+		},
+	}
+	authService := newOAuthEmailFlowAuthService(
+		userRepo,
+		redeemRepo,
+		nil,
+		map[string]string{
+			SettingKeyRegistrationEnabled:   "true",
+			SettingKeyInvitationCodeEnabled: "true",
+			SettingKeyEmailVerifyEnabled:    "true",
+		},
+		emailCache,
+	)
+
+	tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+		context.Background(),
+		"fresh@example.com",
+		"secret-123",
+		"246810",
+		"INVITE123",
+		"oidc",
+	)
+
+	require.Nil(t, tokenPair)
+	require.Nil(t, user)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "generate token pair")
+	require.Equal(t, []int64{42}, userRepo.deletedIDs)
+	require.Len(t, userRepo.created, 1)
+	require.Empty(t, redeemRepo.useCalls)
+	require.Empty(t, redeemRepo.updateCalls)
+}
+
+// A padded/mixed-case known source (" OIDC ") must be normalized to "oidc"
+// on the created user row.
+func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
+	userRepo := &userRepoStub{nextID: 42}
+	emailCache := &emailCacheStub{
+		data: &VerificationCodeData{
+			Code:      "246810",
+			Attempts:  0,
+			CreatedAt: time.Now().UTC(),
+			ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+		},
+	}
+	authService := newOAuthEmailFlowAuthService(
+		userRepo,
+		&redeemCodeRepoStub{},
+		&refreshTokenCacheStub{},
+		map[string]string{
+			SettingKeyRegistrationEnabled: "true",
+			SettingKeyEmailVerifyEnabled:  "true",
+		},
+		emailCache,
+	)
+
+	tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+		context.Background(),
+		"fresh@example.com",
+		"secret-123",
+		"246810",
+		"",
+		" OIDC ",
+	)
+
+	require.NoError(t, err)
+	require.NotNil(t, tokenPair)
+	require.NotNil(t, user)
+	require.Len(t, userRepo.created, 1)
+	require.Equal(t, "oidc", userRepo.created[0].SignupSource)
+}
+
+// An unrecognized signup source ("github") must fall back to "email" on the
+// created user row.
+func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
+	userRepo := &userRepoStub{nextID: 43}
+	emailCache := &emailCacheStub{
+		data: &VerificationCodeData{
+			Code:      "246810",
+			Attempts:  0,
+			CreatedAt: time.Now().UTC(),
+			ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
+		},
+	}
+	authService := newOAuthEmailFlowAuthService(
+		userRepo,
+		&redeemCodeRepoStub{},
+		&refreshTokenCacheStub{},
+		map[string]string{
+			SettingKeyRegistrationEnabled: "true",
+			SettingKeyEmailVerifyEnabled:  "true",
+		},
+		emailCache,
+	)
+
+	tokenPair, user, err := authService.RegisterOAuthEmailAccount(
+		context.Background(),
+		"fallback@example.com",
+		"secret-123",
+		"246810",
+		"",
+		"github",
+	)
+
+	require.NoError(t, err)
+	require.NotNil(t, tokenPair)
+	require.NotNil(t, user)
+	require.Len(t, userRepo.created, 1)
+	require.Equal(t, "email", userRepo.created[0].SignupSource)
+}
+
+// Rolling back a created account must both delete the user and return the
+// invitation it consumed to the unused state (UsedBy/UsedAt cleared).
+func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
+	userRepo := &userRepoStub{}
+	redeemRepo := &redeemCodeRepoStub{
+		codesByCode: map[string]*RedeemCode{
+			"INVITE123": {
+				ID:     7,
+				Code:   "INVITE123",
+				Type:   RedeemTypeInvitation,
+				Status: StatusUsed,
+				UsedBy: func() *int64 {
+					v := int64(42)
+					return &v
+				}(),
+				UsedAt: func() *time.Time {
+					v := time.Now().UTC()
+					return &v
+				}(),
+			},
+		},
+	}
+	authService := newOAuthEmailFlowAuthService(
+		userRepo,
+		redeemRepo,
+		&refreshTokenCacheStub{},
+		map[string]string{
+			SettingKeyRegistrationEnabled:   "true",
+			SettingKeyInvitationCodeEnabled: "true",
+		},
+		&emailCacheStub{},
+	)
+
+	err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
+
+	require.NoError(t, err)
+	require.Equal(t, []int64{42}, userRepo.deletedIDs)
+	require.Len(t, redeemRepo.updateCalls, 1)
+	require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
+	require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
+	require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
+}
+
+// A repository delete failure during rollback must surface to the caller,
+// wrapped with the "delete created oauth user" context.
+func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
+	userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
+	authService := newOAuthEmailFlowAuthService(
+		userRepo,
+		&redeemCodeRepoStub{},
+		&refreshTokenCacheStub{},
+		map[string]string{
+			SettingKeyRegistrationEnabled: "true",
+		},
+		&emailCacheStub{},
+	)
+
+	err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
+
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "delete created oauth user")
+}
diff --git a/backend/internal/service/auth_oauth_first_bind.go b/backend/internal/service/auth_oauth_first_bind.go
new file mode 100644
index 00000000..aa06e59f
--- /dev/null
+++ b/backend/internal/service/auth_oauth_first_bind.go
@@ -0,0 +1,104 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+
+ entsql "entgo.io/ent/dialect/sql"
+)
+
+// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap
+// settings the first time a user binds a third-party identity. The grant is
+// idempotent per user/provider pair.
+//
+// A nil service, missing ent client/setting service, or non-positive userID is
+// a silent no-op. Work runs inside a transaction: an ambient one is reused,
+// otherwise a local one is opened and committed here.
+func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind(
+	ctx context.Context,
+	userID int64,
+	providerType string,
+) error {
+	if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 {
+		return nil
+	}
+
+	if dbent.TxFromContext(ctx) != nil {
+		return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType)
+	}
+
+	tx, err := s.entClient.Tx(ctx)
+	if err != nil {
+		return fmt.Errorf("begin first bind defaults transaction: %w", err)
+	}
+	// Rollback after a successful Commit is a harmless no-op.
+	defer func() { _ = tx.Rollback() }()
+
+	txCtx := dbent.NewTxContext(ctx, tx)
+	if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// applyProviderDefaultSettingsOnFirstBind performs the actual first-bind
+// grant: it inserts a (user, provider, reason) marker row as an idempotency
+// guard and, only when that row is newly inserted, applies the configured
+// balance/concurrency deltas and default subscriptions for the provider.
+func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
+	ctx context.Context,
+	userID int64,
+	providerType string,
+) error {
+	providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true)
+	if err != nil {
+		return fmt.Errorf("load auth source defaults: %w", err)
+	}
+	if !enabled {
+		return nil
+	}
+
+	client := s.entClient
+	if tx := dbent.TxFromContext(ctx); tx != nil {
+		client = tx.Client()
+	}
+
+	// Idempotency guard: ON CONFLICT DO NOTHING means a repeat bind inserts
+	// nothing, and the grant below is skipped via the affected-rows check.
+	// NOTE(review): $1-style placeholders assume a PostgreSQL dialect —
+	// confirm no other drivers are routed through this code path.
+	var result entsql.Result
+	if err := client.Driver().Exec(
+		ctx,
+		`INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
+VALUES ($1, $2, $3)
+ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
+		[]any{userID, strings.TrimSpace(providerType), "first_bind"},
+		&result,
+	); err != nil {
+		return fmt.Errorf("record first bind provider grant: %w", err)
+	}
+
+	affected, err := result.RowsAffected()
+	if err != nil {
+		return fmt.Errorf("read first bind provider grant result: %w", err)
+	}
+	if affected == 0 {
+		// Grant was already applied by a previous bind; nothing more to do.
+		return nil
+	}
+
+	if providerDefaults.Balance != 0 {
+		if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil {
+			return fmt.Errorf("apply first bind balance default: %w", err)
+		}
+	}
+	if providerDefaults.Concurrency != 0 {
+		if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil {
+			return fmt.Errorf("apply first bind concurrency default: %w", err)
+		}
+	}
+	if s.defaultSubAssigner != nil {
+		for _, item := range providerDefaults.Subscriptions {
+			if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
+				UserID:       userID,
+				GroupID:      item.GroupID,
+				ValidityDays: item.ValidityDays,
+				Notes:        "auto assigned by first bind defaults",
+			}); err != nil {
+				return fmt.Errorf("apply first bind subscription default: %w", err)
+			}
+		}
+	}
+
+	return nil
+}
diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go
new file mode 100644
index 00000000..6e69c121
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service.go
@@ -0,0 +1,543 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "hash/fnv"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "entgo.io/ent/dialect"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
+ dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+
+ entsql "entgo.io/ent/dialect/sql"
+)
+
+// Sentinel errors for the pending-auth session/completion-code lifecycle.
+// Each wraps an infraerrors constructor so HTTP handlers can map them to
+// stable machine-readable codes.
+var (
+ ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
+ ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
+ ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
+ ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
+ ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
+ ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
+ ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
+)
+
+// Fallback TTLs used when callers do not supply an explicit expiry:
+// sessions default to 15 minutes, completion codes to 5 minutes.
+const (
+ defaultPendingAuthTTL = 15 * time.Minute
+ defaultPendingAuthCompletionTTL = 5 * time.Minute
+)
+
+// PendingAuthIdentityKey identifies an upstream identity by provider type,
+// provider key (instance), and the provider-issued subject.
+type PendingAuthIdentityKey struct {
+ ProviderType string
+ ProviderKey string
+ ProviderSubject string
+}
+
+// CreatePendingAuthSessionInput carries all fields persisted when a pending
+// auth session is created. SessionToken and ExpiresAt are optional; defaults
+// are generated in CreatePendingSession when they are empty/zero.
+type CreatePendingAuthSessionInput struct {
+ SessionToken string
+ Intent string
+ Identity PendingAuthIdentityKey
+ TargetUserID *int64
+ RedirectTo string
+ ResolvedEmail string
+ RegistrationPasswordHash string
+ BrowserSessionKey string
+ UpstreamIdentityClaims map[string]any
+ LocalFlowState map[string]any
+ ExpiresAt time.Time
+}
+
+// IssuePendingAuthCompletionCodeInput requests a one-time completion code for
+// an existing pending session; TTL <= 0 falls back to the default.
+type IssuePendingAuthCompletionCodeInput struct {
+ PendingAuthSessionID int64
+ BrowserSessionKey string
+ TTL time.Duration
+}
+
+// IssuePendingAuthCompletionCodeResult returns the plaintext code (only the
+// hash is stored server-side) and its expiry.
+type IssuePendingAuthCompletionCodeResult struct {
+ Code string
+ ExpiresAt time.Time
+}
+
+// PendingIdentityAdoptionDecisionInput records which upstream profile fields
+// (display name / avatar) the user chose to adopt for a pending session,
+// optionally tied to an existing identity row.
+type PendingIdentityAdoptionDecisionInput struct {
+ PendingAuthSessionID int64
+ IdentityID *int64
+ AdoptDisplayName bool
+ AdoptAvatar bool
+}
+
+// AuthPendingIdentityService manages pending auth sessions, one-time
+// completion codes, and identity adoption decisions on top of ent.
+type AuthPendingIdentityService struct {
+ entClient *dbent.Client
+}
+
+// Process-wide registry of per-key mutexes; entries are reference-counted and
+// removed from the map once no goroutine holds or waits on them.
+var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
+
+// authPendingIdentityScopedKeyLockRegistry guards a map of named lock entries.
+// The registry mutex protects only the map/refcounts, never held while
+// waiting on an entry's mutex.
+type authPendingIdentityScopedKeyLockRegistry struct {
+ mu sync.Mutex
+ locks map[string]*authPendingIdentityScopedKeyLockEntry
+}
+
+// authPendingIdentityScopedKeyLockEntry is one named lock plus its refcount
+// (refs is mutated only under the registry mutex).
+type authPendingIdentityScopedKeyLockEntry struct {
+ mu sync.Mutex
+ refs int
+}
+
+// newAuthPendingIdentityScopedKeyLockRegistry creates an empty registry.
+func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
+ return &authPendingIdentityScopedKeyLockRegistry{
+ locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
+ }
+}
+
+// lock acquires in-process mutexes for every distinct non-blank key and
+// returns a release func. Keys are deduplicated and sorted first, so all
+// callers acquire in the same order — preventing lock-order deadlocks between
+// goroutines locking overlapping key sets. Entries are refcounted so the map
+// does not grow unboundedly.
+func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
+ normalized := normalizeAuthPendingIdentityLockKeys(keys...)
+ if len(normalized) == 0 {
+ return func() {}
+ }
+
+ // Phase 1: under the registry mutex, look up / create each entry and bump
+ // its refcount. No entry mutex is taken yet, so this cannot block.
+ entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
+ r.mu.Lock()
+ for _, key := range normalized {
+ entry := r.locks[key]
+ if entry == nil {
+ entry = &authPendingIdentityScopedKeyLockEntry{}
+ r.locks[key] = entry
+ }
+ entry.refs++
+ entries = append(entries, entry)
+ }
+ r.mu.Unlock()
+
+ // Phase 2: acquire entry mutexes in sorted-key order (may block).
+ for _, entry := range entries {
+ entry.mu.Lock()
+ }
+
+ return func() {
+ // Unlock in reverse acquisition order, then drop refcounts and prune
+ // entries nobody references anymore.
+ for i := len(entries) - 1; i >= 0; i-- {
+ entries[i].mu.Unlock()
+ }
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ for idx, key := range normalized {
+ entry := entries[idx]
+ entry.refs--
+ if entry.refs == 0 {
+ delete(r.locks, key)
+ }
+ }
+ }
+}
+
+// normalizeAuthPendingIdentityLockKeys trims, drops blanks, deduplicates, and
+// sorts the given keys. Returns nil when nothing usable remains. The sorted
+// order is what gives lock() its deterministic acquisition order.
+func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
+ if len(keys) == 0 {
+ return nil
+ }
+
+ deduped := make(map[string]struct{}, len(keys))
+ for _, key := range keys {
+ trimmed := strings.TrimSpace(key)
+ if trimmed == "" {
+ continue
+ }
+ deduped[trimmed] = struct{}{}
+ }
+ if len(deduped) == 0 {
+ return nil
+ }
+
+ normalized := make([]string, 0, len(deduped))
+ for key := range deduped {
+ normalized = append(normalized, key)
+ }
+ sort.Strings(normalized)
+ return normalized
+}
+
+// authPendingIdentityAdvisoryLockHash maps a lock key to the int64 id used for
+// Postgres advisory locks via FNV-1a. NOTE(review): 64-bit hash collisions are
+// possible in principle; a collision would only over-serialize, not corrupt.
+func authPendingIdentityAdvisoryLockHash(key string) int64 {
+ hasher := fnv.New64a()
+ _, _ = hasher.Write([]byte(key))
+ return int64(hasher.Sum64())
+}
+
+// lockAuthPendingIdentityKeys takes the in-process scoped locks for the keys
+// and, when the client is Postgres, additionally takes pg_advisory_xact_lock
+// for each key so concurrent processes serialize too. The returned release
+// func only frees the in-process locks; the advisory locks are
+// transaction-scoped and released by Postgres at commit/rollback.
+// NOTE(review): the advisory lock is only meaningful when ctx/client are bound
+// to an open transaction — confirm all callers pass a tx-bound client.
+func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
+ release := authPendingIdentityScopedKeyLocks.lock(keys...)
+ normalized := normalizeAuthPendingIdentityLockKeys(keys...)
+ if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
+ return release, nil
+ }
+
+ for _, key := range normalized {
+ var rows entsql.Rows
+ if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
+ // Undo the in-process locks before surfacing the error.
+ release()
+ return nil, err
+ }
+ _ = rows.Close()
+ }
+
+ return release, nil
+}
+
+// pendingIdentityAdoptionLockKeys builds the lock-key set for an adoption
+// decision: always the pending session id, plus the identity id when present
+// and positive.
+func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
+ keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
+ if identityID != nil && *identityID > 0 {
+ keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
+ }
+ return keys
+}
+
+// NewAuthPendingIdentityService constructs the service around an ent client.
+func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
+ return &AuthPendingIdentityService{entClient: entClient}
+}
+
+// CreatePendingSession persists a new pending auth session. A random opaque
+// session token is generated when none is supplied, and a default TTL of
+// defaultPendingAuthTTL is applied when ExpiresAt is zero. All string inputs
+// are trimmed; claim/state maps are shallow-copied before storage.
+func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken := strings.TrimSpace(input.SessionToken)
+ if sessionToken == "" {
+ var err error
+ sessionToken, err = randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ expiresAt := input.ExpiresAt.UTC()
+ if expiresAt.IsZero() {
+ expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
+ }
+
+ create := s.entClient.PendingAuthSession.Create().
+ SetSessionToken(sessionToken).
+ SetIntent(strings.TrimSpace(input.Intent)).
+ SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
+ SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
+ SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
+ SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
+ SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
+ SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
+ SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
+ SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
+ SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
+ SetExpiresAt(expiresAt)
+ if input.TargetUserID != nil {
+ create = create.SetTargetUserID(*input.TargetUserID)
+ }
+ return create.Save(ctx)
+}
+
+// IssueCompletionCode generates a one-time completion code for the session,
+// storing only its SHA-256 hash plus an expiry (input TTL, defaulted when
+// non-positive). A non-blank BrowserSessionKey overwrites the session's key,
+// binding the code to that browser. The plaintext code is returned once and
+// never persisted.
+// NOTE(review): this does not check whether the session is already consumed
+// or expired before issuing; consumption-time validation catches that later —
+// confirm issuing on a stale session is intended to be allowed here.
+func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+
+ code, err := randomOpaqueToken(24)
+ if err != nil {
+ return nil, err
+ }
+ ttl := input.TTL
+ if ttl <= 0 {
+ ttl = defaultPendingAuthCompletionTTL
+ }
+ expiresAt := time.Now().UTC().Add(ttl)
+
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeHash(hashPendingAuthCode(code)).
+ SetCompletionCodeExpiresAt(expiresAt)
+ if strings.TrimSpace(input.BrowserSessionKey) != "" {
+ update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
+ }
+ if _, err := update.Save(ctx); err != nil {
+ return nil, err
+ }
+
+ return &IssuePendingAuthCompletionCodeResult{
+ Code: code,
+ ExpiresAt: expiresAt,
+ }, nil
+}
+
+// ConsumeCompletionCode looks up the session by the hash of the supplied code
+// and atomically consumes it. An unknown hash yields ErrPendingAuthCodeInvalid
+// (this also covers already-consumed codes, since consumption blanks the
+// stored hash). Expiry/consumed/browser-mismatch checks happen in
+// consumeSession with the code-specific error values.
+func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthCodeInvalid
+ }
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
+}
+
+// ConsumeBrowserSession resolves a session by its opaque session token and
+// atomically consumes it, using the session-specific expired/consumed errors.
+func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+
+ return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+}
+
+// GetBrowserSession is the read-only counterpart of ConsumeBrowserSession:
+// it resolves the session by token and validates state (not consumed, not
+// expired, browser key matches) without mutating anything.
+func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ session, err := s.getBrowserSession(ctx, sessionToken)
+ if err != nil {
+ return nil, err
+ }
+ if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
+ return nil, err
+ }
+ return session, nil
+}
+
+// getBrowserSession fetches a session by trimmed token; blank tokens and
+// missing rows both map to ErrPendingAuthSessionNotFound. No state checks.
+func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ sessionToken = strings.TrimSpace(sessionToken)
+ if sessionToken == "" {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+
+ session, err := s.entClient.PendingAuthSession.Query().
+ Where(pendingauthsession.SessionTokenEQ(sessionToken)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, err
+ }
+ return session, nil
+}
+
+// consumeSession marks a session consumed exactly once. It first validates the
+// in-memory snapshot, then performs a conditional update (compare-and-set):
+// the UPDATE only matches while the row is still unconsumed, unexpired, and —
+// when a browser key was recorded — bound to the same key. On success it also
+// scrubs token material from local flow state and invalidates the completion
+// code (hash blanked, expiry cleared). If the conditional update matches no
+// row, the row is re-read to report the precise reason (expired / mismatched /
+// already consumed), defaulting to consumedErr when the fresh row looks valid
+// (meaning a concurrent caller won the race).
+func (s *AuthPendingIdentityService) consumeSession(
+ ctx context.Context,
+ session *dbent.PendingAuthSession,
+ browserSessionKey string,
+ expiredErr error,
+ consumedErr error,
+) (*dbent.PendingAuthSession, error) {
+ if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+
+ sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
+ now := time.Now().UTC()
+ update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
+ Where(
+ pendingauthsession.ConsumedAtIsNil(),
+ pendingauthsession.ExpiresAtGTE(now),
+ pendingauthsession.Or(
+ pendingauthsession.CompletionCodeExpiresAtIsNil(),
+ pendingauthsession.CompletionCodeExpiresAtGTE(now),
+ ),
+ ).
+ SetConsumedAt(now).
+ SetLocalFlowState(sanitizedLocalFlowState).
+ SetCompletionCodeHash("").
+ ClearCompletionCodeExpiresAt()
+ if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
+ update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey))
+ }
+ updated, err := update.Save(ctx)
+ if err == nil {
+ return updated, nil
+ }
+ // ent reports an unmatched conditional UpdateOne as NotFound; anything
+ // else is a real storage error.
+ if !dbent.IsNotFound(err) {
+ return nil, err
+ }
+
+ // CAS lost or row gone: re-read to classify the failure precisely.
+ current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID)
+ if currentErr != nil {
+ if dbent.IsNotFound(currentErr) {
+ return nil, ErrPendingAuthSessionNotFound
+ }
+ return nil, currentErr
+ }
+ if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil {
+ return nil, err
+ }
+ return nil, consumedErr
+}
+
+// sanitizePendingAuthLocalFlowState returns a copy of the flow state with
+// token material (access/refresh token, expiry, token type) removed from the
+// nested "completion_response" map, keeping other keys (e.g. redirect) intact.
+// NOTE(review): copyPendingMap is shallow, so values nested deeper than
+// completion_response still alias the originals — confirm nothing mutates them.
+func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
+ sanitized := copyPendingMap(localFlowState)
+ if len(sanitized) == 0 {
+ return sanitized
+ }
+
+ rawCompletion, ok := sanitized["completion_response"]
+ if !ok {
+ return sanitized
+ }
+ completion, ok := rawCompletion.(map[string]any)
+ if !ok {
+ return sanitized
+ }
+
+ cleanedCompletion := copyPendingMap(completion)
+ for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
+ delete(cleanedCompletion, key)
+ }
+ sanitized["completion_response"] = cleanedCompletion
+ return sanitized
+}
+
+// validatePendingSessionState checks a session snapshot in priority order:
+// nil → not found; already consumed → consumedErr; session or completion-code
+// expiry passed → expiredErr; recorded browser key present but different from
+// the caller's → browser mismatch. A blank recorded key skips the key check.
+func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
+ if session == nil {
+ return ErrPendingAuthSessionNotFound
+ }
+
+ now := time.Now().UTC()
+ if session.ConsumedAt != nil {
+ return consumedErr
+ }
+ if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
+ return expiredErr
+ }
+ if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
+ return expiredErr
+ }
+ if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
+ return ErrPendingAuthBrowserMismatch
+ }
+ return nil
+}
+
+// UpsertAdoptionDecision records (or replaces) the adoption decision for a
+// pending session, one row per session enforced via upsert on the session-id
+// column. When an identity id is supplied, any other decision row currently
+// pointing at that identity is detached first, so an identity is referenced by
+// at most one decision. The whole operation runs inside a transaction (a new
+// one, or the caller's when ctx already carries one) under scoped in-process
+// locks plus Postgres advisory locks keyed by session and identity ids.
+func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
+ if s == nil || s.entClient == nil {
+ return nil, fmt.Errorf("pending auth ent client is not configured")
+ }
+
+ // Start a tx; ErrTxStarted means ctx is already transactional and we
+ // should join the existing transaction instead of owning one.
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
+ return nil, err
+ }
+
+ client := s.entClient
+ txCtx := ctx
+ if err == nil {
+ // We own this tx: rollback is a safe no-op after a successful commit.
+ defer func() { _ = tx.Rollback() }()
+ client = tx.Client()
+ txCtx = dbent.NewTxContext(ctx, tx)
+ } else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
+ client = existingTx.Client()
+ }
+
+ releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
+ if err != nil {
+ return nil, err
+ }
+ defer releaseLocks()
+
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ // Detach the identity from any decision belonging to a different (or
+ // NULL) pending session, so the upsert below leaves a single owner.
+ if _, err := client.IdentityAdoptionDecision.Update().
+ Where(
+ identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
+ dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
+ col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
+ s.Where(entsql.Or(
+ entsql.IsNull(col),
+ entsql.NEQ(col, input.PendingAuthSessionID),
+ ))
+ }),
+ ).
+ ClearIdentityID().
+ Save(txCtx); err != nil {
+ return nil, err
+ }
+ }
+
+ create := client.IdentityAdoptionDecision.Create().
+ SetPendingAuthSessionID(input.PendingAuthSessionID).
+ SetAdoptDisplayName(input.AdoptDisplayName).
+ SetAdoptAvatar(input.AdoptAvatar).
+ SetDecidedAt(time.Now().UTC())
+ if input.IdentityID != nil && *input.IdentityID > 0 {
+ create = create.SetIdentityID(*input.IdentityID)
+ }
+
+ // Upsert keyed on pending_auth_session_id: insert or overwrite the
+ // existing row's fields with the new values.
+ decisionID, err := create.
+ OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
+ UpdateNewValues().
+ ID(txCtx)
+ if err != nil {
+ return nil, err
+ }
+
+ decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
+ if err != nil {
+ return nil, err
+ }
+
+ if tx != nil {
+ if err := tx.Commit(); err != nil {
+ return nil, err
+ }
+ }
+
+ return decision, nil
+}
+
+// copyPendingMap makes a shallow copy of the map; nil or empty input yields a
+// non-nil empty map. Nested values are shared with the original.
+func copyPendingMap(in map[string]any) map[string]any {
+ if len(in) == 0 {
+ return map[string]any{}
+ }
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+// randomOpaqueToken returns byteLen crypto/rand bytes hex-encoded (so the
+// string is 2*byteLen chars); non-positive lengths default to 16 bytes.
+func randomOpaqueToken(byteLen int) (string, error) {
+ if byteLen <= 0 {
+ byteLen = 16
+ }
+ buf := make([]byte, byteLen)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return hex.EncodeToString(buf), nil
+}
+
+// hashPendingAuthCode returns the hex SHA-256 digest of a completion code;
+// only this digest is ever stored, never the plaintext code.
+func hashPendingAuthCode(code string) string {
+ sum := sha256.Sum256([]byte(code))
+ return hex.EncodeToString(sum[:])
+}
diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go
new file mode 100644
index 00000000..555bb0e7
--- /dev/null
+++ b/backend/internal/service/auth_pending_identity_service_test.go
@@ -0,0 +1,526 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "database/sql"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+// newAuthPendingIdentityServiceTestClient builds the service backed by a
+// shared-cache in-memory SQLite database (foreign keys on) and registers
+// cleanup for both the db handle and the ent client.
+func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ return NewAuthPendingIdentityService(client), client
+}
+
+// Verifies CreatePendingSession persists the identity key fields, the target
+// user id, and keeps upstream claims and local flow state as separate maps.
+func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ targetUser, err := client.User.Create().
+ SetEmail("pending-target@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-123",
+ },
+ TargetUserID: &targetUser.ID,
+ RedirectTo: "/profile",
+ ResolvedEmail: "user@example.com",
+ BrowserSessionKey: "browser-1",
+ UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
+ LocalFlowState: map[string]any{"step": "email_required"},
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, session.SessionToken)
+ require.Equal(t, "bind_current_user", session.Intent)
+ require.Equal(t, "wechat", session.ProviderType)
+ require.NotNil(t, session.TargetUserID)
+ require.Equal(t, targetUser.ID, *session.TargetUserID)
+ require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
+ require.Equal(t, "email_required", session.LocalFlowState["step"])
+}
+
+// Verifies a completion code rejects a mismatched browser key, consumes once
+// (hash blanked, expiry cleared), and is invalid on replay.
+func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo-main",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expected",
+ UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
+ LocalFlowState: map[string]any{"step": "pending"},
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expected",
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, issued.Code)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+ require.Empty(t, consumed.CompletionCodeHash)
+ require.Nil(t, consumed.CompletionCodeExpiresAt)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
+ require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
+}
+
+// Verifies a code whose expiry is forced into the past is rejected with
+// ErrPendingAuthCodeExpired.
+func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "oidc",
+ ProviderKey: "https://issuer.example",
+ ProviderSubject: "subject-1",
+ },
+ BrowserSessionKey: "browser-expired",
+ })
+ require.NoError(t, err)
+
+ issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
+ PendingAuthSessionID: session.ID,
+ BrowserSessionKey: "browser-expired",
+ TTL: time.Second,
+ })
+ require.NoError(t, err)
+
+ // Backdate the expiry directly instead of sleeping past the TTL.
+ _, err = client.PendingAuthSession.UpdateOneID(session.ID).
+ SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
+ require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
+}
+
+// Verifies a second upsert for the same pending session updates the existing
+// decision row (same id) rather than inserting a new one.
+func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-adoption").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-adoption",
+ },
+ })
+ require.NoError(t, err)
+
+ first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.True(t, first.AdoptDisplayName)
+ require.False(t, first.AdoptAvatar)
+ require.Nil(t, first.IdentityID)
+
+ second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.ID, second.ID)
+ require.NotNil(t, second.IdentityID)
+ require.Equal(t, identity.ID, *second.IdentityID)
+ require.True(t, second.AdoptAvatar)
+}
+
+// Verifies that adopting the same identity from a second pending session
+// detaches the identity reference from the first session's decision row.
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIdentityReference(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption-reassign@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-open").
+ SetProviderSubject("union-reassign").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ firstSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-reassign",
+ },
+ })
+ require.NoError(t, err)
+
+ firstDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: firstSession.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: false,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, firstDecision.IdentityID)
+ require.Equal(t, identity.ID, *firstDecision.IdentityID)
+
+ secondSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-open",
+ ProviderSubject: "union-reassign",
+ },
+ })
+ require.NoError(t, err)
+
+ secondDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: secondSession.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, secondDecision.IdentityID)
+ require.Equal(t, identity.ID, *secondDecision.IdentityID)
+
+ reloadedFirst, err := client.IdentityAdoptionDecision.Get(ctx, firstDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedFirst.IdentityID)
+}
+
+// Verifies two concurrent upserts for the same session converge on a single
+// decision row. An ent mutation hook blocks the first Create so the second
+// call provably overlaps it.
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("adoption-concurrent@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("union-concurrent").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "union-concurrent",
+ },
+ })
+ require.NoError(t, err)
+
+ // Hook: park only the very first Create mutation until released.
+ firstCreateStarted := make(chan struct{})
+ releaseFirstCreate := make(chan struct{})
+ var firstCreate sync.Once
+ client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
+ return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
+ blocked := false
+ if m.Op().Is(dbent.OpCreate) {
+ firstCreate.Do(func() {
+ blocked = true
+ close(firstCreateStarted)
+ })
+ }
+ if blocked {
+ <-releaseFirstCreate
+ }
+ return next.Mutate(ctx, m)
+ })
+ })
+
+ type adoptionResult struct {
+ decision *dbent.IdentityAdoptionDecision
+ err error
+ }
+
+ input := PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: true,
+ AdoptAvatar: true,
+ }
+
+ results := make(chan adoptionResult, 2)
+ go func() {
+ decision, err := svc.UpsertAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ <-firstCreateStarted
+
+ go func() {
+ decision, err := svc.UpsertAdoptionDecision(ctx, input)
+ results <- adoptionResult{decision: decision, err: err}
+ }()
+
+ // NOTE(review): the sleep only gives the second goroutine time to reach
+ // the lock; correctness does not depend on the exact duration.
+ time.Sleep(100 * time.Millisecond)
+ close(releaseFirstCreate)
+
+ first := <-results
+ second := <-results
+
+ require.NoError(t, first.err)
+ require.NoError(t, second.err)
+ require.NotNil(t, first.decision)
+ require.NotNil(t, second.decision)
+ require.Equal(t, first.decision.ID, second.decision.ID)
+
+ count, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+
+ loaded, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, loaded.IdentityID)
+ require.Equal(t, identity.ID, *loaded.IdentityID)
+}
+
+// Documents (but skips) the legacy case of a decision row with a NULL
+// pending_auth_session_id: the detach-update in UpsertAdoptionDecision should
+// clear its identity reference. The sqlite test schema rejects the NULL row.
+func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
+ t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
+
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ user, err := client.User.Create().
+ SetEmail("legacy-null-session@example.com").
+ SetPasswordHash("hash").
+ SetRole(RoleUser).
+ SetStatus(StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ identity, err := client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("wechat").
+ SetProviderKey("wechat-main").
+ SetProviderSubject("legacy-null-session").
+ SetMetadata(map[string]any{}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ _, err = client.ExecContext(
+ ctx,
+ `INSERT INTO identity_adoption_decisions
+ (identity_id, adopt_display_name, adopt_avatar, decided_at, created_at, updated_at, pending_auth_session_id)
+ VALUES (?, ?, ?, ?, ?, ?, NULL)`,
+ identity.ID,
+ true,
+ false,
+ time.Now().UTC(),
+ time.Now().UTC(),
+ time.Now().UTC(),
+ )
+ require.NoError(t, err)
+ legacyDecision, err := client.IdentityAdoptionDecision.Query().
+ Where(identityadoptiondecision.IdentityIDEQ(identity.ID)).
+ Only(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, legacyDecision.IdentityID)
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "bind_current_user",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "wechat",
+ ProviderKey: "wechat-main",
+ ProviderSubject: "legacy-null-session",
+ },
+ })
+ require.NoError(t, err)
+
+ decision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
+ PendingAuthSessionID: session.ID,
+ IdentityID: &identity.ID,
+ AdoptDisplayName: false,
+ AdoptAvatar: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, decision.IdentityID)
+ require.Equal(t, identity.ID, *decision.IdentityID)
+
+ reloadedLegacy, err := client.IdentityAdoptionDecision.Get(ctx, legacyDecision.ID)
+ require.NoError(t, err)
+ require.Nil(t, reloadedLegacy.IdentityID)
+}
+
+// Verifies ConsumeBrowserSession enforces browser-key binding, consumes once,
+// and rejects a second consume with ErrPendingAuthSessionConsumed.
+func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "subject-session-token",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "token",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
+ require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
+
+// Verifies the conditional-update guard: replaying consumeSession with a stale
+// in-memory snapshot (loaded before the first consume) still fails with the
+// consumed error instead of double-consuming.
+func TestAuthPendingIdentityService_ConsumeBrowserSessionRejectsStaleLoadedSessionReplay(t *testing.T) {
+ svc, _ := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "stale-replay-subject",
+ },
+ BrowserSessionKey: "browser-session",
+ })
+ require.NoError(t, err)
+
+ loaded, err := svc.getBrowserSession(ctx, session.SessionToken)
+ require.NoError(t, err)
+
+ consumed, err := svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ _, err = svc.consumeSession(ctx, loaded, "browser-session", ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
+ require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
+}
+
+// Verifies consumption scrubs token fields from the stored
+// completion_response while leaving non-sensitive keys (redirect) intact.
+func TestAuthPendingIdentityService_ConsumeBrowserSessionScrubsLegacyCompletionTokens(t *testing.T) {
+ svc, client := newAuthPendingIdentityServiceTestClient(t)
+ ctx := context.Background()
+
+ session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
+ Intent: "login",
+ Identity: PendingAuthIdentityKey{
+ ProviderType: "linuxdo",
+ ProviderKey: "linuxdo",
+ ProviderSubject: "legacy-token-subject",
+ },
+ BrowserSessionKey: "browser-session",
+ LocalFlowState: map[string]any{
+ "completion_response": map[string]any{
+ "access_token": "legacy-access-token",
+ "refresh_token": "legacy-refresh-token",
+ "expires_in": float64(3600),
+ "token_type": "Bearer",
+ "redirect": "/dashboard",
+ },
+ },
+ })
+ require.NoError(t, err)
+
+ consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
+ require.NoError(t, err)
+ require.NotNil(t, consumed.ConsumedAt)
+
+ stored, err := client.PendingAuthSession.Get(ctx, session.ID)
+ require.NoError(t, err)
+
+ completion, ok := stored.LocalFlowState["completion_response"].(map[string]any)
+ require.True(t, ok)
+ require.NotContains(t, completion, "access_token")
+ require.NotContains(t, completion, "refresh_token")
+ require.NotContains(t, completion, "expires_in")
+ require.NotContains(t, completion, "token_type")
+ require.Equal(t, "/dashboard", completion["redirect"])
+}
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index 42b6cf91..8f4461e4 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
+ "encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -13,6 +14,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -70,6 +72,7 @@ type AuthService struct {
turnstileService *TurnstileService
emailQueueService *EmailQueueService
promoService *PromoService
+ affiliateService *AffiliateService
defaultSubAssigner DefaultSubscriptionAssigner
}
@@ -77,6 +80,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
+type signupGrantPlan struct {
+ Balance float64
+ Concurrency int
+ Subscriptions []DefaultSubscriptionSetting
+}
+
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
@@ -90,6 +99,7 @@ func NewAuthService(
emailQueueService *EmailQueueService,
promoService *PromoService,
defaultSubAssigner DefaultSubscriptionAssigner,
+ affiliateService *AffiliateService,
) *AuthService {
return &AuthService{
entClient: entClient,
@@ -102,17 +112,25 @@ func NewAuthService(
turnstileService: turnstileService,
emailQueueService: emailQueueService,
promoService: promoService,
+ affiliateService: affiliateService,
defaultSubAssigner: defaultSubAssigner,
}
}
-// Register 用户注册,返回token和用户
-func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
- return s.RegisterWithVerification(ctx, email, password, "", "", "")
+func (s *AuthService) EntClient() *dbent.Client {
+ if s == nil {
+ return nil
+ }
+ return s.entClient
}
-// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户
-func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) {
+// Register 用户注册,返回token和用户
+func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
+ return s.RegisterWithVerification(ctx, email, password, "", "", "", "")
+}
+
+// RegisterWithVerification 用户注册(支持邮件验证、优惠码、邀请码和邀请返利码),返回token和用户。
+func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode, affiliateCode string) (string, *User, error) {
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
@@ -179,12 +197,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 获取默认配置
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
+ grantPlan := s.resolveSignupGrantPlan(ctx, "email")
+
+ // 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。
+ var defaultRPMLimit int
if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
+ defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
// 创建用户
@@ -192,8 +210,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ RPMLimit: defaultRPMLimit,
Status: StatusActive,
}
@@ -205,7 +224,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, "email", true)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ if s.affiliateService != nil {
+ if _, err := s.affiliateService.EnsureUserAffiliate(ctx, user.ID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", user.ID, err)
+ }
+ if code := strings.TrimSpace(affiliateCode); code != "" {
+ if err := s.affiliateService.BindInviterByCode(ctx, user.ID, code); err != nil {
+ // 邀请返利码绑定失败不影响注册,只记录日志
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", user.ID, err)
+ }
+ }
+ }
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -469,12 +500,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err)
}
- // 新用户默认值。
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+ var defaultRPMLimit int
if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
+ defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
@@ -482,9 +512,11 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ RPMLimit: defaultRPMLimit,
Status: StatusActive,
+ SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -501,7 +533,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -520,7 +553,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
-
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
@@ -531,7 +563,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
-func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode string) (*TokenPair, *User, error) {
+// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
+func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@@ -584,11 +617,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
- defaultBalance := s.cfg.Default.UserBalance
- defaultConcurrency := s.cfg.Default.UserConcurrency
+ signupSource := inferLegacySignupSource(email)
+ grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
+ var defaultRPMLimit int
if s.settingService != nil {
- defaultBalance = s.settingService.GetDefaultBalance(ctx)
- defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
+ defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx)
}
newUser := &User{
@@ -596,9 +629,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
- Balance: defaultBalance,
- Concurrency: defaultConcurrency,
+ Balance: grantPlan.Balance,
+ Concurrency: grantPlan.Concurrency,
+ RPMLimit: defaultRPMLimit,
Status: StatusActive,
+ SignupSource: signupSource,
}
if s.entClient != nil && invitationRedeemCode != nil {
@@ -630,7 +665,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
@@ -646,7 +683,9 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
+ s.postAuthUserBootstrap(ctx, user, signupSource, false)
+ s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
+ s.bindOAuthAffiliate(ctx, user.ID, affiliateCode)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
@@ -670,7 +709,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
-
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
@@ -678,80 +716,289 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
-// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
-const pendingOAuthTokenTTL = 10 * time.Minute
-
-// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
-const pendingOAuthPurpose = "pending_oauth_registration"
-
-type pendingOAuthClaims struct {
- Email string `json:"email"`
- Username string `json:"username"`
- Purpose string `json:"purpose"`
- jwt.RegisteredClaims
-}
-
-// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
-// while waiting for the user to supply an invitation code.
-func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: email,
- Username: username,
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- return token.SignedString([]byte(s.cfg.JWT.Secret))
-}
-
-// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
-// Returns ErrInvalidToken when the token is invalid or expired.
-func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
- if len(tokenStr) > maxTokenLength {
- return "", "", ErrInvalidToken
- }
- parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
- token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
- if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
- }
- return []byte(s.cfg.JWT.Secret), nil
- })
- if parseErr != nil {
- return "", "", ErrInvalidToken
- }
- claims, ok := token.Claims.(*pendingOAuthClaims)
- if !ok || !token.Valid {
- return "", "", ErrInvalidToken
- }
- if claims.Purpose != pendingOAuthPurpose {
- return "", "", ErrInvalidToken
- }
- return claims.Email, claims.Username, nil
-}
-
-func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
+func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
- items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
- Notes: "auto assigned by default user subscriptions setting",
+ Notes: notes,
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
+func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
+ plan := signupGrantPlan{}
+ if s != nil && s.cfg != nil {
+ plan.Balance = s.cfg.Default.UserBalance
+ plan.Concurrency = s.cfg.Default.UserConcurrency
+ }
+ if s == nil || s.settingService == nil {
+ return plan
+ }
+
+ plan.Balance = s.settingService.GetDefaultBalance(ctx)
+ plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
+ plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
+
+ resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
+ return plan
+ }
+ if !enabled {
+ return plan
+ }
+
+ plan.Balance = resolved.Balance
+ plan.Concurrency = resolved.Concurrency
+ plan.Subscriptions = resolved.Subscriptions
+ return plan
+}
+
+func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
+ if defaults == nil {
+ return ProviderDefaultGrantSettings{}, false
+ }
+
+ switch strings.ToLower(strings.TrimSpace(signupSource)) {
+ case "email":
+ return defaults.Email, true
+ case "linuxdo":
+ return defaults.LinuxDo, true
+ case "oidc":
+ return defaults.OIDC, true
+ case "wechat":
+ return defaults.WeChat, true
+ default:
+ return ProviderDefaultGrantSettings{}, false
+ }
+}
+
+// bindOAuthAffiliate initializes the affiliate profile and binds the inviter
+// for an OAuth-registered user. Failures are logged but never block registration.
+func (s *AuthService) bindOAuthAffiliate(ctx context.Context, userID int64, affiliateCode string) {
+ if s.affiliateService == nil || userID <= 0 {
+ return
+ }
+ if _, err := s.affiliateService.EnsureUserAffiliate(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to initialize affiliate profile for user %d: %v", userID, err)
+ }
+ if code := strings.TrimSpace(affiliateCode); code != "" {
+ if err := s.affiliateService.BindInviterByCode(ctx, userID, code); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to bind affiliate inviter for user %d: %v", userID, err)
+ }
+ }
+}
+
+func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
+ if user == nil || user.ID <= 0 {
+ return
+ }
+
+ if strings.TrimSpace(signupSource) == "" {
+ signupSource = "email"
+ }
+ s.updateUserSignupSource(ctx, user.ID, signupSource)
+
+ if touchLogin {
+ s.touchUserLogin(ctx, user.ID)
+ }
+}
+
+func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ if strings.TrimSpace(signupSource) == "" {
+ return
+ }
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetSignupSource(signupSource).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
+ }
+}
+
+func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return
+ }
+ now := time.Now().UTC()
+ if err := s.entClient.User.UpdateOneID(userID).
+ SetLastLoginAt(now).
+ SetLastActiveAt(now).
+ Exec(ctx); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
+ }
+}
+
+func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) {
+ if s == nil || user == nil || user.ID <= 0 {
+ return
+ }
+ identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill")
+ if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
+ if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
+ }
+ }
+}
+
+func (s *AuthService) shouldApplyEmailFirstBindDefaults(
+ ctx context.Context,
+ userID int64,
+ identity *dbent.AuthIdentity,
+ created bool,
+) bool {
+	// NOTE(review): nil identity must be rejected before reading identity.Metadata,
+	// otherwise a failed identity lookup panics here.
+	if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
+		return false
+	}
+	source := emailAuthIdentitySource(identity.Metadata)
+	if source == "auth_service_login_backfill" {
+		return false
+	}
+	if created {
+		return true
+	}
+	if source != "auth_service_dual_write" {
+		return false
+	}
+ hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
+ return false
+ }
+ return !hasGrant
+}
+
+func emailAuthIdentitySource(metadata map[string]any) string {
+ if len(metadata) == 0 {
+ return ""
+ }
+ raw, ok := metadata["source"]
+ if !ok {
+ return ""
+ }
+ return strings.TrimSpace(fmt.Sprint(raw))
+}
+
+func (s *AuthService) hasProviderGrantRecord(
+ ctx context.Context,
+ userID int64,
+ providerType string,
+ grantReason string,
+) (bool, error) {
+ if s == nil || s.entClient == nil || userID <= 0 {
+ return false, nil
+ }
+
+ rows, err := s.entClient.QueryContext(
+ ctx,
+ `SELECT 1 FROM user_provider_default_grants WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`,
+ userID,
+ strings.TrimSpace(providerType),
+ strings.TrimSpace(grantReason),
+ )
+ if err != nil {
+ return false, err
+ }
+ defer func() { _ = rows.Close() }()
+ return rows.Next(), rows.Err()
+}
+
+func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) {
+ if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
+ return nil, false
+ }
+
+ email := strings.ToLower(strings.TrimSpace(user.Email))
+ if email == "" || isReservedEmail(email) {
+ return nil, false
+ }
+ if strings.TrimSpace(source) == "" {
+ source = "auth_service_dual_write"
+ }
+
+ client := s.entClient
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ client = tx.Client()
+ }
+
+ buildQuery := func() *dbent.AuthIdentityQuery {
+ return client.AuthIdentity.Query().Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ(email),
+ )
+ }
+
+ existed, err := buildQuery().Exist(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+
+ if !existed {
+		err := client.AuthIdentity.Create().
+			SetUserID(user.ID).
+			SetProviderType("email").
+			SetProviderKey("email").
+			SetProviderSubject(email).
+			SetVerifiedAt(time.Now().UTC()).
+			SetMetadata(map[string]any{
+				"source": strings.TrimSpace(source),
+			}).
+			OnConflictColumns(
+				authidentity.FieldProviderType,
+				authidentity.FieldProviderKey,
+				authidentity.FieldProviderSubject,
+			).
+			DoNothing().
+			Exec(ctx)
+		if err != nil {
+			if isSQLNoRowsError(err) {
+				return nil, false
+			}
+			// Surface real create failures (previously dead code due to err shadowing).
+			logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+			return nil, false
+		}
+	}
+
+ identity, err := buildQuery().Only(ctx)
+ if err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
+ return nil, false
+ }
+ if identity.UserID != user.ID {
+ logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
+ return nil, false
+ }
+
+ return identity, !existed
+}
+
+func inferLegacySignupSource(email string) string {
+ normalized := strings.ToLower(strings.TrimSpace(email))
+ switch {
+ case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
+ return "linuxdo"
+ case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
+ return "oidc"
+ case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
+ return "wechat"
+ default:
+ return "email"
+ }
+}
+
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@@ -833,7 +1080,9 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
- return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
+ return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
@@ -852,7 +1101,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID,
Email: user.Email,
Role: user.Role,
- TokenVersion: user.TokenVersion,
+ TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
@@ -918,7 +1167,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed
- if claims.TokenVersion != user.TokenVersion {
+ if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked
}
@@ -1146,7 +1395,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{
UserID: user.ID,
- TokenVersion: user.TokenVersion,
+ TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
@@ -1226,7 +1475,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
}
// 检查TokenVersion(密码更改后所有Token失效)
- if data.TokenVersion != user.TokenVersion {
+ if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配,撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
@@ -1271,8 +1520,42 @@ func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) e
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
+// RevokeAllUserTokens invalidates both stateless access tokens and refresh sessions.
+// Access/refresh token verification both depend on TokenVersion, so bumping it provides
+// immediate revocation even if refresh-token cache cleanup later fails.
+func (s *AuthService) RevokeAllUserTokens(ctx context.Context, userID int64) error {
+ user, err := s.userRepo.GetByID(ctx, userID)
+ if err != nil {
+ return fmt.Errorf("get user: %w", err)
+ }
+
+ user.TokenVersion++
+ if err := s.userRepo.Update(ctx, user); err != nil {
+ return fmt.Errorf("update user: %w", err)
+ }
+
+ if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
+ logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after token invalidation for user %d: %v", userID, err)
+ }
+ return nil
+}
+
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
+
+func resolvedTokenVersion(user *User) int64 {
+ if user == nil {
+ return 0
+ }
+ if user.TokenVersionResolved {
+ return user.TokenVersion
+ }
+
+ material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
+ sum := sha256.Sum256([]byte(material))
+ fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
+ return user.TokenVersion ^ fingerprint
+}
diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go
new file mode 100644
index 00000000..ea2308f7
--- /dev/null
+++ b/backend/internal/service/auth_service_email_bind_test.go
@@ -0,0 +1,853 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type emailBindDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *emailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
+}
+
+type flakyEmailBindDefaultSubAssignerStub struct {
+ err error
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyEmailBindDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return nil, false, s.err
+}
+
+func newAuthServiceForEmailBind(
+ t *testing.T,
+ settings map[string]string,
+ emailCache service.EmailCache,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ return newAuthServiceForEmailBindWithRefreshCache(t, settings, emailCache, defaultSubAssigner, nil)
+}
+
+func newAuthServiceForEmailBindWithRefreshCache(
+ t *testing.T,
+ settings map[string]string,
+ emailCache service.EmailCache,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+ refreshTokenCache service.RefreshTokenCache,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_email_bind?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-bind-email-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+
+ settingRepo := &emailBindSettingRepoStub{values: settings}
+ settingSvc := service.NewSettingService(settingRepo, cfg)
+
+ var emailSvc *service.EmailService
+ if emailCache != nil {
+ emailSvc = service.NewEmailService(settingRepo, emailCache)
+ }
+
+ svc := service.NewAuthService(client, repo, nil, refreshTokenCache, cfg, settingSvc, emailSvc, nil, nil, nil, defaultSubAssigner, nil)
+ return svc, repo, client
+}
+
+func TestAuthServiceBindEmailIdentity_UpdatesEmailAndAppliesFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("legacy-user" + service.LinuxDoConnectSyntheticEmailDomain).
+ SetUsername("legacy-user").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, " NewEmail@Example.com ", "123456", "new-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "newemail@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "newemail@example.com", storedUser.Email)
+ require.Equal(t, 11.0, storedUser.Balance)
+ require.Equal(t, 5, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("newemail@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, user.ID, assigner.calls[0].UserID)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+ require.Equal(t, 1, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsExistingEmailOnAnotherUser(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ sourceUser, err := client.User.Create().
+ SetEmail("source-user" + service.OIDCConnectSyntheticEmailDomain).
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.User.Create().
+ SetEmail("taken@example.com").
+ SetUsername("taken-user").
+ SetPasswordHash("hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, sourceUser.ID, "taken@example.com", "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailExists)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, sourceUser.ID)
+ require.NoError(t, err)
+ require.Equal(t, "source-user"+service.OIDCConnectSyntheticEmailDomain, storedUser.Email)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, sourceUser.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RollsBackWhenFirstBindDefaultsFail(t *testing.T) {
+ assigner := &flakyEmailBindDefaultSubAssignerStub{err: errors.New("temporary assign failure")}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ originalEmail := "legacy-rollback" + service.LinuxDoConnectSyntheticEmailDomain
+ user, err := client.User.Create().
+ SetEmail(originalEmail).
+ SetUsername("legacy-rollback").
+ SetPasswordHash("old-hash").
+ SetBalance(2.5).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "rollback@example.com", "123456", "new-password")
+ require.ErrorContains(t, err, "apply email first bind defaults")
+ require.ErrorContains(t, err, "temporary assign failure")
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, originalEmail, storedUser.Email)
+ require.Equal(t, "old-hash", storedUser.PasswordHash)
+ require.Equal(t, 2.5, storedUser.Balance)
+ require.Equal(t, 1, storedUser.Concurrency)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("rollback@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, identityCount)
+
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsReservedEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ user, err := client.User.Create().
+ SetEmail("source-user@example.com").
+ SetUsername("source-user").
+ SetPasswordHash("old-hash").
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "reserved"+service.LinuxDoConnectSyntheticEmailDomain, "123456", "new-password")
+ require.ErrorIs(t, err, service.ErrEmailReserved)
+ require.Nil(t, updatedUser)
+}
+
+func TestAuthServiceBindEmailIdentity_ReplacesBoundEmailAndSkipsFirstBindDefaults(t *testing.T) {
+ assigner := &emailBindDefaultSubAssignerStub{}
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, map[string]string{
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, cache, assigner)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(7.5).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "current-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+ require.Equal(t, "new@example.com", updatedUser.Email)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.Equal(t, 7.5, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, newIdentityCount)
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, oldIdentityCount)
+
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceBindEmailIdentity_RejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ svc, _, client := newAuthServiceForEmailBind(t, nil, cache, nil)
+
+ ctx := context.Background()
+ hashedPassword, err := svc.HashPassword("current-password")
+ require.NoError(t, err)
+
+ user, err := client.User.Create().
+ SetEmail("current@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(hashedPassword).
+ SetBalance(1).
+ SetConcurrency(1).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ require.NoError(t, client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("current@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "test"}).
+ Exec(ctx))
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, user.ID, "new@example.com", "123456", "wrong-password")
+ require.ErrorIs(t, err, service.ErrPasswordIncorrect)
+ require.Nil(t, updatedUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "current@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("current-password", storedUser.PasswordHash))
+
+ oldIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("current@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, oldIdentityCount)
+
+ newIdentityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.UserIDEQ(user.ID),
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("new@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 0, newIdentityCount)
+}
+
+func TestAuthServiceBindEmailIdentity_RevokesExistingAccessAndRefreshTokens(t *testing.T) {
+ ctx := context.Background()
+ cache := &emailBindCacheStub{
+ data: &service.VerificationCodeData{
+ Code: "123456",
+ CreatedAt: time.Now().UTC(),
+ ExpiresAt: time.Now().UTC().Add(10 * time.Minute),
+ },
+ }
+ refreshTokenCache := newEmailBindRefreshTokenCacheStub()
+ userRepo := newEmailBindUserRepoStub(&service.User{
+ ID: 41,
+ Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
+ Username: "legacy-user",
+ PasswordHash: "old-hash",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ })
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-bind-email-secret",
+ ExpireHour: 1,
+ AccessTokenExpireMinutes: 60,
+ RefreshTokenExpireDays: 7,
+ },
+ }
+ emailService := service.NewEmailService(nil, cache)
+ svc := service.NewAuthService(nil, userRepo, nil, refreshTokenCache, cfg, nil, emailService, nil, nil, nil, nil, nil)
+
+ oldTokenPair, err := svc.GenerateTokenPair(ctx, &service.User{
+ ID: 41,
+ Email: "legacy-user" + service.OIDCConnectSyntheticEmailDomain,
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ TokenVersion: 4,
+ }, "")
+ require.NoError(t, err)
+
+ updatedUser, err := svc.BindEmailIdentity(ctx, 41, "new@example.com", "123456", "new-password")
+ require.NoError(t, err)
+ require.NotNil(t, updatedUser)
+
+ storedUser, err := userRepo.GetByID(ctx, 41)
+ require.NoError(t, err)
+ require.Equal(t, "new@example.com", storedUser.Email)
+ require.True(t, svc.CheckPassword("new-password", storedUser.PasswordHash))
+
+ _, err = svc.RefreshToken(ctx, oldTokenPair.AccessToken)
+ require.ErrorIs(t, err, service.ErrTokenRevoked)
+
+ _, err = svc.RefreshTokenPair(ctx, oldTokenPair.RefreshToken)
+ require.True(t, errors.Is(err, service.ErrTokenRevoked) || errors.Is(err, service.ErrRefreshTokenInvalid))
+}
+
+type emailBindSettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *emailBindSettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *emailBindSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *emailBindSettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *emailBindSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
+}
+
+func (s *emailBindSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *emailBindSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *emailBindSettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+type emailBindCacheStub struct {
+ data *service.VerificationCodeData
+ err error
+}
+
+func (s *emailBindCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.data, nil
+}
+
+func (s *emailBindCacheStub) SetVerificationCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteVerificationCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeleteNotifyVerifyCode(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) {
+ return nil, nil
+}
+
+func (s *emailBindCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) DeletePasswordResetToken(context.Context, string) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool {
+ return false
+}
+
+func (s *emailBindCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error {
+ return nil
+}
+
+func (s *emailBindCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailBindCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) {
+ return 0, nil
+}
+
+type emailBindRefreshTokenCacheStub struct {
+ mu sync.Mutex
+ tokens map[string]*service.RefreshTokenData
+ userSets map[int64]map[string]struct{}
+ families map[string]map[string]struct{}
+}
+
+func newEmailBindRefreshTokenCacheStub() *emailBindRefreshTokenCacheStub {
+ return &emailBindRefreshTokenCacheStub{
+ tokens: make(map[string]*service.RefreshTokenData),
+ userSets: make(map[int64]map[string]struct{}),
+ families: make(map[string]map[string]struct{}),
+ }
+}
+
+func (s *emailBindRefreshTokenCacheStub) StoreRefreshToken(_ context.Context, tokenHash string, data *service.RefreshTokenData, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ cloned := *data
+ s.tokens[tokenHash] = &cloned
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetRefreshToken(_ context.Context, tokenHash string) (*service.RefreshTokenData, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ data, ok := s.tokens[tokenHash]
+ if !ok {
+ return nil, service.ErrRefreshTokenNotFound
+ }
+ cloned := *data
+ return &cloned, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteRefreshToken(_ context.Context, tokenHash string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.userSets {
+ delete(tokenSet, tokenHash)
+ }
+ for _, tokenSet := range s.families {
+ delete(tokenSet, tokenHash)
+ }
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for tokenHash := range s.userSets[userID] {
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.families {
+ delete(tokenSet, tokenHash)
+ }
+ }
+ delete(s.userSets, userID)
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) DeleteTokenFamily(_ context.Context, familyID string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for tokenHash := range s.families[familyID] {
+ delete(s.tokens, tokenHash)
+ for _, tokenSet := range s.userSets {
+ delete(tokenSet, tokenHash)
+ }
+ }
+ delete(s.families, familyID)
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) AddToUserTokenSet(_ context.Context, userID int64, tokenHash string, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.userSets[userID] == nil {
+ s.userSets[userID] = make(map[string]struct{})
+ }
+ s.userSets[userID][tokenHash] = struct{}{}
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) AddToFamilyTokenSet(_ context.Context, familyID string, tokenHash string, _ time.Duration) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.families[familyID] == nil {
+ s.families[familyID] = make(map[string]struct{})
+ }
+ s.families[familyID][tokenHash] = struct{}{}
+ return nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetUserTokenHashes(_ context.Context, userID int64) ([]string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ tokenSet := s.userSets[userID]
+ out := make([]string, 0, len(tokenSet))
+ for tokenHash := range tokenSet {
+ out = append(out, tokenHash)
+ }
+ return out, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) GetFamilyTokenHashes(_ context.Context, familyID string) ([]string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ tokenSet := s.families[familyID]
+ out := make([]string, 0, len(tokenSet))
+ for tokenHash := range tokenSet {
+ out = append(out, tokenHash)
+ }
+ return out, nil
+}
+
+func (s *emailBindRefreshTokenCacheStub) IsTokenInFamily(_ context.Context, familyID string, tokenHash string) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ _, ok := s.families[familyID][tokenHash]
+ return ok, nil
+}
+
+type emailBindUserRepoStub struct {
+ mu sync.Mutex
+ usersByID map[int64]*service.User
+ usersByEmail map[string]*service.User
+}
+
+func newEmailBindUserRepoStub(user *service.User) *emailBindUserRepoStub {
+ cloned := cloneEmailBindUser(user)
+ return &emailBindUserRepoStub{
+ usersByID: map[int64]*service.User{
+ cloned.ID: cloned,
+ },
+ usersByEmail: map[string]*service.User{
+ cloned.Email: cloned,
+ },
+ }
+}
+
+func (s *emailBindUserRepoStub) Create(context.Context, *service.User) error { return nil }
+
+func (s *emailBindUserRepoStub) GetByID(_ context.Context, id int64) (*service.User, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ user, ok := s.usersByID[id]
+ if !ok {
+ return nil, service.ErrUserNotFound
+ }
+ return cloneEmailBindUser(user), nil
+}
+
+func (s *emailBindUserRepoStub) GetByEmail(_ context.Context, email string) (*service.User, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ user, ok := s.usersByEmail[email]
+ if !ok {
+ return nil, service.ErrUserNotFound
+ }
+ return cloneEmailBindUser(user), nil
+}
+
+func (s *emailBindUserRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
+ panic("unexpected GetFirstAdmin call")
+}
+
+func (s *emailBindUserRepoStub) Update(_ context.Context, user *service.User) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ existing, ok := s.usersByID[user.ID]
+ if !ok {
+ return service.ErrUserNotFound
+ }
+ delete(s.usersByEmail, existing.Email)
+ cloned := cloneEmailBindUser(user)
+ s.usersByID[user.ID] = cloned
+ s.usersByEmail[cloned.Email] = cloned
+ return nil
+}
+
+func (s *emailBindUserRepoStub) Delete(context.Context, int64) error { return nil }
+
+func (s *emailBindUserRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
+ panic("unexpected UpsertUserAvatar call")
+}
+
+func (s *emailBindUserRepoStub) DeleteUserAvatar(context.Context, int64) error {
+ panic("unexpected DeleteUserAvatar call")
+}
+
+func (s *emailBindUserRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected List call")
+}
+
+func (s *emailBindUserRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
+ panic("unexpected ListWithFilters call")
+}
+
+func (s *emailBindUserRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
+ return map[int64]*time.Time{}, nil
+}
+
+func (s *emailBindUserRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
+func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
+func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
+
+func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ _, ok := s.usersByEmail[email]
+ return ok, nil
+}
+
+func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailBindUserRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]service.UserAuthIdentityRecord, error) {
+ return nil, nil
+}
+
+func (s *emailBindUserRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
+ return nil
+}
+
+func (s *emailBindUserRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (s *emailBindUserRepoStub) EnableTotp(context.Context, int64) error { return nil }
+func (s *emailBindUserRepoStub) DisableTotp(context.Context, int64) error { return nil }
+
+func cloneEmailBindUser(user *service.User) *service.User {
+ if user == nil {
+ return nil
+ }
+ cloned := *user
+ return &cloned
+}
diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go
new file mode 100644
index 00000000..53048b92
--- /dev/null
+++ b/backend/internal/service/auth_service_identity_sync_test.go
@@ -0,0 +1,482 @@
+//go:build unit
+
+package service_test
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/authidentity"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/repository"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
+type authIdentityDefaultSubAssignerStub struct {
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *authIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
+type flakyAuthIdentityDefaultSubAssignerStub struct {
+ failuresRemaining int
+ calls []*service.AssignSubscriptionInput
+}
+
+func (s *flakyAuthIdentityDefaultSubAssignerStub) AssignOrExtendSubscription(
+ _ context.Context,
+ input *service.AssignSubscriptionInput,
+) (*service.UserSubscription, bool, error) {
+ cloned := *input
+ s.calls = append(s.calls, &cloned)
+ if s.failuresRemaining > 0 {
+ s.failuresRemaining--
+ return nil, false, errors.New("temporary assign failure")
+ }
+ return &service.UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, true, nil
+}
+
+type authIdentitySettingRepoStub struct {
+ values map[string]string
+}
+
+func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
+ if v, ok := s.values[key]; ok {
+ return v, nil
+ }
+ return "", service.ErrSettingNotFound
+}
+
+func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
+ panic("unexpected Set call")
+}
+
+func (s *authIdentitySettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ out[key] = v
+ }
+ }
+ return out, nil
+}
+
+func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
+ panic("unexpected SetMultiple call")
+}
+
+func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
+ panic("unexpected Delete call")
+}
+
+func newAuthServiceWithEnt(
+ t *testing.T,
+ settings map[string]string,
+ defaultSubAssigner service.DefaultSubscriptionAssigner,
+) (*service.AuthService, service.UserRepository, *dbent.Client) {
+ t.Helper()
+
+ db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+ _, err = db.Exec(`
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ provider_type TEXT NOT NULL,
+ grant_reason TEXT NOT NULL DEFAULT 'first_bind',
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(user_id, provider_type, grant_reason)
+)`)
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := repository.NewUserRepository(client, db)
+ cfg := &config.Config{
+ JWT: config.JWTConfig{
+ Secret: "test-auth-identity-secret",
+ ExpireHour: 1,
+ },
+ Default: config.DefaultConfig{
+ UserBalance: 3.5,
+ UserConcurrency: 2,
+ },
+ }
+ settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
+ values: settings,
+ }, cfg)
+
+ svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, defaultSubAssigner, nil)
+ return svc, repo, client
+}
+
+func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ token, user, err := svc.Register(ctx, "user@example.com", "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, user)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, "email", storedUser.SignupSource)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("user@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+ require.NotNil(t, identity.VerifiedAt)
+}
+
+func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) {
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("login@example.com").
+ SetPasswordHash(passwordHash).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ SetBalance(1).
+ SetConcurrency(1).
+ Save(ctx)
+ require.NoError(t, err)
+
+ old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
+ _, err = client.User.UpdateOneID(user.ID).
+ SetLastLoginAt(old).
+ SetLastActiveAt(old).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.NotNil(t, storedUser.LastLoginAt)
+ require.NotNil(t, storedUser.LastActiveAt)
+ require.True(t, storedUser.LastLoginAt.Equal(old))
+ require.True(t, storedUser.LastActiveAt.Equal(old))
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Zero(t, identityCount)
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("login@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+}
+
+func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) {
+ svc, repo, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ }, nil)
+ ctx := context.Background()
+
+ user := &service.User{
+ Email: "record@example.com",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ Balance: 1,
+ Concurrency: 1,
+ }
+ require.NoError(t, user.SetPassword("password"))
+ require.NoError(t, repo.Create(ctx, user))
+
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ identity, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("record@example.com"),
+ ).
+ Only(ctx)
+ require.NoError(t, err)
+ require.Equal(t, user.ID, identity.UserID)
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("legacy@example.com").
+ SetUsername("legacy-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+
+ identityCount, err := client.AuthIdentity.Query().
+ Where(
+ authidentity.ProviderTypeEQ("email"),
+ authidentity.ProviderKeyEQ("email"),
+ authidentity.ProviderSubjectEQ("legacy@example.com"),
+ ).
+ Count(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, identityCount)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyMergedEmailFirstBindDefaultsWhenBackfillingLegacyEmailIdentity(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyDefaultSubscriptions: `[{"group_id":21,"validity_days":14}]`,
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("merged-first-bind@example.com").
+ SetUsername("merged-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyExists(t *testing.T) {
+ assigner := &authIdentityDefaultSubAssignerStub{}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("bound@example.com").
+ SetUsername("bound-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(2).
+ SetConcurrency(3).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+ _, err = client.AuthIdentity.Create().
+ SetUserID(user.ID).
+ SetProviderType("email").
+ SetProviderKey("email").
+ SetProviderSubject("bound@example.com").
+ SetVerifiedAt(time.Now().UTC()).
+ SetMetadata(map[string]any{"source": "preexisting"}).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 2.0, storedUser.Balance)
+ require.Equal(t, 3, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func TestAuthServiceLogin_DoesNotRetryEmailFirstBindDefaultsForBackfilledEmailIdentity(t *testing.T) {
+ assigner := &flakyAuthIdentityDefaultSubAssignerStub{failuresRemaining: 1}
+ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyAuthSourceDefaultEmailBalance: "8.5",
+ service.SettingKeyAuthSourceDefaultEmailConcurrency: "4",
+ service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "true",
+ }, assigner)
+ ctx := context.Background()
+
+ passwordHash, err := svc.HashPassword("password")
+ require.NoError(t, err)
+ user, err := client.User.Create().
+ SetEmail("retry-first-bind@example.com").
+ SetUsername("retry-user").
+ SetPasswordHash(passwordHash).
+ SetBalance(1.5).
+ SetConcurrency(2).
+ SetRole(service.RoleUser).
+ SetStatus(service.StatusActive).
+ Save(ctx)
+ require.NoError(t, err)
+
+ token, gotUser, err := svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err := client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+
+ token, gotUser, err = svc.Login(ctx, user.Email, "password")
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+ require.NotNil(t, gotUser)
+ svc.RecordSuccessfulLogin(ctx, user.ID)
+
+ storedUser, err = client.User.Get(ctx, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, 1.5, storedUser.Balance)
+ require.Equal(t, 2, storedUser.Concurrency)
+ require.Empty(t, assigner.calls)
+ require.Equal(t, 0, countProviderGrantRecords(t, client, user.ID, "email", "first_bind"))
+}
+
+func countProviderGrantRecords(
+ t *testing.T,
+ client *dbent.Client,
+ userID int64,
+ providerType string,
+ grantReason string,
+) int {
+ t.Helper()
+
+ var count int
+ rows, err := client.QueryContext(
+ context.Background(),
+ `SELECT COUNT(*) FROM user_provider_default_grants WHERE user_id = ? AND provider_type = ? AND grant_reason = ?`,
+ userID,
+ providerType,
+ grantReason,
+ )
+ require.NoError(t, err)
+ defer rows.Close()
+ require.True(t, rows.Next())
+ require.NoError(t, rows.Scan(&count))
+ require.NoError(t, rows.Err())
+ return count
+}
diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go
deleted file mode 100644
index 0472e06c..00000000
--- a/backend/internal/service/auth_service_pending_oauth_test.go
+++ /dev/null
@@ -1,146 +0,0 @@
-//go:build unit
-
-package service
-
-import (
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/golang-jwt/jwt/v5"
- "github.com/stretchr/testify/require"
-)
-
-func newAuthServiceForPendingOAuthTest() *AuthService {
- cfg := &config.Config{
- JWT: config.JWTConfig{
- Secret: "test-secret-pending-oauth",
- ExpireHour: 1,
- },
- }
- return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
-}
-
-// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
-func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
- require.NotEmpty(t, token)
-
- email, username, err := svc.VerifyPendingOAuthToken(token)
- require.NoError(t, err)
- require.Equal(t, "user@example.com", email)
- require.Equal(t, "alice", username)
-}
-
-// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- // 签发一个普通 access token(JWTClaims,无 Purpose 字段)
- accessToken, err := svc.GenerateToken(&User{
- ID: 1,
- Email: "user@example.com",
- Role: RoleUser,
- })
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(accessToken)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "some_other_purpose",
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- now := time.Now()
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: "", // 旧 token 无此字段,反序列化后为零值
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
- IssuedAt: jwt.NewNumericDate(now),
- NotBefore: jwt.NewNumericDate(now),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
-
- past := time.Now().Add(-1 * time.Hour)
- claims := &pendingOAuthClaims{
- Email: "user@example.com",
- Username: "alice",
- Purpose: pendingOAuthPurpose,
- RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(past),
- IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
- },
- }
- tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
- require.NoError(t, err)
-
- _, _, err = svc.VerifyPendingOAuthToken(tokenStr)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
- other := NewAuthService(nil, nil, nil, nil, &config.Config{
- JWT: config.JWTConfig{Secret: "other-secret"},
- }, nil, nil, nil, nil, nil, nil)
-
- token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
- require.NoError(t, err)
-
- svc := newAuthServiceForPendingOAuthTest()
- _, _, err = svc.VerifyPendingOAuthToken(token)
- require.ErrorIs(t, err, ErrInvalidToken)
-}
-
-// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
-func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
- svc := newAuthServiceForPendingOAuthTest()
- giant := make([]byte, maxTokenLength+1)
- for i := range giant {
- giant[i] = 'a'
- }
- _, _, err := svc.VerifyPendingOAuthToken(string(giant))
- require.ErrorIs(t, err, ErrInvalidToken)
-}
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 7b50e90d..acc44a38 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
+ if s.err != nil {
+ return nil, s.err
+ }
+ result := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if v, ok := s.values[key]; ok {
+ result[key] = v
+ }
+ }
+ return result, nil
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
@@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct {
err error
}
+type refreshTokenCacheStub struct{}
+
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
@@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
+func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) {
+ return nil, ErrRefreshTokenNotFound
+}
+
+func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
+ return nil
+}
+
+func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
+ return nil, nil
+}
+
+func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
+ return false, nil
+}
+
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -87,6 +138,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil
}
+func (s *emailCacheStub) GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error) {
+ return nil, nil
+}
+
+func (s *emailCacheStub) SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
+ return nil
+}
+
+func (s *emailCacheStub) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
+ return nil
+}
+
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
@@ -107,6 +170,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
return nil
}
+func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
+ return 0, nil
+}
+
+func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
+ return 0, nil
+}
+
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
@@ -141,6 +212,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil, // promoService
nil, // defaultSubAssigner
+ nil, // affiliateService
)
}
@@ -172,7 +244,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
- _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
+ _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
@@ -184,7 +256,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
- _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
+ _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -198,7 +270,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
SettingKeyEmailVerifyEnabled: "true",
}, cache)
- _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
+ _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}
@@ -302,7 +374,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
@@ -449,8 +522,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
}, nil)
service.defaultSubAssigner = assigner
@@ -464,3 +538,132 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}
+
+func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) {
+ repo := &userRepoStub{nextID: 52}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "12.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "7",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 12.5, user.Balance)
+ require.Equal(t, 7, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(11), assigner.calls[0].GroupID)
+ require.Equal(t, 30, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 53}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "99",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "88",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-global@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 3.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaults(t *testing.T) {
+ repo := &userRepoStub{nextID: 54}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`,
+ SettingKeyAuthSourceDefaultEmailBalance: "9.5",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: `[]`,
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+
+ _, user, err := service.Register(context.Background(), "email-merged@test.com", "password")
+ require.NoError(t, err)
+ require.NotNil(t, user)
+ require.Equal(t, 9.5, user.Balance)
+ require.Equal(t, 2, user.Concurrency)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(31), assigner.calls[0].GroupID)
+ require.Equal(t, 5, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) {
+ repo := &userRepoStub{nextID: 61}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`,
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.NotNil(t, user)
+ require.Equal(t, int64(61), user.ID)
+ require.Equal(t, 21.75, user.Balance)
+ require.Equal(t, 9, user.Concurrency)
+ require.Len(t, repo.created, 1)
+ require.Len(t, assigner.calls, 1)
+ require.Equal(t, int64(22), assigner.calls[0].GroupID)
+ require.Equal(t, 14, assigner.calls[0].ValidityDays)
+}
+
+func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) {
+ existing := &User{
+ ID: 88,
+ Email: "linuxdo-123@linuxdo-connect.invalid",
+ Username: "existing-linuxdo",
+ Role: RoleUser,
+ Status: StatusActive,
+ Balance: 4,
+ Concurrency: 1,
+ TokenVersion: 2,
+ }
+ repo := &userRepoStub{user: existing}
+ assigner := &defaultSubscriptionAssignerStub{}
+ service := newAuthService(repo, map[string]string{
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`,
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
+ }, nil)
+ service.defaultSubAssigner = assigner
+ service.refreshTokenCache = &refreshTokenCacheStub{}
+
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
+ require.NoError(t, err)
+ require.NotNil(t, tokenPair)
+ require.Equal(t, existing.ID, user.ID)
+ require.Equal(t, 4.0, user.Balance)
+ require.Equal(t, 1, user.Concurrency)
+ require.Empty(t, repo.created)
+ require.Empty(t, assigner.calls)
+}
diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go
index 477ba1b2..3512822f 100644
--- a/backend/internal/service/auth_service_turnstile_register_test.go
+++ b/backend/internal/service/auth_service_turnstile_register_test.go
@@ -54,6 +54,7 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner
+ nil, // affiliateService
)
}
diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go
new file mode 100644
index 00000000..7bb4cf9e
--- /dev/null
+++ b/backend/internal/service/balance_notify_check_test.go
@@ -0,0 +1,404 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an
+// in-memory settings repo and a non-nil emailService so that the guard-clause
+// nil-checks pass. The emailService is intentionally minimal — tests must
+// avoid crossing scenarios that would actually dispatch emails.
+func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) {
+ repo := newMockSettingRepo()
+ // EmailService is a concrete type; construct with the same repo so that
+ // any accidental fallback reads still succeed. Tests should not trigger a
+ // crossing that reaches SendEmail.
+ email := NewEmailService(repo, nil)
+ return NewBalanceNotifyService(email, repo, nil), repo
+}
+
+// ---------- guard clauses ----------
+
+func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ // Should not panic.
+ s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50)
+}
+
+func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
+ u := &User{ID: 1, BalanceNotifyEnabled: false}
+ // Even with a crossing, disabled flag short-circuits.
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
+ u := &User{ID: 1, BalanceNotifyEnabled: true}
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "0"
+ u := &User{ID: 1, BalanceNotifyEnabled: true}
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default
+ customThreshold := 5.0
+ u := &User{
+ ID: 1,
+ BalanceNotifyEnabled: true,
+ BalanceNotifyThreshold: &customThreshold,
+ }
+ // User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not
+ // cross 5, so nothing fires (verified by absence of panic).
+ s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
+}
+
+func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
+ u := &User{ID: 1, BalanceNotifyEnabled: true}
+
+ // 100 -> 95, both remain above threshold=10, no crossing.
+ s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5)
+ // 5 -> 3, both already below threshold, no crossing (only fires on first
+ // cross from above-to-below).
+ s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2)
+}
+
+// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ----------
+
+func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ // Should not panic.
+ s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil)
+}
+
+func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
+ s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil)
+}
+
+func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
+ s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil)
+}
+
+func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
+ a := &Account{
+ ID: 1,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "quota_notify_daily_enabled": true,
+ "quota_notify_daily_threshold": 100.0,
+ "quota_daily_limit": 1000.0,
+ "quota_daily_used": 950.0,
+ },
+ }
+ // Global disabled → no processing even if a dim would cross.
+ s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil)
+}
+
+// ---------- sanity: internal helpers still work ----------
+
+func TestGetBalanceNotifyConfig_AllFields(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5"
+ repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay"
+
+ enabled, threshold, url := s.getBalanceNotifyConfig(context.Background())
+ require.True(t, enabled)
+ require.Equal(t, 12.5, threshold)
+ require.Equal(t, "https://example.com/pay", url)
+}
+
+func TestGetBalanceNotifyConfig_Disabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
+
+ enabled, _, _ := s.getBalanceNotifyConfig(context.Background())
+ require.False(t, enabled)
+}
+
+func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
+ repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number"
+
+ enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background())
+ require.True(t, enabled)
+ require.Equal(t, 0.0, threshold)
+}
+
+func TestIsAccountQuotaNotifyEnabled(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+
+ // Missing key → false
+ require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
+
+ // Explicit "false"
+ repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
+ require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
+
+ // Explicit "true"
+ repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true"
+ require.True(t, s.isAccountQuotaNotifyEnabled(context.Background()))
+}
+
+func TestGetSiteName_FallsBackToDefault(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ name := s.getSiteName(context.Background())
+ require.Equal(t, defaultSiteName, name)
+}
+
+func TestGetSiteName_Configured(t *testing.T) {
+ s, repo := newBalanceNotifyServiceForTest()
+ repo.data[SettingKeySiteName] = "My Site"
+ require.Equal(t, "My Site", s.getSiteName(context.Background()))
+}
+
+// ---------- crossedDownward ----------
+
+func TestCrossedDownward_CrossesBelow(t *testing.T) {
+ // oldBalance > threshold, newBalance < threshold → true
+ require.True(t, crossedDownward(100, 5, 10))
+}
+
+func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) {
+ // oldBalance > threshold, newBalance == threshold → false (not below)
+ require.False(t, crossedDownward(100, 10, 10))
+}
+
+func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) {
+ // oldBalance == threshold, newBalance < threshold → true
+ // (at-or-above → below counts as a crossing)
+ require.True(t, crossedDownward(10, 5, 10))
+}
+
+func TestCrossedDownward_AlreadyBelow(t *testing.T) {
+ // oldBalance < threshold → false (already below, no new crossing)
+ require.False(t, crossedDownward(5, 3, 10))
+}
+
+func TestCrossedDownward_BothAbove(t *testing.T) {
+ // oldBalance > threshold, newBalance > threshold → false (no crossing)
+ require.False(t, crossedDownward(100, 50, 10))
+}
+
+func TestCrossedDownward_ZeroThreshold(t *testing.T) {
+ // threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
+ // Typical case: positive balances should not fire when threshold is 0.
+ require.False(t, crossedDownward(10, 5, 0))
+ require.False(t, crossedDownward(0, 0, 0))
+}
+
+func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) {
+ // Edge case: newBalance goes negative with threshold=0.
+ require.True(t, crossedDownward(5, -1, 0))
+}
+
+func TestCrossedDownward_NegativeValues(t *testing.T) {
+ // Both already negative, threshold is positive → no crossing (already below).
+ require.False(t, crossedDownward(-5, -10, 10))
+}
+
+func TestCrossedDownward_LargeDecrement(t *testing.T) {
+ // A single large deduction crosses the threshold.
+ require.True(t, crossedDownward(1000, 0.5, 100))
+}
+
+func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) {
+ // A tiny deduction stays above threshold.
+ require.False(t, crossedDownward(100, 99.99, 10))
+}
+
+// ---------- checkQuotaDimCrossings ----------
+
+func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // Empty dims → no crossing, no panic.
+ s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite")
+ s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: false, // disabled
+ threshold: 100,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 950,
+ limit: 1000,
+ },
+ }
+ // Disabled dimension should be skipped even if crossing would occur.
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 0, // zero threshold
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 950,
+ limit: 1000,
+ },
+ }
+ // Zero threshold → skipped.
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
+ // currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 400,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 300,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
+ // currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 400,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 800,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
+ // Negative resolved threshold → skipped.
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 1200,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 950,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
+ // currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
+ dims := []quotaDim{
+ {
+ name: quotaDimWeekly,
+ enabled: true,
+ threshold: 30,
+ thresholdType: thresholdTypePercentage,
+ currentUsed: 500,
+ limit: 1000,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // limit=0 → resolvedThreshold returns 0 → skipped.
+ dims := []quotaDim{
+ {
+ name: quotaDimTotal,
+ enabled: true,
+ threshold: 100,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 50,
+ limit: 0,
+ },
+ }
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
+
+func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) {
+ s, _ := newBalanceNotifyServiceForTest()
+ account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
+ // dim1: no crossing (both below effective threshold)
+ // dim2: disabled (skipped)
+ // dim3: zero threshold (skipped)
+ dims := []quotaDim{
+ {
+ name: quotaDimDaily,
+ enabled: true,
+ threshold: 400,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below
+ limit: 1000,
+ },
+ {
+ name: quotaDimWeekly,
+ enabled: false,
+ threshold: 100,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 900,
+ limit: 1000,
+ },
+ {
+ name: quotaDimTotal,
+ enabled: true,
+ threshold: 0,
+ thresholdType: thresholdTypeFixed,
+ currentUsed: 500,
+ limit: 1000,
+ },
+ }
+ // None should trigger. No panic expected.
+ s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
+}
diff --git a/backend/internal/service/balance_notify_email_body_test.go b/backend/internal/service/balance_notify_email_body_test.go
new file mode 100644
index 00000000..aee5a5bc
--- /dev/null
+++ b/backend/internal/service/balance_notify_email_body_test.go
@@ -0,0 +1,147 @@
+//go:build unit
+
+package service
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// These tests guard against fmt.Sprintf arg-count mismatches in the email
+// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in
+// the output, which these assertions will catch.
+
+// ---------- buildBalanceLowEmailBody ----------
+
+func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) {
+ s := &BalanceNotifyService{}
+ body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "")
+
+ // All substituted values should appear in the output.
+ require.Contains(t, body, "MySite")
+ require.Contains(t, body, "Alice")
+ require.Contains(t, body, "$3.14")
+ require.Contains(t, body, "$10.00")
+
+ // No fmt.Sprintf format error markers.
+ require.NotContains(t, body, "%!")
+ require.NotContains(t, body, "MISSING")
+ require.NotContains(t, body, "EXTRA")
+}
+
+func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) {
+ s := &BalanceNotifyService{}
+ body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay")
+
+ // The recharge anchor element should appear with the URL.
+ require.Contains(t, body, `href="https://example.com/pay"`)
+ require.Contains(t, body, "立即充值")
+ require.NotContains(t, body, "%!")
+}
+
+func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) {
+ s := &BalanceNotifyService{}
+ // Try a URL with characters that need HTML escaping.
+ body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b=`)))
- require.False(t, IsCloudflareChallengeResponse(http.StatusBadGateway, nil, []byte(`Just a moment... `)))
-}
-
-func TestExtractCloudflareRayID(t *testing.T) {
- headers := make(http.Header)
- headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
- require.Equal(t, "9d01b0e9ecc35829-SEA", ExtractCloudflareRayID(headers, nil))
-
- body := []byte(``)
- require.Equal(t, "9cff2d62d83bb98d", ExtractCloudflareRayID(nil, body))
-}
-
-func TestExtractUpstreamErrorCodeAndMessage(t *testing.T) {
- code, msg := ExtractUpstreamErrorCodeAndMessage([]byte(`{"error":{"code":"cf_shield_429","message":"rate limited"}}`))
- require.Equal(t, "cf_shield_429", code)
- require.Equal(t, "rate limited", msg)
-
- code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`{"code":"unsupported_country_code","message":"not available"}`))
- require.Equal(t, "unsupported_country_code", code)
- require.Equal(t, "not available", msg)
-
- code, msg = ExtractUpstreamErrorCodeAndMessage([]byte(`plain text`))
- require.Equal(t, "", code)
- require.Equal(t, "plain text", msg)
-}
-
-func TestFormatCloudflareChallengeMessage(t *testing.T) {
- headers := make(http.Header)
- headers.Set("cf-ray", "9d03b68c086027a1-SEA")
- msg := FormatCloudflareChallengeMessage("blocked", headers, nil)
- require.Equal(t, "blocked (cf-ray: 9d03b68c086027a1-SEA)", msg)
-}
diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go
index ffca98a5..2279d913 100644
--- a/backend/internal/web/embed_on.go
+++ b/backend/internal/web/embed_on.go
@@ -10,6 +10,8 @@ import (
"io"
"io/fs"
"net/http"
+ "os"
+ "path/filepath"
"strings"
"time"
@@ -32,11 +34,12 @@ type PublicSettingsProvider interface {
// FrontendServer serves the embedded frontend with settings injection
type FrontendServer struct {
- distFS fs.FS
- fileServer http.Handler
- baseHTML []byte
- cache *HTMLCache
- settings PublicSettingsProvider
+ distFS fs.FS
+ fileServer http.Handler
+ baseHTML []byte
+ cache *HTMLCache
+ settings PublicSettingsProvider
+ overrideDir string // local file override directory
}
// NewFrontendServer creates a new frontend server with settings injection
@@ -62,11 +65,12 @@ func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer
cache.SetBaseHTML(baseHTML)
return &FrontendServer{
- distFS: distFS,
- fileServer: http.FileServer(http.FS(distFS)),
- baseHTML: baseHTML,
- cache: cache,
- settings: settingsProvider,
+ distFS: distFS,
+ fileServer: http.FileServer(http.FS(distFS)),
+ baseHTML: baseHTML,
+ cache: cache,
+ settings: settingsProvider,
+ overrideDir: filepath.Join("data", "public"),
}, nil
}
@@ -99,6 +103,11 @@ func (s *FrontendServer) Middleware() gin.HandlerFunc {
return
}
+ // Try local override first
+ if s.tryServeOverride(c, cleanPath) {
+ return
+ }
+
// Serve static files normally
s.fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
@@ -114,6 +123,22 @@ func (s *FrontendServer) fileExists(path string) bool {
return true
}
+// tryServeOverride checks if a local override file exists and serves it.
+// Files in overrideDir take precedence over embedded files.
+func (s *FrontendServer) tryServeOverride(c *gin.Context, cleanPath string) bool {
+ if s.overrideDir == "" {
+ return false
+ }
+ filePath := filepath.Join(s.overrideDir, filepath.Clean("/"+cleanPath))
+ info, err := os.Stat(filePath)
+ if err != nil || info.IsDir() {
+ return false
+ }
+ c.File(filePath)
+ c.Abort()
+ return true
+}
+
func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
// Get nonce from context (generated by SecurityHeaders middleware)
nonce := middleware.GetNonceFromContext(c)
@@ -226,6 +251,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
panic("failed to get dist subdirectory: " + err.Error())
}
fileServer := http.FileServer(http.FS(distFS))
+ overrideDir := filepath.Join("data", "public")
return func(c *gin.Context) {
path := c.Request.URL.Path
@@ -242,6 +268,10 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
if file, err := distFS.Open(cleanPath); err == nil {
_ = file.Close()
+ // Try local override first
+ if tryServeOverrideFile(c, overrideDir, cleanPath) {
+ return
+ }
fileServer.ServeHTTP(c.Writer, c.Request)
c.Abort()
return
@@ -251,17 +281,33 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
}
}
+// tryServeOverrideFile is a standalone version of tryServeOverride for legacy usage.
+func tryServeOverrideFile(c *gin.Context, overrideDir, cleanPath string) bool {
+ if overrideDir == "" {
+ return false
+ }
+ filePath := filepath.Join(overrideDir, filepath.Clean("/"+cleanPath))
+ info, err := os.Stat(filePath)
+ if err != nil || info.IsDir() {
+ return false
+ }
+ c.File(filePath)
+ c.Abort()
+ return true
+}
+
func shouldBypassEmbeddedFrontend(path string) bool {
trimmed := strings.TrimSpace(path)
return strings.HasPrefix(trimmed, "/api/") ||
strings.HasPrefix(trimmed, "/v1/") ||
strings.HasPrefix(trimmed, "/v1beta/") ||
- strings.HasPrefix(trimmed, "/sora/") ||
+ strings.HasPrefix(trimmed, "/backend-api/") ||
strings.HasPrefix(trimmed, "/antigravity/") ||
strings.HasPrefix(trimmed, "/setup/") ||
trimmed == "/health" ||
trimmed == "/responses" ||
- strings.HasPrefix(trimmed, "/responses/")
+ strings.HasPrefix(trimmed, "/responses/") ||
+ strings.HasPrefix(trimmed, "/images/")
}
func serveIndexHTML(c *gin.Context, fsys fs.FS) {
diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go
index fd47c4da..583d98a0 100644
--- a/backend/internal/web/embed_test.go
+++ b/backend/internal/web/embed_test.go
@@ -434,7 +434,8 @@ func TestFrontendServer_Middleware(t *testing.T) {
"/api/v1/users",
"/v1/models",
"/v1beta/chat",
- "/sora/v1/models",
+ "/backend-api/codex/responses",
+ "/backend-api/codex/responses/compact",
"/antigravity/test",
"/setup/init",
"/health",
@@ -637,7 +638,8 @@ func TestServeEmbeddedFrontend(t *testing.T) {
"/api/users",
"/v1/models",
"/v1beta/chat",
- "/sora/v1/models",
+ "/backend-api/codex/responses",
+ "/backend-api/codex/responses/compact",
"/antigravity/test",
"/setup/init",
"/health",
diff --git a/backend/migrations/081_create_channels.sql b/backend/migrations/081_create_channels.sql
new file mode 100644
index 00000000..3059816b
--- /dev/null
+++ b/backend/migrations/081_create_channels.sql
@@ -0,0 +1,56 @@
+-- Create channels table for managing pricing channels.
+-- A channel groups multiple groups together and provides custom model pricing.
+
+SET LOCAL lock_timeout = '5s';
+SET LOCAL statement_timeout = '10min';
+
+-- 渠道表
+CREATE TABLE IF NOT EXISTS channels (
+ id BIGSERIAL PRIMARY KEY,
+ name VARCHAR(100) NOT NULL,
+ description TEXT DEFAULT '',
+ status VARCHAR(20) NOT NULL DEFAULT 'active',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+-- 渠道名称唯一索引
+CREATE UNIQUE INDEX IF NOT EXISTS idx_channels_name ON channels (name);
+CREATE INDEX IF NOT EXISTS idx_channels_status ON channels (status);
+
+-- 渠道-分组关联表(每个分组只能属于一个渠道)
+CREATE TABLE IF NOT EXISTS channel_groups (
+ id BIGSERIAL PRIMARY KEY,
+ channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
+ group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_groups_group_id ON channel_groups (group_id);
+CREATE INDEX IF NOT EXISTS idx_channel_groups_channel_id ON channel_groups (channel_id);
+
+-- 渠道模型定价表(一条定价可绑定多个模型)
+CREATE TABLE IF NOT EXISTS channel_model_pricing (
+ id BIGSERIAL PRIMARY KEY,
+ channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
+ models JSONB NOT NULL DEFAULT '[]',
+ input_price NUMERIC(20,12),
+ output_price NUMERIC(20,12),
+ cache_write_price NUMERIC(20,12),
+ cache_read_price NUMERIC(20,12),
+ image_output_price NUMERIC(20,8),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_channel_id ON channel_model_pricing (channel_id);
+
+COMMENT ON TABLE channels IS '渠道管理:关联多个分组,提供自定义模型定价';
+COMMENT ON TABLE channel_groups IS '渠道-分组关联表:每个分组最多属于一个渠道';
+COMMENT ON TABLE channel_model_pricing IS '渠道模型定价:一条定价可绑定多个模型,价格一致';
+COMMENT ON COLUMN channel_model_pricing.models IS '绑定的模型列表,JSON 数组,如 ["claude-opus-4-6","claude-opus-4-6-thinking"]';
+COMMENT ON COLUMN channel_model_pricing.input_price IS '每 token 输入价格(USD),NULL 表示使用默认';
+COMMENT ON COLUMN channel_model_pricing.output_price IS '每 token 输出价格(USD),NULL 表示使用默认';
+COMMENT ON COLUMN channel_model_pricing.cache_write_price IS '缓存写入每 token 价格,NULL 表示使用默认';
+COMMENT ON COLUMN channel_model_pricing.cache_read_price IS '缓存读取每 token 价格,NULL 表示使用默认';
+COMMENT ON COLUMN channel_model_pricing.image_output_price IS '图片输出价格(Gemini Image 等),NULL 表示使用默认';
diff --git a/backend/migrations/082_refactor_channel_pricing.sql b/backend/migrations/082_refactor_channel_pricing.sql
new file mode 100644
index 00000000..d0a54062
--- /dev/null
+++ b/backend/migrations/082_refactor_channel_pricing.sql
@@ -0,0 +1,67 @@
+-- Extend channel_model_pricing with billing_mode and add context-interval child table.
+-- Supports three billing modes: token (per-token with context intervals),
+-- per_request (per-request with context-size tiers), and image (per-image).
+
+SET LOCAL lock_timeout = '5s';
+SET LOCAL statement_timeout = '10min';
+
+-- 1. 为 channel_model_pricing 添加 billing_mode 列
+ALTER TABLE channel_model_pricing
+ ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20) NOT NULL DEFAULT 'token';
+
+COMMENT ON COLUMN channel_model_pricing.billing_mode IS '计费模式:token(按 token 区间计费)、per_request(按次计费)、image(图片计费)';
+
+-- 2. 创建区间定价子表
+CREATE TABLE IF NOT EXISTS channel_pricing_intervals (
+ id BIGSERIAL PRIMARY KEY,
+ pricing_id BIGINT NOT NULL REFERENCES channel_model_pricing(id) ON DELETE CASCADE,
+ min_tokens INT NOT NULL DEFAULT 0,
+ max_tokens INT,
+ tier_label VARCHAR(50),
+ input_price NUMERIC(20,12),
+ output_price NUMERIC(20,12),
+ cache_write_price NUMERIC(20,12),
+ cache_read_price NUMERIC(20,12),
+ per_request_price NUMERIC(20,12),
+ sort_order INT NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_channel_pricing_intervals_pricing_id
+ ON channel_pricing_intervals (pricing_id);
+
+COMMENT ON TABLE channel_pricing_intervals IS '渠道定价区间:支持按 token 区间、按次分层、图片分辨率分层';
+COMMENT ON COLUMN channel_pricing_intervals.min_tokens IS '区间下界(含),token 模式使用';
+COMMENT ON COLUMN channel_pricing_intervals.max_tokens IS '区间上界(不含),NULL 表示无上限';
+COMMENT ON COLUMN channel_pricing_intervals.tier_label IS '层级标签,按次/图片模式使用(如 1K、2K、4K、HD)';
+COMMENT ON COLUMN channel_pricing_intervals.input_price IS 'token 模式:每 token 输入价';
+COMMENT ON COLUMN channel_pricing_intervals.output_price IS 'token 模式:每 token 输出价';
+COMMENT ON COLUMN channel_pricing_intervals.cache_write_price IS 'token 模式:缓存写入价';
+COMMENT ON COLUMN channel_pricing_intervals.cache_read_price IS 'token 模式:缓存读取价';
+COMMENT ON COLUMN channel_pricing_intervals.per_request_price IS '按次/图片模式:每次请求价格';
+
+-- 3. 迁移现有 flat 定价为单区间 [0, +inf)
+-- 仅迁移有明确定价(至少一个价格字段非 NULL)的条目
+INSERT INTO channel_pricing_intervals (pricing_id, min_tokens, max_tokens, input_price, output_price, cache_write_price, cache_read_price, sort_order)
+SELECT
+ cmp.id,
+ 0,
+ NULL,
+ cmp.input_price,
+ cmp.output_price,
+ cmp.cache_write_price,
+ cmp.cache_read_price,
+ 0
+FROM channel_model_pricing cmp
+WHERE cmp.billing_mode = 'token'
+ AND (cmp.input_price IS NOT NULL OR cmp.output_price IS NOT NULL
+ OR cmp.cache_write_price IS NOT NULL OR cmp.cache_read_price IS NOT NULL)
+ AND NOT EXISTS (
+ SELECT 1 FROM channel_pricing_intervals cpi WHERE cpi.pricing_id = cmp.id
+ );
+
+-- 4. 迁移 image_output_price 为 image 模式的区间条目
+-- 将有 image_output_price 的现有条目复制为 billing_mode='image' 的独立条目
+-- 注意:这里不改变原条目的 billing_mode,而是将 image_output_price 作为向后兼容字段保留
+-- 实际的 image 计费在未来由独立的 billing_mode='image' 条目处理
diff --git a/backend/migrations/083_channel_model_mapping.sql b/backend/migrations/083_channel_model_mapping.sql
new file mode 100644
index 00000000..68e2203f
--- /dev/null
+++ b/backend/migrations/083_channel_model_mapping.sql
@@ -0,0 +1,5 @@
+SET LOCAL lock_timeout = '5s';
+SET LOCAL statement_timeout = '10min';
+
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS model_mapping JSONB DEFAULT '{}';
+COMMENT ON COLUMN channels.model_mapping IS '渠道级模型映射,在账号映射之前执行。格式:{"source_model": "target_model"}';
diff --git a/backend/migrations/084_channel_billing_model_source.sql b/backend/migrations/084_channel_billing_model_source.sql
new file mode 100644
index 00000000..bd615bac
--- /dev/null
+++ b/backend/migrations/084_channel_billing_model_source.sql
@@ -0,0 +1,7 @@
+-- Add billing_model_source to channels (controls whether billing uses requested or upstream model)
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS billing_model_source VARCHAR(20) DEFAULT 'requested';
+
+-- Add channel tracking fields to usage_logs
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS channel_id BIGINT;
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS model_mapping_chain VARCHAR(500);
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_tier VARCHAR(50);
diff --git a/backend/migrations/085_channel_restrict_and_per_request_price.sql b/backend/migrations/085_channel_restrict_and_per_request_price.sql
new file mode 100644
index 00000000..2f494c63
--- /dev/null
+++ b/backend/migrations/085_channel_restrict_and_per_request_price.sql
@@ -0,0 +1,5 @@
+-- Add model restriction switch to channels
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS restrict_models BOOLEAN DEFAULT false;
+
+-- Add default per_request_price to channel_model_pricing (fallback when no tier matches)
+ALTER TABLE channel_model_pricing ADD COLUMN IF NOT EXISTS per_request_price NUMERIC(20,10);
diff --git a/backend/migrations/086_channel_platform_pricing.sql b/backend/migrations/086_channel_platform_pricing.sql
new file mode 100644
index 00000000..f2d08562
--- /dev/null
+++ b/backend/migrations/086_channel_platform_pricing.sql
@@ -0,0 +1,21 @@
+-- 086_channel_platform_pricing.sql
+-- 渠道按平台维度:model_pricing 加 platform 列,model_mapping 改为嵌套格式
+
+-- 1. channel_model_pricing 加 platform 列
+ALTER TABLE channel_model_pricing
+ ADD COLUMN IF NOT EXISTS platform VARCHAR(50) NOT NULL DEFAULT 'anthropic';
+
+CREATE INDEX IF NOT EXISTS idx_channel_model_pricing_platform
+ ON channel_model_pricing (platform);
+
+-- 2. model_mapping: 从扁平 {"src":"dst"} 迁移为嵌套 {"anthropic":{"src":"dst"}}
+-- 仅迁移非空、非 '{}' 的旧格式数据(通过检查第一个 value 是否为字符串来判断是否为旧格式)
+UPDATE channels
+SET model_mapping = jsonb_build_object('anthropic', model_mapping)
+WHERE model_mapping IS NOT NULL
+ AND model_mapping::text NOT IN ('{}', 'null', '')
+ AND NOT EXISTS (
+ SELECT 1 FROM jsonb_each(model_mapping) AS kv
+ WHERE jsonb_typeof(kv.value) = 'object'
+ LIMIT 1
+ );
diff --git a/backend/migrations/087_usage_log_billing_mode.sql b/backend/migrations/087_usage_log_billing_mode.sql
new file mode 100644
index 00000000..8552be0b
--- /dev/null
+++ b/backend/migrations/087_usage_log_billing_mode.sql
@@ -0,0 +1,2 @@
+-- Add billing_mode to usage_logs (records the billing mode: token/per_request/image)
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS billing_mode VARCHAR(20);
diff --git a/backend/migrations/088_channel_billing_model_source_channel_mapped.sql b/backend/migrations/088_channel_billing_model_source_channel_mapped.sql
new file mode 100644
index 00000000..83f96b09
--- /dev/null
+++ b/backend/migrations/088_channel_billing_model_source_channel_mapped.sql
@@ -0,0 +1,3 @@
+-- Change default billing_model_source for new channels to 'channel_mapped'
+-- Existing channels keep their current setting (no UPDATE on existing rows)
+ALTER TABLE channels ALTER COLUMN billing_model_source SET DEFAULT 'channel_mapped';
diff --git a/backend/migrations/089_usage_log_image_output_tokens.sql b/backend/migrations/089_usage_log_image_output_tokens.sql
new file mode 100644
index 00000000..dd142b15
--- /dev/null
+++ b/backend/migrations/089_usage_log_image_output_tokens.sql
@@ -0,0 +1,2 @@
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_tokens INTEGER NOT NULL DEFAULT 0;
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS image_output_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;
diff --git a/backend/migrations/090_drop_sora.sql b/backend/migrations/090_drop_sora.sql
new file mode 100644
index 00000000..9ef3273f
--- /dev/null
+++ b/backend/migrations/090_drop_sora.sql
@@ -0,0 +1,34 @@
+-- Migration: 090_drop_sora
+-- Remove all Sora-related database objects.
+-- Drops tables: sora_tasks, sora_generations, sora_accounts
+-- Drops columns from: groups, users, usage_logs
+
+-- ============================================================
+-- 1. Drop Sora tables
+-- ============================================================
+DROP TABLE IF EXISTS sora_tasks;
+DROP TABLE IF EXISTS sora_generations;
+DROP TABLE IF EXISTS sora_accounts;
+
+-- ============================================================
+-- 2. Drop Sora columns from groups table
+-- ============================================================
+ALTER TABLE groups
+ DROP COLUMN IF EXISTS sora_image_price_360,
+ DROP COLUMN IF EXISTS sora_image_price_540,
+ DROP COLUMN IF EXISTS sora_video_price_per_request,
+ DROP COLUMN IF EXISTS sora_video_price_per_request_hd,
+ DROP COLUMN IF EXISTS sora_storage_quota_bytes;
+
+-- ============================================================
+-- 3. Drop Sora columns from users table
+-- ============================================================
+ALTER TABLE users
+ DROP COLUMN IF EXISTS sora_storage_quota_bytes,
+ DROP COLUMN IF EXISTS sora_storage_used_bytes;
+
+-- ============================================================
+-- 4. Drop Sora column from usage_logs table
+-- ============================================================
+ALTER TABLE usage_logs
+ DROP COLUMN IF EXISTS media_type;
diff --git a/backend/migrations/091_add_group_messages_dispatch_model_config.sql b/backend/migrations/091_add_group_messages_dispatch_model_config.sql
new file mode 100644
index 00000000..8ddfcb0f
--- /dev/null
+++ b/backend/migrations/091_add_group_messages_dispatch_model_config.sql
@@ -0,0 +1,2 @@
+ALTER TABLE groups
+ADD COLUMN IF NOT EXISTS messages_dispatch_model_config JSONB NOT NULL DEFAULT '{}'::jsonb;
diff --git a/backend/migrations/092_payment_orders.sql b/backend/migrations/092_payment_orders.sql
new file mode 100644
index 00000000..036e4ded
--- /dev/null
+++ b/backend/migrations/092_payment_orders.sql
@@ -0,0 +1,47 @@
+CREATE TABLE IF NOT EXISTS payment_orders (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL,
+ user_email VARCHAR(255) NOT NULL DEFAULT '',
+ user_name VARCHAR(100) NOT NULL DEFAULT '',
+ user_notes TEXT,
+ amount DECIMAL(20,2) NOT NULL,
+ pay_amount DECIMAL(20,2) NOT NULL,
+ fee_rate DECIMAL(10,4) NOT NULL DEFAULT 0,
+ recharge_code VARCHAR(64) NOT NULL DEFAULT '',
+ payment_type VARCHAR(30) NOT NULL DEFAULT '',
+ payment_trade_no VARCHAR(128) NOT NULL DEFAULT '',
+ pay_url TEXT,
+ qr_code TEXT,
+ qr_code_img TEXT,
+ order_type VARCHAR(20) NOT NULL DEFAULT 'balance',
+ plan_id BIGINT,
+ subscription_group_id BIGINT,
+ subscription_days INT,
+ provider_instance_id VARCHAR(64),
+ status VARCHAR(30) NOT NULL DEFAULT 'PENDING',
+ refund_amount DECIMAL(20,2) NOT NULL DEFAULT 0,
+ refund_reason TEXT,
+ refund_at TIMESTAMPTZ,
+ force_refund BOOLEAN NOT NULL DEFAULT FALSE,
+ refund_requested_at TIMESTAMPTZ,
+ refund_request_reason TEXT,
+ refund_requested_by VARCHAR(20),
+ expires_at TIMESTAMPTZ NOT NULL,
+ paid_at TIMESTAMPTZ,
+ completed_at TIMESTAMPTZ,
+ failed_at TIMESTAMPTZ,
+ failed_reason TEXT,
+ client_ip VARCHAR(50) NOT NULL DEFAULT '',
+ src_host VARCHAR(255) NOT NULL DEFAULT '',
+ src_url TEXT,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+-- Indexes
+CREATE INDEX IF NOT EXISTS idx_payment_orders_user_id ON payment_orders(user_id);
+CREATE INDEX IF NOT EXISTS idx_payment_orders_status ON payment_orders(status);
+CREATE INDEX IF NOT EXISTS idx_payment_orders_expires_at ON payment_orders(expires_at);
+CREATE INDEX IF NOT EXISTS idx_payment_orders_created_at ON payment_orders(created_at);
+CREATE INDEX IF NOT EXISTS idx_payment_orders_paid_at ON payment_orders(paid_at);
+CREATE INDEX IF NOT EXISTS idx_payment_orders_type_paid ON payment_orders(payment_type, paid_at);
+CREATE INDEX IF NOT EXISTS idx_payment_orders_order_type ON payment_orders(order_type);
diff --git a/backend/migrations/093_payment_audit_logs.sql b/backend/migrations/093_payment_audit_logs.sql
new file mode 100644
index 00000000..d05b15ef
--- /dev/null
+++ b/backend/migrations/093_payment_audit_logs.sql
@@ -0,0 +1,9 @@
+CREATE TABLE IF NOT EXISTS payment_audit_logs (
+ id BIGSERIAL PRIMARY KEY,
+ order_id VARCHAR(64) NOT NULL,
+ action VARCHAR(50) NOT NULL,
+ detail TEXT NOT NULL DEFAULT '',
+ operator VARCHAR(100) NOT NULL DEFAULT 'system',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE INDEX IF NOT EXISTS idx_payment_audit_logs_order_id ON payment_audit_logs(order_id);
diff --git a/backend/migrations/094_removed_payment_channels.sql b/backend/migrations/094_removed_payment_channels.sql
new file mode 100644
index 00000000..cb202347
--- /dev/null
+++ b/backend/migrations/094_removed_payment_channels.sql
@@ -0,0 +1,4 @@
+-- Migration 094: payment_channels table was removed before release.
+-- This file is a no-op placeholder to maintain migration numbering continuity.
+-- The payment system now uses the existing channels table (migration 081).
+SELECT 1;
diff --git a/backend/migrations/095_channel_features.sql b/backend/migrations/095_channel_features.sql
new file mode 100644
index 00000000..5f142002
--- /dev/null
+++ b/backend/migrations/095_channel_features.sql
@@ -0,0 +1,2 @@
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS features TEXT NOT NULL DEFAULT '';
+COMMENT ON COLUMN channels.features IS '渠道特性描述,JSON 数组格式,用于支付页面展示';
diff --git a/backend/migrations/095_subscription_plans.sql b/backend/migrations/095_subscription_plans.sql
new file mode 100644
index 00000000..541d8f0c
--- /dev/null
+++ b/backend/migrations/095_subscription_plans.sql
@@ -0,0 +1,18 @@
+CREATE TABLE IF NOT EXISTS subscription_plans (
+ id BIGSERIAL PRIMARY KEY,
+ group_id BIGINT NOT NULL,
+ name VARCHAR(100) NOT NULL,
+ description TEXT NOT NULL DEFAULT '',
+ price DECIMAL(20,2) NOT NULL,
+ original_price DECIMAL(20,2),
+ validity_days INT NOT NULL DEFAULT 30,
+ validity_unit VARCHAR(10) NOT NULL DEFAULT 'day',
+ features TEXT NOT NULL DEFAULT '',
+ product_name VARCHAR(100) NOT NULL DEFAULT '',
+ for_sale BOOLEAN NOT NULL DEFAULT TRUE,
+ sort_order INT NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE INDEX IF NOT EXISTS idx_subscription_plans_group_id ON subscription_plans(group_id);
+CREATE INDEX IF NOT EXISTS idx_subscription_plans_for_sale ON subscription_plans(for_sale);
diff --git a/backend/migrations/096_payment_provider_instances.sql b/backend/migrations/096_payment_provider_instances.sql
new file mode 100644
index 00000000..bedd75df
--- /dev/null
+++ b/backend/migrations/096_payment_provider_instances.sql
@@ -0,0 +1,15 @@
+CREATE TABLE IF NOT EXISTS payment_provider_instances (
+ id BIGSERIAL PRIMARY KEY,
+ provider_key VARCHAR(30) NOT NULL,
+ name VARCHAR(100) NOT NULL DEFAULT '',
+ config TEXT NOT NULL,
+ supported_types VARCHAR(200) NOT NULL DEFAULT '',
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ sort_order INT NOT NULL DEFAULT 0,
+ limits TEXT NOT NULL DEFAULT '',
+ refund_enabled BOOLEAN NOT NULL DEFAULT FALSE,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE INDEX IF NOT EXISTS idx_payment_provider_instances_provider_key ON payment_provider_instances(provider_key);
+CREATE INDEX IF NOT EXISTS idx_payment_provider_instances_enabled ON payment_provider_instances(enabled);
diff --git a/backend/migrations/097_fix_settings_updated_at_default.sql b/backend/migrations/097_fix_settings_updated_at_default.sql
new file mode 100644
index 00000000..e1d6f9b9
--- /dev/null
+++ b/backend/migrations/097_fix_settings_updated_at_default.sql
@@ -0,0 +1,27 @@
+-- 097_fix_settings_updated_at_default.sql
+--
+-- 修复 settings.updated_at 列在历史实例上可能缺失 SQL DEFAULT 的问题。
+--
+-- 背景:
+-- 早期版本曾依赖 ent 自动迁移建表(ent 的 Default(time.Now) 仅是 Go 层默认值,
+-- 不会在 SQL 层落地为 DEFAULT),随后引入的 005_schema_parity.sql 使用了
+-- CREATE TABLE IF NOT EXISTS,对已存在的 settings 表不会重建,导致这部分实例
+-- 的 updated_at 列虽然是 NOT NULL,但缺少 SQL DEFAULT。
+--
+-- 后续 098_migrate_purchase_subscription_to_custom_menu.sql 是项目中唯一使用
+-- 原生 SQL INSERT INTO settings 的迁移(其余 settings 写入都走 ent / Go 层),
+-- 因此该 schema 缺陷直到 098 才会触发:
+-- "null value in column \"updated_at\" of relation \"settings\" violates not-null constraint"
+--
+-- 幂等性:
+-- - ALTER COLUMN ... SET DEFAULT NOW() 在已经具备相同默认值的实例上是无操作,
+-- 不会报错(PostgreSQL 允许重复设置相同的默认值)。
+-- - UPDATE 子句的 WHERE updated_at IS NULL 在健康实例上匹配 0 行,不影响数据。
+--
+-- 这样可以同时兼容:
+-- 1. 从未运行过旧版迁移的全新部署(005 已经把列建对,本迁移变成 no-op)。
+-- 2. 历史损坏实例(本迁移修复缺失的默认值,使后续 098 能够正常 INSERT)。
+
+ALTER TABLE settings ALTER COLUMN updated_at SET DEFAULT NOW();
+
+UPDATE settings SET updated_at = NOW() WHERE updated_at IS NULL;
diff --git a/backend/migrations/098_migrate_purchase_subscription_to_custom_menu.sql b/backend/migrations/098_migrate_purchase_subscription_to_custom_menu.sql
new file mode 100644
index 00000000..1864459e
--- /dev/null
+++ b/backend/migrations/098_migrate_purchase_subscription_to_custom_menu.sql
@@ -0,0 +1,70 @@
+-- 098_migrate_purchase_subscription_to_custom_menu.sql
+--
+-- Migrates the legacy purchase_subscription_url setting into custom_menu_items.
+-- After migration, purchase_subscription_enabled is set to "false" and
+-- purchase_subscription_url is cleared.
+--
+-- Idempotent: skips if custom_menu_items already contains
+-- "migrated_purchase_subscription".
+
+DO $$
+DECLARE
+ v_enabled text;
+ v_url text;
+ v_raw text;
+ v_items jsonb;
+ v_new_item jsonb;
+BEGIN
+ -- Read legacy settings
+ SELECT value INTO v_enabled
+ FROM settings WHERE key = 'purchase_subscription_enabled';
+ SELECT value INTO v_url
+ FROM settings WHERE key = 'purchase_subscription_url';
+
+ -- Skip if not enabled or URL is empty
+ IF COALESCE(v_enabled, '') <> 'true' OR COALESCE(TRIM(v_url), '') = '' THEN
+ RETURN;
+ END IF;
+
+ -- Read current custom_menu_items
+ SELECT value INTO v_raw
+ FROM settings WHERE key = 'custom_menu_items';
+
+ IF COALESCE(v_raw, '') = '' OR v_raw = 'null' THEN
+ v_items := '[]'::jsonb;
+ ELSE
+ v_items := v_raw::jsonb;
+ END IF;
+
+ -- Skip if already migrated (item with id "migrated_purchase_subscription" exists)
+ IF EXISTS (
+ SELECT 1 FROM jsonb_array_elements(v_items) elem
+ WHERE elem ->> 'id' = 'migrated_purchase_subscription'
+ ) THEN
+ RETURN;
+ END IF;
+
+ -- Build the new menu item
+ v_new_item := jsonb_build_object(
+ 'id', 'migrated_purchase_subscription',
+ 'label', 'Purchase',
+ 'icon_svg', '',
+ 'url', TRIM(v_url),
+ 'visibility', 'user',
+ 'sort_order', 100
+ );
+
+ -- Append to array
+ v_items := v_items || jsonb_build_array(v_new_item);
+
+ -- Upsert custom_menu_items
+ INSERT INTO settings (key, value)
+ VALUES ('custom_menu_items', v_items::text)
+ ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value;
+
+ -- Clear legacy settings
+ UPDATE settings SET value = 'false' WHERE key = 'purchase_subscription_enabled';
+ UPDATE settings SET value = '' WHERE key = 'purchase_subscription_url';
+
+  RAISE NOTICE '[migration-098] Migrated purchase_subscription_url (%) to custom_menu_items', v_url;
+END $$;
diff --git a/backend/migrations/099_fix_migrated_purchase_menu_label_icon.sql b/backend/migrations/099_fix_migrated_purchase_menu_label_icon.sql
new file mode 100644
index 00000000..5361ad81
--- /dev/null
+++ b/backend/migrations/099_fix_migrated_purchase_menu_label_icon.sql
@@ -0,0 +1,51 @@
+-- 099_fix_migrated_purchase_menu_label_icon.sql
+--
+-- Fixes the custom menu item created by migration 098: updates the label
+-- from hardcoded English "Purchase" to "充值/订阅", and sets the icon_svg
+-- to a credit-card SVG matching the sidebar CreditCardIcon.
+--
+-- Idempotent: only modifies items where id = 'migrated_purchase_subscription'.
+
+DO $$
+DECLARE
+ v_raw text;
+ v_items jsonb;
+ v_idx int;
+ v_icon text;
+ v_elem jsonb;
+ v_i int := 0;
+BEGIN
+ SELECT value INTO v_raw
+ FROM settings WHERE key = 'custom_menu_items';
+
+ IF COALESCE(v_raw, '') = '' OR v_raw = 'null' THEN
+ RETURN;
+ END IF;
+
+ v_items := v_raw::jsonb;
+
+ -- Find the index of the migrated item by iterating the array
+ v_idx := NULL;
+ FOR v_elem IN SELECT jsonb_array_elements(v_items) LOOP
+ IF v_elem ->> 'id' = 'migrated_purchase_subscription' THEN
+ v_idx := v_i;
+ EXIT;
+ END IF;
+ v_i := v_i + 1;
+ END LOOP;
+
+ IF v_idx IS NULL THEN
+ RETURN; -- item not found, nothing to fix
+ END IF;
+
+ -- Credit card SVG (Heroicons outline, matches CreditCardIcon in AppSidebar)
+  v_icon := '<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M2.25 8.25h19.5M2.25 9h19.5m-16.5 5.25h6m-6 2.25h3m-3.75 3h15a2.25 2.25 0 0 0 2.25-2.25V6.75a2.25 2.25 0 0 0-2.25-2.25h-15a2.25 2.25 0 0 0-2.25 2.25v10.5a2.25 2.25 0 0 0 2.25 2.25Z"/></svg>';
+
+ -- Update label and icon_svg
+ v_items := jsonb_set(v_items, ARRAY[v_idx::text, 'label'], '"充值/订阅"'::jsonb);
+ v_items := jsonb_set(v_items, ARRAY[v_idx::text, 'icon_svg'], to_jsonb(v_icon));
+
+ UPDATE settings SET value = v_items::text WHERE key = 'custom_menu_items';
+
+  RAISE NOTICE '[migration-099] Fixed migrated_purchase_subscription: label=充值/订阅, icon=CreditCard SVG';
+END $$;
diff --git a/backend/migrations/100_remove_easypay_from_enabled_payment_types.sql b/backend/migrations/100_remove_easypay_from_enabled_payment_types.sql
new file mode 100644
index 00000000..8128ed09
--- /dev/null
+++ b/backend/migrations/100_remove_easypay_from_enabled_payment_types.sql
@@ -0,0 +1,17 @@
+-- 100_remove_easypay_from_enabled_payment_types.sql
+--
+-- Removes "easypay" from ENABLED_PAYMENT_TYPES setting.
+-- "easypay" is a provider key, not a payment type. Valid payment types
+-- are: alipay, wxpay, alipay_direct, wxpay_direct, stripe.
+--
+-- Idempotent: safe to run multiple times.
+
+UPDATE settings
+ SET value = array_to_string(
+ array_remove(
+ string_to_array(value, ','),
+ 'easypay'
+ ), ','
+ )
+ WHERE key = 'ENABLED_PAYMENT_TYPES'
+ AND value LIKE '%easypay%';
diff --git a/backend/migrations/101_add_account_stats_pricing.sql b/backend/migrations/101_add_account_stats_pricing.sql
new file mode 100644
index 00000000..a61d0c26
--- /dev/null
+++ b/backend/migrations/101_add_account_stats_pricing.sql
@@ -0,0 +1,38 @@
+-- Account statistics pricing: allow channels to configure custom pricing for account cost tracking.
+
+-- 1. Channel-level toggle
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS apply_pricing_to_account_stats BOOLEAN NOT NULL DEFAULT FALSE;
+
+-- 2. Account stats pricing rules (ordered list per channel)
+CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_rules (
+ id BIGSERIAL PRIMARY KEY,
+ channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE,
+ name VARCHAR(100) NOT NULL DEFAULT '',
+ group_ids BIGINT[] NOT NULL DEFAULT '{}',
+ account_ids BIGINT[] NOT NULL DEFAULT '{}',
+ sort_order INT NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE INDEX IF NOT EXISTS idx_cas_pricing_rules_channel_id ON channel_account_stats_pricing_rules(channel_id);
+
+-- 3. Model pricing for each rule (same structure as channel_model_pricing)
+CREATE TABLE IF NOT EXISTS channel_account_stats_model_pricing (
+ id BIGSERIAL PRIMARY KEY,
+ rule_id BIGINT NOT NULL REFERENCES channel_account_stats_pricing_rules(id) ON DELETE CASCADE,
+ platform VARCHAR(50) NOT NULL DEFAULT '',
+ models JSONB NOT NULL DEFAULT '[]',
+ billing_mode VARCHAR(20) NOT NULL DEFAULT 'token',
+ input_price NUMERIC(20,10),
+ output_price NUMERIC(20,10),
+ cache_write_price NUMERIC(20,10),
+ cache_read_price NUMERIC(20,10),
+ image_output_price NUMERIC(20,10),
+ per_request_price NUMERIC(20,10),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+CREATE INDEX IF NOT EXISTS idx_cas_model_pricing_rule_id ON channel_account_stats_model_pricing(rule_id);
+
+-- 4. Usage logs: pre-computed account stats cost (NULL = use default formula)
+ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS account_stats_cost NUMERIC(20,10);
diff --git a/backend/migrations/101_add_balance_notify_fields.sql b/backend/migrations/101_add_balance_notify_fields.sql
new file mode 100644
index 00000000..ef0a0930
--- /dev/null
+++ b/backend/migrations/101_add_balance_notify_fields.sql
@@ -0,0 +1,4 @@
+-- Balance notification user preferences
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_enabled BOOLEAN NOT NULL DEFAULT true;
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_threshold DECIMAL(20,8) DEFAULT NULL;
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_extra_emails TEXT NOT NULL DEFAULT '[]';
diff --git a/backend/migrations/101_add_channel_features_config.sql b/backend/migrations/101_add_channel_features_config.sql
new file mode 100644
index 00000000..b054b085
--- /dev/null
+++ b/backend/migrations/101_add_channel_features_config.sql
@@ -0,0 +1,2 @@
+ALTER TABLE channels ADD COLUMN IF NOT EXISTS features_config JSONB NOT NULL DEFAULT '{}';
+COMMENT ON COLUMN channels.features_config IS '渠道特性配置(如 web_search_emulation),JSON 对象格式';
diff --git a/backend/migrations/101_add_payment_mode.sql b/backend/migrations/101_add_payment_mode.sql
new file mode 100644
index 00000000..eeb6ba7b
--- /dev/null
+++ b/backend/migrations/101_add_payment_mode.sql
@@ -0,0 +1,16 @@
+-- Add payment_mode field to payment_provider_instances
+-- Values: 'redirect' (hosted page redirect), 'api' (API call for QR/payurl), '' (default/N/A)
+ALTER TABLE payment_provider_instances ADD COLUMN IF NOT EXISTS payment_mode VARCHAR(20) NOT NULL DEFAULT '';
+
+-- Migrate existing data: easypay instances with 'easypay' in supported_types → redirect mode
+-- Remove 'easypay' from supported_types and set payment_mode = 'redirect'
+UPDATE payment_provider_instances
+SET payment_mode = 'redirect',
+ supported_types = TRIM(BOTH ',' FROM REPLACE(REPLACE(REPLACE(
+ supported_types, 'easypay,', ''), ',easypay', ''), 'easypay', ''))
+WHERE provider_key = 'easypay' AND supported_types LIKE '%easypay%';
+
+-- EasyPay instances without 'easypay' in supported_types → api mode
+UPDATE payment_provider_instances
+SET payment_mode = 'api'
+WHERE provider_key = 'easypay' AND payment_mode = '';
diff --git a/backend/migrations/102_add_balance_notify_threshold_type.sql b/backend/migrations/102_add_balance_notify_threshold_type.sql
new file mode 100644
index 00000000..7ad70552
--- /dev/null
+++ b/backend/migrations/102_add_balance_notify_threshold_type.sql
@@ -0,0 +1,4 @@
+-- Add threshold type support (fixed / percentage) to balance notification
+ALTER TABLE users ADD COLUMN IF NOT EXISTS balance_notify_threshold_type VARCHAR(10) NOT NULL DEFAULT 'fixed';
+-- Track cumulative recharge amount for percentage threshold calculation
+ALTER TABLE users ADD COLUMN IF NOT EXISTS total_recharged DECIMAL(20,8) NOT NULL DEFAULT 0;
diff --git a/backend/migrations/102_add_out_trade_no_to_payment_orders.sql b/backend/migrations/102_add_out_trade_no_to_payment_orders.sql
new file mode 100644
index 00000000..896c3c95
--- /dev/null
+++ b/backend/migrations/102_add_out_trade_no_to_payment_orders.sql
@@ -0,0 +1,6 @@
+-- 102_add_out_trade_no_to_payment_orders.sql
+-- Adds out_trade_no column for external order ID used with payment providers.
+-- Allows webhook handlers to look up orders by external ID instead of embedding DB ID.
+
+ALTER TABLE payment_orders ADD COLUMN IF NOT EXISTS out_trade_no VARCHAR(64) NOT NULL DEFAULT '';
+CREATE INDEX IF NOT EXISTS paymentorder_out_trade_no ON payment_orders (out_trade_no);
diff --git a/backend/migrations/103_add_allow_user_refund.sql b/backend/migrations/103_add_allow_user_refund.sql
new file mode 100644
index 00000000..79525382
--- /dev/null
+++ b/backend/migrations/103_add_allow_user_refund.sql
@@ -0,0 +1 @@
+ALTER TABLE payment_provider_instances ADD COLUMN IF NOT EXISTS allow_user_refund BOOLEAN NOT NULL DEFAULT false;
diff --git a/backend/migrations/104_migrate_notify_emails_to_struct.sql b/backend/migrations/104_migrate_notify_emails_to_struct.sql
new file mode 100644
index 00000000..4356da4f
--- /dev/null
+++ b/backend/migrations/104_migrate_notify_emails_to_struct.sql
@@ -0,0 +1,35 @@
+-- Migrate notification email lists from old []string format to new []NotifyEmailEntry format
+-- Old: ["a@x.com", "b@x.com"]
+-- New: [{"email":"a@x.com","disabled":false,"verified":true}, ...]
+-- Existing emails are marked as verified=false (unverified), disabled=false (enabled)
+
+-- 1. User balance notification emails
+UPDATE users
+SET balance_notify_extra_emails = (
+ SELECT COALESCE(
+ jsonb_agg(jsonb_build_object('email', elem::text, 'disabled', false, 'verified', false)),
+ '[]'::jsonb
+ )::text
+ FROM jsonb_array_elements_text(balance_notify_extra_emails::jsonb) AS elem
+)
+WHERE balance_notify_extra_emails IS NOT NULL
+ AND balance_notify_extra_emails <> '[]'
+ AND balance_notify_extra_emails <> ''
+ AND (balance_notify_extra_emails::jsonb -> 0) IS NOT NULL
+ AND jsonb_typeof(balance_notify_extra_emails::jsonb -> 0) = 'string';
+
+-- 2. Admin account quota notification emails
+UPDATE settings
+SET value = (
+ SELECT COALESCE(
+ jsonb_agg(jsonb_build_object('email', elem::text, 'disabled', false, 'verified', false)),
+ '[]'::jsonb
+ )::text
+ FROM jsonb_array_elements_text(value::jsonb) AS elem
+)
+WHERE key = 'account_quota_notify_emails'
+ AND value IS NOT NULL
+ AND value <> '[]'
+ AND value <> ''
+ AND (value::jsonb -> 0) IS NOT NULL
+ AND jsonb_typeof(value::jsonb -> 0) = 'string';
diff --git a/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql b/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql
new file mode 100644
index 00000000..745e58df
--- /dev/null
+++ b/backend/migrations/105_migrate_websearch_emulation_to_tristate.sql
@@ -0,0 +1,11 @@
+-- Convert old boolean web_search_emulation to tri-state string
+-- true → "enabled", false → remove key (becomes "default")
+UPDATE accounts
+SET extra = (extra - 'web_search_emulation') || jsonb_build_object('web_search_emulation', 'enabled')
+WHERE extra ? 'web_search_emulation'
+ AND extra->>'web_search_emulation' = 'true';
+
+UPDATE accounts
+SET extra = extra - 'web_search_emulation'
+WHERE extra ? 'web_search_emulation'
+ AND extra->>'web_search_emulation' = 'false';
diff --git a/backend/migrations/106_add_account_stats_pricing_intervals.sql b/backend/migrations/106_add_account_stats_pricing_intervals.sql
new file mode 100644
index 00000000..5ae10655
--- /dev/null
+++ b/backend/migrations/106_add_account_stats_pricing_intervals.sql
@@ -0,0 +1,19 @@
+-- Add intervals table for account stats pricing rules (mirrors channel_pricing_intervals).
+CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_intervals (
+ id BIGSERIAL PRIMARY KEY,
+ pricing_id BIGINT NOT NULL REFERENCES channel_account_stats_model_pricing(id) ON DELETE CASCADE,
+ min_tokens INT NOT NULL DEFAULT 0,
+ max_tokens INT,
+ tier_label VARCHAR(50),
+ input_price NUMERIC(20,12),
+ output_price NUMERIC(20,12),
+ cache_write_price NUMERIC(20,12),
+ cache_read_price NUMERIC(20,12),
+ per_request_price NUMERIC(20,12),
+ sort_order INT NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_account_stats_pricing_intervals_pricing_id
+ ON channel_account_stats_pricing_intervals (pricing_id);
diff --git a/backend/migrations/107_add_account_cost_to_dashboard_tables.sql b/backend/migrations/107_add_account_cost_to_dashboard_tables.sql
new file mode 100644
index 00000000..9f815a3f
--- /dev/null
+++ b/backend/migrations/107_add_account_cost_to_dashboard_tables.sql
@@ -0,0 +1,5 @@
+-- Add account_cost column to dashboard aggregation tables for admin dashboard display.
+-- account_cost = SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1))
+
+ALTER TABLE usage_dashboard_hourly ADD COLUMN IF NOT EXISTS account_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;
+ALTER TABLE usage_dashboard_daily ADD COLUMN IF NOT EXISTS account_cost DECIMAL(20, 10) NOT NULL DEFAULT 0;
diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql
new file mode 100644
index 00000000..117e3ca3
--- /dev/null
+++ b/backend/migrations/108_auth_identity_foundation_core.sql
@@ -0,0 +1,141 @@
+ALTER TABLE users
+ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email',
+ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL,
+ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL;
+
+UPDATE users
+SET signup_source = 'email'
+WHERE signup_source IS NULL OR signup_source = '';
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'users_signup_source_check'
+ ) THEN
+ ALTER TABLE users
+ ADD CONSTRAINT users_signup_source_check
+ CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc'));
+ END IF;
+END $$;
+
+CREATE TABLE IF NOT EXISTS auth_identities (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ verified_at TIMESTAMPTZ NULL,
+ issuer TEXT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identities_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key
+ ON auth_identities (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx
+ ON auth_identities (user_id);
+
+CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx
+ ON auth_identities (user_id, provider_type);
+
+CREATE TABLE IF NOT EXISTS auth_identity_channels (
+ id BIGSERIAL PRIMARY KEY,
+ identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ channel VARCHAR(20) NOT NULL,
+ channel_app_id TEXT NOT NULL,
+ channel_subject TEXT NOT NULL,
+ metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT auth_identity_channels_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key
+ ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject);
+
+CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx
+ ON auth_identity_channels (identity_id);
+
+CREATE TABLE IF NOT EXISTS pending_auth_sessions (
+ id BIGSERIAL PRIMARY KEY,
+ session_token VARCHAR(255) NOT NULL,
+ intent VARCHAR(40) NOT NULL,
+ provider_type VARCHAR(20) NOT NULL,
+ provider_key TEXT NOT NULL,
+ provider_subject TEXT NOT NULL,
+ target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ redirect_to TEXT NOT NULL DEFAULT '',
+ resolved_email TEXT NOT NULL DEFAULT '',
+ registration_password_hash TEXT NOT NULL DEFAULT '',
+ upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb,
+ local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb,
+ browser_session_key TEXT NOT NULL DEFAULT '',
+ completion_code_hash TEXT NOT NULL DEFAULT '',
+ completion_code_expires_at TIMESTAMPTZ NULL,
+ email_verified_at TIMESTAMPTZ NULL,
+ password_verified_at TIMESTAMPTZ NULL,
+ totp_verified_at TIMESTAMPTZ NULL,
+ expires_at TIMESTAMPTZ NOT NULL,
+ consumed_at TIMESTAMPTZ NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT pending_auth_sessions_intent_check
+ CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')),
+ CONSTRAINT pending_auth_sessions_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key
+ ON pending_auth_sessions (session_token);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx
+ ON pending_auth_sessions (target_user_id);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx
+ ON pending_auth_sessions (expires_at);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx
+ ON pending_auth_sessions (provider_type, provider_key, provider_subject);
+
+CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx
+ ON pending_auth_sessions (completion_code_hash);
+
+CREATE TABLE IF NOT EXISTS identity_adoption_decisions (
+ id BIGSERIAL PRIMARY KEY,
+ pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE,
+ identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL,
+ adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE,
+ adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE,
+ decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key
+ ON identity_adoption_decisions (pending_auth_session_id);
+
+CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx
+ ON identity_adoption_decisions (identity_id);
+
+CREATE TABLE IF NOT EXISTS auth_identity_migration_reports (
+ id BIGSERIAL PRIMARY KEY,
+ report_type VARCHAR(40) NOT NULL,
+ report_key TEXT NOT NULL,
+ details JSONB NOT NULL DEFAULT '{}'::jsonb,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx
+ ON auth_identity_migration_reports (report_type);
+
+CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key
+ ON auth_identity_migration_reports (report_type, report_key);
diff --git a/backend/migrations/108a_widen_auth_identity_migration_report_type.sql b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql
new file mode 100644
index 00000000..bc170fb8
--- /dev/null
+++ b/backend/migrations/108a_widen_auth_identity_migration_report_type.sql
@@ -0,0 +1,14 @@
+DO $$
+BEGIN
+ IF EXISTS (
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_schema = 'public'
+ AND table_name = 'auth_identity_migration_reports'
+ AND column_name = 'report_type'
+ AND COALESCE(character_maximum_length, 0) < 80
+ ) THEN
+ ALTER TABLE auth_identity_migration_reports
+ ALTER COLUMN report_type TYPE VARCHAR(80);
+ END IF;
+END $$;
diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql
new file mode 100644
index 00000000..ddbbedbc
--- /dev/null
+++ b/backend/migrations/109_auth_identity_compat_backfill.sql
@@ -0,0 +1,125 @@
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'email',
+ 'email',
+ LOWER(BTRIM(u.email)),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'users.email',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND BTRIM(COALESCE(u.email, '')) <> ''
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid'
+ AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'linuxdo',
+ 'linuxdo',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ u.id,
+ 'wechat',
+ 'wechat',
+ SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'),
+ COALESCE(u.updated_at, u.created_at, NOW()),
+ jsonb_build_object(
+ 'backfill_source', 'synthetic_email',
+ 'legacy_email', BTRIM(u.email),
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+
+UPDATE users
+SET signup_source = 'linuxdo'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'wechat'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$';
+
+UPDATE users
+SET signup_source = 'oidc'
+WHERE deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$';
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'oidc_synthetic_email_requires_manual_recovery',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_openid_only_requires_remediation',
+ CAST(u.id AS TEXT),
+ jsonb_build_object(
+ 'user_id', u.id,
+ 'email', LOWER(BTRIM(u.email)),
+ 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists',
+ 'migration', '109_auth_identity_compat_backfill'
+ )
+FROM users AS u
+WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities ai
+ WHERE ai.user_id = u.id
+ AND ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ )
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
new file mode 100644
index 00000000..f59b2188
--- /dev/null
+++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql
@@ -0,0 +1,59 @@
+CREATE TABLE IF NOT EXISTS user_provider_default_grants (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ provider_type VARCHAR(20) NOT NULL,
+ grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind',
+ granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT user_provider_default_grants_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')),
+ CONSTRAINT user_provider_default_grants_reason_check
+ CHECK (grant_reason IN ('signup', 'first_bind'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key
+ ON user_provider_default_grants (user_id, provider_type, grant_reason);
+
+CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx
+ ON user_provider_default_grants (user_id);
+
+CREATE TABLE IF NOT EXISTS user_avatars (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ storage_provider VARCHAR(20) NOT NULL DEFAULT 'database',
+ storage_key TEXT NOT NULL DEFAULT '',
+ url TEXT NOT NULL DEFAULT '',
+ content_type VARCHAR(100) NOT NULL DEFAULT '',
+ byte_size INT NOT NULL DEFAULT 0,
+ sha256 VARCHAR(64) NOT NULL DEFAULT '',
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key
+ ON user_avatars (user_id);
+
+INSERT INTO settings (key, value)
+VALUES
+ ('auth_source_default_email_balance', '0'),
+ ('auth_source_default_email_concurrency', '5'),
+ ('auth_source_default_email_subscriptions', '[]'),
+ ('auth_source_default_email_grant_on_signup', 'false'),
+ ('auth_source_default_email_grant_on_first_bind', 'false'),
+ ('auth_source_default_linuxdo_balance', '0'),
+ ('auth_source_default_linuxdo_concurrency', '5'),
+ ('auth_source_default_linuxdo_subscriptions', '[]'),
+ ('auth_source_default_linuxdo_grant_on_signup', 'false'),
+ ('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
+ ('auth_source_default_oidc_balance', '0'),
+ ('auth_source_default_oidc_concurrency', '5'),
+ ('auth_source_default_oidc_subscriptions', '[]'),
+ ('auth_source_default_oidc_grant_on_signup', 'false'),
+ ('auth_source_default_oidc_grant_on_first_bind', 'false'),
+ ('auth_source_default_wechat_balance', '0'),
+ ('auth_source_default_wechat_concurrency', '5'),
+ ('auth_source_default_wechat_subscriptions', '[]'),
+ ('auth_source_default_wechat_grant_on_signup', 'false'),
+ ('auth_source_default_wechat_grant_on_first_bind', 'false'),
+ ('force_email_on_third_party_signup', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
new file mode 100644
index 00000000..f222a8d4
--- /dev/null
+++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql
@@ -0,0 +1,8 @@
+INSERT INTO settings (key, value)
+VALUES
+ ('payment_visible_method_alipay_source', ''),
+ ('payment_visible_method_wxpay_source', ''),
+ ('payment_visible_method_alipay_enabled', 'false'),
+ ('payment_visible_method_wxpay_enabled', 'false'),
+ ('openai_advanced_scheduler_enabled', 'false')
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
new file mode 100644
index 00000000..d331b824
--- /dev/null
+++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql
@@ -0,0 +1,10 @@
+ALTER TABLE payment_orders ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30);
+
+UPDATE payment_orders
+SET provider_key = (
+ SELECT provider_key
+ FROM payment_provider_instances
+ WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id
+)
+WHERE provider_key IS NULL
+ AND provider_instance_id IS NOT NULL;
diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
new file mode 100644
index 00000000..15610af0
--- /dev/null
+++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql
@@ -0,0 +1,89 @@
+UPDATE auth_identities AS ai
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identities AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+ );
+
+UPDATE auth_identity_channels AS channel
+SET
+ provider_key = 'wechat-main',
+ metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object(
+ 'legacy_provider_key', 'wechat',
+ 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key'
+ ),
+ updated_at = NOW()
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM auth_identity_channels AS canon
+ WHERE canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+ );
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_provider_key_conflict',
+ CAST(ai.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_identity_id', ai.id,
+ 'legacy_user_id', ai.user_id,
+ 'provider_subject', ai.provider_subject,
+ 'canonical_identity_id', canon.id,
+ 'canonical_user_id', canon.user_id,
+ 'same_user', canon.user_id = ai.user_id,
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identities AS ai
+JOIN auth_identities AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.provider_subject = ai.provider_subject
+WHERE ai.provider_type = 'wechat'
+ AND ai.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'wechat_channel_provider_key_conflict',
+ CAST(channel.id AS TEXT),
+ jsonb_build_object(
+ 'legacy_channel_id', channel.id,
+ 'legacy_identity_id', channel.identity_id,
+ 'canonical_channel_id', canon.id,
+ 'canonical_identity_id', canon.identity_id,
+ 'channel', channel.channel,
+ 'channel_app_id', channel.channel_app_id,
+ 'channel_subject', channel.channel_subject,
+ 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE),
+ 'migration', '113_normalize_legacy_wechat_provider_key'
+ )
+FROM auth_identity_channels AS channel
+JOIN auth_identity_channels AS canon
+ ON canon.provider_type = 'wechat'
+ AND canon.provider_key = 'wechat-main'
+ AND canon.channel = channel.channel
+ AND canon.channel_app_id = channel.channel_app_id
+ AND canon.channel_subject = channel.channel_subject
+LEFT JOIN auth_identities AS legacy_identity
+ ON legacy_identity.id = channel.identity_id
+LEFT JOIN auth_identities AS canonical_identity
+ ON canonical_identity.id = canon.identity_id
+WHERE channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat'
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql
new file mode 100644
index 00000000..f84bf822
--- /dev/null
+++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql
@@ -0,0 +1,11 @@
+ALTER TABLE auth_identity_migration_reports
+  ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL;  -- NULL = report still open
+
+ALTER TABLE auth_identity_migration_reports
+  ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL;  -- NOTE(review): no FK constraint here -- confirm that is intentional (e.g. resolver account may be deleted later)
+
+ALTER TABLE auth_identity_migration_reports
+  ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT '';  -- free-form operator note; empty default keeps existing rows valid
+
+CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at
+  ON auth_identity_migration_reports (resolved_at);  -- supports filtering open (resolved_at IS NULL) vs closed reports
diff --git a/backend/migrations/115_auth_identity_legacy_external_backfill.sql b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
new file mode 100644
index 00000000..264da3c9
--- /dev/null
+++ b/backend/migrations/115_auth_identity_legacy_external_backfill.sql
@@ -0,0 +1,268 @@
+CREATE OR REPLACE FUNCTION public.__migration_115_safe_legacy_metadata_jsonb(input_text TEXT)  -- migration-scoped helper (dropped at end of file): coerce legacy TEXT metadata into a JSONB object without ever raising
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+  parsed JSONB;
+BEGIN
+  IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+    RETURN '{}'::jsonb;  -- absent/blank metadata -> empty object
+  END IF;
+
+  BEGIN
+    parsed := input_text::jsonb;
+  EXCEPTION
+    WHEN OTHERS THEN
+      RETURN '{}'::jsonb;  -- unparseable JSON silently downgraded to empty object (reported separately by migration 116)
+  END;
+
+  IF jsonb_typeof(parsed) = 'object' THEN
+    RETURN parsed;
+  END IF;
+
+  RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);  -- valid JSON but non-object (array/scalar): wrap so callers can still || merge
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+  EXECUTE $sql$  -- backfill canonical linuxdo auth identities from the legacy external-identity table
+WITH legacy AS (
+  SELECT
+    uei.id,
+    uei.user_id,
+    BTRIM(uei.provider_user_id) AS provider_user_id,
+    BTRIM(uei.provider_username) AS provider_username,
+    BTRIM(uei.display_name) AS display_name,
+    public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+    uei.created_at,
+    uei.updated_at
+  FROM user_external_identities AS uei
+  JOIN users AS u ON u.id = uei.user_id
+  WHERE u.deleted_at IS NULL
+    AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
+    AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+),
+legacy_subjects AS (
+  SELECT
+    provider_user_id AS provider_subject,
+    COUNT(DISTINCT user_id) AS distinct_user_count  -- a subject claimed by >1 user is ambiguous and excluded below
+  FROM legacy
+  GROUP BY provider_user_id
+),
+canonical_legacy AS (
+  SELECT
+    legacy.*,
+    ROW_NUMBER() OVER (
+      PARTITION BY legacy.provider_user_id
+      ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC  -- most recently touched row wins; id breaks ties deterministically
+    ) AS canonical_row_num
+  FROM legacy
+  JOIN legacy_subjects AS subjects
+    ON subjects.provider_subject = legacy.provider_user_id
+    AND subjects.distinct_user_count = 1  -- only unambiguous subjects are auto-migrated
+)
+INSERT INTO auth_identities (
+  user_id,
+  provider_type,
+  provider_key,
+  provider_subject,
+  verified_at,
+  metadata
+)
+SELECT
+  legacy.user_id,
+  'linuxdo',
+  'linuxdo',
+  legacy.provider_user_id,
+  COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+  legacy.metadata_json || jsonb_build_object(  -- provenance fields layered over the legacy metadata
+    'legacy_identity_id', legacy.id,
+    'provider_user_id', legacy.provider_user_id,
+    'provider_username', legacy.provider_username,
+    'display_name', legacy.display_name,
+    'migration', '115_auth_identity_legacy_external_backfill'
+  )
+FROM canonical_legacy AS legacy
+WHERE legacy.canonical_row_num = 1
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;  -- existing identities always win; safe to re-run
+$sql$;
+
+  EXECUTE $sql$  -- backfill canonical wechat identities keyed by unionid under the 'wechat-main' provider key
+WITH legacy AS (
+  SELECT
+    uei.id,
+    uei.user_id,
+    BTRIM(uei.provider_user_id) AS provider_user_id,
+    BTRIM(uei.provider_union_id) AS provider_union_id,
+    BTRIM(uei.provider_username) AS provider_username,
+    BTRIM(uei.display_name) AS display_name,
+    public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+    uei.created_at,
+    uei.updated_at
+  FROM user_external_identities AS uei
+  JOIN users AS u ON u.id = uei.user_id
+  WHERE u.deleted_at IS NULL
+    AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+    AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''  -- unionid is the canonical subject; openid-only rows are reported separately below
+),
+legacy_subjects AS (
+  SELECT
+    provider_union_id AS provider_subject,
+    COUNT(DISTINCT user_id) AS distinct_user_count
+  FROM legacy
+  GROUP BY provider_union_id
+),
+canonical_legacy AS (
+  SELECT
+    legacy.*,
+    ROW_NUMBER() OVER (
+      PARTITION BY legacy.provider_union_id
+      ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
+    ) AS canonical_row_num
+  FROM legacy
+  JOIN legacy_subjects AS subjects
+    ON subjects.provider_subject = legacy.provider_union_id
+    AND subjects.distinct_user_count = 1  -- unionids claimed by multiple users are skipped (ambiguous)
+)
+INSERT INTO auth_identities (
+  user_id,
+  provider_type,
+  provider_key,
+  provider_subject,
+  verified_at,
+  metadata
+)
+SELECT
+  legacy.user_id,
+  'wechat',
+  'wechat-main',
+  legacy.provider_union_id,
+  COALESCE(legacy.updated_at, legacy.created_at, NOW()),
+  legacy.metadata_json || jsonb_build_object(
+    'legacy_identity_id', legacy.id,
+    'openid', legacy.provider_user_id,
+    'unionid', legacy.provider_union_id,
+    'provider_user_id', legacy.provider_user_id,  -- same values duplicated under both legacy and wechat-native key names
+    'provider_union_id', legacy.provider_union_id,
+    'provider_username', legacy.provider_username,
+    'display_name', legacy.display_name,
+    'migration', '115_auth_identity_legacy_external_backfill'
+  )
+FROM canonical_legacy AS legacy
+WHERE legacy.canonical_row_num = 1
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+  EXECUTE $sql$  -- backfill per-app wechat channels (openid bindings) onto the canonical identities created above
+WITH legacy AS (
+  SELECT
+    uei.user_id,
+    BTRIM(uei.provider_user_id) AS provider_user_id,
+    BTRIM(uei.provider_union_id) AS provider_union_id,
+    BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
+    BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,  -- accept the historical key spellings
+    meta.metadata_json
+  FROM user_external_identities AS uei
+  JOIN users AS u ON u.id = uei.user_id
+  CROSS JOIN LATERAL (
+    SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json  -- parse once, reuse for every extracted field
+  ) AS meta
+  WHERE u.deleted_at IS NULL
+    AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+    AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+),
+legacy_subjects AS (
+  SELECT
+    provider_union_id AS provider_subject,
+    COUNT(DISTINCT user_id) AS distinct_user_count
+  FROM legacy
+  GROUP BY provider_union_id
+)
+INSERT INTO auth_identity_channels (
+  identity_id,
+  provider_type,
+  provider_key,
+  channel,
+  channel_app_id,
+  channel_subject,
+  metadata
+)
+SELECT
+  ai.id,
+  'wechat',
+  'wechat-main',
+  legacy.channel,
+  legacy.channel_app_id,
+  legacy.provider_user_id,  -- channel_subject is the per-app openid
+  legacy.metadata_json || jsonb_build_object(
+    'openid', legacy.provider_user_id,
+    'unionid', legacy.provider_union_id,
+    'migration', '115_auth_identity_legacy_external_backfill'
+  )
+FROM legacy
+JOIN legacy_subjects AS subjects
+  ON subjects.provider_subject = legacy.provider_union_id
+  AND subjects.distinct_user_count = 1
+JOIN auth_identities AS ai
+  ON ai.user_id = legacy.user_id
+  AND ai.provider_type = 'wechat'
+  AND ai.provider_key = 'wechat-main'
+  AND ai.provider_subject = legacy.provider_union_id  -- attach only to the identity backfilled for this same user/unionid
+WHERE legacy.channel <> ''
+  AND legacy.channel_app_id <> ''
+  AND legacy.provider_user_id <> ''  -- the full (channel, app, openid) triple is required to create a binding
+ON CONFLICT DO NOTHING;  -- no conflict target: tolerate any unique violation on re-run
+$sql$;
+
+  EXECUTE $sql$  -- report wechat rows that only carry an openid: they cannot be keyed by unionid offline
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+  'wechat_openid_only_requires_remediation',
+  'legacy_external_identity:' || legacy.id::text,
+  legacy.metadata_json || jsonb_build_object(
+    'legacy_identity_id', legacy.id,
+    'user_id', legacy.user_id,
+    'openid', legacy.provider_user_id,
+    'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+    'migration', '115_auth_identity_legacy_external_backfill'
+  )
+FROM (
+  SELECT
+    uei.id,
+    uei.user_id,
+    BTRIM(uei.provider_user_id) AS provider_user_id,
+    public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+  FROM user_external_identities AS uei
+  JOIN users AS u ON u.id = uei.user_id
+  WHERE u.deleted_at IS NULL
+    AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+    AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+    AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''  -- the defining condition: openid present, unionid absent
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+END $$;
+
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)  -- also flag synthetic wechat identities (from an earlier email backfill) still lacking a unionid
+SELECT
+  'wechat_openid_only_requires_remediation',
+  'synthetic_auth_identity:' || ai.id::text,  -- distinct key prefix keeps these separate from the legacy_external_identity reports above
+  COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object(
+    'auth_identity_id', ai.id,
+    'user_id', ai.user_id,
+    'provider_subject', ai.provider_subject,
+    'reason', 'synthetic wechat auth identity still lacks unionid metadata and needs remediation',
+    'migration', '115_auth_identity_legacy_external_backfill'
+  )
+FROM auth_identities AS ai
+WHERE ai.provider_type = 'wechat'
+  AND COALESCE(ai.metadata ->> 'backfill_source', '') = 'synthetic_email'  -- marker written by an earlier backfill -- NOTE(review): confirm which migration sets it
+  AND BTRIM(COALESCE(ai.metadata ->> 'unionid', '')) = ''
+ON CONFLICT (report_type, report_key) DO NOTHING;
+
+DROP FUNCTION IF EXISTS public.__migration_115_safe_legacy_metadata_jsonb(TEXT);  -- helper is migration-scoped; clean it up
diff --git a/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
new file mode 100644
index 00000000..81eb133c
--- /dev/null
+++ b/backend/migrations/116_auth_identity_legacy_external_safety_reports.sql
@@ -0,0 +1,525 @@
+CREATE OR REPLACE FUNCTION public.__migration_116_safe_legacy_metadata_jsonb(input_text TEXT)  -- deliberate copy of the migration-115 helper so each migration file stays self-contained
+RETURNS JSONB
+LANGUAGE plpgsql
+AS $$
+DECLARE
+  parsed JSONB;
+BEGIN
+  IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+    RETURN '{}'::jsonb;  -- absent/blank metadata -> empty object
+  END IF;
+
+  BEGIN
+    parsed := input_text::jsonb;
+  EXCEPTION
+    WHEN OTHERS THEN
+      RETURN '{}'::jsonb;  -- unparseable JSON downgraded to empty object (reported by the first EXECUTE below)
+  END;
+
+  IF jsonb_typeof(parsed) = 'object' THEN
+    RETURN parsed;
+  END IF;
+
+  RETURN jsonb_build_object('_legacy_metadata_raw_json', parsed);  -- non-object JSON wrapped so callers can still || merge
+END;
+$$;
+
+CREATE OR REPLACE FUNCTION public.__migration_116_is_valid_legacy_metadata_jsonb(input_text TEXT)  -- migration-scoped predicate: TRUE iff legacy metadata is blank or parses as JSON
+RETURNS BOOLEAN
+LANGUAGE plpgsql
+AS $$
+DECLARE
+  parsed JSONB;
+BEGIN
+  IF input_text IS NULL OR BTRIM(input_text) = '' THEN
+    RETURN TRUE;  -- blank counts as valid (the safe helper treats it as {})
+  END IF;
+
+  parsed := input_text::jsonb;  -- cast attempted only for its success/failure; the parsed value itself is unused
+  RETURN TRUE;
+EXCEPTION
+  WHEN OTHERS THEN
+    RETURN FALSE;
+END;
+$$;
+
+DO $$
+BEGIN
+ IF to_regclass('public.user_external_identities') IS NULL THEN
+ RETURN;
+ END IF;
+
+  EXECUTE $sql$  -- report rows whose legacy metadata is not valid JSON (the safe helper silently downgraded them to {})
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+  'legacy_external_identity_invalid_metadata_json',
+  'legacy_external_identity:' || uei.id::text,
+  jsonb_build_object(
+    'legacy_identity_id', uei.id,
+    'user_id', uei.user_id,
+    'provider', LOWER(BTRIM(COALESCE(uei.provider, ''))),
+    'provider_user_id', BTRIM(COALESCE(uei.provider_user_id, '')),
+    'provider_union_id', BTRIM(COALESCE(uei.provider_union_id, '')),
+    'reason', 'legacy metadata is not valid JSON; migration downgraded metadata to empty object',
+    'raw_metadata', LEFT(BTRIM(COALESCE(uei.metadata, '')), 1000),  -- cap the raw sample so report rows stay bounded
+    'migration', '116_auth_identity_legacy_external_safety_reports'
+  )
+FROM user_external_identities AS uei
+JOIN users AS u ON u.id = uei.user_id
+WHERE u.deleted_at IS NULL
+  AND BTRIM(COALESCE(uei.metadata, '')) <> ''
+  AND NOT public.__migration_116_is_valid_legacy_metadata_jsonb(uei.metadata)
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids,
+ 'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject,
+ to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids
+ FROM (
+ SELECT
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+ ) AS legacy_subjects
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) > 1
+) AS ambiguous
+ ON ambiguous.provider_type = legacy.provider_type
+ AND ambiguous.provider_key = legacy.provider_key
+ AND ambiguous.provider_subject = legacy.provider_subject
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_identity_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_identity_id', ai.id,
+ 'existing_user_id', ai.user_id,
+ 'provider_type', legacy.provider_type,
+ 'provider_key', legacy.provider_key,
+ 'provider_subject', legacy.provider_subject,
+ 'reason', 'legacy canonical identity subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+) AS legacy
+JOIN (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject
+ FROM (
+ SELECT
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+ ) AS legacy_subjects
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) = 1
+) AS clear_subjects
+ ON clear_subjects.provider_type = legacy.provider_type
+ AND clear_subjects.provider_key = legacy.provider_key
+ AND clear_subjects.provider_subject = legacy.provider_subject
+JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.id,
+ uei.user_id,
+ LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
+ ELSE 'linuxdo'
+ END AS provider_key,
+ CASE
+ WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
+ ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
+ END AS provider_subject,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ BTRIM(COALESCE(uei.provider_username, '')) AS provider_username,
+ BTRIM(COALESCE(uei.display_name, '')) AS display_name,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ COALESCE(uei.updated_at, uei.created_at, NOW()) AS verified_at
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
+ AND (
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
+ OR
+ (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
+ )
+),
+clear_subjects AS (
+ SELECT
+ provider_type,
+ provider_key,
+ provider_subject
+ FROM legacy
+ GROUP BY provider_type, provider_key, provider_subject
+ HAVING COUNT(DISTINCT user_id) = 1
+),
+canonical_legacy AS (
+ SELECT
+ legacy.*,
+ ROW_NUMBER() OVER (
+ PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject
+ ORDER BY legacy.verified_at DESC, legacy.id DESC
+ ) AS canonical_row_num
+ FROM legacy
+ JOIN clear_subjects
+ ON clear_subjects.provider_type = legacy.provider_type
+ AND clear_subjects.provider_key = legacy.provider_key
+ AND clear_subjects.provider_subject = legacy.provider_subject
+)
+INSERT INTO auth_identities (
+ user_id,
+ provider_type,
+ provider_key,
+ provider_subject,
+ verified_at,
+ metadata
+)
+SELECT
+ legacy.user_id,
+ legacy.provider_type,
+ legacy.provider_key,
+ legacy.provider_subject,
+ legacy.verified_at,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'provider_user_id', legacy.provider_user_id,
+ 'provider_union_id', NULLIF(legacy.provider_union_id, ''),
+ 'provider_username', legacy.provider_username,
+ 'display_name', legacy.display_name,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM canonical_legacy AS legacy
+LEFT JOIN auth_identities AS ai
+ ON ai.provider_type = legacy.provider_type
+ AND ai.provider_key = legacy.provider_key
+ AND ai.provider_subject = legacy.provider_subject
+WHERE legacy.canonical_row_num = 1
+ AND ai.id IS NULL
+ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_external_channel_conflict',
+ 'legacy_external_identity:' || legacy.id::text,
+ legacy.metadata_json || jsonb_build_object(
+ 'legacy_identity_id', legacy.id,
+ 'legacy_user_id', legacy.user_id,
+ 'existing_channel_id', channel.id,
+ 'existing_identity_id', existing_ai.id,
+ 'existing_user_id', existing_ai.user_id,
+ 'provider_type', 'wechat',
+ 'provider_key', 'wechat-main',
+ 'provider_subject', legacy.provider_union_id,
+ 'channel', legacy.channel,
+ 'channel_app_id', legacy.channel_app_id,
+ 'channel_subject', legacy.provider_user_id,
+ 'reason', 'legacy channel subject already belongs to another user',
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM (
+ SELECT
+ uei.id,
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+) AS legacy
+JOIN (
+ SELECT
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+ GROUP BY BTRIM(COALESCE(uei.provider_union_id, ''))
+ HAVING COUNT(DISTINCT uei.user_id) = 1
+) AS clear_subjects
+ ON clear_subjects.provider_subject = legacy.provider_union_id
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+JOIN auth_identities AS existing_ai
+ ON existing_ai.id = channel.identity_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND existing_ai.user_id <> legacy.user_id
+ON CONFLICT (report_type, report_key) DO NOTHING;
+$sql$;
+
+ EXECUTE $sql$
+WITH legacy AS (
+ SELECT
+ uei.user_id,
+ BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+ BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
+ BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
+ BTRIM(COALESCE(
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
+ public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
+ ''
+ )) AS channel_app_id
+ FROM user_external_identities AS uei
+ JOIN users AS u ON u.id = uei.user_id
+ WHERE u.deleted_at IS NULL
+ AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+ AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
+ AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+),
+clear_subjects AS (
+ SELECT
+ provider_union_id AS provider_subject
+ FROM legacy
+ GROUP BY provider_union_id
+ HAVING COUNT(DISTINCT user_id) = 1
+)
+INSERT INTO auth_identity_channels (
+ identity_id,
+ provider_type,
+ provider_key,
+ channel,
+ channel_app_id,
+ channel_subject,
+ metadata
+)
+SELECT
+ legacy_ai.id,
+ 'wechat',
+ 'wechat-main',
+ legacy.channel,
+ legacy.channel_app_id,
+ legacy.provider_user_id,
+ legacy.metadata_json || jsonb_build_object(
+ 'openid', legacy.provider_user_id,
+ 'unionid', legacy.provider_union_id,
+ 'migration', '116_auth_identity_legacy_external_safety_reports'
+ )
+FROM legacy
+JOIN clear_subjects
+ ON clear_subjects.provider_subject = legacy.provider_union_id
+JOIN auth_identities AS legacy_ai
+ ON legacy_ai.user_id = legacy.user_id
+ AND legacy_ai.provider_type = 'wechat'
+ AND legacy_ai.provider_key = 'wechat-main'
+ AND legacy_ai.provider_subject = legacy.provider_union_id
+LEFT JOIN auth_identity_channels AS channel
+ ON channel.provider_type = 'wechat'
+ AND channel.provider_key = 'wechat-main'
+ AND channel.channel = legacy.channel
+ AND channel.channel_app_id = legacy.channel_app_id
+ AND channel.channel_subject = legacy.provider_user_id
+WHERE legacy.channel <> ''
+ AND legacy.channel_app_id <> ''
+ AND channel.id IS NULL
+ON CONFLICT DO NOTHING;
+$sql$;
+
+  EXECUTE $sql$  -- re-emit the openid-only remediation report (catches rows created since migration 115)
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+  'wechat_openid_only_requires_remediation',
+  'legacy_external_identity:' || legacy.id::text,
+  legacy.metadata_json || jsonb_build_object(
+    'legacy_identity_id', legacy.id,
+    'user_id', legacy.user_id,
+    'openid', legacy.provider_user_id,
+    'reason', 'legacy user_external_identities row only has openid and cannot be canonicalized offline',
+    'migration', '116_auth_identity_legacy_external_safety_reports'
+  )
+FROM (
+  SELECT
+    uei.id,
+    uei.user_id,
+    BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
+    public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
+  FROM user_external_identities AS uei
+  JOIN users AS u ON u.id = uei.user_id
+  WHERE u.deleted_at IS NULL
+    AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
+    AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
+    AND BTRIM(COALESCE(uei.provider_union_id, '')) = ''  -- openid present, unionid absent
+) AS legacy
+ON CONFLICT (report_type, report_key) DO NOTHING;  -- rows already reported by migration 115 under the same key are left untouched
+$sql$;
+END $$;
+
+DO $$
+BEGIN
+  IF NOT EXISTS (
+    SELECT 1
+    FROM pg_constraint
+    WHERE conname = 'auth_identities_metadata_is_object_check' AND conrelid = 'auth_identities'::regclass  -- pin the table: constraint names are not unique across a schema, so a name match alone could wrongly skip this ALTER
+  ) THEN
+    ALTER TABLE auth_identities
+      ADD CONSTRAINT auth_identities_metadata_is_object_check
+      CHECK (jsonb_typeof(metadata) = 'object');  -- NULL metadata still passes: CHECK treats a NULL result as satisfied
+  END IF;
+
+  IF NOT EXISTS (
+    SELECT 1
+    FROM pg_constraint
+    WHERE conname = 'auth_identity_channels_metadata_is_object_check' AND conrelid = 'auth_identity_channels'::regclass
+  ) THEN
+    ALTER TABLE auth_identity_channels
+      ADD CONSTRAINT auth_identity_channels_metadata_is_object_check
+      CHECK (jsonb_typeof(metadata) = 'object');
+  END IF;
+
+  IF NOT EXISTS (
+    SELECT 1
+    FROM pg_constraint
+    WHERE conname = 'auth_identity_migration_reports_details_is_object_check' AND conrelid = 'auth_identity_migration_reports'::regclass
+  ) THEN
+    ALTER TABLE auth_identity_migration_reports
+      ADD CONSTRAINT auth_identity_migration_reports_details_is_object_check
+      CHECK (jsonb_typeof(details) = 'object');
+  END IF;
+END $$;
+
+DROP FUNCTION IF EXISTS public.__migration_116_is_valid_legacy_metadata_jsonb(TEXT);  -- helpers are migration-scoped; drop them so they do not linger in the schema
+DROP FUNCTION IF EXISTS public.__migration_116_safe_legacy_metadata_jsonb(TEXT);
diff --git a/backend/migrations/117_add_payment_order_provider_snapshot.sql b/backend/migrations/117_add_payment_order_provider_snapshot.sql
new file mode 100644
index 00000000..56a5fe2d
--- /dev/null
+++ b/backend/migrations/117_add_payment_order_provider_snapshot.sql
@@ -0,0 +1,2 @@
+ALTER TABLE payment_orders
+ADD COLUMN IF NOT EXISTS provider_snapshot JSONB;  -- nullable: historical orders stay NULL; presumably a point-in-time copy of provider data -- confirm semantics with the application code
diff --git a/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql
new file mode 100644
index 00000000..18782617
--- /dev/null
+++ b/backend/migrations/118_wechat_dual_mode_and_auth_source_defaults.sql
@@ -0,0 +1,25 @@
+INSERT INTO settings (key, value)
+VALUES
+  (
+    'wechat_connect_open_enabled',
+    CASE
+      WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''  -- legacy flag never set: '' sentinel -- NOTE(review): consumers must treat '' as unset, confirm
+      WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'  -- legacy connect disabled: both split flags start off
+      WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'false'  -- legacy single-mode 'mp' means open-web login was off
+      ELSE 'true'
+    END
+  ),
+  (
+    'wechat_connect_mp_enabled',
+    CASE
+      WHEN NOT EXISTS (SELECT 1 FROM settings WHERE key = 'wechat_connect_enabled') THEN ''
+      WHEN COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_enabled'), 'false') <> 'true' THEN 'false'
+      WHEN LOWER(TRIM(COALESCE((SELECT value FROM settings WHERE key = 'wechat_connect_mode'), 'open'))) = 'mp' THEN 'true'  -- mirror of the branch above: mp mode maps onto the mp flag
+      ELSE 'false'
+    END
+  ),
+  ('auth_source_default_email_grant_on_signup', 'false'),  -- conservative defaults: no automatic grant on signup
+  ('auth_source_default_linuxdo_grant_on_signup', 'false'),
+  ('auth_source_default_oidc_grant_on_signup', 'false'),
+  ('auth_source_default_wechat_grant_on_signup', 'false')
+ON CONFLICT (key) DO NOTHING;  -- never overwrite values an operator already configured
diff --git a/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql
new file mode 100644
index 00000000..15e2c15f
--- /dev/null
+++ b/backend/migrations/119_enforce_payment_orders_out_trade_no_unique.sql
@@ -0,0 +1,6 @@
+-- Intentionally left as a no-op.
+-- The online index rollout lives in 120_enforce_payment_orders_out_trade_no_unique_notx.sql
+DO $$
+BEGIN
+  NULL;  -- placeholder keeps the migration numbering sequence contiguous
+END $$;
diff --git a/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql
new file mode 100644
index 00000000..638d8622
--- /dev/null
+++ b/backend/migrations/120_enforce_payment_orders_out_trade_no_unique_notx.sql
@@ -0,0 +1,10 @@
+-- Build the payment order uniqueness guarantee online.
+-- The migration runner performs an explicit duplicate out_trade_no precheck and
+-- drops any stale invalid paymentorder_out_trade_no_unique index before retrying.
+-- Create the new partial unique index concurrently first so writes keep flowing,
+-- then remove the legacy index name once the replacement is ready.
+CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
+  ON payment_orders (out_trade_no)
+  WHERE out_trade_no <> '';  -- partial: rows with an empty out_trade_no remain exempt from uniqueness
+
+DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;  -- CONCURRENTLY cannot run inside a transaction: the _notx filename tells the runner to execute this file without one
diff --git a/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql
new file mode 100644
index 00000000..ef2599dc
--- /dev/null
+++ b/backend/migrations/120a_align_payment_orders_out_trade_no_index_name.sql
@@ -0,0 +1,22 @@
+DO $$
+BEGIN
+  IF EXISTS (
+    SELECT 1
+    FROM pg_indexes
+    WHERE schemaname = 'public'
+      AND tablename = 'payment_orders'
+      AND indexname = 'paymentorder_out_trade_no_unique'  -- the replacement index built online by migration 120
+  ) THEN
+    IF EXISTS (
+      SELECT 1
+      FROM pg_indexes
+      WHERE schemaname = 'public'
+        AND tablename = 'payment_orders'
+        AND indexname = 'paymentorder_out_trade_no'  -- the old name still exists: free it up before the rename below can take it
+    ) THEN
+      EXECUTE 'DROP INDEX IF EXISTS paymentorder_out_trade_no';
+    END IF;
+
+    EXECUTE 'ALTER INDEX paymentorder_out_trade_no_unique RENAME TO paymentorder_out_trade_no';  -- restore the historical index name so tooling that references it keeps working
+  END IF;
+END $$;
diff --git a/backend/migrations/121_auth_identity_migration_report_type_widen.sql b/backend/migrations/121_auth_identity_migration_report_type_widen.sql
new file mode 100644
index 00000000..66bfb44a
--- /dev/null
+++ b/backend/migrations/121_auth_identity_migration_report_type_widen.sql
@@ -0,0 +1,2 @@
+ALTER TABLE auth_identity_migration_reports
+ALTER COLUMN report_type TYPE VARCHAR(80);
diff --git a/backend/migrations/122_pending_auth_completion_token_cleanup.sql b/backend/migrations/122_pending_auth_completion_token_cleanup.sql
new file mode 100644
index 00000000..e6341142
--- /dev/null
+++ b/backend/migrations/122_pending_auth_completion_token_cleanup.sql
@@ -0,0 +1,15 @@
+UPDATE pending_auth_sessions
+SET
+ local_flow_state = jsonb_set(
+ local_flow_state,
+ '{completion_response}',
+ ((local_flow_state -> 'completion_response') - 'access_token' - 'refresh_token' - 'expires_in' - 'token_type'),
+ true
+ )
+WHERE jsonb_typeof(local_flow_state -> 'completion_response') = 'object'
+ AND (
+ (local_flow_state -> 'completion_response') ? 'access_token'
+ OR (local_flow_state -> 'completion_response') ? 'refresh_token'
+ OR (local_flow_state -> 'completion_response') ? 'expires_in'
+ OR (local_flow_state -> 'completion_response') ? 'token_type'
+ );
diff --git a/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql
new file mode 100644
index 00000000..4388285a
--- /dev/null
+++ b/backend/migrations/123_fix_legacy_auth_source_grant_on_signup_defaults.sql
@@ -0,0 +1,68 @@
+-- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value.
+-- Rows still matching the migration-110 default payload and timestamp window are treated as
+-- untouched legacy defaults; any remaining legacy true values are reported for manual review.
+
+WITH migration_110 AS (
+ SELECT applied_at
+ FROM schema_migrations
+ WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
+),
+providers AS (
+ SELECT provider_type
+ FROM (
+ VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
+ ) AS providers(provider_type)
+),
+legacy_provider_defaults AS (
+ SELECT providers.provider_type
+ FROM providers
+ CROSS JOIN migration_110
+ JOIN settings balance
+ ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
+ JOIN settings concurrency
+ ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
+ JOIN settings subscriptions
+ ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
+ JOIN settings grant_on_signup
+ ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
+ JOIN settings grant_on_first_bind
+ ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
+ WHERE balance.value = '0'
+ AND concurrency.value = '5'
+ AND subscriptions.value = '[]'
+ AND grant_on_signup.value = 'true'
+ AND grant_on_first_bind.value = 'false'
+ AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+ AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
+),
+updated_signup_grants AS (
+ UPDATE settings
+ SET
+ value = 'false',
+ updated_at = NOW()
+ FROM legacy_provider_defaults
+ WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
+ AND settings.value = 'true'
+ RETURNING legacy_provider_defaults.provider_type
+)
+INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
+SELECT
+ 'legacy_auth_source_signup_grant_review',
+ providers.provider_type,
+ jsonb_build_object(
+ 'provider_type', providers.provider_type,
+ 'current_value', grant_on_signup.value,
+ 'auto_backfilled', FALSE,
+ 'reason', 'legacy_true_default_not_auto_backfilled'
+ )
+FROM providers
+JOIN settings grant_on_signup
+ ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
+LEFT JOIN updated_signup_grants
+ ON updated_signup_grants.provider_type = providers.provider_type
+WHERE grant_on_signup.value = 'true'
+ AND updated_signup_grants.provider_type IS NULL
+ON CONFLICT (report_type, report_key) DO NOTHING;
diff --git a/backend/migrations/124_backfill_legacy_oidc_security_flags.sql b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql
new file mode 100644
index 00000000..e68bb11a
--- /dev/null
+++ b/backend/migrations/124_backfill_legacy_oidc_security_flags.sql
@@ -0,0 +1,32 @@
+-- Preserve legacy OIDC behavior for upgraded installs that predate the
+-- introduction of secure PKCE/id_token defaults. Fresh installs continue to
+-- inherit runtime defaults when these rows are absent.
+
+WITH legacy_oidc_install AS (
+ SELECT 1
+ FROM settings
+ WHERE key IN (
+ 'oidc_connect_enabled',
+ 'oidc_connect_client_id',
+ 'oidc_connect_authorize_url',
+ 'oidc_connect_token_url',
+ 'oidc_connect_issuer_url',
+ 'oidc_connect_userinfo_url',
+ 'oidc_connect_frontend_redirect_url'
+ )
+ LIMIT 1
+)
+INSERT INTO settings (key, value)
+SELECT defaults.key, 'false'
+FROM legacy_oidc_install
+CROSS JOIN (
+ VALUES
+ ('oidc_connect_use_pkce'),
+ ('oidc_connect_validate_id_token')
+) AS defaults(key)
+WHERE NOT EXISTS (
+ SELECT 1
+ FROM settings existing
+ WHERE existing.key = defaults.key
+)
+ON CONFLICT (key) DO NOTHING;
diff --git a/backend/migrations/125_add_channel_monitors.sql b/backend/migrations/125_add_channel_monitors.sql
new file mode 100644
index 00000000..5ec327da
--- /dev/null
+++ b/backend/migrations/125_add_channel_monitors.sql
@@ -0,0 +1,58 @@
+-- Migration: 125_add_channel_monitors
+-- 渠道监控 MVP:周期性对外部 provider/endpoint/api_key 做模型心跳测试。
+--
+-- 表结构说明:
+-- - channel_monitors 渠道配置表(一行 = 一个监控对象)
+-- - channel_monitor_histories 检测历史明细表(一次检测一个模型 = 一行)
+--
+-- 设计要点:
+-- - api_key_encrypted 列存放 AES-256-GCM 密文(base64),由 service 层加密。
+-- - extra_models 用 JSONB 存储字符串数组,便于扩展(后续可加权重等元数据)。
+-- - history 表通过 ON DELETE CASCADE 自动清理已删除监控的历史。
+-- - (enabled, last_checked_at) 索引服务于调度器扫描“到期需要检测”的监控。
+-- - histories 上 (monitor_id, model, checked_at DESC) 服务用户视图聚合查询;
+-- 单独的 (checked_at) 索引服务定期清理 30 天前数据的 DELETE。
+
+CREATE TABLE IF NOT EXISTS channel_monitors (
+ id BIGSERIAL PRIMARY KEY,
+ name VARCHAR(100) NOT NULL,
+ provider VARCHAR(20) NOT NULL, -- openai / anthropic / gemini
+ endpoint VARCHAR(500) NOT NULL, -- base origin
+ api_key_encrypted TEXT NOT NULL, -- AES-256-GCM (base64)
+ primary_model VARCHAR(200) NOT NULL,
+ extra_models JSONB NOT NULL DEFAULT '[]'::jsonb,
+ group_name VARCHAR(100) NOT NULL DEFAULT '',
+ enabled BOOLEAN NOT NULL DEFAULT TRUE,
+ interval_seconds INT NOT NULL,
+ last_checked_at TIMESTAMPTZ,
+ created_by BIGINT NOT NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitors_provider_check CHECK (provider IN ('openai', 'anthropic', 'gemini')),
+ CONSTRAINT channel_monitors_interval_check CHECK (interval_seconds BETWEEN 15 AND 3600)
+);
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_enabled_last_checked
+ ON channel_monitors (enabled, last_checked_at);
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_provider
+ ON channel_monitors (provider);
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_group_name
+ ON channel_monitors (group_name);
+
+CREATE TABLE IF NOT EXISTS channel_monitor_histories (
+ id BIGSERIAL PRIMARY KEY,
+ monitor_id BIGINT NOT NULL REFERENCES channel_monitors(id) ON DELETE CASCADE,
+ model VARCHAR(200) NOT NULL,
+ status VARCHAR(20) NOT NULL,
+ latency_ms INT,
+ ping_latency_ms INT,
+ message VARCHAR(500) NOT NULL DEFAULT '',
+ checked_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitor_histories_status_check
+ CHECK (status IN ('operational', 'degraded', 'failed', 'error'))
+);
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_monitor_model_checked
+ ON channel_monitor_histories (monitor_id, model, checked_at DESC);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_checked_at
+ ON channel_monitor_histories (checked_at);
diff --git a/backend/migrations/125_add_group_rpm_limit.sql b/backend/migrations/125_add_group_rpm_limit.sql
new file mode 100644
index 00000000..fbde1b20
--- /dev/null
+++ b/backend/migrations/125_add_group_rpm_limit.sql
@@ -0,0 +1,7 @@
+-- Add per-group Requests-Per-Minute limit.
+-- rpm_limit: 分组统一 RPM 上限(0 = 不限制)。
+-- 一旦配置即接管该用户在该分组的限流,覆盖用户级 users.rpm_limit。
+-- 计数键:rpm:ug:{user_id}:{group_id}:{minute}。
+ALTER TABLE groups ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN groups.rpm_limit IS '分组 RPM 上限;0 表示不限制;设置后接管该分组用户的限流(覆盖用户级 rpm_limit)。';
diff --git a/backend/migrations/126_add_channel_monitor_aggregation.sql b/backend/migrations/126_add_channel_monitor_aggregation.sql
new file mode 100644
index 00000000..e643763c
--- /dev/null
+++ b/backend/migrations/126_add_channel_monitor_aggregation.sql
@@ -0,0 +1,60 @@
+-- Migration: 126_add_channel_monitor_aggregation
+-- 渠道监控日聚合:把 channel_monitor_histories 的明细按天聚合,明细只保留 1 天,
+-- 聚合保留 30 天。明细和聚合表都用软删除(deleted_at),由 ops cleanup 任务每天
+-- 凌晨随运维监控清理一起跑(共享 cron)。
+--
+-- 设计要点:
+-- - channel_monitor_histories 加 deleted_at 软删除字段(SoftDeleteMixin 全局
+-- Hook 会把 DELETE 自动改写成 UPDATE deleted_at = NOW())。
+-- - channel_monitor_daily_rollups 按 (monitor_id, model, bucket_date) 唯一,
+-- 用 ON CONFLICT DO UPDATE 实现幂等回填,状态分布和延迟分子分母都保留,
+-- 方便后续按窗口任意求加权可用率和均值。
+-- - watermark 表只有一行(id=1),记录最近一次聚合到达的日期,避免重启后重复
+-- 扫全表。
+-- - rollup 上 (bucket_date) 索引服务清理任务的 DELETE WHERE bucket_date < cutoff。
+
+-- 1) 给历史明细表加软删除字段
+ALTER TABLE channel_monitor_histories
+ ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ;
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_histories_deleted_at
+ ON channel_monitor_histories (deleted_at);
+
+-- 2) 创建日聚合表
+CREATE TABLE IF NOT EXISTS channel_monitor_daily_rollups (
+ id BIGSERIAL PRIMARY KEY,
+ monitor_id BIGINT NOT NULL REFERENCES channel_monitors(id) ON DELETE CASCADE,
+ model VARCHAR(200) NOT NULL,
+ bucket_date DATE NOT NULL,
+ total_checks INT NOT NULL DEFAULT 0,
+ ok_count INT NOT NULL DEFAULT 0,
+ operational_count INT NOT NULL DEFAULT 0,
+ degraded_count INT NOT NULL DEFAULT 0,
+ failed_count INT NOT NULL DEFAULT 0,
+ error_count INT NOT NULL DEFAULT 0,
+ sum_latency_ms BIGINT NOT NULL DEFAULT 0,
+ count_latency INT NOT NULL DEFAULT 0,
+ sum_ping_latency_ms BIGINT NOT NULL DEFAULT 0,
+ count_ping_latency INT NOT NULL DEFAULT 0,
+ computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ deleted_at TIMESTAMPTZ
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_unique
+ ON channel_monitor_daily_rollups (monitor_id, model, bucket_date);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_bucket
+ ON channel_monitor_daily_rollups (bucket_date);
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_daily_rollups_deleted_at
+ ON channel_monitor_daily_rollups (deleted_at);
+
+-- 3) 创建 watermark 表(单行:id=1)
+CREATE TABLE IF NOT EXISTS channel_monitor_aggregation_watermark (
+ id INT PRIMARY KEY DEFAULT 1,
+ last_aggregated_date DATE,
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitor_aggregation_watermark_singleton CHECK (id = 1)
+);
+
+INSERT INTO channel_monitor_aggregation_watermark (id, last_aggregated_date, updated_at)
+VALUES (1, NULL, NOW())
+ON CONFLICT (id) DO NOTHING;
diff --git a/backend/migrations/126_add_user_rpm_limit.sql b/backend/migrations/126_add_user_rpm_limit.sql
new file mode 100644
index 00000000..64a8b977
--- /dev/null
+++ b/backend/migrations/126_add_user_rpm_limit.sql
@@ -0,0 +1,7 @@
+-- Add per-user Requests-Per-Minute cap.
+-- rpm_limit: 用户全局 RPM 兜底(0 = 不限制)。
+-- 仅当所访问分组未设置 rpm_limit 且无 user-group rpm_override 时作为兜底生效。
+-- 计数键:rpm:u:{user_id}:{minute}。
+ALTER TABLE users ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN users.rpm_limit IS '用户级 RPM 兜底上限;0 表示不限制;仅当分组未设置 rpm_limit 时生效。';
diff --git a/backend/migrations/127_add_user_group_rpm_override.sql b/backend/migrations/127_add_user_group_rpm_override.sql
new file mode 100644
index 00000000..1d674258
--- /dev/null
+++ b/backend/migrations/127_add_user_group_rpm_override.sql
@@ -0,0 +1,16 @@
+-- 在已有的"用户专属分组倍率表"上扩展 rpm_override 列;同时放宽 rate_multiplier 为可空,
+-- 使一行记录可以只覆盖 rate、只覆盖 rpm,或同时覆盖两者。
+-- 语义:
+-- - rate_multiplier NULL → 该用户在此分组使用 groups.rate_multiplier 默认值
+-- - rate_multiplier 非 NULL → 覆盖分组默认计费倍率
+-- - rpm_override NULL → 该用户在此分组使用 groups.rpm_limit 默认值
+-- - rpm_override 非 NULL → 覆盖分组默认 RPM(0 = 不限制)
+-- 用户级 users.rpm_limit 仍独立生效(跨分组总配额)。
+ALTER TABLE user_group_rate_multipliers
+ ADD COLUMN IF NOT EXISTS rpm_override integer NULL;
+
+ALTER TABLE user_group_rate_multipliers
+ ALTER COLUMN rate_multiplier DROP NOT NULL;
+
+COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率;NULL 表示沿用分组默认倍率。';
+COMMENT ON COLUMN user_group_rate_multipliers.rpm_override IS '专属 RPM 上限;NULL 表示沿用分组默认;0 表示该用户在此分组不受 RPM 限制。';
diff --git a/backend/migrations/127_drop_channel_monitor_deleted_at.sql b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
new file mode 100644
index 00000000..2260f06b
--- /dev/null
+++ b/backend/migrations/127_drop_channel_monitor_deleted_at.sql
@@ -0,0 +1,16 @@
+-- Migration: 127_drop_channel_monitor_deleted_at
+-- 纠正 126 引入的 SoftDeleteMixin:日志/聚合表无恢复需求,软删会让行和索引只增不减,
+-- 徒增磁盘和查询开销。改回分批物理删(由 OpsCleanupService 每天凌晨统一调度,
+-- deleteOldRowsByID 模板,batch=5000)。
+--
+-- 126 尚未跑过聚合/清理(首次 maintenance 在次日 02:00),所以此处不担心业务数据。
+-- 直接 DROP 列 + 索引;对应的 Go 侧 ent schema 已移除 SoftDeleteMixin、repo 的
+-- raw SQL 已移除 deleted_at IS NULL 过滤。
+
+DROP INDEX IF EXISTS idx_channel_monitor_histories_deleted_at;
+ALTER TABLE channel_monitor_histories
+ DROP COLUMN IF EXISTS deleted_at;
+
+DROP INDEX IF EXISTS idx_channel_monitor_daily_rollups_deleted_at;
+ALTER TABLE channel_monitor_daily_rollups
+ DROP COLUMN IF EXISTS deleted_at;
diff --git a/backend/migrations/128_add_channel_monitor_request_templates.sql b/backend/migrations/128_add_channel_monitor_request_templates.sql
new file mode 100644
index 00000000..2db8fef6
--- /dev/null
+++ b/backend/migrations/128_add_channel_monitor_request_templates.sql
@@ -0,0 +1,70 @@
+-- Migration: 128_add_channel_monitor_request_templates
+-- 加请求模板表 + 给 channel_monitors 加 4 个快照字段(template_id 关联引用 + extra_headers /
+-- body_override_mode / body_override 三个真正运行时使用的快照)。
+--
+-- 设计要点:
+-- 1) 模板与监控之间是「应用即拷贝」的快照语义,运行时 checker 不再回查模板表。
+-- 模板 UPDATE 不会自动影响监控;只有用户主动「应用到关联监控」才会刷新快照。
+-- 2) ON DELETE SET NULL:模板删除不级联清理监控;监控保留快照继续工作。
+-- 3) extra_headers / body_override 都是 JSONB;body_override_mode 用 varchar(不是 enum)
+-- 便于将来加新模式无需 ALTER TYPE。
+-- 4) 同一 provider 内模板 name 唯一(允许 Anthropic + OpenAI 重名 "伪装官方客户端")。
+
+CREATE TABLE IF NOT EXISTS channel_monitor_request_templates (
+ id BIGSERIAL PRIMARY KEY,
+ name VARCHAR(100) NOT NULL,
+ provider VARCHAR(20) NOT NULL,
+ description VARCHAR(500) NOT NULL DEFAULT '',
+ extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb,
+ body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off',
+ body_override JSONB NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ CONSTRAINT channel_monitor_request_templates_provider_check
+ CHECK (provider IN ('openai', 'anthropic', 'gemini')),
+ CONSTRAINT channel_monitor_request_templates_body_mode_check
+ CHECK (body_override_mode IN ('off', 'merge', 'replace'))
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS channel_monitor_request_templates_provider_name
+ ON channel_monitor_request_templates (provider, name);
+
+-- channel_monitors 加 4 列(ADD COLUMN IF NOT EXISTS 需要 PG 9.6+,生产使用 PG 16)
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS template_id BIGINT NULL;
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS extra_headers JSONB NOT NULL DEFAULT '{}'::jsonb;
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS body_override_mode VARCHAR(10) NOT NULL DEFAULT 'off';
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS body_override JSONB NULL;
+
+-- 约束 + 外键(DO 块里 IF NOT EXISTS 判断,保证幂等)
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM information_schema.table_constraints
+ WHERE constraint_name = 'channel_monitors_body_mode_check'
+ AND table_name = 'channel_monitors'
+ ) THEN
+ ALTER TABLE channel_monitors
+ ADD CONSTRAINT channel_monitors_body_mode_check
+ CHECK (body_override_mode IN ('off', 'merge', 'replace'));
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1 FROM information_schema.table_constraints
+ WHERE constraint_name = 'channel_monitors_template_id_fkey'
+ AND table_name = 'channel_monitors'
+ ) THEN
+ ALTER TABLE channel_monitors
+ ADD CONSTRAINT channel_monitors_template_id_fkey
+ FOREIGN KEY (template_id)
+ REFERENCES channel_monitor_request_templates (id)
+ ON DELETE SET NULL;
+ END IF;
+END $$;
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_template_id
+ ON channel_monitors (template_id)
+ WHERE template_id IS NOT NULL;
diff --git a/backend/migrations/129_seed_claude_code_template.sql b/backend/migrations/129_seed_claude_code_template.sql
new file mode 100644
index 00000000..d9b062c9
--- /dev/null
+++ b/backend/migrations/129_seed_claude_code_template.sql
@@ -0,0 +1,38 @@
+-- Migration: 129_seed_claude_code_template
+-- 内置「Claude Code 伪装」请求模板,覆盖 Anthropic 上游对官方 CLI 客户端的所有验证项:
+-- 1) User-Agent / X-App / anthropic-beta / anthropic-version 等头
+-- 2) system 数组首项与官方 system prompt 字面一致(Dice >= 0.5)
+-- 3) metadata.user_id 满足 ParseMetadataUserID — 这里用 legacy 格式(user_<64hex>_account_<36char>_session_<36char>)
+-- 避免新版 JSON 字符串内嵌 JSON 在编辑器里出现一长串 \" 转义,便于用户阅读。
+--
+-- ON CONFLICT DO NOTHING:已部署环境(手动建过模板)跑此 migration 不会重复 / 覆盖。
+-- 用户可自行编辑后续覆盖此 seed;CC 升大版时再起一条 migration 提供新模板,不动用户的旧模板。
+
+INSERT INTO channel_monitor_request_templates (
+ name, provider, description, extra_headers, body_override_mode, body_override
+)
+VALUES (
+ 'Claude Code 伪装',
+ 'anthropic',
+ '完整模拟 Claude Code 2.1.114 客户端:UA + anthropic-beta + system + metadata.user_id 全部对齐,绕过 Anthropic 上游 ''Claude Code only'' 限制(如 Max 套餐)。',
+ '{
+ "User-Agent": "claude-cli/2.1.114 (external, sdk-cli)",
+ "X-App": "cli",
+ "anthropic-version": "2023-06-01",
+ "anthropic-beta": "claude-code-20250219,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05,advisor-tool-2026-03-01",
+ "anthropic-dangerous-direct-browser-access": "true"
+ }'::jsonb,
+ 'merge',
+ '{
+ "system": [
+ {
+ "type": "text",
+ "text": "You are Claude Code, Anthropic''s official CLI for Claude."
+ }
+ ],
+ "metadata": {
+ "user_id": "user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000"
+ }
+ }'::jsonb
+)
+ON CONFLICT (provider, name) DO NOTHING;
diff --git a/backend/migrations/130_add_user_affiliates.sql b/backend/migrations/130_add_user_affiliates.sql
new file mode 100644
index 00000000..d8c001e0
--- /dev/null
+++ b/backend/migrations/130_add_user_affiliates.sql
@@ -0,0 +1,20 @@
+CREATE TABLE IF NOT EXISTS user_affiliates (
+ user_id BIGINT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
+ aff_code VARCHAR(32) NOT NULL UNIQUE,
+ inviter_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ aff_count INTEGER NOT NULL DEFAULT 0,
+ aff_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
+ aff_history_quota DECIMAL(20,8) NOT NULL DEFAULT 0,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_inviter_id ON user_affiliates(inviter_id);
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_aff_quota ON user_affiliates(aff_quota);
+
+COMMENT ON TABLE user_affiliates IS '用户邀请返利信息';
+COMMENT ON COLUMN user_affiliates.aff_code IS '用户邀请代码';
+COMMENT ON COLUMN user_affiliates.inviter_id IS '邀请人用户ID';
+COMMENT ON COLUMN user_affiliates.aff_count IS '累计邀请人数';
+COMMENT ON COLUMN user_affiliates.aff_quota IS '当前可提取返利金额';
+COMMENT ON COLUMN user_affiliates.aff_history_quota IS '累计返利历史金额';
diff --git a/backend/migrations/131_affiliate_rebate_hardening.sql b/backend/migrations/131_affiliate_rebate_hardening.sql
new file mode 100644
index 00000000..81e37a9e
--- /dev/null
+++ b/backend/migrations/131_affiliate_rebate_hardening.sql
@@ -0,0 +1,58 @@
+-- 1) Normalize historical affiliate rebate rate values.
+-- Legacy compatibility treated values in (0, 1] as fractions (e.g. 0.2 = 20%).
+-- We now use pure percentage semantics, so convert persisted fractional values once.
+UPDATE settings
+SET value = to_char((value::numeric * 100), 'FM999999990.########'),
+    updated_at = NOW()
+WHERE key = 'affiliate_rebate_rate'
+  AND value ~ '^-?[0-9]+(\.[0-9]+)?$'
+  AND value::numeric > 0
+  AND value::numeric <= 1;
+
+-- 2) Affiliate ledger for accrual/transfer traceability.
+CREATE TABLE IF NOT EXISTS user_affiliate_ledger (
+ id BIGSERIAL PRIMARY KEY,
+ user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+ action VARCHAR(32) NOT NULL,
+ amount DECIMAL(20,8) NOT NULL,
+ source_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_user_id ON user_affiliate_ledger(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_affiliate_ledger_action ON user_affiliate_ledger(action);
+
+COMMENT ON TABLE user_affiliate_ledger IS '邀请返利资金流水(累计/转入)';
+COMMENT ON COLUMN user_affiliate_ledger.action IS 'accrue|transfer';
+
+-- 3) Enforce idempotency at DB layer for payment audit actions.
+WITH ranked AS (
+ SELECT id,
+ ROW_NUMBER() OVER (PARTITION BY order_id, action ORDER BY id) AS rn
+ FROM payment_audit_logs
+)
+DELETE FROM payment_audit_logs p
+USING ranked r
+WHERE p.id = r.id
+ AND r.rn > 1;
+
+CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq
+ON payment_audit_logs(order_id, action);
+
+-- 4) Prevent retroactive affiliate rebate issuance for legacy completed balance orders.
+INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at)
+SELECT po.id::text,
+ 'AFFILIATE_REBATE_SKIPPED',
+ '{"reason":"baseline before affiliate rebate idempotency rollout"}',
+ 'system',
+ NOW()
+FROM payment_orders po
+WHERE po.order_type = 'balance'
+ AND po.status = 'COMPLETED'
+ AND NOT EXISTS (
+ SELECT 1
+ FROM payment_audit_logs pal
+ WHERE pal.order_id = po.id::text
+ AND pal.action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED')
+ );
diff --git a/backend/migrations/132_affiliate_custom_settings.sql b/backend/migrations/132_affiliate_custom_settings.sql
new file mode 100644
index 00000000..840fe8e0
--- /dev/null
+++ b/backend/migrations/132_affiliate_custom_settings.sql
@@ -0,0 +1,16 @@
+-- 邀请返利:用户专属配置增强
+-- 1) aff_rebate_rate_percent: 用户作为邀请人时的专属返利比例(百分比,NULL 表示沿用全局比例)
+-- 2) aff_code_custom: 标记当前 aff_code 是否被管理员手动改写过(用于"专属用户"列表筛选)
+
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_rebate_rate_percent DECIMAL(5,2);
+
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_code_custom BOOLEAN NOT NULL DEFAULT false;
+
+CREATE INDEX IF NOT EXISTS idx_user_affiliates_admin_settings
+ ON user_affiliates (updated_at)
+ WHERE aff_code_custom = true OR aff_rebate_rate_percent IS NOT NULL;
+
+COMMENT ON COLUMN user_affiliates.aff_rebate_rate_percent IS '专属返利比例(百分比 0-100,NULL 表示沿用全局)';
+COMMENT ON COLUMN user_affiliates.aff_code_custom IS '邀请码是否由管理员改写过(用于专属用户筛选)';
diff --git a/backend/migrations/133_affiliate_rebate_freeze.sql b/backend/migrations/133_affiliate_rebate_freeze.sql
new file mode 100644
index 00000000..b87d59b7
--- /dev/null
+++ b/backend/migrations/133_affiliate_rebate_freeze.sql
@@ -0,0 +1,17 @@
+-- 1) Add frozen quota column to user_affiliates for rebate freeze period.
+ALTER TABLE user_affiliates
+ ADD COLUMN IF NOT EXISTS aff_frozen_quota DECIMAL(20,8) NOT NULL DEFAULT 0;
+
+COMMENT ON COLUMN user_affiliates.aff_frozen_quota IS 'Rebate quota currently frozen (pending thaw after freeze period)';
+
+-- 2) Add frozen_until column to user_affiliate_ledger for per-entry freeze tracking.
+-- NULL = no freeze (or already thawed); non-NULL = frozen until this timestamp.
+ALTER TABLE user_affiliate_ledger
+ ADD COLUMN IF NOT EXISTS frozen_until TIMESTAMPTZ NULL;
+
+COMMENT ON COLUMN user_affiliate_ledger.frozen_until IS 'Rebate frozen until this time; NULL means already thawed or never frozen';
+
+-- 3) Partial index for efficient thaw queries (only rows still frozen).
+CREATE INDEX IF NOT EXISTS idx_ual_frozen_thaw
+ ON user_affiliate_ledger (user_id, frozen_until)
+ WHERE frozen_until IS NOT NULL;
diff --git a/backend/migrations/auth_identity_payment_migrations_regression_test.go b/backend/migrations/auth_identity_payment_migrations_regression_test.go
new file mode 100644
index 00000000..798ae0fe
--- /dev/null
+++ b/backend/migrations/auth_identity_payment_migrations_regression_test.go
@@ -0,0 +1,129 @@
+package migrations
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestMigration112UsesIdempotentAddColumn(t *testing.T) {
+ content, err := FS.ReadFile("112_add_payment_order_provider_key_snapshot.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "ADD COLUMN IF NOT EXISTS provider_key VARCHAR(30)")
+ require.NotContains(t, sql, "ADD COLUMN provider_key VARCHAR(30);")
+}
+
+func TestMigration118DoesNotForceOverwriteAuthSourceGrantDefaults(t *testing.T) {
+ content, err := FS.ReadFile("118_wechat_dual_mode_and_auth_source_defaults.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.NotContains(t, sql, "UPDATE settings")
+ require.NotContains(t, sql, "SET value = 'false'")
+ require.True(t, strings.Contains(sql, "ON CONFLICT (key) DO NOTHING"))
+ require.Contains(t, sql, "THEN ''")
+}
+
+func TestAuthIdentityReportTypeWideningRunsBeforeLongReportWritersAndStillReconcilesAt121(t *testing.T) {
+ preflightContent, err := FS.ReadFile("108a_widen_auth_identity_migration_report_type.sql")
+ require.NoError(t, err)
+
+ preflightSQL := string(preflightContent)
+ require.Contains(t, preflightSQL, "ALTER TABLE auth_identity_migration_reports")
+ require.Contains(t, preflightSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)")
+
+ content, err := FS.ReadFile("109_auth_identity_compat_backfill.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.NotContains(t, sql, "ALTER TABLE auth_identity_migration_reports")
+
+ followupContent, err := FS.ReadFile("121_auth_identity_migration_report_type_widen.sql")
+ require.NoError(t, err)
+
+ followupSQL := string(followupContent)
+ require.Contains(t, followupSQL, "ALTER TABLE auth_identity_migration_reports")
+ require.Contains(t, followupSQL, "ALTER COLUMN report_type TYPE VARCHAR(80)")
+}
+
+func TestMigration119DefersPaymentIndexRolloutToOnlineFollowup(t *testing.T) {
+ content, err := FS.ReadFile("119_enforce_payment_orders_out_trade_no_unique.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "120_enforce_payment_orders_out_trade_no_unique_notx.sql")
+ require.Contains(t, sql, "NULL;")
+ require.NotContains(t, sql, "CREATE UNIQUE INDEX")
+ require.NotContains(t, sql, "DROP INDEX")
+
+ followupContent, err := FS.ReadFile("120_enforce_payment_orders_out_trade_no_unique_notx.sql")
+ require.NoError(t, err)
+
+ followupSQL := string(followupContent)
+ require.Contains(t, followupSQL, "explicit duplicate out_trade_no precheck")
+ require.Contains(t, followupSQL, "stale invalid paymentorder_out_trade_no_unique index")
+ require.Contains(t, followupSQL, "CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique")
+ require.NotContains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique")
+ require.Contains(t, followupSQL, "DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no")
+ require.Contains(t, followupSQL, "WHERE out_trade_no <> ''")
+
+ alignmentContent, err := FS.ReadFile("120a_align_payment_orders_out_trade_no_index_name.sql")
+ require.NoError(t, err)
+
+ alignmentSQL := string(alignmentContent)
+ require.Contains(t, alignmentSQL, "paymentorder_out_trade_no_unique")
+ require.Contains(t, alignmentSQL, "RENAME TO paymentorder_out_trade_no")
+}
+
+func TestMigration110SeedsAuthSourceSignupGrantsDisabledByDefault(t *testing.T) {
+ content, err := FS.ReadFile("110_pending_auth_and_provider_default_grants.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "('auth_source_default_email_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_linuxdo_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_oidc_grant_on_signup', 'false')")
+ require.Contains(t, sql, "('auth_source_default_wechat_grant_on_signup', 'false')")
+ require.NotContains(t, sql, "('auth_source_default_email_grant_on_signup', 'true')")
+}
+
+func TestMigration122ScrubsPendingOAuthCompletionTokensAtRest(t *testing.T) {
+ content, err := FS.ReadFile("122_pending_auth_completion_token_cleanup.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "UPDATE pending_auth_sessions")
+ require.Contains(t, sql, "completion_response")
+ require.Contains(t, sql, "access_token")
+ require.Contains(t, sql, "refresh_token")
+ require.Contains(t, sql, "expires_in")
+ require.Contains(t, sql, "token_type")
+}
+
+func TestMigration123BackfillsLegacyAuthSourceGrantDefaultsSafely(t *testing.T) {
+ content, err := FS.ReadFile("123_fix_legacy_auth_source_grant_on_signup_defaults.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "110_pending_auth_and_provider_default_grants.sql")
+ require.Contains(t, sql, "schema_migrations")
+ require.Contains(t, sql, "updated_at")
+ require.Contains(t, sql, "'_grant_on_signup'")
+ require.Contains(t, sql, "value = 'false'")
+ require.Contains(t, sql, "auth_identity_migration_reports")
+}
+
+func TestMigration124BackfillsLegacyOIDCSecurityFlagsSafely(t *testing.T) {
+ content, err := FS.ReadFile("124_backfill_legacy_oidc_security_flags.sql")
+ require.NoError(t, err)
+
+ sql := string(content)
+ require.Contains(t, sql, "oidc_connect_use_pkce")
+ require.Contains(t, sql, "oidc_connect_validate_id_token")
+ require.Contains(t, sql, "ON CONFLICT (key) DO NOTHING")
+ require.Contains(t, sql, "oidc_connect_enabled")
+ require.Contains(t, sql, "'false'")
+}
diff --git a/deploy/Dockerfile b/deploy/Dockerfile
index 7caa5ca6..b0b6036c 100644
--- a/deploy/Dockerfile
+++ b/deploy/Dockerfile
@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
-ARG GOLANG_IMAGE=golang:1.26.1-alpine
+ARG GOLANG_IMAGE=golang:1.26.2-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
diff --git a/deploy/codex-instructions.md.tmpl b/deploy/codex-instructions.md.tmpl
new file mode 100644
index 00000000..87ad0a3d
--- /dev/null
+++ b/deploy/codex-instructions.md.tmpl
@@ -0,0 +1,5 @@
+You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
+
+{{ if .ExistingInstructions }}
+{{ .ExistingInstructions }}
+{{ end }}
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 8f60acd5..dfc363b5 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -202,6 +202,32 @@ gateway:
#
# 注意:开启后会影响所有客户端的行为(不仅限于 VS Code / Codex CLI),请谨慎开启。
force_codex_cli: false
+ # Optional: template file used to build the final top-level Codex `instructions`.
+ # 可选:用于构建最终 Codex 顶层 `instructions` 的模板文件路径。
+ #
+ # This is applied on the `/v1/messages -> Responses/Codex` conversion path,
+ # after Claude `system` has already been normalized into Codex `instructions`.
+ # 该模板作用于 `/v1/messages -> Responses/Codex` 转换链路,且发生在 Claude `system`
+ # 已经被归一化为 Codex `instructions` 之后。
+ #
+ # The template can reference:
+ # 模板可引用:
+ # - {{ .ExistingInstructions }} : converted client instructions/system
+ # - {{ .OriginalModel }} : original requested model
+ # - {{ .NormalizedModel }} : normalized routing model
+ # - {{ .BillingModel }} : billing model
+ # - {{ .UpstreamModel }} : final upstream model
+ #
+ # If you want to preserve client system prompts, keep {{ .ExistingInstructions }}
+ # somewhere in the template. If omitted, the template output fully replaces it.
+ # 如需保留客户端 system 提示词,请在模板中显式包含 {{ .ExistingInstructions }}。
+ # 若省略,则模板输出会完全覆盖它。
+ #
+ # Docker users can mount a host file to /app/data/codex-instructions.md.tmpl
+ # and point this field there.
+ # Docker 用户可将宿主机文件挂载到 /app/data/codex-instructions.md.tmpl,
+ # 然后把本字段指向该路径。
+ forced_codex_instructions_template_file: ""
# OpenAI 透传模式是否放行客户端超时头(如 x-stainless-timeout)
# 默认 false:过滤超时头,降低上游提前断流风险。
openai_passthrough_allow_timeout_headers: false
@@ -815,7 +841,47 @@ linuxdo_connect:
frontend_redirect_url: "/auth/linuxdo/callback"
token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none
# 注意:当 token_auth_method=none(public client)时,必须启用 PKCE
+ use_pkce: true
+ userinfo_email_path: ""
+ userinfo_id_path: ""
+ userinfo_username_path: ""
+
+# =============================================================================
+# Generic OIDC OAuth Login (SSO)
+# 通用 OIDC OAuth 登录(用于 Sub2API 用户登录)
+# =============================================================================
+oidc_connect:
+ enabled: false
+ provider_name: "OIDC"
+ client_id: ""
+ client_secret: ""
+ # 例如: "https://keycloak.example.com/realms/myrealm"
+ issuer_url: ""
+ # 可选: OIDC Discovery URL。为空时可手动填写 authorize/token/userinfo/jwks
+ discovery_url: ""
+ authorize_url: ""
+ token_url: ""
+ # 可选(仅补充 email/username,不用于 sub 可信绑定)
+ userinfo_url: ""
+ # validate_id_token=true 时必填
+ jwks_url: ""
+ scopes: "openid email profile"
+ # 示例: "https://your-domain.com/api/v1/auth/oauth/oidc/callback"
+ redirect_url: ""
+ # 安全提示:
+ # - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名
+ # - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token)
+ frontend_redirect_url: "/auth/oidc/callback"
+ token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none
+ # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE
use_pkce: false
+ # 开启后强制校验 id_token 的签名和 claims(推荐)
+ validate_id_token: true
+ allowed_signing_algs: "RS256,ES256,PS256"
+ # 允许的时钟偏移(秒)
+ clock_skew_seconds: 120
+ # 若 Provider 返回 email_verified=false,是否拒绝登录
+ require_email_verified: false
userinfo_email_path: ""
userinfo_id_path: ""
userinfo_username_path: ""
diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml
index d1a564a3..8d412321 100644
--- a/deploy/docker-compose.yml
+++ b/deploy/docker-compose.yml
@@ -32,6 +32,10 @@ services:
# Optional: Mount custom config.yaml (uncomment and create the file first)
# Copy config.example.yaml to config.yaml, modify it, then uncomment:
# - ./config.yaml:/app/data/config.yaml
+ # Optional: Mount a custom Codex instructions template file, then point
+ # gateway.forced_codex_instructions_template_file at /app/data/codex-instructions.md.tmpl
+ # in config.yaml.
+ # - ./codex-instructions.md.tmpl:/app/data/codex-instructions.md.tmpl:ro
environment:
# Auto Setup
- AUTO_SETUP=true
@@ -130,7 +134,17 @@ services:
- sub2api-network
- 1panel-network
healthcheck:
- test: ["CMD", "wget", "-q", "-T", "5", "-O", "/dev/null", "http://localhost:8080/health"]
+ test:
+ [
+ "CMD",
+ "wget",
+ "-q",
+ "-T",
+ "5",
+ "-O",
+ "/dev/null",
+ "http://localhost:8080/health",
+ ]
interval: 30s
timeout: 10s
retries: 3
@@ -161,7 +175,11 @@ services:
networks:
- sub2api-network
healthcheck:
- test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}"]
+ test:
+ [
+ "CMD-SHELL",
+ "pg_isready -U ${POSTGRES_USER:-sub2api} -d ${POSTGRES_DB:-sub2api}",
+ ]
interval: 10s
timeout: 5s
retries: 5
diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md
new file mode 100644
index 00000000..af93fa7e
--- /dev/null
+++ b/docs/PAYMENT.md
@@ -0,0 +1,287 @@
+# Payment System Configuration Guide
+
+Sub2API has a built-in payment system that enables user self-service top-up without deploying a separate payment service.
+
+---
+
+## Table of Contents
+
+- [Supported Payment Methods](#supported-payment-methods)
+- [Quick Start](#quick-start)
+- [System Settings](#system-settings)
+- [Provider Configuration](#provider-configuration)
+- [Provider Instance Management](#provider-instance-management)
+- [Webhook Configuration](#webhook-configuration)
+- [Payment Flow](#payment-flow)
+- [Migrating from Sub2ApiPay](#migrating-from-sub2apipay)
+
+---
+
+## Supported Payment Methods
+
+| Provider | Payment Methods | Description |
+|----------|----------------|-------------|
+| **EasyPay** | Alipay, WeChat Pay | Third-party aggregation via EasyPay protocol |
+| **Alipay (Direct)** | Desktop QR code, mobile Alipay redirect | Direct integration with Alipay Open Platform, returning desktop QR codes and mobile WAP/app launch links |
+| **WeChat Pay (Direct)** | Native QR, H5, MP/JSAPI Pay | Direct integration with WeChat Pay APIv3 with environment-aware routing |
+| **Stripe** | Card, Alipay, WeChat Pay, Link, etc. | International payments, multi-currency support |
+
+> Alipay/WeChat Pay direct and EasyPay can both exist as backend provider instances, but the frontend always exposes only two visible buttons: `Alipay` and `WeChat Pay`. Admins choose exactly one source for each visible method: direct or EasyPay. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup.
+
+> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need:
+>
+> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it.
+> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. The withdrawal threshold is relatively high, so it is recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2API author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer.
+>
+> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them.
+
+---
+
+## Quick Start
+
+1. Go to Admin Dashboard → **Settings** → **Payment Settings** tab
+2. Enable **Payment**
+3. Configure basic parameters (amount range, timeout, etc.)
+4. Add at least one provider instance in **Provider Management**
+5. Users can now top up from the frontend
+
+---
+
+## System Settings
+
+Configure the following in Admin Dashboard **Settings → Payment Settings**:
+
+### Basic Settings
+
+| Setting | Description | Default |
+|---------|-------------|---------|
+| **Enable Payment** | Enable or disable the payment system | Off |
+| **Product Name Prefix** | Prefix shown on payment page | - |
+| **Product Name Suffix** | Suffix (e.g., "Credits") | - |
+| **Minimum Amount** | Minimum single top-up amount | 1 |
+| **Maximum Amount** | Maximum single top-up amount (empty = unlimited) | - |
+| **Daily Limit** | Per-user daily cumulative limit (empty = unlimited) | - |
+| **Order Timeout** | Order timeout in minutes (minimum 1) | 30 |
+| **Max Pending Orders** | Maximum concurrent pending orders per user | 3 |
+| **Load Balance Strategy** | Strategy for selecting provider instances | Round Robin |
+
+### Frontend Visible Method Routing
+
+The current payment UX keeps the frontend method list unified and does not expose provider brands directly:
+
+- **Alipay**: when enabled, this button must be routed to either `Alipay (Direct)` or `EasyPay Alipay`
+- **WeChat Pay**: when enabled, this button must be routed to either `WeChat Pay (Direct)` or `EasyPay WeChat`
+- Each visible method can route to only one source at a time
+- If a visible method is enabled without a selected source, the frontend will not expose that method
+
+### Load Balance Strategies
+
+| Strategy | Description |
+|----------|-------------|
+| **Round Robin** | Distribute orders to instances in rotation |
+| **Least Amount** | Prefer instances with the lowest daily cumulative amount |
+
+### Cancel Rate Limiting
+
+Prevents users from repeatedly creating and canceling orders:
+
+| Setting | Description |
+|---------|-------------|
+| **Enable Limit** | Toggle |
+| **Window Mode** | Sliding / Fixed window |
+| **Time Window** | Window duration |
+| **Window Unit** | Minutes / Hours |
+| **Max Cancels** | Maximum cancellations allowed within the window |
+
+### Help Information
+
+| Setting | Description |
+|---------|-------------|
+| **Help Image** | Customer service QR code or help image (supports upload) |
+| **Help Text** | Instructions displayed on the payment page |
+
+---
+
+## Provider Configuration
+
+Each provider type requires different credentials. Select the type when adding a new provider instance in **Provider Management → Add Provider**.
+
+> **Callback URLs are auto-generated**: When adding a provider, the Notify URL and Return URL are automatically constructed from your site domain. You only need to confirm the domain is correct.
+
+### EasyPay
+
+Compatible with any payment service that implements the EasyPay protocol.
+
+| Parameter | Description | Required |
+|-----------|-------------|----------|
+| **Merchant ID (PID)** | EasyPay merchant ID | Yes |
+| **Merchant Key (PKey)** | EasyPay merchant secret key | Yes |
+| **API Base URL** | EasyPay API base address | Yes |
+| **Alipay Channel ID** | Specify Alipay channel (optional) | No |
+| **WeChat Channel ID** | Specify WeChat channel (optional) | No |
+
+### Alipay (Direct)
+
+Direct integration with Alipay Open Platform. Mobile flows return an Alipay WAP/app redirect URL. Desktop flows prefer Face-to-Face Precreate QR payloads; if the merchant has not enabled that product, the provider falls back to Computer Website Pay and also returns the cashier URL so the frontend can render a QR code or open the hosted checkout page directly.
+
+| Parameter | Description | Required |
+|-----------|-------------|----------|
+| **AppID** | Alipay application AppID | Yes |
+| **Private Key** | RSA2 application private key | Yes |
+| **Alipay Public Key** | Alipay public key | Yes |
+
+### WeChat Pay (Direct)
+
+Direct integration with WeChat Pay APIv3. Supports Native QR code payment, H5 payment, and MP/JSAPI payment inside the WeChat environment.
+
+| Parameter | Description | Required |
+|-----------|-------------|----------|
+| **AppID** | WeChat Pay AppID | Yes |
+| **Merchant ID (MchID)** | WeChat Pay merchant ID | Yes |
+| **Merchant API Private Key** | Merchant API private key (PEM format) | Yes |
+| **APIv3 Key** | 32-byte APIv3 key | Yes |
+| **WeChat Pay Public Key** | WeChat Pay public key (PEM format) | Yes |
+| **WeChat Pay Public Key ID** | WeChat Pay public key ID | Yes |
+| **Certificate Serial Number** | Merchant certificate serial number | Yes |
+
+### Stripe
+
+International payment platform supporting multiple payment methods and currencies.
+
+| Parameter | Description | Required |
+|-----------|-------------|----------|
+| **Secret Key** | Stripe secret key (`sk_live_...` or `sk_test_...`) | Yes |
+| **Publishable Key** | Stripe publishable key (`pk_live_...` or `pk_test_...`) | Yes |
+| **Webhook Secret** | Stripe Webhook signing secret (`whsec_...`) | Yes |
+
+---
+
+## Provider Instance Management
+
+You can create **multiple instances** of the same provider type for load balancing and risk control:
+
+- **Multi-instance load balancing** — Distribute orders via round-robin or least-amount strategy
+- **Independent limits** — Each instance can have its own min/max amount and daily limit
+- **Independent toggle** — Enable/disable individual instances without affecting others
+- **Refund control** — Enable or disable refunds per instance
+- **Payment methods** — Each instance can support a subset of payment methods
+- **Ordering** — Drag to reorder instances
+
+### Instance Limit Configuration
+
+Each instance supports these limits:
+
+| Limit | Description |
+|-------|-------------|
+| **Minimum Amount** | Minimum order amount accepted by this instance |
+| **Maximum Amount** | Maximum order amount accepted by this instance |
+| **Daily Limit** | Daily cumulative transaction limit for this instance |
+
+> During load balancing, instances that exceed their limits are automatically skipped.
+
+---
+
+## Webhook Configuration
+
+Payment callbacks are essential for the payment system to work correctly.
+
+### Callback URL Format
+
+When adding a provider, the system auto-generates callback URLs from your site domain:
+
+| Provider | Callback Path |
+|----------|-------------|
+| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` |
+| **Alipay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/alipay` |
+| **WeChat Pay (Direct)** | `https://your-domain.com/api/v1/payment/webhook/wxpay` |
+| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` |
+
+> Replace `your-domain.com` with your actual domain. For EasyPay / Alipay / WeChat Pay, the callback URL is auto-filled when adding the provider — no manual configuration needed.
+
+### Stripe Webhook Setup
+
+1. Log in to [Stripe Dashboard](https://dashboard.stripe.com/)
+2. Go to **Developers → Webhooks**
+3. Add an endpoint with the callback URL
+4. Subscribe to events: `payment_intent.succeeded`, `payment_intent.payment_failed`
+5. Copy the generated Webhook Secret (`whsec_...`) to your provider configuration
+
+### Important Notes
+
+- Callback URLs must use **HTTPS** (required by Stripe, strongly recommended for others)
+- Ensure your firewall allows callback requests from payment platforms
+- The system automatically verifies callback signatures to prevent forgery
+- Balance top-up is processed automatically upon successful payment — no manual intervention needed
+
+---
+
+## Payment Flow
+
+```
+User selects amount and payment method
+ │
+ ▼
+ Create Order (PENDING)
+ ├─ Validate amount range, pending order count, daily limit
+ ├─ Load balance to select provider instance
+ └─ Call provider to get payment info
+ │
+ ▼
+ User completes payment
+ ├─ EasyPay → QR code / H5 redirect
+ ├─ Alipay → Desktop QR payload (Face-to-Face preferred, Website Pay fallback) / mobile Alipay redirect
+ ├─ WeChat Pay → Desktop Native QR / non-WeChat H5 / in-WeChat JSAPI
+ └─ Stripe → Payment Element (card/Alipay/WeChat/etc.)
+ │
+ ▼
+ Webhook callback verified → Order PAID
+ │
+ ▼
+ Auto top-up to user balance → Order COMPLETED
+```
+
+### Order Status Reference
+
+| Status | Description |
+|--------|-------------|
+| `PENDING` | Waiting for user to complete payment |
+| `PAID` | Payment confirmed, awaiting balance credit |
+| `COMPLETED` | Balance credited successfully |
+| `EXPIRED` | Timed out without payment |
+| `CANCELLED` | Cancelled by user |
+| `FAILED` | Balance credit failed, admin can retry |
+| `REFUND_REQUESTED` | Refund requested |
+| `REFUNDING` | Refund in progress |
+| `REFUNDED` | Refund completed |
+
+### Timeout and Fallback
+
+- Before marking an order as expired, the background job queries the upstream payment status first
+- If the user has actually paid but the callback was delayed, the system will reconcile automatically
+- The background job runs every 60 seconds to check for timed-out orders
+
+---
+
+## Migrating from Sub2ApiPay
+
+If you previously used [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) as an external payment system, you can migrate to the built-in payment system:
+
+### Key Differences
+
+| Aspect | Sub2ApiPay | Built-in Payment |
+|--------|-----------|-----------------|
+| Deployment | Separate service (Next.js + PostgreSQL) | Built into Sub2API, no extra deployment |
+| Payment Methods | EasyPay, Alipay, WeChat, Stripe | Same |
+| Configuration | Environment variables + separate admin UI | Unified in Sub2API admin dashboard |
+| Top-up Integration | Via Admin API callback | Internal processing, more reliable |
+| Subscription Plans | Supported | Not yet (planned) |
+| Order Management | Separate admin interface | Integrated in Sub2API admin dashboard |
+
+### Migration Steps
+
+1. Enable payment in Sub2API admin dashboard and configure providers (use the same payment credentials)
+2. Update webhook callback URLs to Sub2API's callback endpoints
+3. Verify that new orders are processed correctly via built-in payment
+4. Decommission the Sub2ApiPay service
+
+> **Note**: Historical order data from Sub2ApiPay will not be automatically migrated. Keep Sub2ApiPay running for a while to access historical records.
diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md
new file mode 100644
index 00000000..ae765fb9
--- /dev/null
+++ b/docs/PAYMENT_CN.md
@@ -0,0 +1,287 @@
+# 支付系统配置指南
+
+Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支付服务。
+
+---
+
+## 目录
+
+- [支持的支付方式](#支持的支付方式)
+- [快速开始](#快速开始)
+- [系统设置](#系统设置)
+- [服务商配置](#服务商配置)
+- [服务商实例管理](#服务商实例管理)
+- [Webhook 配置](#webhook-配置)
+- [支付流程](#支付流程)
+- [从 Sub2ApiPay 迁移](#从-sub2apipay-迁移)
+
+---
+
+## 支持的支付方式
+
+| 服务商 | 支付方式 | 说明 |
+|--------|---------|------|
+| **EasyPay(易支付)** | 支付宝、微信支付 | 兼容易支付协议的第三方聚合支付 |
+| **支付宝官方** | 桌面二维码扫码、移动端支付宝跳转 | 直接对接支付宝开放平台,桌面端返回二维码,移动端返回 WAP/唤起链接 |
+| **微信官方** | Native 扫码、H5、公众号/JSAPI 支付 | 直接对接微信支付 APIv3,按终端环境自动分流 |
+| **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 |
+
+> 支付宝官方 / 微信官方与易支付可以同时作为后台服务商实例存在,但前台始终只展示 `支付宝`、`微信支付` 两个可见按钮。管理员需要分别为这两个按钮选择唯一支付来源:官方或易支付。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。
+
+> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择:
+>
+> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。
+> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2API 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。
+>
+> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。
+
+---
+
+## 快速开始
+
+1. 进入管理后台 → **设置** → **支付设置** 标签页
+2. 开启 **启用支付**
+3. 配置基本参数(金额范围、超时时间等)
+4. 在 **服务商管理** 中添加至少一个服务商实例
+5. 用户即可在前端页面进行充值
+
+---
+
+## 系统设置
+
+在管理后台 **设置 → 支付设置** 中配置以下参数:
+
+### 基本设置
+
+| 设置项 | 说明 | 默认值 |
+|--------|------|--------|
+| **启用支付** | 启用或禁用支付系统 | 关闭 |
+| **商品名前缀** | 支付页面显示的商品名前缀 | - |
+| **商品名后缀** | 商品名后缀(如"元") | - |
+| **最低金额** | 单笔最低充值金额 | 1 |
+| **最高金额** | 单笔最高充值金额(留空表示不限制) | - |
+| **每日限额** | 每用户每日累计充值上限(留空表示不限制) | - |
+| **订单超时时间** | 订单超时分钟数,至少 1 分钟 | 30 |
+| **最大待支付订单数** | 同一用户最大并行待支付订单数 | 3 |
+| **负载均衡策略** | 多服务商实例时的选择策略 | 轮询 |
+
+### 前台可见支付方式路由
+
+当前版本对用户统一展示支付方式,不区分官方渠道还是易支付:
+
+- **支付宝**:后台启用后,需要额外指定该按钮路由到 `支付宝官方` 或 `易支付支付宝`
+- **微信支付**:后台启用后,需要额外指定该按钮路由到 `微信官方` 或 `易支付微信`
+- 同一个可见支付方式在同一时刻只能路由到一个来源
+- 支付来源未选择时,即使对应按钮被开启,前台也不会暴露该支付方式
+
+### 负载均衡策略
+
+| 策略 | 说明 |
+|------|------|
+| **轮询(round-robin)** | 按顺序轮流分配到各服务商实例 |
+| **最少金额(least-amount)** | 优先分配到当日累计金额最少的实例 |
+
+### 取消频率限制
+
+防止用户频繁创建并取消订单:
+
+| 设置项 | 说明 |
+|--------|------|
+| **启用限制** | 开关 |
+| **窗口模式** | 滚动窗口 / 固定窗口 |
+| **时间窗口** | 窗口长度 |
+| **窗口单位** | 分钟 / 小时 |
+| **最大次数** | 窗口内允许的最大取消次数 |
+
+### 帮助信息
+
+| 设置项 | 说明 |
+|--------|------|
+| **帮助图片** | 充值页面显示的客服二维码等图片(支持上传) |
+| **帮助文本** | 充值页面显示的说明文字 |
+
+---
+
+## 服务商配置
+
+每种服务商需要不同的凭证和参数。在 **服务商管理 → 添加服务商** 中选择类型后填写。
+
+> **回调地址自动生成**:添加服务商时,异步回调地址(Notify URL)和同步跳转地址(Return URL)由系统根据你的站点域名自动拼接,无需手动填写。管理员只需确认域名正确即可。
+
+### EasyPay(易支付)
+
+兼容任何 EasyPay 协议的支付服务商。
+
+| 参数 | 说明 | 必填 |
+|------|------|------|
+| **商户 ID(PID)** | EasyPay 商户 ID | 是 |
+| **商户密钥(PKey)** | EasyPay 商户密钥 | 是 |
+| **API 地址** | EasyPay API 基础地址 | 是 |
+| **支付宝通道 ID** | 指定支付宝通道(可选) | 否 |
+| **微信通道 ID** | 指定微信通道(可选) | 否 |
+
+### 支付宝官方
+
+直接对接支付宝开放平台。移动端走支付宝手机网站支付跳转;桌面端优先使用当面付返回扫码串,若商户未开通当面付则回退到电脑网站支付,并将收银台链接同时返回给前端用于渲染二维码或直接打开支付页。
+
+| 参数 | 说明 | 必填 |
+|------|------|------|
+| **AppID** | 支付宝应用 AppID | 是 |
+| **应用私钥** | RSA2 应用私钥 | 是 |
+| **支付宝公钥** | 支付宝公钥 | 是 |
+
+### 微信官方
+
+直接对接微信支付 APIv3,支持 Native 扫码支付、H5 支付,以及在微信环境内的公众号/JSAPI 支付。
+
+| 参数 | 说明 | 必填 |
+|------|------|------|
+| **AppID** | 微信支付 AppID | 是 |
+| **商户号(MchID)** | 微信支付商户号 | 是 |
+| **商户 API 私钥** | 商户 API 私钥(PEM 格式) | 是 |
+| **APIv3 密钥** | 32 位 APIv3 密钥 | 是 |
+| **微信支付公钥** | 微信支付公钥(PEM 格式) | 是 |
+| **微信支付公钥 ID** | 微信支付公钥 ID | 是 |
+| **商户证书序列号** | 商户证书序列号 | 是 |
+
+### Stripe
+
+国际支付平台,支持多种支付方式和币种。
+
+| 参数 | 说明 | 必填 |
+|------|------|------|
+| **Secret Key** | Stripe 密钥(`sk_live_...` 或 `sk_test_...`) | 是 |
+| **Publishable Key** | Stripe 可公开密钥(`pk_live_...` 或 `pk_test_...`) | 是 |
+| **Webhook Secret** | Stripe Webhook 签名密钥(`whsec_...`) | 是 |
+
+---
+
+## 服务商实例管理
+
+同一种服务商可以创建**多个实例**,实现负载均衡和风控:
+
+- **多实例负载均衡** — 按轮询或最少金额策略分流订单
+- **独立限额** — 每个实例可独立配置单笔最小/最大金额和每日限额
+- **独立启停** — 可单独启用/禁用某个实例,不影响其他实例
+- **退款控制** — 每个实例可单独开启或关闭退款功能
+- **支付方式** — 每个实例可选择支持的支付方式子集
+- **排序** — 拖拽调整实例顺序
+
+### 实例限额配置
+
+每个实例支持以下限额:
+
+| 限额项 | 说明 |
+|--------|------|
+| **单笔最小金额** | 该实例接受的最小订单金额 |
+| **单笔最大金额** | 该实例接受的最大订单金额 |
+| **每日限额** | 该实例每日累计交易上限 |
+
+> 负载均衡时,系统会自动跳过超出限额的实例。
+
+---
+
+## Webhook 配置
+
+支付回调是支付系统的核心环节,必须正确配置:
+
+### 回调地址格式
+
+添加服务商时,系统会自动根据站点域名拼接回调地址,格式如下:
+
+| 服务商 | 回调路径 |
+|--------|---------|
+| **EasyPay** | `https://your-domain.com/api/v1/payment/webhook/easypay` |
+| **支付宝官方** | `https://your-domain.com/api/v1/payment/webhook/alipay` |
+| **微信官方** | `https://your-domain.com/api/v1/payment/webhook/wxpay` |
+| **Stripe** | `https://your-domain.com/api/v1/payment/webhook/stripe` |
+
+> 将 `your-domain.com` 替换为你的实际域名。EasyPay / 支付宝 / 微信的回调地址在添加服务商时自动填入,无需手动配置。
+
+### Stripe Webhook 设置
+
+1. 登录 [Stripe Dashboard](https://dashboard.stripe.com/)
+2. 进入 **Developers → Webhooks**
+3. 添加端点,填写回调地址
+4. 订阅事件:`payment_intent.succeeded`、`payment_intent.payment_failed`
+5. 将生成的 Webhook Secret(`whsec_...`)填入服务商配置
+
+### 注意事项
+
+- 回调地址必须是 **HTTPS**(Stripe 强制要求,其他服务商强烈推荐)
+- 确保服务器防火墙允许支付平台的回调请求
+- 系统会自动进行签名验证,防止伪造回调
+- 支付成功后自动完成余额充值,无需人工干预
+
+---
+
+## 支付流程
+
+```
+用户选择充值金额和支付方式
+ │
+ ▼
+ 创建订单 (PENDING)
+ ├─ 校验金额范围、待支付订单数、每日限额
+ ├─ 负载均衡选择服务商实例
+ └─ 调用服务商获取支付信息
+ │
+ ▼
+ 用户完成支付
+ ├─ EasyPay → 扫码 / H5 跳转
+ ├─ 支付宝官方 → 桌面扫码单(当面付优先,电脑网站支付回退)/ 移动端支付宝跳转
+ ├─ 微信官方 → 桌面 Native 扫码 / 非微信 H5 / 微信内 JSAPI
+ └─ Stripe → Payment Element(银行卡/支付宝/微信等)
+ │
+ ▼
+ 支付回调验签 → 订单 PAID
+ │
+ ▼
+ 自动充值到用户余额 → 订单 COMPLETED
+```
+
+### 订单状态说明
+
+| 状态 | 说明 |
+|------|------|
+| `PENDING` | 待支付,等待用户完成支付 |
+| `PAID` | 已支付,等待充值到账 |
+| `COMPLETED` | 已完成,余额已到账 |
+| `EXPIRED` | 已过期,超时未支付 |
+| `CANCELLED` | 已取消,用户主动取消 |
+| `FAILED` | 充值失败,可管理员重试 |
+| `REFUND_REQUESTED` | 已申请退款 |
+| `REFUNDING` | 退款处理中 |
+| `REFUNDED` | 已退款 |
+
+### 超时与兜底
+
+- 订单超时后,后台任务会先查询上游支付状态再标记过期
+- 如果用户实际已支付但回调延迟,系统会通过查询补单
+- 后台任务每 60 秒执行一次超时检查
+
+---
+
+## 从 Sub2ApiPay 迁移
+
+如果你之前使用 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 作为外部支付系统,现在可以迁移到内置支付:
+
+### 主要差异
+
+| 对比项 | Sub2ApiPay | 内置支付 |
+|--------|-----------|---------|
+| 部署方式 | 独立服务(Next.js + PostgreSQL) | 内置于 Sub2API,无需额外部署 |
+| 支付方式 | EasyPay、支付宝、微信、Stripe | 相同 |
+| 配置方式 | 环境变量 + 独立管理后台 | Sub2API 管理后台内统一配置 |
+| 充值对接 | 通过 Admin API 回调 | 内部直接处理,更可靠 |
+| 订阅套餐 | 支持 | 暂不支持(计划中) |
+| 订单管理 | 独立管理界面 | 集成在 Sub2API 管理后台 |
+
+### 迁移步骤
+
+1. 在 Sub2API 管理后台启用支付并配置服务商(使用相同的支付凭证)
+2. 更新 Webhook 回调地址为 Sub2API 的回调地址
+3. 确认新订单通过内置支付正常处理
+4. 停用 Sub2ApiPay 服务
+
+> **注意**:Sub2ApiPay 中的历史订单数据不会自动迁移。建议保留 Sub2ApiPay 一段时间以便查询历史记录。
diff --git a/frontend/package.json b/frontend/package.json
index d2a6dede..a220d3a7 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -16,9 +16,10 @@
},
"dependencies": {
"@lobehub/icons": "^4.0.2",
+ "@stripe/stripe-js": "^9.0.1",
"@tanstack/vue-virtual": "^3.13.23",
"@vueuse/core": "^10.7.0",
- "axios": "^1.13.5",
+ "axios": "^1.15.0",
"chart.js": "^4.4.1",
"dompurify": "^3.3.1",
"driver.js": "^1.4.0",
diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml
index 505b72f3..0a7b3fa1 100644
--- a/frontend/pnpm-lock.yaml
+++ b/frontend/pnpm-lock.yaml
@@ -11,6 +11,9 @@ importers:
'@lobehub/icons':
specifier: ^4.0.2
version: 4.0.2(@lobehub/ui@4.9.2)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@stripe/stripe-js':
+ specifier: ^9.0.1
+ version: 9.0.1
'@tanstack/vue-virtual':
specifier: ^3.13.23
version: 3.13.23(vue@3.5.26(typescript@5.6.3))
@@ -18,8 +21,8 @@ importers:
specifier: ^10.7.0
version: 10.11.1(vue@3.5.26(typescript@5.6.3))
axios:
- specifier: ^1.13.5
- version: 1.13.5
+ specifier: ^1.15.0
+ version: 1.15.0
chart.js:
specifier: ^4.4.1
version: 4.5.1
@@ -134,11 +137,11 @@ packages:
resolution: {integrity: sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==}
engines: {node: '>=6.0.0'}
- '@ant-design/colors@8.0.0':
- resolution: {integrity: sha512-6YzkKCw30EI/E9kHOIXsQDHmMvTllT8STzjMb4K2qzit33RW2pqCJP0sk+hidBntXxE+Vz4n1+RvCTfBw6OErw==}
+ '@ant-design/colors@8.0.1':
+ resolution: {integrity: sha512-foPVl0+SWIslGUtD/xBr1p9U4AKzPhNYEseXYRRo5QSzGACYZrQbe11AYJbYfAWnWSpGBx6JjBmSeugUsD9vqQ==}
- '@ant-design/cssinjs-utils@2.0.2':
- resolution: {integrity: sha512-Mq3Hm6fJuQeFNKSp3+yT4bjuhVbdrsyXE2RyfpJFL0xiYNZdaJ6oFaE3zFrzmHbmvTd2Wp3HCbRtkD4fU+v2ZA==}
+ '@ant-design/cssinjs-utils@2.1.2':
+ resolution: {integrity: sha512-5fTHQ158jJJ5dC/ECeyIdZUzKxE/mpEMRZxthyG1sw/AKRHKgJBg00Yi6ACVXgycdje7KahRNvNET/uBccwCnA==}
peerDependencies:
react: '>=18'
react-dom: '>=18'
@@ -149,15 +152,21 @@ packages:
react: '>=16.0.0'
react-dom: '>=16.0.0'
- '@ant-design/fast-color@3.0.0':
- resolution: {integrity: sha512-eqvpP7xEDm2S7dUzl5srEQCBTXZMmY3ekf97zI+M2DHOYyKdJGH0qua0JACHTqbkRnD/KHFQP9J1uMJ/XWVzzA==}
+ '@ant-design/cssinjs@2.1.2':
+ resolution: {integrity: sha512-2Hy8BnCEH31xPeSLbhhB2ctCPXE2ZnASdi+KbSeS79BNbUhL9hAEe20SkUk+BR8aKTmqb6+FKFruk7w8z0VoRQ==}
+ peerDependencies:
+ react: '>=16.0.0'
+ react-dom: '>=16.0.0'
+
+ '@ant-design/fast-color@3.0.1':
+ resolution: {integrity: sha512-esKJegpW4nckh0o6kV3Tkb7NPIZYbPnnFxmQDUmL08ukXZAvV85TZBr70eGuke/CIArLaP6aw8lt9KILjnWuOw==}
engines: {node: '>=8.x'}
'@ant-design/icons-svg@4.4.2':
resolution: {integrity: sha512-vHbT+zJEVzllwP+CM+ul7reTEfBR0vgxFe7+lREAsAA7YGsYpboiq2sQNeQeRvh09GfQgs/GyFEvZpJ9cLXpXA==}
- '@ant-design/icons@6.1.0':
- resolution: {integrity: sha512-KrWMu1fIg3w/1F2zfn+JlfNDU8dDqILfA5Tg85iqs1lf8ooyGlbkA+TkwfOKKgqpUmAiRY1PTFpuOU2DAIgSUg==}
+ '@ant-design/icons@6.1.1':
+ resolution: {integrity: sha512-AMT4N2y++TZETNHiM77fs4a0uPVCJGuL5MTonk13Pvv7UN7sID1cNEZOc1qNqx6zLKAOilTEFAdAoAFKa0U//Q==}
engines: {node: '>=8'}
peerDependencies:
react: '>=16.0.0'
@@ -208,6 +217,10 @@ packages:
resolution: {integrity: sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==}
engines: {node: '>=6.9.0'}
+ '@babel/runtime@7.29.2':
+ resolution: {integrity: sha512-JiDShH45zKHWyGe4ZNVRrCjBz8Nh9TMmZG1kh4QTK8hCBTWBi8Da+i7s1fJw7/lYpM4ccepSNfqzZ/QvABBi5g==}
+ engines: {node: '>=6.9.0'}
+
'@babel/template@7.27.2':
resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==}
engines: {node: '>=6.9.0'}
@@ -220,8 +233,8 @@ packages:
resolution: {integrity: sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==}
engines: {node: '>=6.9.0'}
- '@base-ui/react@1.0.0':
- resolution: {integrity: sha512-4USBWz++DUSLTuIYpbYkSgy1F9ZmNG9S/lXvlUN6qMK0P0RlW+6eQmDUB4DgZ7HVvtXl4pvi4z5J2fv6Z3+9hg==}
+ '@base-ui/react@1.3.0':
+ resolution: {integrity: sha512-FwpKqZbPz14AITp1CVgf4AjhKPe1OeeVKSBMdgD10zbFlj3QSWelmtCMLi2+/PFZZcIm3l87G7rwtCZJwHyXWA==}
engines: {node: '>=14.0.0'}
peerDependencies:
'@types/react': ^17 || ^18 || ^19
@@ -231,8 +244,8 @@ packages:
'@types/react':
optional: true
- '@base-ui/utils@0.2.3':
- resolution: {integrity: sha512-/CguQ2PDaOzeVOkllQR8nocJ0FFIDqsWIcURsVmm53QGo8NhFNpePjNlyPIB41luxfOqnG7PU0xicMEw3ls7XQ==}
+ '@base-ui/utils@0.2.6':
+ resolution: {integrity: sha512-yQ+qeuqohwhsNpoYDqqXaLllYAkPCP4vYdDrVo8FQXaAPfHWm1pG/Vm+jmGTA5JFS0BAIjookyapuJFY8F9PIw==}
peerDependencies:
'@types/react': ^17 || ^18 || ^19
react: ^17 || ^18 || ^19
@@ -244,23 +257,23 @@ packages:
'@bcoe/v8-coverage@0.2.3':
resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==}
- '@braintree/sanitize-url@7.1.1':
- resolution: {integrity: sha512-i1L7noDNxtFyL5DmZafWy1wRVhGehQmzZaz1HiN5e7iylJMSZR7ekOV7NsIqa5qBldlLrsKv4HbgFUVlQrz8Mw==}
+ '@braintree/sanitize-url@7.1.2':
+ resolution: {integrity: sha512-jigsZK+sMF/cuiB7sERuo9V7N9jx+dhmHHnQyDSVdpZwVutaBu7WvNYqMDLSgFgfB30n452TP3vjDAvFC973mA==}
- '@chevrotain/cst-dts-gen@11.0.3':
- resolution: {integrity: sha512-BvIKpRLeS/8UbfxXxgC33xOumsacaeCKAjAeLyOn7Pcp95HiRbrpl14S+9vaZLolnbssPIUuiUd8IvgkRyt6NQ==}
+ '@chevrotain/cst-dts-gen@12.0.0':
+ resolution: {integrity: sha512-fSL4KXjTl7cDgf0B5Rip9Q05BOrYvkJV/RrBTE/bKDN096E4hN/ySpcBK5B24T76dlQ2i32Zc3PAE27jFnFrKg==}
- '@chevrotain/gast@11.0.3':
- resolution: {integrity: sha512-+qNfcoNk70PyS/uxmj3li5NiECO+2YKZZQMbmjTqRI3Qchu8Hig/Q9vgkHpI3alNjr7M+a2St5pw5w5F6NL5/Q==}
+ '@chevrotain/gast@12.0.0':
+ resolution: {integrity: sha512-1ne/m3XsIT8aEdrvT33so0GUC+wkctpUPK6zU9IlOyJLUbR0rg4G7ZiApiJbggpgPir9ERy3FRjT6T7lpgetnQ==}
- '@chevrotain/regexp-to-ast@11.0.3':
- resolution: {integrity: sha512-1fMHaBZxLFvWI067AVbGJav1eRY7N8DDvYCTwGBiE/ytKBgP8azTdgyrKyWZ9Mfh09eHWb5PgTSO8wi7U824RA==}
+ '@chevrotain/regexp-to-ast@12.0.0':
+ resolution: {integrity: sha512-p+EW9MaJwgaHguhoqwOtx/FwuGr+DnNn857sXWOi/mClXIkPGl3rn7hGNWvo31HA3vyeQxjqe+H36yZJwYU8cA==}
- '@chevrotain/types@11.0.3':
- resolution: {integrity: sha512-gsiM3G8b58kZC2HaWR50gu6Y1440cHiJ+i3JUvcp/35JchYejb2+5MVeJK0iKThYpAa/P2PYFV4hoi44HD+aHQ==}
+ '@chevrotain/types@12.0.0':
+ resolution: {integrity: sha512-S+04vjFQKeuYw0/eW3U52LkAHQsB1ASxsPGsLPUyQgrZ2iNNibQrsidruDzjEX2JYfespXMG0eZmXlhA6z7nWA==}
- '@chevrotain/utils@11.0.3':
- resolution: {integrity: sha512-YslZMgtJUyuMbZ+aKvfF3x1f5liK4mWNxghFRv7jqRR9C3R3fAOGTTKvxXDa2Y1s9zSbcpuO0cAxDYsc9SrXoQ==}
+ '@chevrotain/utils@12.0.0':
+ resolution: {integrity: sha512-lB59uJoaGIfOOL9knQqQRfhl9g7x8/wqFkp13zTdkRu1huG9kg6IJs1O8hqj9rs6h7orGxHJUKb+mX3rPbWGhA==}
'@csstools/color-helpers@5.1.0':
resolution: {integrity: sha512-S11EXWJyy0Mz5SYvRmY8nJYTFFd1LCNV+7cXyAgQtOOuzb4EsgfqDufL+9esx72/eLhsRdGZwaldu/h+E4t4BA==}
@@ -536,26 +549,26 @@ packages:
resolution: {integrity: sha512-d9zaMRSTIKDLhctzH12MtXvJKSSUhaHcjV+2Z+GK+EEY7XKpP5yR4x+N3TAcHTcu963nIr+TMcCb4DBCYX1z6Q==}
engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0}
- '@floating-ui/core@1.7.3':
- resolution: {integrity: sha512-sGnvb5dmrJaKEZ+LDIpguvdX3bDlEllmv4/ClQ9awcmCZrlx5jQyyMWFM5kBI+EyNOCDDiKk8il0zeuX3Zlg/w==}
+ '@floating-ui/core@1.7.5':
+ resolution: {integrity: sha512-1Ih4WTWyw0+lKyFMcBHGbb5U5FtuHJuujoyyr5zTaWS5EYMeT6Jb2AuDeftsCsEuchO+mM2ij5+q9crhydzLhQ==}
- '@floating-ui/dom@1.7.4':
- resolution: {integrity: sha512-OOchDgh4F2CchOX94cRVqhvy7b3AFb+/rQXyswmzmGakRfkMgoWVjfnLWkRirfLEfuD4ysVW16eXzwt3jHIzKA==}
+ '@floating-ui/dom@1.7.6':
+ resolution: {integrity: sha512-9gZSAI5XM36880PPMm//9dfiEngYoC6Am2izES1FF406YFsjvyBMmeJ2g4SAju3xWwtuynNRFL2s9hgxpLI5SQ==}
- '@floating-ui/react-dom@2.1.6':
- resolution: {integrity: sha512-4JX6rEatQEvlmgU80wZyq9RT96HZJa88q8hp0pBd+LrczeDI4o6uA2M+uvxngVHo4Ihr8uibXxH6+70zhAFrVw==}
+ '@floating-ui/react-dom@2.1.8':
+ resolution: {integrity: sha512-cC52bHwM/n/CxS87FH0yWdngEZrjdtLW/qVruo68qg+prK7ZQ4YGdut2GyDVpoGeAYe/h899rVeOVm6Oi40k2A==}
peerDependencies:
react: '>=16.8.0'
react-dom: '>=16.8.0'
- '@floating-ui/react@0.27.16':
- resolution: {integrity: sha512-9O8N4SeG2z++TSM8QA/KTeKFBVCNEz/AGS7gWPJf6KFRzmRWixFRnCnkPHRDwSVZW6QPDO6uT0P2SpWNKCc9/g==}
+ '@floating-ui/react@0.27.19':
+ resolution: {integrity: sha512-31B8h5mm8YxotlE7/AU/PhNAl8eWxAmjL/v2QOxroDNkTFLk3Uu82u63N3b6TXa4EGJeeZLVcd/9AlNlVqzeog==}
peerDependencies:
react: '>=17.0.0'
react-dom: '>=17.0.0'
- '@floating-ui/utils@0.2.10':
- resolution: {integrity: sha512-aGTxbpbg8/b5JfU1HXSrbH3wXZuLPJcNEcZQFMxLs3oSzgtVu6nFPkbbGGUvBcUjKV2YyB9Wxxabo+HEH9tcRQ==}
+ '@floating-ui/utils@0.2.11':
+ resolution: {integrity: sha512-RiB/yIh78pcIxl6lLMG0CgBXAZ2Y0eVHqMPYugu+9U0AeT6YBeiJpf7lbdJNIugFP5SIjwNRgo4DhR1Qxi26Gg==}
'@giscus/react@3.1.0':
resolution: {integrity: sha512-0TCO2TvL43+oOdyVVGHDItwxD1UMKP2ZYpT6gXmhFOqfAJtZxTzJ9hkn34iAF/b6YzyJ4Um89QIt9z/ajmAEeg==}
@@ -618,8 +631,8 @@ packages:
'@kurkle/color@0.3.4':
resolution: {integrity: sha512-M5UknZPHRu3DEDWoipU6sE8PdkZ6Z/S+v4dD+Ke8IaNlpdSQah50lz1KtcFBa2vsdOnwbbnxJwVM4wty6udA5w==}
- '@lit-labs/ssr-dom-shim@1.5.0':
- resolution: {integrity: sha512-HLomZXMmrCFHSRKESF5vklAKsDY7/fsT/ZhqCu3V0UoW/Qbv8wxmO4W9bx4KnCCF2Zak4yuk+AGraK/bPmI4kA==}
+ '@lit-labs/ssr-dom-shim@1.5.1':
+ resolution: {integrity: sha512-Aou5UdlSpr5whQe8AA/bZG0jMj96CoJIWbGfZ91qieWu5AWUMKw8VR/pAkQkJYvBNhmCcWnZlyyk5oze8JIqYA==}
'@lit/reactive-element@2.1.2':
resolution: {integrity: sha512-pbCDiVMnne1lYUIaYNN5wrwQXDtHaYtg7YEFPeW+hws6U47WeFvISGUWekPGKWOP1ygrs0ef0o1VJMk1exos5A==}
@@ -660,8 +673,8 @@ packages:
'@types/react': '>=16'
react: '>=16'
- '@mermaid-js/parser@0.6.3':
- resolution: {integrity: sha512-lnjOhe7zyHjc+If7yT4zoedx2vo4sHaTmtkl1+or8BRTnCtDmcTpAjpzDSfCZrshM5bCoz0GyidzadJAH1xobA==}
+ '@mermaid-js/parser@1.1.0':
+ resolution: {integrity: sha512-gxK9ZX2+Fex5zu8LhRQoMeMPEHbc73UKZ0FQ54YrQtUxE1VVhMwzeNtKRPAu5aXks4FasbMe4xB4bWrmq6Jlxw==}
'@nodelib/fs.scandir@2.1.5':
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
@@ -682,8 +695,8 @@ packages:
resolution: {integrity: sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==}
engines: {node: '>=14'}
- '@primer/octicons@19.21.1':
- resolution: {integrity: sha512-7tgtBkCNcg75YJnckinzvES+uxysYQCe+CHSEnzr3VYgxttzKRvfmrnVogl3aEuHCQP4xhiE9k2lFDhYwGtTzQ==}
+ '@primer/octicons@19.23.1':
+ resolution: {integrity: sha512-CzjGmxkmNhyst6EekrS3SJPdtzgIkUMP/LSJch65y99/kmiFXbO1a+q7zoYe3hnI9NaOM0IN+ydDIbOmd8YqcA==}
'@radix-ui/primitive@1.1.3':
resolution: {integrity: sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==}
@@ -929,8 +942,8 @@ packages:
'@radix-ui/rect@1.1.1':
resolution: {integrity: sha512-HPwpGIzkl28mWyZqG52jiqDJ12waP11Pa1lGoiyUkIEuMLBP0oeK/C89esbXrxsky5we7dfd8U58nm0SgAWpVw==}
- '@rc-component/async-validator@5.0.4':
- resolution: {integrity: sha512-qgGdcVIF604M9EqjNF0hbUTz42bz/RDtxWdWuU5EQe3hi7M8ob54B6B35rOsvX5eSvIHIzT9iH1R3n+hk3CGfg==}
+ '@rc-component/async-validator@5.1.0':
+ resolution: {integrity: sha512-n4HcR5siNUXRX23nDizbZBQPO0ZM/5oTtmKZ6/eqL0L2bo747cklFdZGRN2f+c9qWGICwDzrhW0H7tE9PptdcA==}
engines: {node: '>=14.x'}
'@rc-component/cascader@1.10.0':
@@ -981,8 +994,8 @@ packages:
react: '>=16.11.0'
react-dom: '>=16.11.0'
- '@rc-component/form@1.6.0':
- resolution: {integrity: sha512-A7vrN8kExtw4sW06mrsgCb1rowhvBFFvQU6Bk/NL0Fj6Wet/5GF0QnGCxBu/sG3JI9FEhsJWES0D44BW2d0hzg==}
+ '@rc-component/form@1.6.2':
+ resolution: {integrity: sha512-OgIn2RAoaSBqaIgzJf/X6iflIa9LpTozci1lagLBdURDFhGA370v0+T0tXxOi8YShMjTha531sFhwtnrv+EJaQ==}
engines: {node: '>=8.x'}
peerDependencies:
react: '>=16.9.0'
@@ -1018,8 +1031,8 @@ packages:
react: '>=16.9.0'
react-dom: '>=16.9.0'
- '@rc-component/mini-decimal@1.1.0':
- resolution: {integrity: sha512-jS4E7T9Li2GuYwI6PyiVXmxTiM6b07rlD9Ge8uGZSCz3WlzcG5ZK7g5bbuKNeZ9pgUuPK/5guV781ujdVpm4HQ==}
+ '@rc-component/mini-decimal@1.1.3':
+ resolution: {integrity: sha512-bk/FJ09fLf+NLODMAFll6CfYrHPBioTedhW6lxDBuuWucJEqFUd4l/D/5JgIi3dina6sYahB8iuPAZTNz2pMxw==}
engines: {node: '>=8.x'}
'@rc-component/motion@1.1.6':
@@ -1054,8 +1067,8 @@ packages:
react: '>=16.9.0'
react-dom: '>=16.9.0'
- '@rc-component/picker@1.9.0':
- resolution: {integrity: sha512-OLisdk8AWVCG9goBU1dWzuH5QlBQk8jktmQ6p0/IyBFwdKGwyIZOSjnBYo8hooHiTdl0lU+wGf/OfMtVBw02KQ==}
+ '@rc-component/picker@1.9.1':
+ resolution: {integrity: sha512-9FBYYsvH3HMLICaPDA/1Th5FLaDkFa7qAtangIdlhKb3ZALaR745e9PsOhheJb6asS4QXc12ffiAcjdkZ4C5/g==}
engines: {node: '>=12.x'}
peerDependencies:
date-fns: '>= 2.x'
@@ -1108,8 +1121,8 @@ packages:
react: '>=16.9.0'
react-dom: '>=16.9.0'
- '@rc-component/resize-observer@1.0.1':
- resolution: {integrity: sha512-r+w+Mz1EiueGk1IgjB3ptNXLYSLZ5vnEfKHH+gfgj7JMupftyzvUUl3fRcMZe5uMM04x0n8+G2o/c6nlO2+Wag==}
+ '@rc-component/resize-observer@1.1.2':
+ resolution: {integrity: sha512-t/Bb0W8uvL4PYKAB3YcChC+DlHh0Wt5kM7q/J+0qpVEUMLe7Hk5zuvc9km0hMnTFPSx5Z7Wu/fzCLN6erVLE8Q==}
peerDependencies:
react: '>=16.9.0'
react-dom: '>=16.9.0'
@@ -1193,15 +1206,15 @@ packages:
react: '*'
react-dom: '*'
- '@rc-component/trigger@2.3.0':
- resolution: {integrity: sha512-iwaxZyzOuK0D7lS+0AQEtW52zUWxoGqTGkke3dRyb8pYiShmRpCjB/8TzPI4R6YySCH7Vm9BZj/31VPiiQTLBg==}
+ '@rc-component/trigger@2.3.1':
+ resolution: {integrity: sha512-ORENF39PeXTzM+gQEshuk460Z8N4+6DkjpxlpE7Q3gYy1iBpLrx0FOJz3h62ryrJZ/3zCAUIkT1Pb/8hHWpb3A==}
engines: {node: '>=8.x'}
peerDependencies:
react: '>=16.9.0'
react-dom: '>=16.9.0'
- '@rc-component/trigger@3.8.1':
- resolution: {integrity: sha512-walnDJnKq+OcPQFHBMN+YZmdHV8+6z75+Rgpc0dW1c+Dmy6O7tRueDs4LdbwjlryQfTdsw84PIkNPzcx5yQ7qQ==}
+ '@rc-component/trigger@3.9.0':
+ resolution: {integrity: sha512-X8btpwfrT27AgrZVOz4swclhEHTZcqaHeQMXXBgveagOiakTa36uObXbdwerXffgV8G9dH1fAAE0DHtVQs8EHg==}
engines: {node: '>=8.x'}
peerDependencies:
react: '>=18.0.0'
@@ -1213,6 +1226,12 @@ packages:
react: '>=16.9.0'
react-dom: '>=16.9.0'
+ '@rc-component/util@1.10.1':
+ resolution: {integrity: sha512-q++9S6rUa5Idb/xIBNz6jtvumw5+O5YV5V0g4iK9mn9jWs4oGJheE3ZN1kAnE723AXyaD8v95yeOASmdk8Jnng==}
+ peerDependencies:
+ react: '>=18.0.0'
+ react-dom: '>=18.0.0'
+
'@rc-component/util@1.7.0':
resolution: {integrity: sha512-tIvIGj4Vl6fsZFvWSkYw9sAfiCKUXMyhVz6kpKyZbwyZyRPqv2vxYZROdaO1VB4gqTNvUZFXh6i3APUiterw5g==}
peerDependencies:
@@ -1347,26 +1366,26 @@ packages:
cpu: [x64]
os: [win32]
- '@shikijs/core@3.20.0':
- resolution: {integrity: sha512-f2ED7HYV4JEk827mtMDwe/yQ25pRiXZmtHjWF8uzZKuKiEsJR7Ce1nuQ+HhV9FzDcbIo4ObBCD9GPTzNuy9S1g==}
+ '@shikijs/core@3.23.0':
+ resolution: {integrity: sha512-NSWQz0riNb67xthdm5br6lAkvpDJRTgB36fxlo37ZzM2yq0PQFFzbd8psqC2XMPgCzo1fW6cVi18+ArJ44wqgA==}
- '@shikijs/engine-javascript@3.20.0':
- resolution: {integrity: sha512-OFx8fHAZuk7I42Z9YAdZ95To6jDePQ9Rnfbw9uSRTSbBhYBp1kEOKv/3jOimcj3VRUKusDYM6DswLauwfhboLg==}
+ '@shikijs/engine-javascript@3.23.0':
+ resolution: {integrity: sha512-aHt9eiGFobmWR5uqJUViySI1bHMqrAgamWE1TYSUoftkAeCCAiGawPMwM+VCadylQtF4V3VNOZ5LmfItH5f3yA==}
- '@shikijs/engine-oniguruma@3.20.0':
- resolution: {integrity: sha512-Yx3gy7xLzM0ZOjqoxciHjA7dAt5tyzJE3L4uQoM83agahy+PlW244XJSrmJRSBvGYELDhYXPacD4R/cauV5bzQ==}
+ '@shikijs/engine-oniguruma@3.23.0':
+ resolution: {integrity: sha512-1nWINwKXxKKLqPibT5f4pAFLej9oZzQTsby8942OTlsJzOBZ0MWKiwzMsd+jhzu8YPCHAswGnnN1YtQfirL35g==}
- '@shikijs/langs@3.20.0':
- resolution: {integrity: sha512-le+bssCxcSHrygCWuOrYJHvjus6zhQ2K7q/0mgjiffRbkhM4o1EWu2m+29l0yEsHDbWaWPNnDUTRVVBvBBeKaA==}
+ '@shikijs/langs@3.23.0':
+ resolution: {integrity: sha512-2Ep4W3Re5aB1/62RSYQInK9mM3HsLeB91cHqznAJMuylqjzNVAVCMnNWRHFtcNHXsoNRayP9z1qj4Sq3nMqYXg==}
- '@shikijs/themes@3.20.0':
- resolution: {integrity: sha512-U1NSU7Sl26Q7ErRvJUouArxfM2euWqq1xaSrbqMu2iqa+tSp0D1Yah8216sDYbdDHw4C8b75UpE65eWorm2erQ==}
+ '@shikijs/themes@3.23.0':
+ resolution: {integrity: sha512-5qySYa1ZgAT18HR/ypENL9cUSGOeI2x+4IvYJu4JgVJdizn6kG4ia5Q1jDEOi7gTbN4RbuYtmHh0W3eccOrjMA==}
- '@shikijs/transformers@3.20.0':
- resolution: {integrity: sha512-PrHHMRr3Q5W1qB/42kJW6laqFyWdhrPF2hNR9qjOm1xcSiAO3hAHo7HaVyHE6pMyevmy3i51O8kuGGXC78uK3g==}
+ '@shikijs/transformers@3.23.0':
+ resolution: {integrity: sha512-F9msZVxdF+krQNSdQ4V+Ja5QemeAoTQ2jxt7nJCwhDsdF1JWS3KxIQXA3lQbyKwS3J61oHRUSv4jYWv3CkaKTQ==}
- '@shikijs/types@3.20.0':
- resolution: {integrity: sha512-lhYAATn10nkZcBQ0BlzSbJA3wcmL5MXUUF8d2Zzon6saZDlToKaiRX60n2+ZaHJCmXEcZRWNzn+k9vplr8Jhsw==}
+ '@shikijs/types@3.23.0':
+ resolution: {integrity: sha512-3JZ5HXOZfYjsYSk0yPwBrkupyYSLpAE26Qc0HLghhZNGTZg/SKxXIIgoxOpmmeQP0RRSDJTk1/vPfw9tbw+jSQ==}
'@shikijs/vscode-textmate@10.0.2':
resolution: {integrity: sha512-83yeghZ2xxin3Nj8z1NMd/NCuca+gsYXswywDy5bHvwlWL8tpTQmzGeUuHd9FC3E/SBEMvzJRwWEOz5gGes9Qg==}
@@ -1379,6 +1398,10 @@ packages:
peerDependencies:
react: '>= 16.3.0'
+ '@stripe/stripe-js@9.0.1':
+ resolution: {integrity: sha512-un0URSosrW7wNr7xZ5iI2mC9mdeXZ3KERoVlA2RdmeLXYxHUPXq0yHzir2n/MtyXXEdSaELtz4WXGS6dzPEeKA==}
+ engines: {node: '>=12.16'}
+
'@tanstack/virtual-core@3.13.23':
resolution: {integrity: sha512-zSz2Z2HNyLjCplANTDyl3BcdQJc2k1+yyFoKhNRmCr7V7dY8o8q5m8uFTI1/Pg1kL+Hgrz6u3Xo6eFUB7l66cg==}
@@ -1459,8 +1482,8 @@ packages:
'@types/d3-selection@3.0.11':
resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==}
- '@types/d3-shape@3.1.7':
- resolution: {integrity: sha512-VLvUQ33C+3J+8p+Daf+nYSOsjB4GXp19/S/aGo60m9h1v6XaxjiT82lKVWJCfzhtuZ3yD7i/TPeC/fuKLLOSmg==}
+ '@types/d3-shape@3.1.8':
+ resolution: {integrity: sha512-lae0iWfcDeR7qt7rA88BNiqdvPS5pFVPpo5OfjElwNaT2yyekbM0C9vK+yqBqEmHr6lDkRnYNoTBYlAgJa7a4w==}
'@types/d3-time-format@4.0.3':
resolution: {integrity: sha512-5xg9rC+wWL8kdDj153qZcsJ0FWiFt0J5RB6LYUNZjwSnesfblqrI/bJ1wBdJ8OQfncgbJG5+2F+qfqnqyzYxyg==}
@@ -1480,8 +1503,8 @@ packages:
'@types/d3@7.4.3':
resolution: {integrity: sha512-lZXZ9ckh5R8uiFVt8ogUNf+pIrK4EsWrx2Np75WvF/eTpJ0FMHNhjXk8CKEx/+gpHbNQyJWehbFaTvqmHWB3ww==}
- '@types/debug@4.1.12':
- resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==}
+ '@types/debug@4.1.13':
+ resolution: {integrity: sha512-KSVgmQmzMwPlmtljOomayoR89W4FynCAi3E8PPs7vmDVPe84hT+vGPKkJfThkmXs0x0jAaa9U8uW8bbfyS2fWw==}
'@types/dompurify@3.2.0':
resolution: {integrity: sha512-Fgg31wv9QbLDA0SpTOXO3MaxySc4DKGLi8sna4/Utjo4r3ZRPdCt4UQee8BWr+Q5z21yifghREPJGYaEOEIACg==}
@@ -1505,8 +1528,8 @@ packages:
'@types/js-cookie@3.0.6':
resolution: {integrity: sha512-wkw9yd1kEXOPnvEeEV1Go1MmxtBJL0RR79aOTAApecWFVu7w0NNXNqhcWgvw2YgZDYadliXkl14pa3WXw5jlCQ==}
- '@types/katex@0.16.7':
- resolution: {integrity: sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==}
+ '@types/katex@0.16.8':
+ resolution: {integrity: sha512-trgaNyfU+Xh2Tc+ABIb44a5AYUpicB3uwirOioeOkNPPbmgRNtcWyDeeFRzjPZENO9Vq8gvVqfhaaXWLlevVwg==}
'@types/mdast@4.0.4':
resolution: {integrity: sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==}
@@ -1605,6 +1628,9 @@ packages:
'@ungap/structured-clone@1.3.0':
resolution: {integrity: sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==}
+ '@upsetjs/venn.js@2.0.0':
+ resolution: {integrity: sha512-WbBhLrooyePuQ1VZxrJjtLvTc4NVfpOyKx0sKqioq9bX1C1m7Jgykkn8gLrtwumBioXIqam8DLxp88Adbue6Hw==}
+
'@use-gesture/core@10.3.1':
resolution: {integrity: sha512-WcINiDt8WjqBdUXye25anHiNxPc0VOrlT8F6LLkU6cycrOGUDyY/yyFmsg3k8i5OLvv25llc0QC45GhR/C8llw==}
@@ -1736,6 +1762,11 @@ packages:
engines: {node: '>=0.4.0'}
hasBin: true
+ acorn@8.16.0:
+ resolution: {integrity: sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==}
+ engines: {node: '>=0.4.0'}
+ hasBin: true
+
adler-32@1.3.1:
resolution: {integrity: sha512-ynZ4w/nUUv5rrsR8UUGoe1VC9hZj6V5hU9Qw1HlMDJGEJw5S7TfTErWTjMys6M7vr0YWcPqs3qAr4ss0nDfP+A==}
engines: {node: '>=0.8'}
@@ -1744,8 +1775,8 @@ packages:
resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==}
engines: {node: '>= 14'}
- ahooks@3.9.6:
- resolution: {integrity: sha512-Mr7f05swd5SmKlR9SZo5U6M0LsL4ErweLzpdgXjA1JPmnZ78Vr6wzx0jUtvoxrcqGKYnX0Yjc02iEASVxHFPjQ==}
+ ahooks@3.9.7:
+ resolution: {integrity: sha512-S0lvzhbdlhK36RFBkGv+RbOM/dbbweym+BIHM/bwwuWVSVN5TuVErHPMWo4w0t1NDYg5KPp2iEf7Y7E5LASYiw==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
@@ -1827,8 +1858,8 @@ packages:
peerDependencies:
postcss: ^8.1.0
- axios@1.13.5:
- resolution: {integrity: sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==}
+ axios@1.15.0:
+ resolution: {integrity: sha512-wWyJDlAatxk30ZJer+GeCWS209sA42X+N5jU2jy6oHTp7ufw8uzUTVFBX9+wTfAlhiJXGS0Bq7X6efruWjuK9Q==}
babel-plugin-macros@3.1.0:
resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==}
@@ -1924,13 +1955,14 @@ packages:
resolution: {integrity: sha512-PAJdDJusoxnwm1VwW07VWwUN1sl7smmC3OKggvndJFadxxDRyFJBX/ggnu/KE4kQAB7a3Dp8f/YXC1FlUprWmA==}
engines: {node: '>= 16'}
- chevrotain-allstar@0.3.1:
- resolution: {integrity: sha512-b7g+y9A0v4mxCW1qUhf3BSVPg+/NvGErk/dOkrDaHA0nQIQGAtrOjlX//9OQtRlSCy+x9rfB5N8yC71lH1nvMw==}
+ chevrotain-allstar@0.4.1:
+ resolution: {integrity: sha512-PvVJm3oGqrveUVW2Vt/eZGeiAIsJszYweUcYwcskg9e+IubNYKKD+rHHem7A6XVO22eDAL+inxNIGAzZ/VIWlA==}
peerDependencies:
- chevrotain: ^11.0.0
+ chevrotain: ^12.0.0
- chevrotain@11.0.3:
- resolution: {integrity: sha512-ci2iJH6LeIkvP9eJW6gpueU8cnZhv85ELY8w8WiFtNjMHA5ad6pQLaJo9mEly/9qUyCpvqX8/POVUTf18/HFdw==}
+ chevrotain@12.0.0:
+ resolution: {integrity: sha512-csJvb+6kEiQaqo1woTdSAuOWdN0WTLIydkKrBnS+V5gZz0oqBrp4kQ35519QgK6TpBThiG3V1vNSHlIkv4AglQ==}
+ engines: {node: '>=22.0.0'}
chokidar@3.6.0:
resolution: {integrity: sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==}
@@ -1952,10 +1984,6 @@ packages:
cliui@6.0.0:
resolution: {integrity: sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ==}
- clsx@1.2.1:
- resolution: {integrity: sha512-EcR6r5a8bj6pu3ycsa/E/cKVGuTgZJZdsyUYHOksG/UHIiKfjxzRxYJpyVBwYaQeOvghal9fcc4PidlgzugAQg==}
- engines: {node: '>=6'}
-
clsx@2.1.1:
resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==}
engines: {node: '>=6'}
@@ -2056,8 +2084,8 @@ packages:
peerDependencies:
cytoscape: ^3.2.0
- cytoscape@3.33.1:
- resolution: {integrity: sha512-iJc4TwyANnOGR1OmWhsS9ayRS3s+XQ185FmuHObThD+5AeJCakAAbWv8KimMTt08xCCLNgneQwFp+JRJOr9qGQ==}
+ cytoscape@3.33.2:
+ resolution: {integrity: sha512-sj4HXd3DokGhzZAdjDejGvTPLqlt84vNFN8m7bGsOzDY5DyVcxIb2ejIXat2Iy7HxWhdT/N1oKyheJ5YdpsGuw==}
engines: {node: '>=0.10'}
d3-array@2.12.1:
@@ -2116,8 +2144,8 @@ packages:
resolution: {integrity: sha512-zxV/SsA+U4yte8051P4ECydjD/S+qeYtnaIyAs9tgHCqfguma/aAQDjo85A9Z6EKhBirHRJHXIgJUlffT4wdLg==}
engines: {node: '>=12'}
- d3-format@3.1.0:
- resolution: {integrity: sha512-YyUI6AEuY/Wpt8KWLgZHsIU86atmikuoOmCfommt0LYHiQSPjvX2AcFc38PX0CBpr2RCyZhjex+NS/LPOv6YqA==}
+ d3-format@3.1.2:
+ resolution: {integrity: sha512-AJDdYOdnyRDV5b6ArilzCPPwc1ejkHcoyFarqlPqT7zRYjhavcT3uSrqcMvsgh2CgoPbK3RCwyHaVyxYcP2Arg==}
engines: {node: '>=12'}
d3-geo@3.1.1:
@@ -2199,15 +2227,15 @@ packages:
resolution: {integrity: sha512-e1U46jVP+w7Iut8Jt8ri1YsPOvFpg46k+K8TpCb0P+zjCkjkPnV7WzfDJzMHy1LnA+wj5pLT1wjO901gLXeEhA==}
engines: {node: '>=12'}
- dagre-d3-es@7.0.13:
- resolution: {integrity: sha512-efEhnxpSuwpYOKRm/L5KbqoZmNNukHa/Flty4Wp62JRvgH2ojwVgPgdYyr4twpieZnyRDdIH7PY2mopX26+j2Q==}
+ dagre-d3-es@7.0.14:
+ resolution: {integrity: sha512-P4rFMVq9ESWqmOgK+dlXvOtLwYg0i7u0HBGJER0LZDJT2VHIPAMZ/riPxqJceWMStH5+E61QxFra9kIS3AqdMg==}
data-urls@5.0.0:
resolution: {integrity: sha512-ZYP5VBHshaDAiVZxjbRVcFJpc+4xGgT0bK3vzy1HLN8jTO975HEbuYzZJcHoQEY5K1a0z8YayJkyVETa08eNTg==}
engines: {node: '>=18'}
- dayjs@1.11.19:
- resolution: {integrity: sha512-t5EcLVS6QPBNqM2z8fakk/NKel+Xzshgt8FFKAn+qwlD1pzZWxh0nVCrvFK7ZDb6XucZeF9z8C7CBWTRIVApAw==}
+ dayjs@1.11.20:
+ resolution: {integrity: sha512-YbwwqR/uYpeoP4pu043q+LTDLFBLApUP6VxRihdfNTqu4ubqMlGDLd6ErXhEgsyvY0K6nCs7nggYumAN+9uEuQ==}
de-indent@1.0.2:
resolution: {integrity: sha512-e/1zu3xH5MQryN2zdVaF0OrdNLUbvWxzMbi+iNA6Bky7l1RoP8a2fIbRocyHclXt/arDrrR6lL3TqFD9pMQTsg==}
@@ -2228,8 +2256,8 @@ packages:
decimal.js@10.6.0:
resolution: {integrity: sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==}
- decode-named-character-reference@1.2.0:
- resolution: {integrity: sha512-c6fcElNV6ShtZXmsgNgFFV5tVX2PaV4g+MOAkb8eXHvn6sryJBrZa9r0zV6+dtTyoCKxtDy5tyQ5ZwQuidtd+Q==}
+ decode-named-character-reference@1.3.0:
+ resolution: {integrity: sha512-GtpQYB283KrPp6nRw50q3U9/VfOutZOe103qlN7BPP6Ad27xYnOIWv4lPzo8HCAL+mMZofJ9KEy30fq6MfaK6Q==}
decode-uri-component@0.4.1:
resolution: {integrity: sha512-+8VxcR21HhTy8nOt6jf20w0c9CADrw1O8d+VZ/YzzCt4bJ3uBjw+D1q2osAB8RnpwwaeYBxy0HyKQxD5JBMuuQ==}
@@ -2242,8 +2270,8 @@ packages:
deep-is@0.1.4:
resolution: {integrity: sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==}
- delaunator@5.0.1:
- resolution: {integrity: sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw==}
+ delaunator@5.1.0:
+ resolution: {integrity: sha512-AGrQ4QSgssa1NGmWmLPqN5NY2KajF5MqxetNEO+o0n3ZwZZeTmt7bBnvzHWrmkZFxGgr4HdyFgelzgi06otLuQ==}
delayed-stream@1.0.0:
resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==}
@@ -2276,6 +2304,9 @@ packages:
dompurify@3.3.1:
resolution: {integrity: sha512-qkdCKzLNtrgPFP1Vo+98FRzJnBRGe4ffyCea9IwHB1fyxPOeNTHpLKYGd4Uk9xvNoH0ZoOjwZxNptyMwqrId1Q==}
+ dompurify@3.3.3:
+ resolution: {integrity: sha512-Oj6pzI2+RqBfFG+qOaOLbFXLQ90ARpcGG6UePL82bJLtdsa6CYJD7nmiU8MW9nQNOtCHV3lZ/Bzq1X0QYbBZCA==}
+
driver.js@1.4.0:
resolution: {integrity: sha512-Gm64jm6PmcU+si21sQhBrTAM1JvUrR0QhNmjkprNLxohOBzul9+pNHXgQaT9lW84gwg9GMLB3NZGuGolsz5uew==}
@@ -2336,8 +2367,8 @@ packages:
resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==}
engines: {node: '>= 0.4'}
- es-toolkit@1.43.0:
- resolution: {integrity: sha512-SKCT8AsWvYzBBuUqMk4NPwFlSdqLpJwmy6AP322ERn8W2YLIB6JBXnwMI2Qsh2gfphT3q7EKAxKb23cvFHFwKA==}
+ es-toolkit@1.45.1:
+ resolution: {integrity: sha512-/jhoOj/Fx+A+IIyDNOvO3TItGmlMKhtX8ISAHKE90c4b/k1tqaqEZ+uUqfpU8DMnW5cgNJv606zS55jGvza0Xw==}
esast-util-from-estree@2.0.0:
resolution: {integrity: sha512-4CyanoAudUSBAn5K13H4JhsMH6L9ZP7XbLVe/dKybkxMO7eDyLsT8UHl9TRNrU2Gr9nz+FovfSIjuXWJ81uVwQ==}
@@ -2531,8 +2562,8 @@ packages:
fraction.js@5.3.4:
resolution: {integrity: sha512-1X1NTtiJphryn/uLQz3whtY6jK3fTqoE3ohKs0tT+Ujr1W59oopxmoEh7Lu5p6vBaPbgoM0bzveAW4Qi5RyWDQ==}
- framer-motion@12.23.26:
- resolution: {integrity: sha512-cPcIhgR42xBn1Uj+PzOyheMtZ73H927+uWPDVhUMqxy8UHt6Okavb6xIz9J/phFUHUj0OncR6UvMfJTXoc/LKA==}
+ framer-motion@12.38.0:
+ resolution: {integrity: sha512-rFYkY/pigbcswl1XQSb7q424kSTQ8q6eAC+YUsSKooHQYuLdzdHjrt6uxUC+PRAO++q5IS7+TamgIw1AphxR+g==}
peerDependencies:
'@emotion/is-prop-valid': '*'
react: ^18.0.0 || ^19.0.0
@@ -2560,8 +2591,8 @@ packages:
resolution: {integrity: sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==}
engines: {node: 6.* || 8.* || >= 10.*}
- get-east-asian-width@1.4.0:
- resolution: {integrity: sha512-QZjmEOC+IT1uk6Rx0sX22V6uHWVwbdbxf1faPqJ1QhLdGgsRGCZoyaQBm/piRdJy/D2um6hM1UP7ZEeQ4EkP+Q==}
+ get-east-asian-width@1.5.0:
+ resolution: {integrity: sha512-CQ+bEO+Tva/qlmw24dCejulK5pMzVnUOFOijVogd3KQs07HnRIgp8TGipvCCRT06xeYEbpbgwaCxglFyiuIcmA==}
engines: {node: '>=18'}
get-intrinsic@1.3.0:
@@ -2707,8 +2738,8 @@ packages:
resolution: {integrity: sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==}
engines: {node: '>= 4'}
- immer@11.1.3:
- resolution: {integrity: sha512-6jQTc5z0KJFtr1UgFpIL3N9XSC3saRaI9PwWtzM2pSqkNGtiNkYY2OSwkOGDK2XcTRcLb1pi/aNkKZz0nxVH4Q==}
+ immer@11.1.4:
+ resolution: {integrity: sha512-XREFCPo6ksxVzP4E0ekD5aMdf8WMwmdNaz6vuvxgI40UaEiu6q3p8X52aU6GdyvLY3XXX/8R7JOTXStz/nBbRw==}
import-fresh@3.3.1:
resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==}
@@ -2882,8 +2913,8 @@ packages:
json2mq@0.2.0:
resolution: {integrity: sha512-SzoRg7ux5DWTII9J2qkrZrqV1gt+rTaoufMxEzXbS26Uid0NwaJd123HcoB80TgubEppxxIGdNxCx50fEoEWQA==}
- katex@0.16.27:
- resolution: {integrity: sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==}
+ katex@0.16.45:
+ resolution: {integrity: sha512-pQpZbdBu7wCTmQUh7ufPmLr0pFoObnGUoL/yhtwJDgmmQpbkg/0HSVti25Fu4rmd1oCR6NGWe9vqTWuWv3GcNA==}
hasBin: true
keyv@4.5.4:
@@ -2892,9 +2923,9 @@ packages:
khroma@2.1.0:
resolution: {integrity: sha512-Ls993zuzfayK269Svk9hzpeGUKob/sIgZzyHYdjQoAdQetRKpOLj+k/QQQ/6Qi0Yz65mlROrfd+Ev+1+7dz9Kw==}
- langium@3.3.1:
- resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==}
- engines: {node: '>=16.0.0'}
+ langium@4.2.2:
+ resolution: {integrity: sha512-JUshTRAfHI4/MF9dH2WupvjSXyn8JBuUEWazB8ZVJUtXutT0doDlAv1XKbZ1Pb5sMexa8FF4CFBc0iiul7gbUQ==}
+ engines: {node: '>=20.10.0', npm: '>=10.2.3'}
layout-base@1.0.2:
resolution: {integrity: sha512-8h2oVEZNktL4BH2JCOI90iD1yXwL6iNW7KcCKT2QZgQJR2vbqDsldCTPRU9NifTCqHZci57XvQQ15YTu+sTYPg==}
@@ -2936,11 +2967,8 @@ packages:
resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==}
engines: {node: '>=10'}
- lodash-es@4.17.21:
- resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==}
-
- lodash-es@4.17.22:
- resolution: {integrity: sha512-XEawp1t0gxSi9x01glktRZ5HDy0HXqrM0x5pXQM98EaI0NxO6jVM7omDOxsuEo5UIASAnm2bRp1Jt/e0a2XU8Q==}
+ lodash-es@4.18.1:
+ resolution: {integrity: sha512-J8xewKD/Gk22OZbhpOVSwcs60zhd95ESDwezOFuA3/099925PdHJ7OFHNTGtajL3AlZkykD32HykiMo+BIBI8A==}
lodash.merge@4.6.2:
resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==}
@@ -2948,6 +2976,9 @@ packages:
lodash@4.17.21:
resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==}
+ lodash@4.18.1:
+ resolution: {integrity: sha512-dMInicTPVE8d1e5otfwmmjlxkZoUpiVLwyeTdUsi/Caj/gfzzblBcCE5sRHV/AsjuCmxWrte2TNGSYuCeCq+0Q==}
+
longest-streak@3.1.0:
resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==}
@@ -2998,6 +3029,11 @@ packages:
engines: {node: '>= 20'}
hasBin: true
+ marked@17.0.6:
+ resolution: {integrity: sha512-gB0gkNafnonOw0obSTEGZTT86IuhILt2Wfx0mWH/1Au83kybTayroZ/V6nS25mN7u8ASy+5fMhgB3XPNrOZdmA==}
+ engines: {node: '>= 20'}
+ hasBin: true
+
math-intrinsics@1.1.0:
resolution: {integrity: sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==}
engines: {node: '>= 0.4'}
@@ -3005,8 +3041,8 @@ packages:
mdast-util-find-and-replace@3.0.2:
resolution: {integrity: sha512-Tmd1Vg/m3Xz43afeNxDIhWRtFZgM2VLyaf4vSTYwudTyeuTneoL3qtWMA5jeLyz/O1vDJmmV4QuScFCA2tBPwg==}
- mdast-util-from-markdown@2.0.2:
- resolution: {integrity: sha512-uZhTV/8NBuw0WHkPTrCqDOl0zVe1BIng5ZtHoDk49ME1qqcjYmmLmOf0gELgcRMxN4w2iuIeVso5/6QymSrgmA==}
+ mdast-util-from-markdown@2.0.3:
+ resolution: {integrity: sha512-W4mAWTvSlKvf8L6J+VN9yLSqQ9AOAAvHuoDAmPkz4dHf553m5gVj2ejadHJhoJmcmxEnOv6Pa8XJhpxE93kb8Q==}
mdast-util-gfm-autolink-literal@2.0.1:
resolution: {integrity: sha512-5HVP2MKaP6L+G6YaxPNjuL0BPrq9orG3TsrZ9YXbA3vDw/ACI4MEsnoDpn6ZNm7GnZgtAcONJyPhOP8tNJQavQ==}
@@ -3064,8 +3100,8 @@ packages:
resolution: {integrity: sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==}
engines: {node: '>= 8'}
- mermaid@11.12.2:
- resolution: {integrity: sha512-n34QPDPEKmaeCG4WDMGy0OT6PSyxKCfy2pJgShP+Qow2KLrvWjclwbc3yXfSIf4BanqWEhQEpngWwNp/XhZt6w==}
+ mermaid@11.14.0:
+ resolution: {integrity: sha512-GSGloRsBs+JINmmhl0JDwjpuezCsHB4WGI4NASHxL3fHo3o/BRXTxhDLKnln8/Q0lRFRyDdEjmk1/d5Sn1Xz8g==}
micromark-core-commonmark@2.0.3:
resolution: {integrity: sha512-RDBrHEMSxVFLg6xvnXmb1Ayr2WzLAWjeSATAoxwKYJV94TeNavgoIdA0a9ytzDSVzBy2YKFK+emCPOEibLeCrg==}
@@ -3225,14 +3261,14 @@ packages:
resolution: {integrity: sha512-WRoDn//mXBiJ1H40rqa3vH0toePwSsGb45iInWlTySa+Uu4k3tYUSxa2v1KqAiLtvlrSzaExqS1gtk96A9zvEA==}
engines: {node: '>=0.10.0'}
- mlly@1.8.0:
- resolution: {integrity: sha512-l8D9ODSRWLe2KHJSifWGwBqpTZXIXTeo8mlKjY+E2HAakaTeNpqAyBZ8GSqLzHgw4XmHmC8whvpjJNMbFZN7/g==}
+ mlly@1.8.2:
+ resolution: {integrity: sha512-d+ObxMQFmbt10sretNDytwt85VrbkhhUA/JBGm1MPaWJ65Cl4wOgLaB1NYvJSZ0Ef03MMEU/0xpPMXUIQ29UfA==}
- motion-dom@12.23.23:
- resolution: {integrity: sha512-n5yolOs0TQQBRUFImrRfs/+6X4p3Q4n1dUEqt/H58Vx7OW6RF+foWEgmTVDhIWJIMXOuNNL0apKH2S16en9eiA==}
+ motion-dom@12.38.0:
+ resolution: {integrity: sha512-pdkHLD8QYRp8VfiNLb8xIBJis1byQ9gPT3Jnh2jqfFtAsWUA3dEepDlsWe/xMpO8McV+VdpKVcp+E+TGJEtOoA==}
- motion-utils@12.23.6:
- resolution: {integrity: sha512-eAWoPgr4eFEOFfg2WjIsMoqJTW6Z8MTUCgn/GZ3VRpClWBdnbjryiA3ZSNLyxCTmCQx4RmYX6jX1iWHbenUPNQ==}
+ motion-utils@12.36.0:
+ resolution: {integrity: sha512-eHWisygbiwVvf6PZ1vhaHCLamvkSbPIeAYxWUuL3a2PD/TROgE7FvfHWTIH4vMl798QLfMw15nRqIaRDXTlYRg==}
motion@12.23.26:
resolution: {integrity: sha512-Ll8XhVxY8LXMVYTCfme27WH2GjBrCIzY4+ndr5QKxsK+YwCtOi2B/oBi5jcIbik5doXuWT/4KKDOVAZJkeY5VQ==}
@@ -3308,8 +3344,8 @@ packages:
oniguruma-parser@0.12.1:
resolution: {integrity: sha512-8Unqkvk1RYc6yq2WBYRj4hdnsAxVze8i7iPfQr8e4uSP3tRv0rpZcbGUDvxfQQcdwHt/e9PrMvGCsa8OqG9X3w==}
- oniguruma-to-es@4.3.4:
- resolution: {integrity: sha512-3VhUGN3w2eYxnTzHn+ikMI+fp/96KoRSVK9/kMTcFqj1NRDh2IhQCKvYxDnWePKRXY/AqH+Fuiyb7VHSzBjHfA==}
+ oniguruma-to-es@4.3.5:
+ resolution: {integrity: sha512-Zjygswjpsewa0NLTsiizVuMQZbp0MDyM6lIt66OxsF21npUDlzpHi1Mgb/qhQdkb+dWFTzJmFbEWdvZgRho8eQ==}
optionator@0.9.4:
resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==}
@@ -3503,8 +3539,9 @@ packages:
proto-list@1.2.4:
resolution: {integrity: sha512-vtK/94akxsTMhe0/cbfpR+syPuszcuwhqVjJq26CuNDgFGj682oRBXOP5MJpv2r7JtE8MsiepGIqvvOTBwn2vA==}
- proxy-from-env@1.1.0:
- resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==}
+ proxy-from-env@2.1.0:
+ resolution: {integrity: sha512-cJ+oHTW1VAEa8cJslgmUZrc+sjRKgAKl3Zyse6+PV38hZe/V6Z14TbCuXcan9F9ghlz4QrFr2c92TNF82UkYHA==}
+ engines: {node: '>=10'}
psl@1.15.0:
resolution: {integrity: sha512-JZd3gMVBAVQkSs6HdNZo9Sdo0LNcQeMNP3CozBJb3JYC/QUYZTnKxP+f8oWRX4rHP5EurWxqAHTSwUCjlNKa1w==}
@@ -3617,8 +3654,8 @@ packages:
peerDependencies:
react: ^19.2.3
- react-draggable@4.4.6:
- resolution: {integrity: sha512-LtY5Xw1zTPqHkVmtM3X8MUOxNDOUhv/khTgBgrUvwaS064bwVvxT+q5El0uUFNx5IEPKXuRejr7UqLwBIg5pdw==}
+ react-draggable@4.5.0:
+ resolution: {integrity: sha512-VC+HBLEZ0XJxnOxVAZsdRi8rD04Iz3SiiKOoYzamjylUcju/hP9np/aZdLHf/7WOD268WMoNJMvYfB5yAK45cw==}
peerDependencies:
react: '>= 16.3.0'
react-dom: '>= 16.3.0'
@@ -3629,17 +3666,16 @@ packages:
peerDependencies:
react: '>= 16.8'
- react-error-boundary@6.0.1:
- resolution: {integrity: sha512-zArgQpjJUN1ZLMEKWtifxQweW3yfvwL5j2nh3Pesze1qG6r5oCDMy/TA97bUF01wy4xCeeL4/pd8GHmvEsP3Bg==}
+ react-error-boundary@6.1.1:
+ resolution: {integrity: sha512-BrYwPOdXi5mqkk5lw+Uvt0ThHx32rCt3BkukS4X23A2AIWDPSGX6iaWTc0y9TU/mHDA/6qOSGel+B2ERkOvD1w==}
peerDependencies:
react: ^18.0.0 || ^19.0.0
- react-dom: ^18.0.0 || ^19.0.0
react-fast-compare@3.2.2:
resolution: {integrity: sha512-nsO+KSNgo1SbJqJEYRE9ERzo7YtYbou/OqjSQKxV7jcKox7+usiUVZOAC+XnDOABXggQTno0Y1CpVnuWEc1boQ==}
- react-hotkeys-hook@5.2.1:
- resolution: {integrity: sha512-xbKh6zJxd/vJHT4Bw4+0pBD662Fk20V+VFhLqciCg+manTVO4qlqRqiwFOYelfHN9dBvWj9vxaPkSS26ZSIJGg==}
+ react-hotkeys-hook@5.2.4:
+ resolution: {integrity: sha512-BgKg+A1+TawkYluh5Bo4cTmcgMN5L29uhJbDUQdHwPX+qgXRjIPYU5kIDHyxnAwCkCBiu9V5OpB2mpyeluVF2A==}
peerDependencies:
react: '>=16.8.0'
react-dom: '>=16.8.0'
@@ -3664,8 +3700,8 @@ packages:
react:
optional: true
- react-rnd@10.5.2:
- resolution: {integrity: sha512-0Tm4x7k7pfHf2snewJA8x7Nwgt3LV+58MVEWOVsFjk51eYruFEa6Wy7BNdxt4/lH0wIRsu7Gm3KjSXY2w7YaNw==}
+ react-rnd@10.5.3:
+ resolution: {integrity: sha512-s/sIT3pGZnQ+57egijkTp9mizjIWrJz68Pq6yd+F/wniFY3IriML18dUXnQe/HP9uMiJ+9MAp44hljG99fZu6Q==}
peerDependencies:
react: '>=16.3.0'
react-dom: '>=16.3.0'
@@ -3795,8 +3831,8 @@ packages:
deprecated: Rimraf versions prior to v4 are no longer supported
hasBin: true
- robust-predicates@3.0.2:
- resolution: {integrity: sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==}
+ robust-predicates@3.0.3:
+ resolution: {integrity: sha512-NS3levdsRIUOmiJ8FZWCP7LG3QpJyrs/TE0Zpf1yvZu8cAJJ6QMW92H1c7kWpdIHo8RvmLxN/o2JXTKHp74lUA==}
rollup@4.54.0:
resolution: {integrity: sha512-3nk8Y3a9Ea8szgKhinMlGMhGMw89mqule3KWczxhIzqudyHdCIOHw8WJlj/r329fACjKLEh13ZSk7oE22kyeIw==}
@@ -3858,19 +3894,22 @@ packages:
resolution: {integrity: sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==}
engines: {node: '>=8'}
- shiki-stream@0.1.3:
- resolution: {integrity: sha512-pDIqmaP/zJWHNV8bJKp0tD0CZ6OkF+lWTIvmNRLktlTjBjN3+durr19JarS657U1oSEf/WrSYmdzwr9CeD6m2Q==}
+ shiki-stream@0.1.4:
+ resolution: {integrity: sha512-4pz6JGSDmVTTkPJ/ueixHkFAXY4ySCc+unvCaDZV7hqq/sdJZirRxgIXSuNSKgiFlGTgRR97sdu2R8K55sPsrw==}
peerDependencies:
react: ^19.0.0
+ solid-js: ^1.9.0
vue: ^3.2.0
peerDependenciesMeta:
react:
optional: true
+ solid-js:
+ optional: true
vue:
optional: true
- shiki@3.20.0:
- resolution: {integrity: sha512-kgCOlsnyWb+p0WU+01RjkCH+eBVsjL1jOwUYWv0YDWkM2/A46+LDKVs5yZCUXjJG6bj4ndFoAg5iLIIue6dulg==}
+ shiki@3.23.0:
+ resolution: {integrity: sha512-55Dj73uq9ZXL5zyeRPzHQsK7Nbyt6Y10k5s7OjuFZGMhpp4r/rsLBH0o/0fstIzX1Lep9VxefWljK/SKCzygIA==}
siginfo@2.0.0:
resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==}
@@ -3967,8 +4006,8 @@ packages:
resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==}
engines: {node: '>= 0.4'}
- swr@2.3.8:
- resolution: {integrity: sha512-gaCPRVoMq8WGDcWj9p4YWzCMPHzE0WNl6W8ADIx9c3JBEIdMkJGMzW+uzXvxHMltwcYACr9jP+32H8/hgwMR7w==}
+ swr@2.4.1:
+ resolution: {integrity: sha512-2CC6CiKQtEwaEeNiqWTAw9PGykW8SR5zZX8MZk6TeAvEAnVS7Visz8WzphqgtQ8v2xz/4Q5K+j+SeMaKXeeQIA==}
peerDependencies:
react: ^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
@@ -4010,8 +4049,8 @@ packages:
tinyexec@0.3.2:
resolution: {integrity: sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==}
- tinyexec@1.0.2:
- resolution: {integrity: sha512-W/KYk+NFhkmsYpuHq5JykngiOCnxeVL8v8dFnqxSD8qEEdRfXk1SDM6JzNqcERbcGYj9tMrDQBYV9cjgnunFIg==}
+ tinyexec@1.1.1:
+ resolution: {integrity: sha512-VKS/ZaQhhkKFMANmAOhhXVoIfBXblQxGX1myCQ2faQrfmobMftXeJPcZGp0gS07ocvGJWDLZGyOZDadDBqYIJg==}
engines: {node: '>=18'}
tinyglobby@0.2.15:
@@ -4087,8 +4126,8 @@ packages:
engines: {node: '>=14.17'}
hasBin: true
- ufo@1.6.1:
- resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==}
+ ufo@1.6.3:
+ resolution: {integrity: sha512-yDJTmhydvl5lJzBmy/hyOAA0d+aqCBuwl818haVdYCRrWV84o7YyeVm4QlVHStqNrrJSTb6jKuFAVqAFsr+K3Q==}
undici-types@6.21.0:
resolution: {integrity: sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==}
@@ -4121,8 +4160,8 @@ packages:
unist-util-visit-parents@6.0.2:
resolution: {integrity: sha512-goh1s1TBrqSqukSc8wrjwWhL0hiJxgA8m4kFxGlQ+8FYQ3C/m11FcTs4YYem7V664AhHVvgoQLk890Ssdsr2IQ==}
- unist-util-visit@5.0.0:
- resolution: {integrity: sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==}
+ unist-util-visit@5.1.0:
+ resolution: {integrity: sha512-m+vIdyeCOpdr/QeQCu2EzxX/ohgS8KbnPDgFni4dQsfSCtpz8UqDyY5GjRru8PDKuYn7Fq19j1CQ+nJSsGKOzg==}
universalify@0.2.0:
resolution: {integrity: sha512-CJ1QgKmNg3CwvAv/kOFmtnEN05f0D/cn9QntgNOQlQF9dgvVTHj3t+8JPdjqawCHk7V/KA+fbUqzZ9XWhcqPUg==}
@@ -4289,9 +4328,6 @@ packages:
resolution: {integrity: sha512-woByF3PDpkHFUreUa7Hos7+pUWdeWMXRd26+ZX2A8cFx6v/JPTtd4/uN0/jB6XQHYaOlHbio03NTHCqrgG5n7g==}
hasBin: true
- vscode-uri@3.0.8:
- resolution: {integrity: sha512-AyFQ0EVmsOZOlAnxoFOGOq1SQDWAB7C6aqMGS23svWAllfOaxbuFvcT8D1i8z3Gyn8fraVeZNNmN6e9bxxXkKw==}
-
vscode-uri@3.1.0:
resolution: {integrity: sha512-/BpdSx+yCQGnCvecbyXdxHDkuk55/G3xwnC0GqY4gmQ3j+A+g8kzzgB4Nk/SINjqn6+waqw3EgbVF2QKExkRxQ==}
@@ -4487,15 +4523,15 @@ snapshots:
'@jridgewell/gen-mapping': 0.3.13
'@jridgewell/trace-mapping': 0.3.31
- '@ant-design/colors@8.0.0':
+ '@ant-design/colors@8.0.1':
dependencies:
- '@ant-design/fast-color': 3.0.0
+ '@ant-design/fast-color': 3.0.1
- '@ant-design/cssinjs-utils@2.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@ant-design/cssinjs-utils@2.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@ant-design/cssinjs': 2.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@babel/runtime': 7.28.4
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@ant-design/cssinjs': 2.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@babel/runtime': 7.29.2
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -4511,22 +4547,34 @@ snapshots:
react-dom: 19.2.3(react@19.2.3)
stylis: 4.3.6
- '@ant-design/fast-color@3.0.0': {}
+ '@ant-design/cssinjs@2.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ dependencies:
+ '@babel/runtime': 7.29.2
+ '@emotion/hash': 0.8.0
+ '@emotion/unitless': 0.7.5
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ clsx: 2.1.1
+ csstype: 3.2.3
+ react: 19.2.3
+ react-dom: 19.2.3(react@19.2.3)
+ stylis: 4.3.6
+
+ '@ant-design/fast-color@3.0.1': {}
'@ant-design/icons-svg@4.4.2': {}
- '@ant-design/icons@6.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@ant-design/icons@6.1.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@ant-design/colors': 8.0.0
+ '@ant-design/colors': 8.0.1
'@ant-design/icons-svg': 4.4.2
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@ant-design/react-slick@2.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
clsx: 2.1.1
json2mq: 0.2.0
react: 19.2.3
@@ -4536,7 +4584,7 @@ snapshots:
'@antfu/install-pkg@1.1.0':
dependencies:
package-manager-detector: 1.6.0
- tinyexec: 1.0.2
+ tinyexec: 1.1.1
'@asamuzakjp/css-color@3.2.0':
dependencies:
@@ -4579,6 +4627,8 @@ snapshots:
'@babel/runtime@7.28.4': {}
+ '@babel/runtime@7.29.2': {}
+
'@babel/template@7.27.2':
dependencies:
'@babel/code-frame': 7.27.1
@@ -4602,24 +4652,23 @@ snapshots:
'@babel/helper-string-parser': 7.27.1
'@babel/helper-validator-identifier': 7.28.5
- '@base-ui/react@1.0.0(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@base-ui/react@1.3.0(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
- '@base-ui/utils': 0.2.3(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@floating-ui/react-dom': 2.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@floating-ui/utils': 0.2.10
+ '@babel/runtime': 7.29.2
+ '@base-ui/utils': 0.2.6(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@floating-ui/react-dom': 2.1.8(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@floating-ui/utils': 0.2.11
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- reselect: 5.1.1
tabbable: 6.4.0
use-sync-external-store: 1.6.0(react@19.2.3)
optionalDependencies:
'@types/react': 19.2.7
- '@base-ui/utils@0.2.3(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@base-ui/utils@0.2.6(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
- '@floating-ui/utils': 0.2.10
+ '@babel/runtime': 7.29.2
+ '@floating-ui/utils': 0.2.11
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
reselect: 5.1.1
@@ -4629,24 +4678,22 @@ snapshots:
'@bcoe/v8-coverage@0.2.3': {}
- '@braintree/sanitize-url@7.1.1': {}
+ '@braintree/sanitize-url@7.1.2': {}
- '@chevrotain/cst-dts-gen@11.0.3':
+ '@chevrotain/cst-dts-gen@12.0.0':
dependencies:
- '@chevrotain/gast': 11.0.3
- '@chevrotain/types': 11.0.3
- lodash-es: 4.17.21
+ '@chevrotain/gast': 12.0.0
+ '@chevrotain/types': 12.0.0
- '@chevrotain/gast@11.0.3':
+ '@chevrotain/gast@12.0.0':
dependencies:
- '@chevrotain/types': 11.0.3
- lodash-es: 4.17.21
+ '@chevrotain/types': 12.0.0
- '@chevrotain/regexp-to-ast@11.0.3': {}
+ '@chevrotain/regexp-to-ast@12.0.0': {}
- '@chevrotain/types@11.0.3': {}
+ '@chevrotain/types@12.0.0': {}
- '@chevrotain/utils@11.0.3': {}
+ '@chevrotain/utils@12.0.0': {}
'@csstools/color-helpers@5.1.0': {}
@@ -4881,30 +4928,30 @@ snapshots:
'@eslint/js@8.57.1': {}
- '@floating-ui/core@1.7.3':
+ '@floating-ui/core@1.7.5':
dependencies:
- '@floating-ui/utils': 0.2.10
+ '@floating-ui/utils': 0.2.11
- '@floating-ui/dom@1.7.4':
+ '@floating-ui/dom@1.7.6':
dependencies:
- '@floating-ui/core': 1.7.3
- '@floating-ui/utils': 0.2.10
+ '@floating-ui/core': 1.7.5
+ '@floating-ui/utils': 0.2.11
- '@floating-ui/react-dom@2.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@floating-ui/react-dom@2.1.8(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@floating-ui/dom': 1.7.4
+ '@floating-ui/dom': 1.7.6
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- '@floating-ui/react@0.27.16(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@floating-ui/react@0.27.19(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@floating-ui/react-dom': 2.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@floating-ui/utils': 0.2.10
+ '@floating-ui/react-dom': 2.1.8(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@floating-ui/utils': 0.2.11
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
tabbable: 6.4.0
- '@floating-ui/utils@0.2.10': {}
+ '@floating-ui/utils@0.2.11': {}
'@giscus/react@3.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
@@ -4930,7 +4977,7 @@ snapshots:
dependencies:
'@antfu/install-pkg': 1.1.0
'@iconify/types': 2.0.0
- mlly: 1.8.0
+ mlly: 1.8.2
'@intlify/core-base@9.14.5':
dependencies:
@@ -4971,11 +5018,11 @@ snapshots:
'@kurkle/color@0.3.4': {}
- '@lit-labs/ssr-dom-shim@1.5.0': {}
+ '@lit-labs/ssr-dom-shim@1.5.1': {}
'@lit/reactive-element@2.1.2':
dependencies:
- '@lit-labs/ssr-dom-shim': 1.5.0
+ '@lit-labs/ssr-dom-shim': 1.5.1
'@lobehub/emojilib@1.0.0': {}
@@ -4984,7 +5031,7 @@ snapshots:
'@lobehub/emojilib': 1.0.0
antd-style: 4.1.0(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
emoji-regex: 10.6.0
- es-toolkit: 1.43.0
+ es-toolkit: 1.45.1
lucide-react: 0.562.0(react@19.2.3)
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5009,8 +5056,8 @@ snapshots:
'@lobehub/ui@4.9.2(@lobehub/fluent-emoji@4.1.0(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(@lobehub/icons@4.0.2)(@types/mdast@4.0.4)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(micromark-util-types@2.0.2)(micromark@4.0.2)(motion@12.23.26(@emotion/is-prop-valid@1.4.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)(vue@3.5.26(typescript@5.6.3))':
dependencies:
- '@ant-design/cssinjs': 2.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@base-ui/react': 1.0.0(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@ant-design/cssinjs': 2.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@base-ui/react': 1.3.0(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@dnd-kit/core': 6.3.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@dnd-kit/modifiers': 9.0.0(@dnd-kit/core@6.3.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react@19.2.3)
'@dnd-kit/sortable': 10.0.0(@dnd-kit/core@6.3.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react@19.2.3)
@@ -5018,32 +5065,32 @@ snapshots:
'@emoji-mart/data': 1.2.1
'@emoji-mart/react': 1.1.1(emoji-mart@5.6.0)(react@19.2.3)
'@emotion/is-prop-valid': 1.4.0
- '@floating-ui/react': 0.27.16(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@floating-ui/react': 0.27.19(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@giscus/react': 3.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@lobehub/fluent-emoji': 4.1.0(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@lobehub/icons': 4.0.2(@lobehub/ui@4.9.2)(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@mdx-js/mdx': 3.1.1
'@mdx-js/react': 3.1.1(@types/react@19.2.7)(react@19.2.3)
'@radix-ui/react-slot': 1.2.4(@types/react@19.2.7)(react@19.2.3)
- '@shikijs/core': 3.20.0
- '@shikijs/transformers': 3.20.0
+ '@shikijs/core': 3.23.0
+ '@shikijs/transformers': 3.23.0
'@splinetool/runtime': 0.9.526
- ahooks: 3.9.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ ahooks: 3.9.7(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
antd: 6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
antd-style: 4.1.0(@types/react@19.2.7)(antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3))(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
chroma-js: 3.2.0
class-variance-authority: 0.7.1
clsx: 2.1.1
- dayjs: 1.11.19
+ dayjs: 1.11.20
emoji-mart: 5.6.0
- es-toolkit: 1.43.0
+ es-toolkit: 1.45.1
fast-deep-equal: 3.1.3
- immer: 11.1.3
- katex: 0.16.27
+ immer: 11.1.4
+ katex: 0.16.45
leva: 0.10.1(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
lucide-react: 0.562.0(react@19.2.3)
- marked: 17.0.1
- mermaid: 11.12.2
+ marked: 17.0.6
+ mermaid: 11.14.0
motion: 12.23.26(@emotion/is-prop-valid@1.4.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
numeral: 2.0.6
polished: 4.3.1
@@ -5057,11 +5104,11 @@ snapshots:
react: 19.2.3
react-avatar-editor: 14.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react-dom: 19.2.3(react@19.2.3)
- react-error-boundary: 6.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- react-hotkeys-hook: 5.2.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ react-error-boundary: 6.1.1(react@19.2.3)
+ react-hotkeys-hook: 5.2.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react-markdown: 10.1.0(@types/react@19.2.7)(react@19.2.3)
react-merge-refs: 3.0.2(react@19.2.3)
- react-rnd: 10.5.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ react-rnd: 10.5.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react-zoom-pan-pinch: 3.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
rehype-github-alerts: 4.2.0
rehype-katex: 7.0.1
@@ -5071,9 +5118,9 @@ snapshots:
remark-gfm: 4.0.1
remark-github: 12.0.0
remark-math: 6.0.0
- shiki: 3.20.0
- shiki-stream: 0.1.3(react@19.2.3)(vue@3.5.26(typescript@5.6.3))
- swr: 2.3.8(react@19.2.3)
+ shiki: 3.23.0
+ shiki-stream: 0.1.4(react@19.2.3)(vue@3.5.26(typescript@5.6.3))
+ swr: 2.4.1(react@19.2.3)
ts-md5: 2.0.1
unified: 11.0.5
url-join: 5.0.0
@@ -5085,6 +5132,7 @@ snapshots:
- '@types/react-dom'
- micromark
- micromark-util-types
+ - solid-js
- supports-color
- vue
@@ -5094,7 +5142,7 @@ snapshots:
'@types/estree-jsx': 1.0.5
'@types/hast': 3.0.4
'@types/mdx': 2.0.13
- acorn: 8.15.0
+ acorn: 8.16.0
collapse-white-space: 2.1.0
devlop: 1.1.0
estree-util-is-identifier-name: 3.0.0
@@ -5103,7 +5151,7 @@ snapshots:
hast-util-to-jsx-runtime: 2.3.6
markdown-extensions: 2.0.0
recma-build-jsx: 1.0.0
- recma-jsx: 1.0.1(acorn@8.15.0)
+ recma-jsx: 1.0.1(acorn@8.16.0)
recma-stringify: 1.0.0
rehype-recma: 1.0.0
remark-mdx: 3.1.1
@@ -5113,7 +5161,7 @@ snapshots:
unified: 11.0.5
unist-util-position-from-estree: 2.0.0
unist-util-stringify-position: 4.0.0
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
vfile: 6.0.3
transitivePeerDependencies:
- supports-color
@@ -5124,9 +5172,9 @@ snapshots:
'@types/react': 19.2.7
react: 19.2.3
- '@mermaid-js/parser@0.6.3':
+ '@mermaid-js/parser@1.1.0':
dependencies:
- langium: 3.3.1
+ langium: 4.2.2
'@nodelib/fs.scandir@2.1.5':
dependencies:
@@ -5145,7 +5193,7 @@ snapshots:
'@pkgjs/parseargs@0.11.0':
optional: true
- '@primer/octicons@19.21.1':
+ '@primer/octicons@19.23.1':
dependencies:
object-assign: 4.1.1
@@ -5192,7 +5240,7 @@ snapshots:
'@radix-ui/react-popper@1.2.8(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@floating-ui/react-dom': 2.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@floating-ui/react-dom': 2.1.8(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@radix-ui/react-arrow': 1.1.7(@types/react@19.2.7)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@radix-ui/react-compose-refs': 1.1.2(@types/react@19.2.7)(react@19.2.3)
'@radix-ui/react-context': 1.1.2(@types/react@19.2.7)(react@19.2.3)
@@ -5341,46 +5389,46 @@ snapshots:
'@radix-ui/rect@1.1.1': {}
- '@rc-component/async-validator@5.0.4':
+ '@rc-component/async-validator@5.1.0':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/cascader@1.10.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/select': 1.4.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/tree': 1.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/checkbox@1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/collapse@1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/color-picker@3.0.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@ant-design/fast-color': 3.0.0
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@ant-design/fast-color': 3.0.1
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/context@2.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5388,7 +5436,7 @@ snapshots:
dependencies:
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/portal': 2.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5397,23 +5445,23 @@ snapshots:
dependencies:
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/portal': 2.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/dropdown@1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- '@rc-component/form@1.6.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@rc-component/form@1.6.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/async-validator': 5.0.4
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/async-validator': 5.1.0
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5422,22 +5470,22 @@ snapshots:
dependencies:
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/portal': 2.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/input-number@1.6.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/mini-decimal': 1.1.0
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/mini-decimal': 1.1.3
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/input@1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5447,8 +5495,8 @@ snapshots:
'@rc-component/input': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/menu': 1.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/textarea': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5457,68 +5505,68 @@ snapshots:
dependencies:
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/overflow': 1.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- '@rc-component/mini-decimal@1.1.0':
+ '@rc-component/mini-decimal@1.1.3':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/motion@1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/mutate-observer@2.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/notification@1.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/overflow@1.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@babel/runtime': 7.29.2
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/pagination@1.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- '@rc-component/picker@1.9.0(dayjs@1.11.19)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@rc-component/picker@1.9.1(dayjs@1.11.20)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/overflow': 1.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
optionalDependencies:
- dayjs: 1.11.19
+ dayjs: 1.11.20
'@rc-component/portal@1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
classnames: 2.5.1
rc-util: 5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
@@ -5526,42 +5574,42 @@ snapshots:
'@rc-component/portal@2.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/progress@1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/qrcode@1.1.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/rate@1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- '@rc-component/resize-observer@1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@rc-component/resize-observer@1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/segmented@1.3.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5569,8 +5617,8 @@ snapshots:
'@rc-component/select@1.4.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/overflow': 1.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/virtual-list': 1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
@@ -5578,21 +5626,21 @@ snapshots:
'@rc-component/slider@1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/steps@1.2.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/switch@1.0.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5600,8 +5648,8 @@ snapshots:
'@rc-component/table@1.9.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/context': 2.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/virtual-list': 1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
@@ -5612,8 +5660,8 @@ snapshots:
'@rc-component/dropdown': 1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/menu': 1.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5621,16 +5669,16 @@ snapshots:
'@rc-component/textarea@1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/input': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/tooltip@1.4.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5638,8 +5686,8 @@ snapshots:
'@rc-component/tour@2.2.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/portal': 2.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5648,7 +5696,7 @@ snapshots:
dependencies:
'@rc-component/select': 1.4.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/tree': 1.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5656,15 +5704,15 @@ snapshots:
'@rc-component/tree@1.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/virtual-list': 1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- '@rc-component/trigger@2.3.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@rc-component/trigger@2.3.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/portal': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
classnames: 2.5.1
rc-motion: 2.9.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -5673,23 +5721,30 @@ snapshots:
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- '@rc-component/trigger@3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ '@rc-component/trigger@3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
'@rc-component/motion': 1.1.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/portal': 2.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
'@rc-component/upload@1.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
+ '@rc-component/util@1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
+ dependencies:
+ is-mobile: 5.0.0
+ react: 19.2.3
+ react-dom: 19.2.3(react@19.2.3)
+ react-is: 18.3.1
+
'@rc-component/util@1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
is-mobile: 5.0.0
@@ -5699,9 +5754,9 @@ snapshots:
'@rc-component/virtual-list@1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)':
dependencies:
- '@babel/runtime': 7.28.4
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@babel/runtime': 7.29.2
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -5772,38 +5827,38 @@ snapshots:
'@rollup/rollup-win32-x64-msvc@4.54.0':
optional: true
- '@shikijs/core@3.20.0':
+ '@shikijs/core@3.23.0':
dependencies:
- '@shikijs/types': 3.20.0
+ '@shikijs/types': 3.23.0
'@shikijs/vscode-textmate': 10.0.2
'@types/hast': 3.0.4
hast-util-to-html: 9.0.5
- '@shikijs/engine-javascript@3.20.0':
+ '@shikijs/engine-javascript@3.23.0':
dependencies:
- '@shikijs/types': 3.20.0
+ '@shikijs/types': 3.23.0
'@shikijs/vscode-textmate': 10.0.2
- oniguruma-to-es: 4.3.4
+ oniguruma-to-es: 4.3.5
- '@shikijs/engine-oniguruma@3.20.0':
+ '@shikijs/engine-oniguruma@3.23.0':
dependencies:
- '@shikijs/types': 3.20.0
+ '@shikijs/types': 3.23.0
'@shikijs/vscode-textmate': 10.0.2
- '@shikijs/langs@3.20.0':
+ '@shikijs/langs@3.23.0':
dependencies:
- '@shikijs/types': 3.20.0
+ '@shikijs/types': 3.23.0
- '@shikijs/themes@3.20.0':
+ '@shikijs/themes@3.23.0':
dependencies:
- '@shikijs/types': 3.20.0
+ '@shikijs/types': 3.23.0
- '@shikijs/transformers@3.20.0':
+ '@shikijs/transformers@3.23.0':
dependencies:
- '@shikijs/core': 3.20.0
- '@shikijs/types': 3.20.0
+ '@shikijs/core': 3.23.0
+ '@shikijs/types': 3.23.0
- '@shikijs/types@3.20.0':
+ '@shikijs/types@3.23.0':
dependencies:
'@shikijs/vscode-textmate': 10.0.2
'@types/hast': 3.0.4
@@ -5819,6 +5874,8 @@ snapshots:
dependencies:
react: 19.2.3
+ '@stripe/stripe-js@9.0.1': {}
+
'@tanstack/virtual-core@3.13.23': {}
'@tanstack/vue-virtual@3.13.23(vue@3.5.26(typescript@5.6.3))':
@@ -5891,7 +5948,7 @@ snapshots:
'@types/d3-selection@3.0.11': {}
- '@types/d3-shape@3.1.7':
+ '@types/d3-shape@3.1.8':
dependencies:
'@types/d3-path': 3.1.1
@@ -5936,14 +5993,14 @@ snapshots:
'@types/d3-scale': 4.0.9
'@types/d3-scale-chromatic': 3.1.0
'@types/d3-selection': 3.0.11
- '@types/d3-shape': 3.1.7
+ '@types/d3-shape': 3.1.8
'@types/d3-time': 3.0.4
'@types/d3-time-format': 4.0.3
'@types/d3-timer': 3.0.2
'@types/d3-transition': 3.0.9
'@types/d3-zoom': 3.0.8
- '@types/debug@4.1.12':
+ '@types/debug@4.1.13':
dependencies:
'@types/ms': 2.1.0
@@ -5967,7 +6024,7 @@ snapshots:
'@types/js-cookie@3.0.6': {}
- '@types/katex@0.16.7': {}
+ '@types/katex@0.16.8': {}
'@types/mdast@4.0.4':
dependencies:
@@ -6084,6 +6141,11 @@ snapshots:
'@ungap/structured-clone@1.3.0': {}
+ '@upsetjs/venn.js@2.0.0':
+ optionalDependencies:
+ d3-selection: 3.0.0
+ d3-transition: 3.0.1(d3-selection@3.0.0)
+
'@use-gesture/core@10.3.1': {}
'@use-gesture/react@10.3.1(react@19.2.3)':
@@ -6270,20 +6332,26 @@ snapshots:
dependencies:
acorn: 8.15.0
+ acorn-jsx@5.3.2(acorn@8.16.0):
+ dependencies:
+ acorn: 8.16.0
+
acorn@8.15.0: {}
+ acorn@8.16.0: {}
+
adler-32@1.3.1: {}
agent-base@7.1.4: {}
- ahooks@3.9.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
+ ahooks@3.9.7(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@types/js-cookie': 3.0.6
- dayjs: 1.11.19
+ dayjs: 1.11.20
intersection-observer: 0.12.2
js-cookie: 3.0.5
- lodash: 4.17.21
+ lodash: 4.18.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
react-fast-compare: 3.2.2
@@ -6329,13 +6397,13 @@ snapshots:
antd@6.1.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@ant-design/colors': 8.0.0
- '@ant-design/cssinjs': 2.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@ant-design/cssinjs-utils': 2.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@ant-design/fast-color': 3.0.0
- '@ant-design/icons': 6.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@ant-design/colors': 8.0.1
+ '@ant-design/cssinjs': 2.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@ant-design/cssinjs-utils': 2.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@ant-design/fast-color': 3.0.1
+ '@ant-design/icons': 6.1.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@ant-design/react-slick': 2.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/cascader': 1.10.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/checkbox': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/collapse': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -6343,7 +6411,7 @@ snapshots:
'@rc-component/dialog': 1.5.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/drawer': 1.3.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/dropdown': 1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/form': 1.6.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/form': 1.6.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/image': 1.5.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/input': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/input-number': 1.6.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -6353,11 +6421,11 @@ snapshots:
'@rc-component/mutate-observer': 2.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/notification': 1.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/pagination': 1.2.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/picker': 1.9.0(dayjs@1.11.19)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/picker': 1.9.1(dayjs@1.11.20)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/progress': 1.0.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/qrcode': 1.1.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/rate': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/resize-observer': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/resize-observer': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/segmented': 1.3.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/select': 1.4.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/slider': 1.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -6370,11 +6438,11 @@ snapshots:
'@rc-component/tour': 2.2.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/tree': 1.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/tree-select': 1.5.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/trigger': 3.8.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/trigger': 3.9.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
'@rc-component/upload': 1.1.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
- '@rc-component/util': 1.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@rc-component/util': 1.10.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
clsx: 2.1.1
- dayjs: 1.11.19
+ dayjs: 1.11.20
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
scroll-into-view-if-needed: 3.1.0
@@ -6416,11 +6484,11 @@ snapshots:
postcss: 8.5.6
postcss-value-parser: 4.2.0
- axios@1.13.5:
+ axios@1.15.0:
dependencies:
follow-redirects: 1.15.11
form-data: 4.0.5
- proxy-from-env: 1.1.0
+ proxy-from-env: 2.1.0
transitivePeerDependencies:
- debug
@@ -6510,19 +6578,18 @@ snapshots:
check-error@2.1.3: {}
- chevrotain-allstar@0.3.1(chevrotain@11.0.3):
+ chevrotain-allstar@0.4.1(chevrotain@12.0.0):
dependencies:
- chevrotain: 11.0.3
- lodash-es: 4.17.22
+ chevrotain: 12.0.0
+ lodash-es: 4.18.1
- chevrotain@11.0.3:
+ chevrotain@12.0.0:
dependencies:
- '@chevrotain/cst-dts-gen': 11.0.3
- '@chevrotain/gast': 11.0.3
- '@chevrotain/regexp-to-ast': 11.0.3
- '@chevrotain/types': 11.0.3
- '@chevrotain/utils': 11.0.3
- lodash-es: 4.17.21
+ '@chevrotain/cst-dts-gen': 12.0.0
+ '@chevrotain/gast': 12.0.0
+ '@chevrotain/regexp-to-ast': 12.0.0
+ '@chevrotain/types': 12.0.0
+ '@chevrotain/utils': 12.0.0
chokidar@3.6.0:
dependencies:
@@ -6554,8 +6621,6 @@ snapshots:
strip-ansi: 6.0.1
wrap-ansi: 6.2.0
- clsx@1.2.1: {}
-
clsx@2.1.1: {}
codepage@1.15.0: {}
@@ -6630,17 +6695,17 @@ snapshots:
csstype@3.2.3: {}
- cytoscape-cose-bilkent@4.1.0(cytoscape@3.33.1):
+ cytoscape-cose-bilkent@4.1.0(cytoscape@3.33.2):
dependencies:
cose-base: 1.0.3
- cytoscape: 3.33.1
+ cytoscape: 3.33.2
- cytoscape-fcose@2.2.0(cytoscape@3.33.1):
+ cytoscape-fcose@2.2.0(cytoscape@3.33.2):
dependencies:
cose-base: 2.2.0
- cytoscape: 3.33.1
+ cytoscape: 3.33.2
- cytoscape@3.33.1: {}
+ cytoscape@3.33.2: {}
d3-array@2.12.1:
dependencies:
@@ -6672,7 +6737,7 @@ snapshots:
d3-delaunay@6.0.4:
dependencies:
- delaunator: 5.0.1
+ delaunator: 5.1.0
d3-dispatch@3.0.1: {}
@@ -6699,7 +6764,7 @@ snapshots:
d3-quadtree: 3.0.1
d3-timer: 3.0.1
- d3-format@3.1.0: {}
+ d3-format@3.1.2: {}
d3-geo@3.1.1:
dependencies:
@@ -6734,7 +6799,7 @@ snapshots:
d3-scale@4.0.2:
dependencies:
d3-array: 3.2.4
- d3-format: 3.1.0
+ d3-format: 3.1.2
d3-interpolate: 3.0.1
d3-time: 3.1.0
d3-time-format: 4.1.0
@@ -6791,7 +6856,7 @@ snapshots:
d3-ease: 3.0.1
d3-fetch: 3.0.1
d3-force: 3.0.0
- d3-format: 3.1.0
+ d3-format: 3.1.2
d3-geo: 3.1.1
d3-hierarchy: 3.1.2
d3-interpolate: 3.0.1
@@ -6809,17 +6874,17 @@ snapshots:
d3-transition: 3.0.1(d3-selection@3.0.0)
d3-zoom: 3.0.0
- dagre-d3-es@7.0.13:
+ dagre-d3-es@7.0.14:
dependencies:
d3: 7.9.0
- lodash-es: 4.17.22
+ lodash-es: 4.18.1
data-urls@5.0.0:
dependencies:
whatwg-mimetype: 4.0.0
whatwg-url: 14.2.0
- dayjs@1.11.19: {}
+ dayjs@1.11.20: {}
de-indent@1.0.2: {}
@@ -6831,7 +6896,7 @@ snapshots:
decimal.js@10.6.0: {}
- decode-named-character-reference@1.2.0:
+ decode-named-character-reference@1.3.0:
dependencies:
character-entities: 2.0.2
@@ -6841,9 +6906,9 @@ snapshots:
deep-is@0.1.4: {}
- delaunator@5.0.1:
+ delaunator@5.1.0:
dependencies:
- robust-predicates: 3.0.2
+ robust-predicates: 3.0.3
delayed-stream@1.0.0: {}
@@ -6871,6 +6936,10 @@ snapshots:
optionalDependencies:
'@types/trusted-types': 2.0.7
+ dompurify@3.3.3:
+ optionalDependencies:
+ '@types/trusted-types': 2.0.7
+
driver.js@1.4.0: {}
dunder-proto@1.0.1:
@@ -6923,7 +6992,7 @@ snapshots:
has-tostringtag: 1.0.2
hasown: 2.0.2
- es-toolkit@1.43.0: {}
+ es-toolkit@1.45.1: {}
esast-util-from-estree@2.0.0:
dependencies:
@@ -6935,7 +7004,7 @@ snapshots:
esast-util-from-js@2.0.1:
dependencies:
'@types/estree-jsx': 1.0.5
- acorn: 8.15.0
+ acorn: 8.16.0
esast-util-from-estree: 2.0.0
vfile-message: 4.0.3
@@ -7180,10 +7249,10 @@ snapshots:
fraction.js@5.3.4: {}
- framer-motion@12.23.26(@emotion/is-prop-valid@1.4.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
+ framer-motion@12.38.0(@emotion/is-prop-valid@1.4.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- motion-dom: 12.23.23
- motion-utils: 12.23.6
+ motion-dom: 12.38.0
+ motion-utils: 12.36.0
tslib: 2.8.1
optionalDependencies:
'@emotion/is-prop-valid': 1.4.0
@@ -7199,7 +7268,7 @@ snapshots:
get-caller-file@2.0.5: {}
- get-east-asian-width@1.4.0: {}
+ get-east-asian-width@1.5.0: {}
get-intrinsic@1.3.0:
dependencies:
@@ -7334,7 +7403,7 @@ snapshots:
mdast-util-to-hast: 13.2.1
parse5: 7.3.0
unist-util-position: 5.0.0
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
vfile: 6.0.3
web-namespaces: 2.0.1
zwitch: 2.0.4
@@ -7459,7 +7528,7 @@ snapshots:
ignore@5.3.2: {}
- immer@11.1.3: {}
+ immer@11.1.4: {}
import-fresh@3.3.1:
dependencies:
@@ -7625,7 +7694,7 @@ snapshots:
dependencies:
string-convert: 0.2.1
- katex@0.16.27:
+ katex@0.16.45:
dependencies:
commander: 8.3.0
@@ -7635,13 +7704,14 @@ snapshots:
khroma@2.1.0: {}
- langium@3.3.1:
+ langium@4.2.2:
dependencies:
- chevrotain: 11.0.3
- chevrotain-allstar: 0.3.1(chevrotain@11.0.3)
+ '@chevrotain/regexp-to-ast': 12.0.0
+ chevrotain: 12.0.0
+ chevrotain-allstar: 0.4.1(chevrotain@12.0.0)
vscode-languageserver: 9.0.1
vscode-languageserver-textdocument: 1.0.12
- vscode-uri: 3.0.8
+ vscode-uri: 3.1.0
layout-base@1.0.2: {}
@@ -7677,7 +7747,7 @@ snapshots:
lit-element@4.2.2:
dependencies:
- '@lit-labs/ssr-dom-shim': 1.5.0
+ '@lit-labs/ssr-dom-shim': 1.5.1
'@lit/reactive-element': 2.1.2
lit-html: 3.3.2
@@ -7699,14 +7769,14 @@ snapshots:
dependencies:
p-locate: 5.0.0
- lodash-es@4.17.21: {}
-
- lodash-es@4.17.22: {}
+ lodash-es@4.18.1: {}
lodash.merge@4.6.2: {}
lodash@4.17.21: {}
+ lodash@4.18.1: {}
+
longest-streak@3.1.0: {}
loose-envify@1.4.0:
@@ -7747,6 +7817,8 @@ snapshots:
marked@17.0.1: {}
+ marked@17.0.6: {}
+
math-intrinsics@1.1.0: {}
mdast-util-find-and-replace@3.0.2:
@@ -7756,11 +7828,11 @@ snapshots:
unist-util-is: 6.0.1
unist-util-visit-parents: 6.0.2
- mdast-util-from-markdown@2.0.2:
+ mdast-util-from-markdown@2.0.3:
dependencies:
'@types/mdast': 4.0.4
'@types/unist': 3.0.3
- decode-named-character-reference: 1.2.0
+ decode-named-character-reference: 1.3.0
devlop: 1.1.0
mdast-util-to-string: 4.0.0
micromark: 4.0.2
@@ -7785,7 +7857,7 @@ snapshots:
dependencies:
'@types/mdast': 4.0.4
devlop: 1.1.0
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
micromark-util-normalize-identifier: 2.0.1
transitivePeerDependencies:
@@ -7794,7 +7866,7 @@ snapshots:
mdast-util-gfm-strikethrough@2.0.0:
dependencies:
'@types/mdast': 4.0.4
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
transitivePeerDependencies:
- supports-color
@@ -7804,7 +7876,7 @@ snapshots:
'@types/mdast': 4.0.4
devlop: 1.1.0
markdown-table: 3.0.4
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
transitivePeerDependencies:
- supports-color
@@ -7813,14 +7885,14 @@ snapshots:
dependencies:
'@types/mdast': 4.0.4
devlop: 1.1.0
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
transitivePeerDependencies:
- supports-color
mdast-util-gfm@3.1.0:
dependencies:
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-gfm-autolink-literal: 2.0.1
mdast-util-gfm-footnote: 2.1.0
mdast-util-gfm-strikethrough: 2.0.0
@@ -7836,7 +7908,7 @@ snapshots:
'@types/mdast': 4.0.4
devlop: 1.1.0
longest-streak: 3.1.0
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
unist-util-remove-position: 5.0.0
transitivePeerDependencies:
@@ -7848,7 +7920,7 @@ snapshots:
'@types/hast': 3.0.4
'@types/mdast': 4.0.4
devlop: 1.1.0
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
transitivePeerDependencies:
- supports-color
@@ -7861,7 +7933,7 @@ snapshots:
'@types/unist': 3.0.3
ccount: 2.0.1
devlop: 1.1.0
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
parse-entities: 4.0.2
stringify-entities: 4.0.4
@@ -7872,7 +7944,7 @@ snapshots:
mdast-util-mdx@3.0.0:
dependencies:
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-mdx-expression: 2.0.1
mdast-util-mdx-jsx: 3.2.0
mdast-util-mdxjs-esm: 2.0.1
@@ -7886,7 +7958,7 @@ snapshots:
'@types/hast': 3.0.4
'@types/mdast': 4.0.4
devlop: 1.1.0
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
mdast-util-to-markdown: 2.1.2
transitivePeerDependencies:
- supports-color
@@ -7910,7 +7982,7 @@ snapshots:
micromark-util-sanitize-uri: 2.0.1
trim-lines: 3.0.1
unist-util-position: 5.0.0
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
vfile: 6.0.3
mdast-util-to-markdown@2.1.2:
@@ -7922,7 +7994,7 @@ snapshots:
mdast-util-to-string: 4.0.0
micromark-util-classify-character: 2.0.1
micromark-util-decode-string: 2.0.1
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
zwitch: 2.0.4
mdast-util-to-string@4.0.0:
@@ -7938,23 +8010,24 @@ snapshots:
merge2@1.4.1: {}
- mermaid@11.12.2:
+ mermaid@11.14.0:
dependencies:
- '@braintree/sanitize-url': 7.1.1
+ '@braintree/sanitize-url': 7.1.2
'@iconify/utils': 3.1.0
- '@mermaid-js/parser': 0.6.3
+ '@mermaid-js/parser': 1.1.0
'@types/d3': 7.4.3
- cytoscape: 3.33.1
- cytoscape-cose-bilkent: 4.1.0(cytoscape@3.33.1)
- cytoscape-fcose: 2.2.0(cytoscape@3.33.1)
+ '@upsetjs/venn.js': 2.0.0
+ cytoscape: 3.33.2
+ cytoscape-cose-bilkent: 4.1.0(cytoscape@3.33.2)
+ cytoscape-fcose: 2.2.0(cytoscape@3.33.2)
d3: 7.9.0
d3-sankey: 0.12.3
- dagre-d3-es: 7.0.13
- dayjs: 1.11.19
- dompurify: 3.3.1
- katex: 0.16.27
+ dagre-d3-es: 7.0.14
+ dayjs: 1.11.20
+ dompurify: 3.3.3
+ katex: 0.16.45
khroma: 2.1.0
- lodash-es: 4.17.22
+ lodash-es: 4.18.1
marked: 16.4.2
roughjs: 4.6.6
stylis: 4.3.6
@@ -7963,7 +8036,7 @@ snapshots:
micromark-core-commonmark@2.0.3:
dependencies:
- decode-named-character-reference: 1.2.0
+ decode-named-character-reference: 1.3.0
devlop: 1.1.0
micromark-factory-destination: 2.0.1
micromark-factory-label: 2.0.1
@@ -7982,7 +8055,7 @@ snapshots:
micromark-extension-cjk-friendly-util@2.1.1(micromark-util-types@2.0.2):
dependencies:
- get-east-asian-width: 1.4.0
+ get-east-asian-width: 1.5.0
micromark-util-character: 2.1.1
micromark-util-symbol: 2.0.1
optionalDependencies:
@@ -8059,9 +8132,9 @@ snapshots:
micromark-extension-math@3.1.0:
dependencies:
- '@types/katex': 0.16.7
+ '@types/katex': 0.16.8
devlop: 1.1.0
- katex: 0.16.27
+ katex: 0.16.45
micromark-factory-space: 2.0.1
micromark-util-character: 2.1.1
micromark-util-symbol: 2.0.1
@@ -8109,8 +8182,8 @@ snapshots:
micromark-extension-mdxjs@3.0.0:
dependencies:
- acorn: 8.15.0
- acorn-jsx: 5.3.2(acorn@8.15.0)
+ acorn: 8.16.0
+ acorn-jsx: 5.3.2(acorn@8.16.0)
micromark-extension-mdx-expression: 3.0.1
micromark-extension-mdx-jsx: 3.0.2
micromark-extension-mdx-md: 2.0.0
@@ -8188,7 +8261,7 @@ snapshots:
micromark-util-decode-string@2.0.1:
dependencies:
- decode-named-character-reference: 1.2.0
+ decode-named-character-reference: 1.3.0
micromark-util-character: 2.1.1
micromark-util-decode-numeric-character-reference: 2.0.2
micromark-util-symbol: 2.0.1
@@ -8234,9 +8307,9 @@ snapshots:
micromark@4.0.2:
dependencies:
- '@types/debug': 4.1.12
+ '@types/debug': 4.1.13
debug: 4.4.3
- decode-named-character-reference: 1.2.0
+ decode-named-character-reference: 1.3.0
devlop: 1.1.0
micromark-core-commonmark: 2.0.3
micromark-factory-space: 2.0.1
@@ -8284,22 +8357,22 @@ snapshots:
for-in: 1.0.2
is-extendable: 1.0.1
- mlly@1.8.0:
+ mlly@1.8.2:
dependencies:
- acorn: 8.15.0
+ acorn: 8.16.0
pathe: 2.0.3
pkg-types: 1.3.1
- ufo: 1.6.1
+ ufo: 1.6.3
- motion-dom@12.23.23:
+ motion-dom@12.38.0:
dependencies:
- motion-utils: 12.23.6
+ motion-utils: 12.36.0
- motion-utils@12.23.6: {}
+ motion-utils@12.36.0: {}
motion@12.23.26(@emotion/is-prop-valid@1.4.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- framer-motion: 12.23.26(@emotion/is-prop-valid@1.4.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ framer-motion: 12.38.0(@emotion/is-prop-valid@1.4.0)(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
tslib: 2.8.1
optionalDependencies:
'@emotion/is-prop-valid': 1.4.0
@@ -8353,7 +8426,7 @@ snapshots:
oniguruma-parser@0.12.1: {}
- oniguruma-to-es@4.3.4:
+ oniguruma-to-es@4.3.5:
dependencies:
oniguruma-parser: 0.12.1
regex: 6.1.0
@@ -8399,7 +8472,7 @@ snapshots:
'@types/unist': 2.0.11
character-entities-legacy: 3.0.0
character-reference-invalid: 2.0.1
- decode-named-character-reference: 1.2.0
+ decode-named-character-reference: 1.3.0
is-alphanumerical: 2.0.1
is-decimal: 2.0.1
is-hexadecimal: 2.0.1
@@ -8465,7 +8538,7 @@ snapshots:
pkg-types@1.3.1:
dependencies:
confbox: 0.1.8
- mlly: 1.8.0
+ mlly: 1.8.2
pathe: 2.0.3
pngjs@5.0.0: {}
@@ -8530,7 +8603,7 @@ snapshots:
proto-list@1.2.4: {}
- proxy-from-env@1.1.0: {}
+ proxy-from-env@2.1.0: {}
psl@1.15.0:
dependencies:
@@ -8556,7 +8629,7 @@ snapshots:
rc-collapse@4.0.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
classnames: 2.5.1
rc-motion: 2.9.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
rc-util: 5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -8565,7 +8638,7 @@ snapshots:
rc-dialog@9.6.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/portal': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
classnames: 2.5.1
rc-motion: 2.9.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -8575,14 +8648,14 @@ snapshots:
rc-footer@0.6.8(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
classnames: 2.5.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
rc-image@7.12.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
'@rc-component/portal': 1.1.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
classnames: 2.5.1
rc-dialog: 9.6.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -8593,8 +8666,8 @@ snapshots:
rc-input-number@9.5.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
- '@rc-component/mini-decimal': 1.1.0
+ '@babel/runtime': 7.29.2
+ '@rc-component/mini-decimal': 1.1.3
classnames: 2.5.1
rc-input: 1.8.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
rc-util: 5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -8603,7 +8676,7 @@ snapshots:
rc-input@1.8.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
classnames: 2.5.1
rc-util: 5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
@@ -8611,8 +8684,8 @@ snapshots:
rc-menu@9.16.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
- '@rc-component/trigger': 2.3.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ '@babel/runtime': 7.29.2
+ '@rc-component/trigger': 2.3.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
classnames: 2.5.1
rc-motion: 2.9.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
rc-overflow: 1.5.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -8622,7 +8695,7 @@ snapshots:
rc-motion@2.9.5(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
classnames: 2.5.1
rc-util: 5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
@@ -8630,7 +8703,7 @@ snapshots:
rc-overflow@1.5.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
classnames: 2.5.1
rc-resize-observer: 1.4.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
rc-util: 5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
@@ -8639,7 +8712,7 @@ snapshots:
rc-resize-observer@1.4.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
classnames: 2.5.1
rc-util: 5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
@@ -8648,7 +8721,7 @@ snapshots:
rc-util@5.44.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- '@babel/runtime': 7.28.4
+ '@babel/runtime': 7.29.2
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
react-is: 18.3.1
@@ -8673,9 +8746,9 @@ snapshots:
react: 19.2.3
scheduler: 0.27.0
- react-draggable@4.4.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
+ react-draggable@4.5.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
- clsx: 1.2.1
+ clsx: 2.1.1
prop-types: 15.8.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -8687,14 +8760,13 @@ snapshots:
prop-types: 15.8.1
react: 19.2.3
- react-error-boundary@6.0.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
+ react-error-boundary@6.1.1(react@19.2.3):
dependencies:
react: 19.2.3
- react-dom: 19.2.3(react@19.2.3)
react-fast-compare@3.2.2: {}
- react-hotkeys-hook@5.2.1(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
+ react-hotkeys-hook@5.2.4(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@@ -8716,7 +8788,7 @@ snapshots:
remark-parse: 11.0.0
remark-rehype: 11.1.2
unified: 11.0.5
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
vfile: 6.0.3
transitivePeerDependencies:
- supports-color
@@ -8725,12 +8797,12 @@ snapshots:
optionalDependencies:
react: 19.2.3
- react-rnd@10.5.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
+ react-rnd@10.5.3(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
dependencies:
re-resizable: 6.11.2(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
- react-draggable: 4.4.6(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
+ react-draggable: 4.5.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3)
tslib: 2.6.2
react-zoom-pan-pinch@3.7.0(react-dom@19.2.3(react@19.2.3))(react@19.2.3):
@@ -8756,10 +8828,10 @@ snapshots:
estree-util-build-jsx: 3.0.1
vfile: 6.0.3
- recma-jsx@1.0.1(acorn@8.15.0):
+ recma-jsx@1.0.1(acorn@8.16.0):
dependencies:
- acorn: 8.15.0
- acorn-jsx: 5.3.2(acorn@8.15.0)
+ acorn: 8.16.0
+ acorn-jsx: 5.3.2(acorn@8.16.0)
estree-util-to-js: 2.0.0
recma-parse: 1.0.0
recma-stringify: 1.0.0
@@ -8791,18 +8863,18 @@ snapshots:
rehype-github-alerts@4.2.0:
dependencies:
- '@primer/octicons': 19.21.1
+ '@primer/octicons': 19.23.1
hast-util-from-html: 2.0.3
hast-util-is-element: 3.0.0
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
rehype-katex@7.0.1:
dependencies:
'@types/hast': 3.0.4
- '@types/katex': 0.16.7
+ '@types/katex': 0.16.8
hast-util-from-html-isomorphic: 2.0.0
hast-util-to-text: 4.0.2
- katex: 0.16.27
+ katex: 0.16.45
unist-util-visit-parents: 6.0.2
vfile: 6.0.3
@@ -8853,7 +8925,7 @@ snapshots:
mdast-util-find-and-replace: 3.0.2
mdast-util-to-string: 4.0.0
to-vfile: 8.0.0
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
vfile: 6.0.3
remark-math@6.0.0:
@@ -8875,7 +8947,7 @@ snapshots:
remark-parse@11.0.0:
dependencies:
'@types/mdast': 4.0.4
- mdast-util-from-markdown: 2.0.2
+ mdast-util-from-markdown: 2.0.3
micromark-util-types: 2.0.2
unified: 11.0.5
transitivePeerDependencies:
@@ -8919,7 +8991,7 @@ snapshots:
dependencies:
glob: 7.2.3
- robust-predicates@3.0.2: {}
+ robust-predicates@3.0.3: {}
rollup@4.54.0:
dependencies:
@@ -8999,21 +9071,21 @@ snapshots:
shebang-regex@3.0.0: {}
- shiki-stream@0.1.3(react@19.2.3)(vue@3.5.26(typescript@5.6.3)):
+ shiki-stream@0.1.4(react@19.2.3)(vue@3.5.26(typescript@5.6.3)):
dependencies:
- '@shikijs/core': 3.20.0
+ '@shikijs/core': 3.23.0
optionalDependencies:
react: 19.2.3
vue: 3.5.26(typescript@5.6.3)
- shiki@3.20.0:
+ shiki@3.23.0:
dependencies:
- '@shikijs/core': 3.20.0
- '@shikijs/engine-javascript': 3.20.0
- '@shikijs/engine-oniguruma': 3.20.0
- '@shikijs/langs': 3.20.0
- '@shikijs/themes': 3.20.0
- '@shikijs/types': 3.20.0
+ '@shikijs/core': 3.23.0
+ '@shikijs/engine-javascript': 3.23.0
+ '@shikijs/engine-oniguruma': 3.23.0
+ '@shikijs/langs': 3.23.0
+ '@shikijs/themes': 3.23.0
+ '@shikijs/types': 3.23.0
'@shikijs/vscode-textmate': 10.0.2
'@types/hast': 3.0.4
@@ -9102,7 +9174,7 @@ snapshots:
supports-preserve-symlinks-flag@1.0.0: {}
- swr@2.3.8(react@19.2.3):
+ swr@2.4.1(react@19.2.3):
dependencies:
dequal: 2.0.3
react: 19.2.3
@@ -9164,7 +9236,7 @@ snapshots:
tinyexec@0.3.2: {}
- tinyexec@1.0.2: {}
+ tinyexec@1.1.1: {}
tinyglobby@0.2.15:
dependencies:
@@ -9222,7 +9294,7 @@ snapshots:
typescript@5.6.3: {}
- ufo@1.6.1: {}
+ ufo@1.6.3: {}
undici-types@6.21.0: {}
@@ -9258,7 +9330,7 @@ snapshots:
unist-util-remove-position@5.0.0:
dependencies:
'@types/unist': 3.0.3
- unist-util-visit: 5.0.0
+ unist-util-visit: 5.1.0
unist-util-stringify-position@4.0.0:
dependencies:
@@ -9269,7 +9341,7 @@ snapshots:
'@types/unist': 3.0.3
unist-util-is: 6.0.1
- unist-util-visit@5.0.0:
+ unist-util-visit@5.1.0:
dependencies:
'@types/unist': 3.0.3
unist-util-is: 6.0.1
@@ -9421,8 +9493,6 @@ snapshots:
dependencies:
vscode-languageserver-protocol: 3.17.5
- vscode-uri@3.0.8: {}
-
vscode-uri@3.1.0: {}
vue-chartjs@5.3.3(chart.js@4.5.1)(vue@3.5.26(typescript@5.6.3)):
diff --git a/frontend/src/api/__tests__/admin.users.spec.ts b/frontend/src/api/__tests__/admin.users.spec.ts
new file mode 100644
index 00000000..37656b78
--- /dev/null
+++ b/frontend/src/api/__tests__/admin.users.spec.ts
@@ -0,0 +1,117 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { post } = vi.hoisted(() => ({
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post,
+ },
+}))
+
+import {
+ bindUserAuthIdentity,
+ type AdminBindAuthIdentityRequest,
+ type AdminBoundAuthIdentity,
+} from '@/api/admin/users'
+
+type Assert = T
+type IsExact = (
+ (() => G extends T ? 1 : 2) extends (() => G extends U ? 1 : 2)
+ ? ((() => G extends U ? 1 : 2) extends (() => G extends T ? 1 : 2) ? true : false)
+ : false
+)
+
+type ExpectedAdminBindAuthIdentityRequest = {
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ issuer?: string
+ metadata?: Record
+ channel?: {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata?: Record
+ }
+}
+
+type ExpectedAdminBoundAuthIdentity = {
+ user_id: number
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ verified_at?: string | null
+ issuer?: string | null
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ channel?: {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ } | null
+}
+
+const requestContractExact: Assert<
+ IsExact
+> = true
+const responseContractExact: Assert<
+ IsExact
+> = true
+
+describe('admin users api auth identity binding', () => {
+ beforeEach(() => {
+ post.mockReset()
+ })
+
+ it('posts the backend-compatible auth identity bind payload and returns the backend response shape', async () => {
+ const payload: AdminBindAuthIdentityRequest = {
+ provider_type: 'wechat',
+ provider_key: 'wechat-main',
+ provider_subject: 'union-123',
+ metadata: { source: 'admin-repair' },
+ channel: {
+ channel: 'open',
+ channel_app_id: 'wx-open',
+ channel_subject: 'openid-123',
+ metadata: { scene: 'migration' },
+ },
+ }
+
+ const response: AdminBoundAuthIdentity = {
+ user_id: 9,
+ provider_type: 'wechat',
+ provider_key: 'wechat-main',
+ provider_subject: 'union-123',
+ verified_at: '2026-04-22T00:00:00Z',
+ issuer: null,
+ metadata: { source: 'admin-repair' },
+ created_at: '2026-04-22T00:00:00Z',
+ updated_at: '2026-04-22T00:00:00Z',
+ channel: {
+ channel: 'open',
+ channel_app_id: 'wx-open',
+ channel_subject: 'openid-123',
+ metadata: { scene: 'migration' },
+ created_at: '2026-04-22T00:00:00Z',
+ updated_at: '2026-04-22T00:00:00Z',
+ },
+ }
+ post.mockResolvedValue({ data: response })
+
+ const result = await bindUserAuthIdentity(9, payload)
+
+ expect(post).toHaveBeenCalledWith('/admin/users/9/auth-identities', payload)
+ expect(result).toEqual(response)
+ })
+
+ it('keeps bind auth identity request and response types aligned with the backend contract', () => {
+ expect(requestContractExact).toBe(true)
+ expect(responseContractExact).toBe(true)
+ })
+})
diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
new file mode 100644
index 00000000..07a68c03
--- /dev/null
+++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts
@@ -0,0 +1,224 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const post = vi.fn()
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ post
+ }
+}))
+
+describe('oauth adoption auth api', () => {
+ beforeEach(() => {
+ post.mockReset()
+ post.mockResolvedValue({ data: {} })
+ localStorage.clear()
+ document.cookie = 'oauth_bind_access_token=; Max-Age=0; path=/'
+ })
+
+ it('posts adoption decisions when exchanging pending oauth completion', async () => {
+ const { exchangePendingOAuthCompletion } = await import('@/api/auth')
+
+ await exchangePendingOAuthCompletion({
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts bind-login decisions when finalizing pending oauth bind flow', async () => {
+ const { completePendingOAuthBindLogin } = await import('@/api/auth')
+
+ await completePendingOAuthBindLogin({
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts linuxdo invitation completion with adoption decisions', async () => {
+ const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+
+ await completeLinuxDoOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts linuxdo create-account completion with adoption decisions', async () => {
+ const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth')
+
+ await createPendingLinuxDoOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts affiliate code when completing linuxdo oauth registration', async () => {
+ const { completeLinuxDoOAuthRegistration } = await import('@/api/auth')
+
+ await completeLinuxDoOAuthRegistration(
+ 'invite-code',
+ {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ },
+ ' AFF123 '
+ )
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', {
+ invitation_code: 'invite-code',
+ aff_code: 'AFF123',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts oidc invitation completion with adoption decisions', async () => {
+ const { completeOIDCOAuthRegistration } = await import('@/api/auth')
+
+ await completeOIDCOAuthRegistration('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts oidc create-account completion with adoption decisions', async () => {
+ const { createPendingOIDCOAuthAccount } = await import('@/api/auth')
+
+ await createPendingOIDCOAuthAccount('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts wechat invitation completion with adoption decisions', async () => {
+ const { completeWeChatOAuthRegistration } = await import('@/api/auth')
+
+ await completeWeChatOAuthRegistration('invite-code', {
+ adoptDisplayName: true,
+ adoptAvatar: true
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: true,
+ adopt_avatar: true
+ })
+ })
+
+ it('posts wechat create-account completion with adoption decisions', async () => {
+ const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+
+ await createPendingWeChatOAuthAccount('invite-code', {
+ adoptDisplayName: false,
+ adoptAvatar: false
+ })
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ adopt_display_name: false,
+ adopt_avatar: false
+ })
+ })
+
+ it('posts affiliate code when creating pending wechat oauth account', async () => {
+ const { createPendingWeChatOAuthAccount } = await import('@/api/auth')
+
+ await createPendingWeChatOAuthAccount(
+ 'invite-code',
+ {
+ adoptDisplayName: false,
+ adoptAvatar: true
+ },
+ 'WXAFF'
+ )
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', {
+ invitation_code: 'invite-code',
+ aff_code: 'WXAFF',
+ adopt_display_name: false,
+ adopt_avatar: true
+ })
+ })
+
+ it('classifies oauth completion results as login or bind', async () => {
+ const { getOAuthCompletionKind } = await import('@/api/auth')
+
+ expect(getOAuthCompletionKind({ access_token: 'access-token' })).toBe('login')
+ expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind')
+ })
+
+ it('provides bind-login utility helpers for invitation and suggested profile states', async () => {
+ const {
+ getPendingOAuthBindLoginKind,
+ hasPendingOAuthSuggestedProfile,
+ isPendingOAuthCreateAccountRequired
+ } = await import('@/api/auth')
+
+ expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login')
+ expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind')
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'invitation_required'
+ })
+ ).toBe(true)
+ expect(
+ isPendingOAuthCreateAccountRequired({
+ error: 'other'
+ })
+ ).toBe(false)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_display_name: 'OAuth Nick'
+ })
+ ).toBe(true)
+ expect(
+ hasPendingOAuthSuggestedProfile({
+ suggested_avatar_url: 'https://cdn.example/avatar.png'
+ })
+ ).toBe(true)
+ expect(hasPendingOAuthSuggestedProfile({})).toBe(false)
+ })
+
+ it('requests an HttpOnly oauth bind cookie before redirect binding', async () => {
+ localStorage.setItem('auth_token', 'access-token-value')
+ const { prepareOAuthBindAccessTokenCookie } = await import('@/api/auth')
+
+ await prepareOAuthBindAccessTokenCookie()
+
+ expect(post).toHaveBeenCalledWith('/auth/oauth/bind-token')
+ })
+})
diff --git a/frontend/src/api/__tests__/client.spec.ts b/frontend/src/api/__tests__/client.spec.ts
index 0f663e76..a46c39eb 100644
--- a/frontend/src/api/__tests__/client.spec.ts
+++ b/frontend/src/api/__tests__/client.spec.ts
@@ -91,6 +91,22 @@ describe('API Client', () => {
const config = adapter.mock.calls[0][0]
expect(config.params?.timezone).toBeUndefined()
})
+
+ it('请求默认带 withCredentials 以支持跨域 cookie', async () => {
+ const adapter = vi.fn().mockResolvedValue({
+ status: 200,
+ data: { code: 0, data: {} },
+ headers: {},
+ config: {},
+ statusText: 'OK',
+ })
+ apiClient.defaults.adapter = adapter
+
+ await apiClient.post('/auth/oauth/bind-token')
+
+ const config = adapter.mock.calls[0][0]
+ expect(config.withCredentials).toBe(true)
+ })
})
// --- 响应拦截器 ---
diff --git a/frontend/src/api/__tests__/payment.spec.ts b/frontend/src/api/__tests__/payment.spec.ts
new file mode 100644
index 00000000..e38fba57
--- /dev/null
+++ b/frontend/src/api/__tests__/payment.spec.ts
@@ -0,0 +1,40 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+
+const { get, post } = vi.hoisted(() => ({
+ get: vi.fn(),
+ post: vi.fn(),
+}))
+
+vi.mock('@/api/client', () => ({
+ apiClient: {
+ get,
+ post,
+ },
+}))
+
+import { paymentAPI } from '@/api/payment'
+
+describe('payment api', () => {
+ beforeEach(() => {
+ get.mockReset()
+ post.mockReset()
+ get.mockResolvedValue({ data: {} })
+ post.mockResolvedValue({ data: {} })
+ })
+
+ it('keeps legacy public out_trade_no verification for upgrade compatibility', async () => {
+ await paymentAPI.verifyOrderPublic('legacy-order-no')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/verify', {
+ out_trade_no: 'legacy-order-no',
+ })
+ })
+
+ it('keeps signed public resume-token resolve endpoint', async () => {
+ await paymentAPI.resolveOrderPublicByResumeToken('resume-token-123')
+
+ expect(post).toHaveBeenCalledWith('/payment/public/orders/resolve', {
+ resume_token: 'resume-token-123',
+ })
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
new file mode 100644
index 00000000..10f6247a
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts
@@ -0,0 +1,131 @@
+import { describe, expect, it } from "vitest";
+
+import {
+ appendAuthSourceDefaultsToUpdateRequest,
+ buildAuthSourceDefaultsState,
+ type UpdateSettingsRequest,
+} from "@/api/admin/settings";
+
+describe("admin settings auth source defaults helpers", () => {
+ it("builds auth source defaults state from flat settings fields", () => {
+ const state = buildAuthSourceDefaultsState({
+ auth_source_default_email_balance: 9.5,
+ auth_source_default_email_concurrency: 3,
+ auth_source_default_email_subscriptions: [
+ { group_id: 1, validity_days: 30 },
+ ],
+ auth_source_default_email_grant_on_signup: false,
+ auth_source_default_email_grant_on_first_bind: true,
+ auth_source_default_linuxdo_balance: 6,
+ auth_source_default_linuxdo_concurrency: 8,
+ auth_source_default_linuxdo_subscriptions: [
+ { group_id: 2, validity_days: 60 },
+ ],
+ auth_source_default_linuxdo_grant_on_signup: true,
+ auth_source_default_linuxdo_grant_on_first_bind: false,
+ });
+
+ expect(state.email).toEqual({
+ balance: 9.5,
+ concurrency: 3,
+ subscriptions: [{ group_id: 1, validity_days: 30 }],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ });
+ expect(state.linuxdo).toEqual({
+ balance: 6,
+ concurrency: 8,
+ subscriptions: [{ group_id: 2, validity_days: 60 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ });
+ expect(state.oidc).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ });
+ expect(state.wechat).toEqual({
+ balance: 0,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ });
+ });
+
+ it("defaults grant-on-signup to disabled when settings are missing", () => {
+ const state = buildAuthSourceDefaultsState({});
+
+ expect(state.email.grant_on_signup).toBe(false);
+ expect(state.linuxdo.grant_on_signup).toBe(false);
+ expect(state.oidc.grant_on_signup).toBe(false);
+ expect(state.wechat.grant_on_signup).toBe(false);
+ });
+
+ it("appends auth source defaults back onto update payload", () => {
+ const payload: UpdateSettingsRequest = {
+ site_name: "Sub2API",
+ };
+
+ appendAuthSourceDefaultsToUpdateRequest(payload, {
+ email: {
+ balance: 1.25,
+ concurrency: 2,
+ subscriptions: [{ group_id: 3, validity_days: 7 }],
+ grant_on_signup: true,
+ grant_on_first_bind: false,
+ },
+ linuxdo: {
+ balance: 0,
+ concurrency: 6,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: true,
+ },
+ oidc: {
+ balance: 4,
+ concurrency: 9,
+ subscriptions: [{ group_id: 9, validity_days: 90 }],
+ grant_on_signup: true,
+ grant_on_first_bind: true,
+ },
+ wechat: {
+ balance: 2,
+ concurrency: 5,
+ subscriptions: [],
+ grant_on_signup: false,
+ grant_on_first_bind: false,
+ },
+ });
+
+ expect(payload).toMatchObject({
+ site_name: "Sub2API",
+ auth_source_default_email_balance: 1.25,
+ auth_source_default_email_concurrency: 2,
+ auth_source_default_email_subscriptions: [
+ { group_id: 3, validity_days: 7 },
+ ],
+ auth_source_default_email_grant_on_signup: true,
+ auth_source_default_email_grant_on_first_bind: false,
+ auth_source_default_linuxdo_balance: 0,
+ auth_source_default_linuxdo_concurrency: 6,
+ auth_source_default_linuxdo_subscriptions: [],
+ auth_source_default_linuxdo_grant_on_signup: false,
+ auth_source_default_linuxdo_grant_on_first_bind: true,
+ auth_source_default_oidc_balance: 4,
+ auth_source_default_oidc_concurrency: 9,
+ auth_source_default_oidc_subscriptions: [
+ { group_id: 9, validity_days: 90 },
+ ],
+ auth_source_default_oidc_grant_on_signup: true,
+ auth_source_default_oidc_grant_on_first_bind: true,
+ auth_source_default_wechat_balance: 2,
+ auth_source_default_wechat_concurrency: 5,
+ auth_source_default_wechat_subscriptions: [],
+ auth_source_default_wechat_grant_on_signup: false,
+ auth_source_default_wechat_grant_on_first_bind: false,
+ });
+ });
+});
diff --git a/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
new file mode 100644
index 00000000..ad355afe
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.paymentVisibleMethods.spec.ts
@@ -0,0 +1,63 @@
+import { describe, expect, it } from 'vitest'
+
+import {
+ getPaymentVisibleMethodSourceOptions,
+ normalizePaymentVisibleMethodSource,
+} from '@/api/admin/settings'
+
+describe('admin settings payment visible method helpers', () => {
+ it('normalizes aliases into canonical source keys per visible method', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'alipay_direct')).toBe('official_alipay')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'easypay')).toBe('easypay_alipay')
+
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'wechat')).toBe('official_wxpay')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'easypay')).toBe('easypay_wxpay')
+ })
+
+ it('rejects unknown or cross-method source values', () => {
+ expect(normalizePaymentVisibleMethodSource('alipay', 'official_wxpay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', 'official_alipay')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('alipay', 'unknown')).toBe('')
+ expect(normalizePaymentVisibleMethodSource('wxpay', null)).toBe('')
+ })
+
+ it('exposes method-scoped source options instead of arbitrary strings', () => {
+ expect(getPaymentVisibleMethodSourceOptions('alipay')).toEqual([
+ {
+ value: '',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
+ },
+ {
+ value: 'official_alipay',
+ labelZh: '支付宝官方',
+ labelEn: 'Official Alipay',
+ },
+ {
+ value: 'easypay_alipay',
+ labelZh: '易支付支付宝',
+ labelEn: 'EasyPay Alipay',
+ },
+ ])
+
+ expect(getPaymentVisibleMethodSourceOptions('wxpay')).toEqual([
+ {
+ value: '',
+ labelZh: '未配置',
+ labelEn: 'Not configured',
+ },
+ {
+ value: 'official_wxpay',
+ labelZh: '微信官方',
+ labelEn: 'Official WeChat Pay',
+ },
+ {
+ value: 'easypay_wxpay',
+ labelZh: '易支付微信',
+ labelEn: 'EasyPay WeChat Pay',
+ },
+ ])
+ })
+})
diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
new file mode 100644
index 00000000..eccb7214
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts
@@ -0,0 +1,21 @@
+import { describe, expect, it } from "vitest";
+
+import {
+ defaultWeChatConnectScopesForMode,
+ normalizeWeChatConnectMode,
+} from "@/api/admin/settings";
+
+describe("admin settings wechat connect helpers", () => {
+ it("normalizes legacy or noisy mode values to the backend contract", () => {
+ expect(normalizeWeChatConnectMode("OPEN")).toBe("open");
+ expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open");
+ expect(normalizeWeChatConnectMode("mp")).toBe("mp");
+ expect(normalizeWeChatConnectMode("official_account")).toBe("mp");
+ expect(normalizeWeChatConnectMode("unknown")).toBe("open");
+ });
+
+ it("maps each mode to the backend default scopes", () => {
+ expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login");
+ expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo");
+ });
+});
diff --git a/frontend/src/api/__tests__/sora.spec.ts b/frontend/src/api/__tests__/sora.spec.ts
deleted file mode 100644
index 88c0c416..00000000
--- a/frontend/src/api/__tests__/sora.spec.ts
+++ /dev/null
@@ -1,80 +0,0 @@
-import { describe, expect, it } from 'vitest'
-import {
- normalizeGenerationListResponse,
- normalizeModelFamiliesResponse
-} from '../sora'
-
-describe('sora api normalizers', () => {
- it('normalizes generation list from data shape', () => {
- const result = normalizeGenerationListResponse({
- data: [{ id: 1, status: 'pending' }],
- total: 9,
- page: 2
- })
-
- expect(result.data).toHaveLength(1)
- expect(result.total).toBe(9)
- expect(result.page).toBe(2)
- })
-
- it('normalizes generation list from items shape', () => {
- const result = normalizeGenerationListResponse({
- items: [{ id: 1, status: 'completed' }],
- total: 1
- })
-
- expect(result.data).toHaveLength(1)
- expect(result.total).toBe(1)
- expect(result.page).toBe(1)
- })
-
- it('falls back to empty generation list on invalid payload', () => {
- const result = normalizeGenerationListResponse(null)
- expect(result).toEqual({ data: [], total: 0, page: 1 })
- })
-
- it('normalizes family model payload', () => {
- const result = normalizeModelFamiliesResponse({
- data: [
- {
- id: 'sora2',
- name: 'Sora 2',
- type: 'video',
- orientations: ['landscape', 'portrait'],
- durations: [10, 15]
- }
- ]
- })
-
- expect(result).toHaveLength(1)
- expect(result[0].id).toBe('sora2')
- expect(result[0].orientations).toEqual(['landscape', 'portrait'])
- expect(result[0].durations).toEqual([10, 15])
- })
-
- it('normalizes legacy flat model list into families', () => {
- const result = normalizeModelFamiliesResponse({
- items: [
- { id: 'sora2-landscape-10s', type: 'video' },
- { id: 'sora2-portrait-15s', type: 'video' },
- { id: 'gpt-image-square', type: 'image' }
- ]
- })
-
- const sora2 = result.find((m) => m.id === 'sora2')
- expect(sora2).toBeTruthy()
- expect(sora2?.orientations).toEqual(['landscape', 'portrait'])
- expect(sora2?.durations).toEqual([10, 15])
-
- const image = result.find((m) => m.id === 'gpt-image')
- expect(image).toBeTruthy()
- expect(image?.type).toBe('image')
- expect(image?.orientations).toEqual(['square'])
- })
-
- it('falls back to empty families on invalid payload', () => {
- expect(normalizeModelFamiliesResponse(undefined)).toEqual([])
- expect(normalizeModelFamiliesResponse({})).toEqual([])
- })
-})
-
diff --git a/frontend/src/api/__tests__/user.spec.ts b/frontend/src/api/__tests__/user.spec.ts
new file mode 100644
index 00000000..887046da
--- /dev/null
+++ b/frontend/src/api/__tests__/user.spec.ts
@@ -0,0 +1,32 @@
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+describe('user api oauth binding urls', () => {
+ beforeEach(() => {
+ vi.resetModules()
+ vi.stubEnv('VITE_API_BASE_URL', 'https://api.example.com/api/v1')
+ })
+
+ afterEach(() => {
+ vi.unstubAllEnvs()
+ })
+
+ it('builds third-party bind urls against the bind start endpoint', async () => {
+ const { buildOAuthBindingStartURL } = await import('@/api/user')
+
+ expect(buildOAuthBindingStartURL('linuxdo', { redirectTo: '/settings/profile' })).toBe(
+ 'https://api.example.com/api/v1/auth/oauth/linuxdo/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user'
+ )
+ expect(
+ buildOAuthBindingStartURL('wechat', {
+ redirectTo: '/settings/profile',
+ wechatOAuthSettings: {
+ wechat_oauth_open_enabled: true,
+ wechat_oauth_mp_enabled: false,
+ wechat_oauth_mobile_enabled: false
+ }
+ })
+ ).toBe(
+ 'https://api.example.com/api/v1/auth/oauth/wechat/bind/start?redirect=%2Fsettings%2Fprofile&intent=bind_current_user&mode=open'
+ )
+ })
+})
diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts
index fd93fe7e..a146f1f7 100644
--- a/frontend/src/api/admin/accounts.ts
+++ b/frontend/src/api/admin/accounts.ts
@@ -38,6 +38,8 @@ export async function list(
search?: string
privacy_mode?: string
lite?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
@@ -71,6 +73,8 @@ export async function listWithEtag(
search?: string
privacy_mode?: string
lite?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
@@ -500,7 +504,11 @@ export async function exportData(options?: {
platform?: string
type?: string
status?: string
+ group?: string
+ privacy_mode?: string
search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
}
includeProxies?: boolean
}): Promise {
@@ -508,11 +516,15 @@ export async function exportData(options?: {
if (options?.ids && options.ids.length > 0) {
params.ids = options.ids.join(',')
} else if (options?.filters) {
- const { platform, type, status, search } = options.filters
+ const { platform, type, status, group, privacy_mode, search, sort_by, sort_order } = options.filters
if (platform) params.platform = platform
if (type) params.type = type
if (status) params.status = status
+ if (group) params.group = group
+ if (privacy_mode) params.privacy_mode = privacy_mode
if (search) params.search = search
+ if (sort_by) params.sort_by = sort_by
+ if (sort_order) params.sort_order = sort_order
}
if (options?.includeProxies === false) {
params.include_proxies = 'false'
@@ -568,28 +580,6 @@ export async function refreshOpenAIToken(
return data
}
-/**
- * Validate Sora session token and exchange to access token
- * @param sessionToken - Sora session token
- * @param proxyId - Optional proxy ID
- * @param endpoint - API endpoint path
- * @returns Token information including access_token
- */
-export async function validateSoraSessionToken(
- sessionToken: string,
- proxyId?: number | null,
- endpoint: string = '/admin/sora/st2at'
-): Promise> {
- const payload: { session_token: string; proxy_id?: number } = {
- session_token: sessionToken
- }
- if (proxyId) {
- payload.proxy_id = proxyId
- }
- const { data } = await apiClient.post>(endpoint, payload)
- return data
-}
-
/**
* Batch operation result type
*/
@@ -663,7 +653,6 @@ export const accountsAPI = {
generateAuthUrl,
exchangeCode,
refreshOpenAIToken,
- validateSoraSessionToken,
batchCreate,
batchUpdateCredentials,
bulkUpdate,
diff --git a/frontend/src/api/admin/affiliates.ts b/frontend/src/api/admin/affiliates.ts
new file mode 100644
index 00000000..22639bd2
--- /dev/null
+++ b/frontend/src/api/admin/affiliates.ts
@@ -0,0 +1,108 @@
+/**
+ * Admin Affiliate API endpoints
+ * Manage per-user affiliate (邀请返利) configurations:
+ * exclusive invite codes (overrides aff_code) and exclusive rebate rates.
+ */
+
+import { apiClient } from '../client'
+import type { PaginatedResponse } from '@/types'
+
+export interface AffiliateAdminEntry {
+ user_id: number
+ email: string
+ username: string
+ aff_code: string
+ aff_code_custom: boolean
+ aff_rebate_rate_percent?: number | null
+ aff_count: number
+}
+
+export interface ListAffiliateUsersParams {
+ page?: number
+ page_size?: number
+ search?: string
+}
+
+export interface UpdateAffiliateUserRequest {
+ aff_code?: string
+ aff_rebate_rate_percent?: number | null
+ /** Set true to explicitly clear the per-user rate (sets it to NULL). */
+ clear_rebate_rate?: boolean
+}
+
+export interface BatchSetRateRequest {
+ user_ids: number[]
+ aff_rebate_rate_percent?: number | null
+ /** Set true to clear rates instead of setting. */
+ clear?: boolean
+}
+
+export interface SimpleUser {
+ id: number
+ email: string
+ username: string
+}
+
+export async function listUsers(
+ params: ListAffiliateUsersParams = {},
+): Promise> {
+ const { data } = await apiClient.get>(
+ '/admin/affiliates/users',
+ {
+ params: {
+ page: params.page ?? 1,
+ page_size: params.page_size ?? 20,
+ search: params.search ?? '',
+ },
+ },
+ )
+ return data
+}
+
+export async function lookupUsers(q: string): Promise {
+ const { data } = await apiClient.get(
+ '/admin/affiliates/users/lookup',
+ { params: { q } },
+ )
+ return data
+}
+
+export async function updateUserSettings(
+ userId: number,
+ payload: UpdateAffiliateUserRequest,
+): Promise<{ user_id: number }> {
+ const { data } = await apiClient.put<{ user_id: number }>(
+ `/admin/affiliates/users/${userId}`,
+ payload,
+ )
+ return data
+}
+
+export async function clearUserSettings(
+ userId: number,
+): Promise<{ user_id: number }> {
+ const { data } = await apiClient.delete<{ user_id: number }>(
+ `/admin/affiliates/users/${userId}`,
+ )
+ return data
+}
+
+export async function batchSetRate(
+ payload: BatchSetRateRequest,
+): Promise<{ affected: number }> {
+ const { data } = await apiClient.post<{ affected: number }>(
+ '/admin/affiliates/users/batch-rate',
+ payload,
+ )
+ return data
+}
+
+export const affiliatesAPI = {
+ listUsers,
+ lookupUsers,
+ updateUserSettings,
+ clearUserSettings,
+ batchSetRate,
+}
+
+export default affiliatesAPI
diff --git a/frontend/src/api/admin/announcements.ts b/frontend/src/api/admin/announcements.ts
index d02fdda7..92392a67 100644
--- a/frontend/src/api/admin/announcements.ts
+++ b/frontend/src/api/admin/announcements.ts
@@ -17,10 +17,16 @@ export async function list(
filters?: {
status?: string
search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
+ },
+ options?: {
+ signal?: AbortSignal
}
): Promise> {
const { data } = await apiClient.get>('/admin/announcements', {
- params: { page, page_size: pageSize, ...filters }
+ params: { page, page_size: pageSize, ...filters },
+ signal: options?.signal
})
return data
}
@@ -49,11 +55,21 @@ export async function getReadStatus(
id: number,
page: number = 1,
pageSize: number = 20,
- search: string = ''
+ filters?: {
+ search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
+ },
+ options?: {
+ signal?: AbortSignal
+ }
): Promise> {
const { data } = await apiClient.get>(
`/admin/announcements/${id}/read-status`,
- { params: { page, page_size: pageSize, search } }
+ {
+ params: { page, page_size: pageSize, ...filters },
+ signal: options?.signal
+ }
)
return data
}
@@ -68,4 +84,3 @@ const announcementsAPI = {
}
export default announcementsAPI
-
diff --git a/frontend/src/api/admin/channelMonitor.ts b/frontend/src/api/admin/channelMonitor.ts
new file mode 100644
index 00000000..949c4bc8
--- /dev/null
+++ b/frontend/src/api/admin/channelMonitor.ts
@@ -0,0 +1,202 @@
+/**
+ * Admin Channel Monitor API endpoints
+ * Handles channel monitor (uptime/health) management for administrators
+ */
+
+import { apiClient } from '../client'
+
+export type Provider = 'openai' | 'anthropic' | 'gemini'
+export type MonitorStatus = 'operational' | 'degraded' | 'failed' | 'error'
+export type BodyOverrideMode = 'off' | 'merge' | 'replace'
+
+export interface ChannelMonitor {
+ id: number
+ name: string
+ provider: Provider
+ endpoint: string
+ api_key_masked: string
+ /**
+ * True when the stored encrypted API key cannot be decrypted (e.g. the
+ * encryption key has changed). Admin must re-edit the monitor to provide
+ * a fresh key. Backend skips checks for these monitors.
+ */
+ api_key_decrypt_failed?: boolean
+ primary_model: string
+ extra_models: string[]
+ group_name: string
+ enabled: boolean
+ interval_seconds: number
+ last_checked_at: string | null
+ created_by: number
+ created_at: string
+ updated_at: string
+ /** Latest status of the primary model (empty when no history yet) */
+ primary_status: MonitorStatus | ''
+ /** Latest latency of the primary model in ms (null when no history yet) */
+ primary_latency_ms: number | null
+ /** Primary model 7-day availability percentage (0-100) */
+ availability_7d: number
+ /** Latest status per extra model (used for hover tooltip) */
+ extra_models_status: ExtraModelStatus[]
+ /** 请求自定义快照字段(高级设置) */
+ template_id: number | null
+ extra_headers: Record
+ body_override_mode: BodyOverrideMode
+ body_override: Record | null
+}
+
+export interface ExtraModelStatus {
+ model: string
+ status: MonitorStatus | ''
+ latency_ms: number | null
+}
+
+export interface ListParams {
+ page?: number
+ page_size?: number
+ provider?: Provider
+ enabled?: boolean
+ search?: string
+}
+
+export interface ListResponse {
+ items: ChannelMonitor[]
+ total: number
+ page: number
+ page_size: number
+ pages: number
+}
+
+export interface CreateParams {
+ name: string
+ provider: Provider
+ endpoint: string
+ api_key: string
+ primary_model: string
+ extra_models?: string[]
+ group_name?: string
+ enabled?: boolean
+ interval_seconds: number
+ template_id?: number | null
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+// Update request: api_key 空串 = 不修改;clear_template=true 时把 template_id 置空
+export type UpdateParams = Partial & {
+ clear_template?: boolean
+}
+
+export interface CheckResult {
+ model: string
+ status: MonitorStatus
+ latency_ms: number | null
+ ping_latency_ms: number | null
+ message: string
+ checked_at: string
+}
+
+export interface RunNowResponse {
+ results: CheckResult[]
+}
+
+export interface HistoryItem {
+ id: number
+ model: string
+ status: MonitorStatus
+ latency_ms: number | null
+ ping_latency_ms: number | null
+ message: string
+ checked_at: string
+}
+
+export interface HistoryParams {
+ model?: string
+ limit?: number
+}
+
+export interface HistoryResponse {
+ items: HistoryItem[]
+}
+
+/**
+ * List channel monitors with pagination and filters
+ */
+export async function list(
+ params: ListParams = {},
+ options?: { signal?: AbortSignal }
+): Promise {
+ const { data } = await apiClient.get('/admin/channel-monitors', {
+ params,
+ signal: options?.signal,
+ })
+ return data
+}
+
+/**
+ * Get a channel monitor by ID
+ */
+export async function get(id: number): Promise {
+ const { data } = await apiClient.get(`/admin/channel-monitors/${id}`)
+ return data
+}
+
+/**
+ * Create a new channel monitor
+ */
+export async function create(params: CreateParams): Promise {
+ const { data } = await apiClient.post('/admin/channel-monitors', params)
+ return data
+}
+
+/**
+ * Update an existing channel monitor.
+ * api_key field: empty string means "do not modify".
+ */
+export async function update(id: number, params: UpdateParams): Promise {
+ const { data } = await apiClient.put(`/admin/channel-monitors/${id}`, params)
+ return data
+}
+
+/**
+ * Delete a channel monitor
+ */
+export async function del(id: number): Promise {
+ await apiClient.delete(`/admin/channel-monitors/${id}`)
+}
+
+/**
+ * Trigger an immediate manual check for a channel monitor.
+ * Returns the latest check results for primary + extra models.
+ */
+export async function runNow(id: number): Promise {
+ const { data } = await apiClient.post(`/admin/channel-monitors/${id}/run`)
+ return data
+}
+
+/**
+ * List historical check results for a monitor.
+ */
+export async function listHistory(
+ id: number,
+ params: HistoryParams = {}
+): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitors/${id}/history`,
+ { params }
+ )
+ return data
+}
+
+export const channelMonitorAPI = {
+ list,
+ get,
+ create,
+ update,
+ del,
+ runNow,
+ listHistory,
+}
+
+export default channelMonitorAPI
diff --git a/frontend/src/api/admin/channelMonitorTemplate.ts b/frontend/src/api/admin/channelMonitorTemplate.ts
new file mode 100644
index 00000000..01b3c2d0
--- /dev/null
+++ b/frontend/src/api/admin/channelMonitorTemplate.ts
@@ -0,0 +1,132 @@
+/**
+ * Admin Channel Monitor Request Template API.
+ *
+ * 模板 = 一组可复用的 headers + 可选 body 覆盖配置。
+ * 应用到监控 = 拷贝快照;模板后续变动不自动同步,需手动点「应用到关联监控」刷新。
+ */
+
+import { apiClient } from '../client'
+import type { BodyOverrideMode, Provider } from './channelMonitor'
+
+export interface ChannelMonitorTemplate {
+ id: number
+ name: string
+ provider: Provider
+ description: string
+ extra_headers: Record
+ body_override_mode: BodyOverrideMode
+ body_override: Record | null
+ created_at: string
+ updated_at: string
+ /** 关联的监控数量(快照来自此模板,仅 template_id 匹配即可) */
+ associated_monitors: number
+}
+
+export interface ListParams {
+ provider?: Provider
+}
+
+export interface ListResponse {
+ items: ChannelMonitorTemplate[]
+}
+
+export interface CreateParams {
+ name: string
+ provider: Provider
+ description?: string
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+export interface UpdateParams {
+ name?: string
+ description?: string
+ extra_headers?: Record
+ body_override_mode?: BodyOverrideMode
+ body_override?: Record | null
+}
+
+export interface ApplyResponse {
+ affected: number
+}
+
+export interface AssociatedMonitorBrief {
+ id: number
+ name: string
+ provider: Provider
+ enabled: boolean
+}
+
+export interface AssociatedMonitorsResponse {
+ items: AssociatedMonitorBrief[]
+}
+
+export async function list(params: ListParams = {}): Promise {
+ const { data } = await apiClient.get('/admin/channel-monitor-templates', {
+ params,
+ })
+ return data
+}
+
+export async function get(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitor-templates/${id}`,
+ )
+ return data
+}
+
+export async function create(params: CreateParams): Promise {
+ const { data } = await apiClient.post(
+ '/admin/channel-monitor-templates',
+ params,
+ )
+ return data
+}
+
+export async function update(id: number, params: UpdateParams): Promise {
+ const { data } = await apiClient.put(
+ `/admin/channel-monitor-templates/${id}`,
+ params,
+ )
+ return data
+}
+
+export async function del(id: number): Promise {
+ await apiClient.delete(`/admin/channel-monitor-templates/${id}`)
+}
+
+/**
+ * Apply the template to the specified associated monitors (overwrite snapshot fields).
+ * monitorIds must be a non-empty subset of the template's associated monitors.
+ * Returns count of actually affected monitors.
+ */
+export async function apply(id: number, monitorIds: number[]): Promise {
+ const { data } = await apiClient.post(
+ `/admin/channel-monitor-templates/${id}/apply`,
+ { monitor_ids: monitorIds },
+ )
+ return data
+}
+
+/**
+ * List monitors currently associated to this template (used by apply picker).
+ */
+export async function listAssociatedMonitors(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/channel-monitor-templates/${id}/monitors`,
+ )
+ return data
+}
+
+export const channelMonitorTemplateAPI = {
+ list,
+ get,
+ create,
+ update,
+ del,
+ apply,
+ listAssociatedMonitors,
+}
+
+export default channelMonitorTemplateAPI
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
new file mode 100644
index 00000000..9d430134
--- /dev/null
+++ b/frontend/src/api/admin/channels.ts
@@ -0,0 +1,168 @@
+/**
+ * Admin Channels API endpoints
+ * Handles channel management for administrators
+ */
+
+import { apiClient } from '../client'
+import type { BillingMode, ChannelStatus, BillingModelSource } from '@/constants/channel'
+
+export type { BillingMode } from '@/constants/channel'
+
+export interface PricingInterval {
+ id?: number
+ min_tokens: number
+ max_tokens: number | null
+ tier_label: string
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ per_request_price: number | null
+ sort_order: number
+}
+
+export interface ChannelModelPricing {
+ id?: number
+ platform: string
+ models: string[]
+ billing_mode: BillingMode
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ image_output_price: number | null
+ per_request_price: number | null
+ intervals: PricingInterval[]
+}
+
+export interface AccountStatsPricingRule {
+ id?: number
+ name: string
+ group_ids: number[]
+ account_ids: number[]
+ pricing: ChannelModelPricing[]
+}
+
+export interface Channel {
+ id: number
+ name: string
+ description: string
+ status: ChannelStatus
+ billing_model_source: BillingModelSource
+ restrict_models: boolean
+ features_config?: Record
+ group_ids: number[]
+ model_pricing: ChannelModelPricing[]
+ model_mapping: Record> // platform → {src→dst}
+ apply_pricing_to_account_stats: boolean
+ account_stats_pricing_rules: AccountStatsPricingRule[]
+ created_at: string
+ updated_at: string
+}
+
+export interface CreateChannelRequest {
+ name: string
+ description?: string
+ group_ids?: number[]
+ model_pricing?: ChannelModelPricing[]
+ model_mapping?: Record>
+ billing_model_source?: string
+ restrict_models?: boolean
+ features_config?: Record
+ apply_pricing_to_account_stats?: boolean
+ account_stats_pricing_rules?: AccountStatsPricingRule[]
+}
+
+export interface UpdateChannelRequest {
+ name?: string
+ description?: string
+ status?: string
+ group_ids?: number[]
+ model_pricing?: ChannelModelPricing[]
+ model_mapping?: Record>
+ billing_model_source?: string
+ restrict_models?: boolean
+ features_config?: Record
+ apply_pricing_to_account_stats?: boolean
+ account_stats_pricing_rules?: AccountStatsPricingRule[]
+}
+
+interface PaginatedResponse {
+ items: T[]
+ total: number
+}
+
+/**
+ * List channels with pagination
+ */
+export async function list(
+ page: number = 1,
+ pageSize: number = 20,
+ filters?: {
+ status?: string
+ search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
+ },
+ options?: { signal?: AbortSignal }
+): Promise> {
+ const { data } = await apiClient.get>('/admin/channels', {
+ params: {
+ page,
+ page_size: pageSize,
+ ...filters
+ },
+ signal: options?.signal
+ })
+ return data
+}
+
+/**
+ * Get channel by ID
+ */
+export async function getById(id: number): Promise {
+ const { data } = await apiClient.get(`/admin/channels/${id}`)
+ return data
+}
+
+/**
+ * Create a new channel
+ */
+export async function create(req: CreateChannelRequest): Promise {
+ const { data } = await apiClient.post('/admin/channels', req)
+ return data
+}
+
+/**
+ * Update a channel
+ */
+export async function update(id: number, req: UpdateChannelRequest): Promise {
+ const { data } = await apiClient.put(`/admin/channels/${id}`, req)
+ return data
+}
+
+/**
+ * Delete a channel
+ */
+export async function remove(id: number): Promise {
+ await apiClient.delete(`/admin/channels/${id}`)
+}
+
+export interface ModelDefaultPricing {
+ found: boolean
+ input_price?: number // per-token price
+ output_price?: number
+ cache_write_price?: number
+ cache_read_price?: number
+ image_output_price?: number
+}
+
+export async function getModelDefaultPricing(model: string): Promise {
+ const { data } = await apiClient.get('/admin/channels/model-pricing', {
+ params: { model }
+ })
+ return data
+}
+
+const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing }
+export default channelsAPI
diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts
index 15d1540f..49e487b7 100644
--- a/frontend/src/api/admin/dashboard.ts
+++ b/frontend/src/api/admin/dashboard.ts
@@ -167,6 +167,13 @@ export interface UserBreakdownParams {
endpoint?: string
endpoint_type?: 'inbound' | 'upstream' | 'path'
limit?: number
+ // Additional filter conditions
+ user_id?: number
+ api_key_id?: number
+ account_id?: number
+ request_type?: number
+ stream?: boolean
+ billing_type?: number | null
}
export interface UserBreakdownResponse {
diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts
index 5885dc6a..6b94b799 100644
--- a/frontend/src/api/admin/groups.ts
+++ b/frontend/src/api/admin/groups.ts
@@ -27,6 +27,8 @@ export async function list(
status?: 'active' | 'inactive'
is_exclusive?: boolean
search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
@@ -162,7 +164,8 @@ export interface GroupRateMultiplierEntry {
user_email: string
user_notes: string
user_status: string
- rate_multiplier: number
+ rate_multiplier?: number | null
+ rpm_override?: number | null
}
/**
@@ -203,9 +206,7 @@ export async function clearGroupRateMultipliers(id: number): Promise<{ message:
/**
* Batch set rate multipliers for users in a group
- * @param id - Group ID
- * @param entries - Array of { user_id, rate_multiplier }
- * @returns Success confirmation
+ * Only touches rate_multiplier column; preserves rpm_override on existing rows.
*/
export async function batchSetGroupRateMultipliers(
id: number,
@@ -218,6 +219,60 @@ export async function batchSetGroupRateMultipliers(
return data
}
+/**
+ * RPM override entry for a user in a group
+ */
+export interface GroupRPMOverrideEntry {
+ user_id: number
+ user_name: string
+ user_email: string
+ user_notes: string
+ user_status: string
+ rpm_override: number
+}
+
+/**
+ * Get RPM overrides for users in a group (subset of rate-multipliers endpoint).
+ */
+export async function getGroupRPMOverrides(id: number): Promise {
+ const { data } = await apiClient.get(
+ `/admin/groups/${id}/rate-multipliers`
+ )
+ return data
+ .filter(e => e.rpm_override != null)
+ .map(e => ({
+ user_id: e.user_id,
+ user_name: e.user_name,
+ user_email: e.user_email,
+ user_notes: e.user_notes,
+ user_status: e.user_status,
+ rpm_override: e.rpm_override as number
+ }))
+}
+
+/**
+ * Batch set RPM overrides for users in a group.
+ * Only touches rpm_override column; preserves rate_multiplier on existing rows.
+ */
+export async function batchSetGroupRPMOverrides(
+ id: number,
+ entries: Array<{ user_id: number; rpm_override: number }>
+): Promise<{ message: string }> {
+ const { data } = await apiClient.put<{ message: string }>(
+ `/admin/groups/${id}/rpm-overrides`,
+ { entries }
+ )
+ return data
+}
+
+/**
+ * Clear all RPM overrides for a group (preserves rate_multiplier).
+ */
+export async function clearGroupRPMOverrides(id: number): Promise<{ message: string }> {
+ const { data } = await apiClient.delete<{ message: string }>(`/admin/groups/${id}/rpm-overrides`)
+ return data
+}
+
/**
* Get usage summary (today + cumulative cost) for all groups
* @param timezone - IANA timezone string (e.g. "Asia/Shanghai")
@@ -260,6 +315,9 @@ export const groupsAPI = {
getGroupRateMultipliers,
clearGroupRateMultipliers,
batchSetGroupRateMultipliers,
+ getGroupRPMOverrides,
+ clearGroupRPMOverrides,
+ batchSetGroupRPMOverrides,
updateSortOrder,
getUsageSummary,
getCapacitySummary
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts
index 9a3fb8c5..80241794 100644
--- a/frontend/src/api/admin/index.ts
+++ b/frontend/src/api/admin/index.ts
@@ -25,6 +25,11 @@ import apiKeysAPI from './apiKeys'
import scheduledTestsAPI from './scheduledTests'
import backupAPI from './backup'
import tlsFingerprintProfileAPI from './tlsFingerprintProfile'
+import channelsAPI from './channels'
+import channelMonitorAPI from './channelMonitor'
+import channelMonitorTemplateAPI from './channelMonitorTemplate'
+import adminPaymentAPI from './payment'
+import affiliatesAPI from './affiliates'
/**
* Unified admin API object for convenient access
@@ -51,7 +56,12 @@ export const adminAPI = {
apiKeys: apiKeysAPI,
scheduledTests: scheduledTestsAPI,
backup: backupAPI,
- tlsFingerprintProfiles: tlsFingerprintProfileAPI
+ tlsFingerprintProfiles: tlsFingerprintProfileAPI,
+ channels: channelsAPI,
+ channelMonitor: channelMonitorAPI,
+ channelMonitorTemplate: channelMonitorTemplateAPI,
+ payment: adminPaymentAPI,
+ affiliates: affiliatesAPI
}
export {
@@ -76,7 +86,12 @@ export {
apiKeysAPI,
scheduledTestsAPI,
backupAPI,
- tlsFingerprintProfileAPI
+ tlsFingerprintProfileAPI,
+ channelsAPI,
+ channelMonitorAPI,
+ channelMonitorTemplateAPI,
+ adminPaymentAPI,
+ affiliatesAPI
}
export default adminAPI
diff --git a/frontend/src/api/admin/payment.ts b/frontend/src/api/admin/payment.ts
new file mode 100644
index 00000000..3daf56b2
--- /dev/null
+++ b/frontend/src/api/admin/payment.ts
@@ -0,0 +1,178 @@
+/**
+ * Admin Payment API endpoints
+ * Handles payment management operations for administrators
+ */
+
+import { apiClient } from '../client'
+import type {
+ DashboardStats,
+ PaymentOrder,
+ PaymentChannel,
+ SubscriptionPlan,
+ ProviderInstance
+} from '@/types/payment'
+import type { BasePaginationResponse } from '@/types'
+
+/** Admin-facing payment config returned by GET /admin/payment/config */
+export interface AdminPaymentConfig {
+ enabled: boolean
+ min_amount: number
+ max_amount: number
+ daily_limit: number
+ order_timeout_minutes: number
+ max_pending_orders: number
+ enabled_payment_types: string[]
+ balance_disabled: boolean
+ balance_recharge_multiplier: number
+ load_balance_strategy: string
+ product_name_prefix: string
+ product_name_suffix: string
+ help_image_url: string
+ help_text: string
+}
+
+/** Fields accepted by PUT /admin/payment/config (all optional via pointer semantics) */
+export interface UpdatePaymentConfigRequest {
+ enabled?: boolean
+ min_amount?: number
+ max_amount?: number
+ daily_limit?: number
+ order_timeout_minutes?: number
+ max_pending_orders?: number
+ enabled_payment_types?: string[]
+ balance_disabled?: boolean
+ balance_recharge_multiplier?: number
+ load_balance_strategy?: string
+ product_name_prefix?: string
+ product_name_suffix?: string
+ help_image_url?: string
+ help_text?: string
+}
+
+export const adminPaymentAPI = {
+ // ==================== Config ====================
+
+ /** Get payment configuration (admin view) */
+ getConfig() {
+ return apiClient.get('/admin/payment/config')
+ },
+
+ /** Update payment configuration */
+ updateConfig(data: UpdatePaymentConfigRequest) {
+ return apiClient.put('/admin/payment/config', data)
+ },
+
+ // ==================== Dashboard ====================
+
+ /** Get payment dashboard statistics */
+ getDashboard(days?: number) {
+ return apiClient.get('/admin/payment/dashboard', {
+ params: days ? { days } : undefined
+ })
+ },
+
+ // ==================== Orders ====================
+
+ /** Get all orders (paginated, with filters) */
+ getOrders(params?: {
+ page?: number
+ page_size?: number
+ status?: string
+ payment_type?: string
+ user_id?: number
+ keyword?: string
+ start_date?: string
+ end_date?: string
+ order_type?: string
+ }) {
+ return apiClient.get>('/admin/payment/orders', { params })
+ },
+
+ /** Get a specific order by ID */
+ getOrder(id: number) {
+ return apiClient.get(`/admin/payment/orders/${id}`)
+ },
+
+ /** Cancel an order (admin) */
+ cancelOrder(id: number) {
+ return apiClient.post(`/admin/payment/orders/${id}/cancel`)
+ },
+
+ /** Retry recharge for a failed order */
+ retryRecharge(id: number) {
+ return apiClient.post(`/admin/payment/orders/${id}/retry`)
+ },
+
+ /** Process a refund */
+ refundOrder(id: number, data: { amount: number; reason: string; deduct_balance?: boolean; force?: boolean }) {
+ return apiClient.post(`/admin/payment/orders/${id}/refund`, data)
+ },
+
+ // ==================== Channels ====================
+
+ /** Get all payment channels */
+ getChannels() {
+ return apiClient.get('/admin/payment/channels')
+ },
+
+ /** Create a payment channel */
+ createChannel(data: Partial) {
+ return apiClient.post('/admin/payment/channels', data)
+ },
+
+ /** Update a payment channel */
+ updateChannel(id: number, data: Partial) {
+ return apiClient.put(`/admin/payment/channels/${id}`, data)
+ },
+
+ /** Delete a payment channel */
+ deleteChannel(id: number) {
+ return apiClient.delete(`/admin/payment/channels/${id}`)
+ },
+
+ // ==================== Subscription Plans ====================
+
+ /** Get all subscription plans */
+ getPlans() {
+ return apiClient.get('/admin/payment/plans')
+ },
+
+ /** Create a subscription plan */
+ createPlan(data: Record) {
+ return apiClient.post('/admin/payment/plans', data)
+ },
+
+ /** Update a subscription plan */
+ updatePlan(id: number, data: Record) {
+ return apiClient.put(`/admin/payment/plans/${id}`, data)
+ },
+
+ /** Delete a subscription plan */
+ deletePlan(id: number) {
+ return apiClient.delete(`/admin/payment/plans/${id}`)
+ },
+
+ // ==================== Provider Instances ====================
+
+ /** Get all provider instances */
+ getProviders() {
+ return apiClient.get('/admin/payment/providers')
+ },
+
+ /** Create a provider instance */
+ createProvider(data: Partial) {
+ return apiClient.post('/admin/payment/providers', data)
+ },
+
+ /** Update a provider instance */
+ updateProvider(id: number, data: Partial) {
+ return apiClient.put(`/admin/payment/providers/${id}`, data)
+ },
+
+ /** Delete a provider instance */
+ deleteProvider(id: number) {
+ return apiClient.delete(`/admin/payment/providers/${id}`)
+ }
+}
+
+export default adminPaymentAPI
diff --git a/frontend/src/api/admin/promo.ts b/frontend/src/api/admin/promo.ts
index 6a8c4559..b24dffc2 100644
--- a/frontend/src/api/admin/promo.ts
+++ b/frontend/src/api/admin/promo.ts
@@ -17,10 +17,16 @@ export async function list(
filters?: {
status?: string
search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
+ },
+ options?: {
+ signal?: AbortSignal
}
): Promise> {
const { data } = await apiClient.get>('/admin/promo-codes', {
- params: { page, page_size: pageSize, ...filters }
+ params: { page, page_size: pageSize, ...filters },
+ signal: options?.signal
})
return data
}
diff --git a/frontend/src/api/admin/proxies.ts b/frontend/src/api/admin/proxies.ts
index 5e31ae20..3e041ba9 100644
--- a/frontend/src/api/admin/proxies.ts
+++ b/frontend/src/api/admin/proxies.ts
@@ -29,6 +29,8 @@ export async function list(
protocol?: string
status?: 'active' | 'inactive'
search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
@@ -227,16 +229,20 @@ export async function exportData(options?: {
protocol?: string
status?: 'active' | 'inactive'
search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
}
}): Promise {
const params: Record = {}
if (options?.ids && options.ids.length > 0) {
params.ids = options.ids.join(',')
} else if (options?.filters) {
- const { protocol, status, search } = options.filters
+ const { protocol, status, search, sort_by, sort_order } = options.filters
if (protocol) params.protocol = protocol
if (status) params.status = status
if (search) params.search = search
+ if (sort_by) params.sort_by = sort_by
+ if (sort_order) params.sort_order = sort_order
}
const { data } = await apiClient.get('/admin/proxies/data', { params })
return data
diff --git a/frontend/src/api/admin/redeem.ts b/frontend/src/api/admin/redeem.ts
index a53c3566..57626b1e 100644
--- a/frontend/src/api/admin/redeem.ts
+++ b/frontend/src/api/admin/redeem.ts
@@ -25,6 +25,8 @@ export async function list(
type?: RedeemCodeType
status?: 'active' | 'used' | 'expired' | 'unused'
search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
@@ -151,7 +153,10 @@ export async function getStats(): Promise<{
*/
export async function exportCodes(filters?: {
type?: RedeemCodeType
- status?: 'active' | 'used' | 'expired'
+ status?: 'used' | 'expired' | 'unused'
+ search?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
}): Promise {
const response = await apiClient.get('/admin/redeem-codes/export', {
params: filters,
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index cabdd5aa..defbab43 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -3,12 +3,293 @@
* Handles system settings management for administrators
*/
-import { apiClient } from '../client'
-import type { CustomMenuItem, CustomEndpoint } from '@/types'
+import { apiClient } from "../client";
+import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types";
export interface DefaultSubscriptionSetting {
- group_id: number
- validity_days: number
+ group_id: number;
+ validity_days: number;
+}
+
+export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat";
+
+export interface AuthSourceDefaultsValue {
+ balance: number;
+ concurrency: number;
+ subscriptions: DefaultSubscriptionSetting[];
+ grant_on_signup: boolean;
+ grant_on_first_bind: boolean;
+}
+
+export type AuthSourceDefaultsState = Record<
+ AuthSourceType,
+ AuthSourceDefaultsValue
+>;
+export type PaymentVisibleMethod = "alipay" | "wxpay";
+export type PaymentVisibleMethodSource =
+ | ""
+ | "official_alipay"
+ | "easypay_alipay"
+ | "official_wxpay"
+ | "easypay_wxpay";
+export type WeChatConnectMode = "open" | "mp" | "mobile";
+
+export interface PaymentVisibleMethodSourceOption {
+ value: PaymentVisibleMethodSource;
+ labelZh: string;
+ labelEn: string;
+}
+
+export interface WeChatConnectModeOption {
+ value: WeChatConnectMode;
+ labelZh: string;
+ labelEn: string;
+}
+
+const AUTH_SOURCE_TYPES: AuthSourceType[] = [
+ "email",
+ "linuxdo",
+ "oidc",
+ "wechat",
+];
+const AUTH_SOURCE_DEFAULT_BALANCE = 0;
+const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5;
+const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record<
+ PaymentVisibleMethod,
+ PaymentVisibleMethodSourceOption[]
+> = {
+ alipay: [
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_alipay",
+ labelZh: "支付宝官方",
+ labelEn: "Official Alipay",
+ },
+ {
+ value: "easypay_alipay",
+ labelZh: "易支付支付宝",
+ labelEn: "EasyPay Alipay",
+ },
+ ],
+ wxpay: [
+ { value: "", labelZh: "未配置", labelEn: "Not configured" },
+ {
+ value: "official_wxpay",
+ labelZh: "微信官方",
+ labelEn: "Official WeChat Pay",
+ },
+ {
+ value: "easypay_wxpay",
+ labelZh: "易支付微信",
+ labelEn: "EasyPay WeChat Pay",
+ },
+ ],
+};
+const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record<
+ PaymentVisibleMethod,
+ Record
+> = {
+ alipay: {
+ official_alipay: "official_alipay",
+ alipay: "official_alipay",
+ alipay_direct: "official_alipay",
+ official: "official_alipay",
+ easypay_alipay: "easypay_alipay",
+ easypay: "easypay_alipay",
+ },
+ wxpay: {
+ official_wxpay: "official_wxpay",
+ wxpay: "official_wxpay",
+ wxpay_direct: "official_wxpay",
+ wechat: "official_wxpay",
+ official: "official_wxpay",
+ easypay_wxpay: "easypay_wxpay",
+ easypay: "easypay_wxpay",
+ },
+};
+const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [
+ { value: "open", labelZh: "PC 应用", labelEn: "PC App" },
+ {
+ value: "mp",
+ labelZh: "公众号",
+ labelEn: "Official Account",
+ },
+ {
+ value: "mobile",
+ labelZh: "移动应用",
+ labelEn: "Mobile App",
+ },
+];
+const WECHAT_CONNECT_MODE_ALIASES: Record = {
+ open: "open",
+ open_platform: "open",
+ official: "open",
+ wx_open: "open",
+ mp: "mp",
+ official_account: "mp",
+ wechat_mp: "mp",
+ mini_program: "mp",
+ mobile: "mobile",
+ mobile_app: "mobile",
+ native_app: "mobile",
+};
+
+export function normalizeDefaultSubscriptionSettings(
+ subscriptions: DefaultSubscriptionSetting[] | null | undefined,
+): DefaultSubscriptionSetting[] {
+ if (!Array.isArray(subscriptions)) return [];
+
+ return subscriptions
+ .filter((item) => item.group_id > 0 && item.validity_days > 0)
+ .map((item) => ({
+ group_id: Math.floor(item.group_id),
+ validity_days: Math.min(
+ 36500,
+ Math.max(1, Math.floor(item.validity_days)),
+ ),
+ }));
+}
+
+export function buildAuthSourceDefaultsState(
+ settings: Partial,
+): AuthSourceDefaultsState {
+ const raw = settings as Record;
+
+ return AUTH_SOURCE_TYPES.reduce((acc, source) => {
+ const subscriptions = raw[`auth_source_default_${source}_subscriptions`];
+ acc[source] = {
+ balance: Number(
+ raw[`auth_source_default_${source}_balance`] ??
+ AUTH_SOURCE_DEFAULT_BALANCE,
+ ),
+ concurrency: Math.max(
+ 1,
+ Number(
+ raw[`auth_source_default_${source}_concurrency`] ??
+ AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
+ ),
+ subscriptions: normalizeDefaultSubscriptionSettings(
+ Array.isArray(subscriptions)
+ ? (subscriptions as DefaultSubscriptionSetting[])
+ : [],
+ ),
+ grant_on_signup:
+ raw[`auth_source_default_${source}_grant_on_signup`] === true,
+ grant_on_first_bind:
+ raw[`auth_source_default_${source}_grant_on_first_bind`] === true,
+ };
+ return acc;
+ }, {} as AuthSourceDefaultsState);
+}
+
+export function appendAuthSourceDefaultsToUpdateRequest(
+ payload: UpdateSettingsRequest,
+ authSourceDefaults: AuthSourceDefaultsState,
+): UpdateSettingsRequest {
+ const target = payload as Record;
+
+ for (const source of AUTH_SOURCE_TYPES) {
+ const current = authSourceDefaults[source];
+ target[`auth_source_default_${source}_balance`] =
+ Number(current.balance) || 0;
+ target[`auth_source_default_${source}_concurrency`] = Math.max(
+ 1,
+ Math.floor(
+ Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY,
+ ),
+ );
+ target[`auth_source_default_${source}_subscriptions`] =
+ normalizeDefaultSubscriptionSettings(current.subscriptions);
+ target[`auth_source_default_${source}_grant_on_signup`] =
+ current.grant_on_signup;
+ target[`auth_source_default_${source}_grant_on_first_bind`] =
+ current.grant_on_first_bind;
+ }
+
+ return payload;
+}
+
+export function getPaymentVisibleMethodSourceOptions(
+ method: PaymentVisibleMethod,
+): PaymentVisibleMethodSourceOption[] {
+ return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method];
+}
+
+export function normalizePaymentVisibleMethodSource(
+ method: PaymentVisibleMethod,
+ source: unknown,
+): PaymentVisibleMethodSource {
+ if (typeof source !== "string") return "";
+
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "";
+
+ return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? "";
+}
+
+export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] {
+ return WECHAT_CONNECT_MODE_OPTIONS;
+}
+
+export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode {
+ if (typeof source !== "string") return "open";
+
+ const normalized = source.trim().toLowerCase();
+ if (!normalized) return "open";
+
+ return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open";
+}
+
+export function defaultWeChatConnectScopesForMode(mode: unknown): string {
+ switch (normalizeWeChatConnectMode(mode)) {
+ case "mp":
+ return "snsapi_userinfo";
+ case "mobile":
+ return "";
+ default:
+ return "snsapi_login";
+ }
+}
+
+export function resolveWeChatConnectModeCapabilities(
+ openEnabled: unknown,
+ mpEnabled: unknown,
+ mobileEnabled: unknown,
+ legacyMode: unknown,
+): { openEnabled: boolean; mpEnabled: boolean; mobileEnabled: boolean } {
+ if (
+ typeof openEnabled === "boolean" ||
+ typeof mpEnabled === "boolean" ||
+ typeof mobileEnabled === "boolean"
+ ) {
+ return {
+ openEnabled: openEnabled === true,
+ mpEnabled: mpEnabled === true,
+ mobileEnabled: mobileEnabled === true,
+ };
+ }
+
+ switch (normalizeWeChatConnectMode(legacyMode)) {
+ case "mp":
+ return { openEnabled: false, mpEnabled: true, mobileEnabled: false };
+ case "mobile":
+ return { openEnabled: false, mpEnabled: false, mobileEnabled: true };
+ default:
+ return { openEnabled: true, mpEnabled: false, mobileEnabled: false };
+ }
+}
+
+export function deriveWeChatConnectStoredMode(
+ openEnabled: boolean,
+ mpEnabled: boolean,
+ mobileEnabled: boolean,
+ legacyMode: unknown,
+): WeChatConnectMode {
+ if (mpEnabled) return "mp";
+ if (mobileEnabled) return "mobile";
+ if (openEnabled) return "open";
+ return normalizeWeChatConnectMode(legacyMode);
}
/**
@@ -16,138 +297,357 @@ export interface DefaultSubscriptionSetting {
*/
export interface SystemSettings {
// Registration settings
- registration_enabled: boolean
- email_verify_enabled: boolean
- registration_email_suffix_whitelist: string[]
- promo_code_enabled: boolean
- password_reset_enabled: boolean
- frontend_url: string
- invitation_code_enabled: boolean
- totp_enabled: boolean // TOTP 双因素认证
- totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置
+ registration_enabled: boolean;
+ email_verify_enabled: boolean;
+ registration_email_suffix_whitelist: string[];
+ promo_code_enabled: boolean;
+ password_reset_enabled: boolean;
+ frontend_url: string;
+ invitation_code_enabled: boolean;
+ totp_enabled: boolean; // TOTP 双因素认证
+ totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置
// Default settings
- default_balance: number
- default_concurrency: number
- default_subscriptions: DefaultSubscriptionSetting[]
+ default_balance: number;
+ affiliate_rebate_rate: number;
+ affiliate_rebate_freeze_hours: number;
+ affiliate_rebate_duration_days: number;
+ affiliate_rebate_per_invitee_cap: number;
+ default_concurrency: number;
+ default_user_rpm_limit: number;
+ default_subscriptions: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
// OEM settings
- site_name: string
- site_logo: string
- site_subtitle: string
- api_base_url: string
- contact_info: string
- doc_url: string
- home_content: string
- hide_ccs_import_button: boolean
- purchase_subscription_enabled: boolean
- purchase_subscription_url: string
- sora_client_enabled: boolean
- backend_mode_enabled: boolean
- custom_menu_items: CustomMenuItem[]
- custom_endpoints: CustomEndpoint[]
+ site_name: string;
+ site_logo: string;
+ site_subtitle: string;
+ api_base_url: string;
+ contact_info: string;
+ doc_url: string;
+ home_content: string;
+ hide_ccs_import_button: boolean;
+ table_default_page_size: number;
+ table_page_size_options: number[];
+ backend_mode_enabled: boolean;
+ custom_menu_items: CustomMenuItem[];
+ custom_endpoints: CustomEndpoint[];
// SMTP settings
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password_configured: boolean
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password_configured: boolean;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
// Cloudflare Turnstile settings
- turnstile_enabled: boolean
- turnstile_site_key: string
- turnstile_secret_key_configured: boolean
+ turnstile_enabled: boolean;
+ turnstile_site_key: string;
+ turnstile_secret_key_configured: boolean;
// LinuxDo Connect OAuth settings
- linuxdo_connect_enabled: boolean
- linuxdo_connect_client_id: string
- linuxdo_connect_client_secret_configured: boolean
- linuxdo_connect_redirect_url: string
+ linuxdo_connect_enabled: boolean;
+ linuxdo_connect_client_id: string;
+ linuxdo_connect_client_secret_configured: boolean;
+ linuxdo_connect_redirect_url: string;
+
+ // WeChat Connect OAuth settings
+ wechat_connect_enabled: boolean;
+ wechat_connect_app_id: string;
+ wechat_connect_app_secret_configured: boolean;
+ wechat_connect_open_app_id?: string;
+ wechat_connect_open_app_secret_configured?: boolean;
+ wechat_connect_mp_app_id?: string;
+ wechat_connect_mp_app_secret_configured?: boolean;
+ wechat_connect_mobile_app_id?: string;
+ wechat_connect_mobile_app_secret_configured?: boolean;
+ wechat_connect_open_enabled?: boolean;
+ wechat_connect_mp_enabled?: boolean;
+ wechat_connect_mobile_enabled?: boolean;
+ wechat_connect_mode: string;
+ wechat_connect_scopes: string;
+ wechat_connect_redirect_url: string;
+ wechat_connect_frontend_redirect_url: string;
+
+ // Generic OIDC OAuth settings
+ oidc_connect_enabled: boolean;
+ oidc_connect_provider_name: string;
+ oidc_connect_client_id: string;
+ oidc_connect_client_secret_configured: boolean;
+ oidc_connect_issuer_url: string;
+ oidc_connect_discovery_url: string;
+ oidc_connect_authorize_url: string;
+ oidc_connect_token_url: string;
+ oidc_connect_userinfo_url: string;
+ oidc_connect_jwks_url: string;
+ oidc_connect_scopes: string;
+ oidc_connect_redirect_url: string;
+ oidc_connect_frontend_redirect_url: string;
+ oidc_connect_token_auth_method: string;
+ oidc_connect_use_pkce: boolean;
+ oidc_connect_validate_id_token: boolean;
+ oidc_connect_allowed_signing_algs: string;
+ oidc_connect_clock_skew_seconds: number;
+ oidc_connect_require_email_verified: boolean;
+ oidc_connect_userinfo_email_path: string;
+ oidc_connect_userinfo_id_path: string;
+ oidc_connect_userinfo_username_path: string;
// Model fallback configuration
- enable_model_fallback: boolean
- fallback_model_anthropic: string
- fallback_model_openai: string
- fallback_model_gemini: string
- fallback_model_antigravity: string
+ enable_model_fallback: boolean;
+ fallback_model_anthropic: string;
+ fallback_model_openai: string;
+ fallback_model_gemini: string;
+ fallback_model_antigravity: string;
// Identity patch configuration (Claude -> Gemini)
- enable_identity_patch: boolean
- identity_patch_prompt: string
+ enable_identity_patch: boolean;
+ identity_patch_prompt: string;
// Ops Monitoring (vNext)
- ops_monitoring_enabled: boolean
- ops_realtime_monitoring_enabled: boolean
- ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds: number
+ ops_monitoring_enabled: boolean;
+ ops_realtime_monitoring_enabled: boolean;
+ ops_query_mode_default: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds: number;
// Claude Code version check
- min_claude_code_version: string
- max_claude_code_version: string
+ min_claude_code_version: string;
+ max_claude_code_version: string;
// 分组隔离
- allow_ungrouped_key_scheduling: boolean
+ allow_ungrouped_key_scheduling: boolean;
// Gateway forwarding behavior
- enable_fingerprint_unification: boolean
- enable_metadata_passthrough: boolean
+ enable_fingerprint_unification: boolean;
+ enable_metadata_passthrough: boolean;
+ enable_cch_signing: boolean;
+ web_search_emulation_enabled?: boolean;
+
+ // Payment configuration
+ payment_enabled: boolean;
+ payment_min_amount: number;
+ payment_max_amount: number;
+ payment_daily_limit: number;
+ payment_order_timeout_minutes: number;
+ payment_max_pending_orders: number;
+ payment_enabled_types: string[];
+ payment_balance_disabled: boolean;
+ payment_balance_recharge_multiplier: number;
+ payment_recharge_fee_rate: number;
+ payment_load_balance_strategy: string;
+ payment_product_name_prefix: string;
+ payment_product_name_suffix: string;
+ payment_help_image_url: string;
+ payment_help_text: string;
+ payment_cancel_rate_limit_enabled: boolean;
+ payment_cancel_rate_limit_max: number;
+ payment_cancel_rate_limit_window: number;
+ payment_cancel_rate_limit_unit: string;
+ payment_cancel_rate_limit_window_mode: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
+
+ // Balance & quota notification
+ balance_low_notify_enabled: boolean;
+ balance_low_notify_threshold: number;
+ balance_low_notify_recharge_url: string;
+ account_quota_notify_enabled: boolean;
+ account_quota_notify_emails: NotifyEmailEntry[];
+
+ // Channel Monitor feature switch
+ channel_monitor_enabled: boolean;
+ channel_monitor_default_interval_seconds: number;
+
+ // Available Channels feature switch
+ available_channels_enabled: boolean;
+
+ // Affiliate (邀请返利) feature switch
+ affiliate_enabled: boolean;
}
export interface UpdateSettingsRequest {
- registration_enabled?: boolean
- email_verify_enabled?: boolean
- registration_email_suffix_whitelist?: string[]
- promo_code_enabled?: boolean
- password_reset_enabled?: boolean
- frontend_url?: string
- invitation_code_enabled?: boolean
- totp_enabled?: boolean // TOTP 双因素认证
- default_balance?: number
- default_concurrency?: number
- default_subscriptions?: DefaultSubscriptionSetting[]
- site_name?: string
- site_logo?: string
- site_subtitle?: string
- api_base_url?: string
- contact_info?: string
- doc_url?: string
- home_content?: string
- hide_ccs_import_button?: boolean
- purchase_subscription_enabled?: boolean
- purchase_subscription_url?: string
- sora_client_enabled?: boolean
- backend_mode_enabled?: boolean
- custom_menu_items?: CustomMenuItem[]
- custom_endpoints?: CustomEndpoint[]
- smtp_host?: string
- smtp_port?: number
- smtp_username?: string
- smtp_password?: string
- smtp_from_email?: string
- smtp_from_name?: string
- smtp_use_tls?: boolean
- turnstile_enabled?: boolean
- turnstile_site_key?: string
- turnstile_secret_key?: string
- linuxdo_connect_enabled?: boolean
- linuxdo_connect_client_id?: string
- linuxdo_connect_client_secret?: string
- linuxdo_connect_redirect_url?: string
- enable_model_fallback?: boolean
- fallback_model_anthropic?: string
- fallback_model_openai?: string
- fallback_model_gemini?: string
- fallback_model_antigravity?: string
- enable_identity_patch?: boolean
- identity_patch_prompt?: string
- ops_monitoring_enabled?: boolean
- ops_realtime_monitoring_enabled?: boolean
- ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string
- ops_metrics_interval_seconds?: number
- min_claude_code_version?: string
- max_claude_code_version?: string
- allow_ungrouped_key_scheduling?: boolean
- enable_fingerprint_unification?: boolean
- enable_metadata_passthrough?: boolean
+ registration_enabled?: boolean;
+ email_verify_enabled?: boolean;
+ registration_email_suffix_whitelist?: string[];
+ promo_code_enabled?: boolean;
+ password_reset_enabled?: boolean;
+ frontend_url?: string;
+ invitation_code_enabled?: boolean;
+ totp_enabled?: boolean; // TOTP 双因素认证
+ default_balance?: number;
+ affiliate_rebate_rate?: number;
+ affiliate_rebate_freeze_hours?: number;
+ affiliate_rebate_duration_days?: number;
+ affiliate_rebate_per_invitee_cap?: number;
+ default_concurrency?: number;
+ default_user_rpm_limit?: number;
+ default_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_balance?: number;
+ auth_source_default_email_concurrency?: number;
+ auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_email_grant_on_signup?: boolean;
+ auth_source_default_email_grant_on_first_bind?: boolean;
+ auth_source_default_linuxdo_balance?: number;
+ auth_source_default_linuxdo_concurrency?: number;
+ auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_linuxdo_grant_on_signup?: boolean;
+ auth_source_default_linuxdo_grant_on_first_bind?: boolean;
+ auth_source_default_oidc_balance?: number;
+ auth_source_default_oidc_concurrency?: number;
+ auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_oidc_grant_on_signup?: boolean;
+ auth_source_default_oidc_grant_on_first_bind?: boolean;
+ auth_source_default_wechat_balance?: number;
+ auth_source_default_wechat_concurrency?: number;
+ auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_wechat_grant_on_signup?: boolean;
+ auth_source_default_wechat_grant_on_first_bind?: boolean;
+ force_email_on_third_party_signup?: boolean;
+ site_name?: string;
+ site_logo?: string;
+ site_subtitle?: string;
+ api_base_url?: string;
+ contact_info?: string;
+ doc_url?: string;
+ home_content?: string;
+ hide_ccs_import_button?: boolean;
+ table_default_page_size?: number;
+ table_page_size_options?: number[];
+ backend_mode_enabled?: boolean;
+ custom_menu_items?: CustomMenuItem[];
+ custom_endpoints?: CustomEndpoint[];
+ smtp_host?: string;
+ smtp_port?: number;
+ smtp_username?: string;
+ smtp_password?: string;
+ smtp_from_email?: string;
+ smtp_from_name?: string;
+ smtp_use_tls?: boolean;
+ turnstile_enabled?: boolean;
+ turnstile_site_key?: string;
+ turnstile_secret_key?: string;
+ linuxdo_connect_enabled?: boolean;
+ linuxdo_connect_client_id?: string;
+ linuxdo_connect_client_secret?: string;
+ linuxdo_connect_redirect_url?: string;
+ wechat_connect_enabled?: boolean;
+ wechat_connect_app_id?: string;
+ wechat_connect_app_secret?: string;
+ wechat_connect_open_app_id?: string;
+ wechat_connect_open_app_secret?: string;
+ wechat_connect_mp_app_id?: string;
+ wechat_connect_mp_app_secret?: string;
+ wechat_connect_mobile_app_id?: string;
+ wechat_connect_mobile_app_secret?: string;
+ wechat_connect_open_enabled?: boolean;
+ wechat_connect_mp_enabled?: boolean;
+ wechat_connect_mobile_enabled?: boolean;
+ wechat_connect_mode?: string;
+ wechat_connect_scopes?: string;
+ wechat_connect_redirect_url?: string;
+ wechat_connect_frontend_redirect_url?: string;
+ oidc_connect_enabled?: boolean;
+ oidc_connect_provider_name?: string;
+ oidc_connect_client_id?: string;
+ oidc_connect_client_secret?: string;
+ oidc_connect_issuer_url?: string;
+ oidc_connect_discovery_url?: string;
+ oidc_connect_authorize_url?: string;
+ oidc_connect_token_url?: string;
+ oidc_connect_userinfo_url?: string;
+ oidc_connect_jwks_url?: string;
+ oidc_connect_scopes?: string;
+ oidc_connect_redirect_url?: string;
+ oidc_connect_frontend_redirect_url?: string;
+ oidc_connect_token_auth_method?: string;
+ oidc_connect_use_pkce?: boolean;
+ oidc_connect_validate_id_token?: boolean;
+ oidc_connect_allowed_signing_algs?: string;
+ oidc_connect_clock_skew_seconds?: number;
+ oidc_connect_require_email_verified?: boolean;
+ oidc_connect_userinfo_email_path?: string;
+ oidc_connect_userinfo_id_path?: string;
+ oidc_connect_userinfo_username_path?: string;
+ enable_model_fallback?: boolean;
+ fallback_model_anthropic?: string;
+ fallback_model_openai?: string;
+ fallback_model_gemini?: string;
+ fallback_model_antigravity?: string;
+ enable_identity_patch?: boolean;
+ identity_patch_prompt?: string;
+ ops_monitoring_enabled?: boolean;
+ ops_realtime_monitoring_enabled?: boolean;
+ ops_query_mode_default?: "auto" | "raw" | "preagg" | string;
+ ops_metrics_interval_seconds?: number;
+ min_claude_code_version?: string;
+ max_claude_code_version?: string;
+ allow_ungrouped_key_scheduling?: boolean;
+ enable_fingerprint_unification?: boolean;
+ enable_metadata_passthrough?: boolean;
+ enable_cch_signing?: boolean;
+ // Payment configuration
+ payment_enabled?: boolean;
+ payment_min_amount?: number;
+ payment_max_amount?: number;
+ payment_daily_limit?: number;
+ payment_order_timeout_minutes?: number;
+ payment_max_pending_orders?: number;
+ payment_enabled_types?: string[];
+ payment_balance_disabled?: boolean;
+ payment_balance_recharge_multiplier?: number;
+ payment_recharge_fee_rate?: number;
+ payment_load_balance_strategy?: string;
+ payment_product_name_prefix?: string;
+ payment_product_name_suffix?: string;
+ payment_help_image_url?: string;
+ payment_help_text?: string;
+ payment_cancel_rate_limit_enabled?: boolean;
+ payment_cancel_rate_limit_max?: number;
+ payment_cancel_rate_limit_window?: number;
+ payment_cancel_rate_limit_unit?: string;
+ payment_cancel_rate_limit_window_mode?: string;
+ payment_visible_method_alipay_source?: string;
+ payment_visible_method_wxpay_source?: string;
+ payment_visible_method_alipay_enabled?: boolean;
+ payment_visible_method_wxpay_enabled?: boolean;
+ openai_advanced_scheduler_enabled?: boolean;
+ // Balance & quota notification
+ balance_low_notify_enabled?: boolean;
+ balance_low_notify_threshold?: number;
+ balance_low_notify_recharge_url?: string;
+ account_quota_notify_enabled?: boolean;
+ account_quota_notify_emails?: NotifyEmailEntry[];
+
+ // Channel Monitor feature switch
+ channel_monitor_enabled?: boolean;
+ channel_monitor_default_interval_seconds?: number;
+
+ // Available Channels feature switch
+ available_channels_enabled?: boolean;
+
+ // Affiliate (邀请返利) feature switch
+ affiliate_enabled?: boolean;
}
/**
@@ -155,8 +655,8 @@ export interface UpdateSettingsRequest {
* @returns System settings
*/
export async function getSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings')
- return data
+ const { data } = await apiClient.get("/admin/settings");
+ return data;
}
/**
@@ -164,20 +664,25 @@ export async function getSettings(): Promise {
* @param settings - Partial settings to update
* @returns Updated settings
*/
-export async function updateSettings(settings: UpdateSettingsRequest): Promise {
- const { data } = await apiClient.put('/admin/settings', settings)
- return data
+export async function updateSettings(
+ settings: UpdateSettingsRequest,
+): Promise {
+ const { data } = await apiClient.put(
+ "/admin/settings",
+ settings,
+ );
+ return data;
}
/**
* Test SMTP connection request
*/
export interface TestSmtpRequest {
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_use_tls: boolean
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_use_tls: boolean;
}
/**
@@ -185,23 +690,28 @@ export interface TestSmtpRequest {
* @param config - SMTP configuration to test
* @returns Test result message
*/
-export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> {
- const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config)
- return data
+export async function testSmtpConnection(
+ config: TestSmtpRequest,
+): Promise<{ message: string }> {
+ const { data } = await apiClient.post<{ message: string }>(
+ "/admin/settings/test-smtp",
+ config,
+ );
+ return data;
}
/**
* Send test email request
*/
export interface SendTestEmailRequest {
- email: string
- smtp_host: string
- smtp_port: number
- smtp_username: string
- smtp_password: string
- smtp_from_email: string
- smtp_from_name: string
- smtp_use_tls: boolean
+ email: string;
+ smtp_host: string;
+ smtp_port: number;
+ smtp_username: string;
+ smtp_password: string;
+ smtp_from_email: string;
+ smtp_from_name: string;
+ smtp_use_tls: boolean;
}
/**
@@ -209,20 +719,22 @@ export interface SendTestEmailRequest {
* @param request - Email address and SMTP config
* @returns Test result message
*/
-export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> {
+export async function sendTestEmail(
+ request: SendTestEmailRequest,
+): Promise<{ message: string }> {
const { data } = await apiClient.post<{ message: string }>(
- '/admin/settings/send-test-email',
- request
- )
- return data
+ "/admin/settings/send-test-email",
+ request,
+ );
+ return data;
}
/**
* Admin API Key status response
*/
export interface AdminApiKeyStatus {
- exists: boolean
- masked_key: string
+ exists: boolean;
+ masked_key: string;
}
/**
@@ -230,8 +742,10 @@ export interface AdminApiKeyStatus {
* @returns Status indicating if key exists and masked version
*/
export async function getAdminApiKey(): Promise {
- const { data } = await apiClient.get('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
/**
@@ -239,8 +753,10 @@ export async function getAdminApiKey(): Promise {
* @returns The new full API key (only shown once)
*/
export async function regenerateAdminApiKey(): Promise<{ key: string }> {
- const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate')
- return data
+ const { data } = await apiClient.post<{ key: string }>(
+ "/admin/settings/admin-api-key/regenerate",
+ );
+ return data;
}
/**
@@ -248,8 +764,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> {
* @returns Success message
*/
export async function deleteAdminApiKey(): Promise<{ message: string }> {
- const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key')
- return data
+ const { data } = await apiClient.delete<{ message: string }>(
+ "/admin/settings/admin-api-key",
+ );
+ return data;
}
// ==================== Overload Cooldown Settings ====================
@@ -258,23 +776,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> {
* Overload cooldown settings interface (529 handling)
*/
export interface OverloadCooldownSettings {
- enabled: boolean
- cooldown_minutes: number
+ enabled: boolean;
+ cooldown_minutes: number;
}
export async function getOverloadCooldownSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/overload-cooldown')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/overload-cooldown",
+ );
+ return data;
}
export async function updateOverloadCooldownSettings(
- settings: OverloadCooldownSettings
+ settings: OverloadCooldownSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/overload-cooldown',
- settings
- )
- return data
+ "/admin/settings/overload-cooldown",
+ settings,
+ );
+ return data;
}
// ==================== Stream Timeout Settings ====================
@@ -283,11 +803,11 @@ export async function updateOverloadCooldownSettings(
* Stream timeout settings interface
*/
export interface StreamTimeoutSettings {
- enabled: boolean
- action: 'temp_unsched' | 'error' | 'none'
- temp_unsched_minutes: number
- threshold_count: number
- threshold_window_minutes: number
+ enabled: boolean;
+ action: "temp_unsched" | "error" | "none";
+ temp_unsched_minutes: number;
+ threshold_count: number;
+ threshold_window_minutes: number;
}
/**
@@ -295,8 +815,10 @@ export interface StreamTimeoutSettings {
* @returns Stream timeout settings
*/
export async function getStreamTimeoutSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/stream-timeout')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/stream-timeout",
+ );
+ return data;
}
/**
@@ -305,13 +827,13 @@ export async function getStreamTimeoutSettings(): Promise
* @returns Updated settings
*/
export async function updateStreamTimeoutSettings(
- settings: StreamTimeoutSettings
+ settings: StreamTimeoutSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/stream-timeout',
- settings
- )
- return data
+ "/admin/settings/stream-timeout",
+ settings,
+ );
+ return data;
}
// ==================== Rectifier Settings ====================
@@ -320,11 +842,11 @@ export async function updateStreamTimeoutSettings(
* Rectifier settings interface
*/
export interface RectifierSettings {
- enabled: boolean
- thinking_signature_enabled: boolean
- thinking_budget_enabled: boolean
- apikey_signature_enabled: boolean
- apikey_signature_patterns: string[]
+ enabled: boolean;
+ thinking_signature_enabled: boolean;
+ thinking_budget_enabled: boolean;
+ apikey_signature_enabled: boolean;
+ apikey_signature_patterns: string[];
}
/**
@@ -332,8 +854,10 @@ export interface RectifierSettings {
* @returns Rectifier settings
*/
export async function getRectifierSettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/rectifier')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/rectifier",
+ );
+ return data;
}
/**
@@ -342,13 +866,13 @@ export async function getRectifierSettings(): Promise {
* @returns Updated settings
*/
export async function updateRectifierSettings(
- settings: RectifierSettings
+ settings: RectifierSettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/rectifier',
- settings
- )
- return data
+ "/admin/settings/rectifier",
+ settings,
+ );
+ return data;
}
// ==================== Beta Policy Settings ====================
@@ -357,17 +881,20 @@ export async function updateRectifierSettings(
* Beta policy rule interface
*/
export interface BetaPolicyRule {
- beta_token: string
- action: 'pass' | 'filter' | 'block'
- scope: 'all' | 'oauth' | 'apikey' | 'bedrock'
- error_message?: string
+ beta_token: string;
+ action: "pass" | "filter" | "block";
+ scope: "all" | "oauth" | "apikey" | "bedrock";
+ error_message?: string;
+ model_whitelist?: string[];
+ fallback_action?: "pass" | "filter" | "block";
+ fallback_error_message?: string;
}
/**
* Beta policy settings interface
*/
export interface BetaPolicySettings {
- rules: BetaPolicyRule[]
+ rules: BetaPolicyRule[];
}
/**
@@ -375,8 +902,10 @@ export interface BetaPolicySettings {
* @returns Beta policy settings
*/
export async function getBetaPolicySettings(): Promise {
- const { data } = await apiClient.get('/admin/settings/beta-policy')
- return data
+ const { data } = await apiClient.get(
+ "/admin/settings/beta-policy",
+ );
+ return data;
}
/**
@@ -385,149 +914,73 @@ export async function getBetaPolicySettings(): Promise {
* @returns Updated settings
*/
export async function updateBetaPolicySettings(
- settings: BetaPolicySettings
+ settings: BetaPolicySettings,
): Promise {
const { data } = await apiClient.put(
- '/admin/settings/beta-policy',
- settings
- )
- return data
+ "/admin/settings/beta-policy",
+ settings,
+ );
+ return data;
}
-// ==================== Sora S3 Settings ====================
+// --- Web Search Emulation Config ---
-export interface SoraS3Settings {
- enabled: boolean
- endpoint: string
- region: string
- bucket: string
- access_key_id: string
- secret_access_key_configured: boolean
- prefix: string
- force_path_style: boolean
- cdn_url: string
- default_storage_quota_bytes: number
+export interface WebSearchProviderConfig {
+ type: "brave" | "tavily";
+ api_key: string;
+ api_key_configured: boolean;
+ quota_limit: number | null;
+ subscribed_at: number | null;
+ quota_used?: number;
+ proxy_id: number | null;
+ expires_at: number | null;
}
-export interface SoraS3Profile {
- profile_id: string
- name: string
- is_active: boolean
- enabled: boolean
- endpoint: string
- region: string
- bucket: string
- access_key_id: string
- secret_access_key_configured: boolean
- prefix: string
- force_path_style: boolean
- cdn_url: string
- default_storage_quota_bytes: number
- updated_at: string
+export interface WebSearchEmulationConfig {
+ enabled: boolean;
+ providers: WebSearchProviderConfig[];
}
-export interface ListSoraS3ProfilesResponse {
- active_profile_id: string
- items: SoraS3Profile[]
+export interface WebSearchTestResult {
+ provider: string;
+ results: { url: string; title: string; snippet: string; page_age?: string }[];
+ query: string;
}
-export interface UpdateSoraS3SettingsRequest {
- profile_id?: string
- enabled: boolean
- endpoint: string
- region: string
- bucket: string
- access_key_id: string
- secret_access_key?: string
- prefix: string
- force_path_style: boolean
- cdn_url: string
- default_storage_quota_bytes: number
+export async function getWebSearchEmulationConfig(): Promise {
+ const { data } = await apiClient.get(
+ "/admin/settings/web-search-emulation",
+ );
+ return data;
}
-export interface CreateSoraS3ProfileRequest {
- profile_id: string
- name: string
- set_active?: boolean
- enabled: boolean
- endpoint: string
- region: string
- bucket: string
- access_key_id: string
- secret_access_key?: string
- prefix: string
- force_path_style: boolean
- cdn_url: string
- default_storage_quota_bytes: number
+export async function updateWebSearchEmulationConfig(
+ config: WebSearchEmulationConfig,
+): Promise {
+ const { data } = await apiClient.put(
+ "/admin/settings/web-search-emulation",
+ config,
+ );
+ return data;
}
-export interface UpdateSoraS3ProfileRequest {
- name: string
- enabled: boolean
- endpoint: string
- region: string
- bucket: string
- access_key_id: string
- secret_access_key?: string
- prefix: string
- force_path_style: boolean
- cdn_url: string
- default_storage_quota_bytes: number
+export async function testWebSearchEmulation(
+ query: string,
+): Promise {
+ const { data } = await apiClient.post(
+ "/admin/settings/web-search-emulation/test",
+ { query },
+ );
+ return data;
}
-export interface TestSoraS3ConnectionRequest {
- profile_id?: string
- enabled: boolean
- endpoint: string
- region: string
- bucket: string
- access_key_id: string
- secret_access_key?: string
- prefix: string
- force_path_style: boolean
- cdn_url: string
- default_storage_quota_bytes?: number
-}
-
-export async function getSoraS3Settings(): Promise {
- const { data } = await apiClient.get('/admin/settings/sora-s3')
- return data
-}
-
-export async function updateSoraS3Settings(settings: UpdateSoraS3SettingsRequest): Promise {
- const { data } = await apiClient.put('/admin/settings/sora-s3', settings)
- return data
-}
-
-export async function testSoraS3Connection(
- settings: TestSoraS3ConnectionRequest
-): Promise<{ message: string }> {
- const { data } = await apiClient.post<{ message: string }>('/admin/settings/sora-s3/test', settings)
- return data
-}
-
-export async function listSoraS3Profiles(): Promise {
- const { data } = await apiClient.get('/admin/settings/sora-s3/profiles')
- return data
-}
-
-export async function createSoraS3Profile(request: CreateSoraS3ProfileRequest): Promise {
- const { data } = await apiClient.post('/admin/settings/sora-s3/profiles', request)
- return data
-}
-
-export async function updateSoraS3Profile(profileID: string, request: UpdateSoraS3ProfileRequest): Promise {
- const { data } = await apiClient.put(`/admin/settings/sora-s3/profiles/${profileID}`, request)
- return data
-}
-
-export async function deleteSoraS3Profile(profileID: string): Promise {
- await apiClient.delete(`/admin/settings/sora-s3/profiles/${profileID}`)
-}
-
-export async function setActiveSoraS3Profile(profileID: string): Promise {
- const { data } = await apiClient.post(`/admin/settings/sora-s3/profiles/${profileID}/activate`)
- return data
+export async function resetWebSearchUsage(payload: {
+ provider_type: string;
+}): Promise {
+ await apiClient.post(
+ "/admin/settings/web-search-emulation/reset-usage",
+ payload,
+ );
}
export const settingsAPI = {
@@ -546,14 +999,10 @@ export const settingsAPI = {
updateRectifierSettings,
getBetaPolicySettings,
updateBetaPolicySettings,
- getSoraS3Settings,
- updateSoraS3Settings,
- testSoraS3Connection,
- listSoraS3Profiles,
- createSoraS3Profile,
- updateSoraS3Profile,
- deleteSoraS3Profile,
- setActiveSoraS3Profile
-}
+ getWebSearchEmulationConfig,
+ updateWebSearchEmulationConfig,
+ testWebSearchEmulation,
+ resetWebSearchUsage,
+};
-export default settingsAPI
+export default settingsAPI;
diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts
index bd7e3e57..7ad00742 100644
--- a/frontend/src/api/admin/usage.ts
+++ b/frontend/src/api/admin/usage.ts
@@ -17,7 +17,7 @@ export interface AdminUsageStatsResponse {
total_tokens: number
total_cost: number
total_actual_cost: number
- total_account_cost?: number
+ total_account_cost: number
average_duration_ms: number
endpoints?: EndpointStat[]
upstream_endpoints?: EndpointStat[]
@@ -80,6 +80,9 @@ export interface CreateUsageCleanupTaskRequest {
export interface AdminUsageQueryParams extends UsageQueryParams {
user_id?: number
exact_total?: boolean
+ billing_mode?: string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
}
// ==================== API Functions ====================
diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts
index bbf0ab51..3c75a6c4 100644
--- a/frontend/src/api/admin/users.ts
+++ b/frontend/src/api/admin/users.ts
@@ -6,6 +6,44 @@
import { apiClient } from '../client'
import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types'
+export interface AdminBindAuthIdentityChannelRequest {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata?: Record | null
+}
+
+export interface AdminBindAuthIdentityRequest {
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ issuer?: string | null
+ metadata?: Record | null
+ channel?: AdminBindAuthIdentityChannelRequest
+}
+
+export interface AdminBoundAuthIdentityChannel {
+ channel: string
+ channel_app_id: string
+ channel_subject: string
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+}
+
+export interface AdminBoundAuthIdentity {
+ user_id: number
+ provider_type: string
+ provider_key: string
+ provider_subject: string
+ verified_at?: string | null
+ issuer?: string | null
+ metadata: Record | null
+ created_at: string
+ updated_at: string
+ channel?: AdminBoundAuthIdentityChannel | null
+}
+
/**
* List all users with pagination
* @param page - Page number (default: 1)
@@ -24,6 +62,8 @@ export async function list(
group_name?: string // fuzzy filter by allowed group name
attributes?: Record // attributeId -> value
include_subscriptions?: boolean
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
},
options?: {
signal?: AbortSignal
@@ -37,7 +77,9 @@ export async function list(
role: filters?.role,
search: filters?.search,
group_name: filters?.group_name,
- include_subscriptions: filters?.include_subscriptions
+ include_subscriptions: filters?.include_subscriptions,
+ sort_by: filters?.sort_by,
+ sort_order: filters?.sort_order
}
// Add attribute filters as attr[id]=value
@@ -244,6 +286,17 @@ export async function replaceGroup(
return data
}
+export async function bindUserAuthIdentity(
+ userId: number,
+ input: AdminBindAuthIdentityRequest
+): Promise {
+ const { data } = await apiClient.post(
+ `/admin/users/${userId}/auth-identities`,
+ input
+ )
+ return data
+}
+
export const usersAPI = {
list,
getById,
@@ -256,7 +309,8 @@ export const usersAPI = {
getUserApiKeys,
getUserUsageStats,
getUserBalanceHistory,
- replaceGroup
+ replaceGroup,
+ bindUserAuthIdentity
}
export default usersAPI
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index c5e1f35d..bb990fc4 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -186,6 +186,108 @@ export interface RefreshTokenResponse {
token_type: string
}
+export interface OAuthTokenResponse {
+ access_token: string
+ refresh_token?: string
+ expires_in?: number
+ token_type?: string
+}
+
+export interface PendingOAuthBindLoginResponse extends Partial {
+ auth_result?: string
+ redirect?: string
+ error?: string
+ requires_2fa?: boolean
+ temp_token?: string
+ user_email_masked?: string
+ adoption_required?: boolean
+ suggested_display_name?: string
+ suggested_avatar_url?: string
+}
+
+export type PendingOAuthExchangeResponse = PendingOAuthBindLoginResponse
+
+export interface PendingOAuthCreateAccountResponse extends OAuthTokenResponse {
+ auth_result?: string
+}
+
+export interface PendingOAuthSendVerifyCodeResponse extends SendVerifyCodeResponse {
+ auth_result?: string
+ provider?: string
+ redirect?: string
+}
+
+export type OAuthCompletionKind = 'login' | 'bind'
+
+export interface OAuthAdoptionDecision {
+ adoptDisplayName?: boolean
+ adoptAvatar?: boolean
+}
+
+function serializeOAuthAdoptionDecision(
+ decision?: OAuthAdoptionDecision
+): Record<string, boolean> {
+ const payload: Record<string, boolean> = {}
+
+ if (typeof decision?.adoptDisplayName === 'boolean') {
+ payload.adopt_display_name = decision.adoptDisplayName
+ }
+ if (typeof decision?.adoptAvatar === 'boolean') {
+ payload.adopt_avatar = decision.adoptAvatar
+ }
+
+ return payload
+}
+
+export function isOAuthLoginCompletion(
+ completion: Partial
+): completion is OAuthTokenResponse {
+ return typeof completion.access_token === 'string' && completion.access_token.trim().length > 0
+}
+
+export function getOAuthCompletionKind(
+ completion: Partial
+): OAuthCompletionKind {
+ return isOAuthLoginCompletion(completion) ? 'login' : 'bind'
+}
+
+export function getPendingOAuthBindLoginKind(
+ completion: PendingOAuthBindLoginResponse
+): OAuthCompletionKind {
+ return getOAuthCompletionKind(completion)
+}
+
+export function isPendingOAuthCreateAccountRequired(
+ completion: Pick
+): boolean {
+ return completion.error === 'invitation_required'
+}
+
+export function hasPendingOAuthSuggestedProfile(
+ completion: Pick<
+ PendingOAuthBindLoginResponse,
+ 'suggested_display_name' | 'suggested_avatar_url'
+ >
+): boolean {
+ return Boolean(completion.suggested_display_name || completion.suggested_avatar_url)
+}
+
+export function persistOAuthTokenContext(tokens: Partial): void {
+ if (tokens.refresh_token) {
+ setRefreshToken(tokens.refresh_token)
+ }
+ if (tokens.expires_in) {
+ setTokenExpiresAt(tokens.expires_in)
+ }
+}
+
+export async function prepareOAuthBindAccessTokenCookie(): Promise {
+ if (!getAuthToken()) {
+ return
+ }
+ await apiClient.post('/auth/oauth/bind-token')
+}
+
/**
* Refresh the access token using the refresh token
* @returns New token pair
@@ -234,6 +336,116 @@ export async function getPublicSettings(): Promise {
return data
}
+export type WeChatOAuthMode = 'open' | 'mp'
+export type WeChatOAuthUnavailableReason =
+ | 'not_configured'
+ | 'capability_unknown'
+ | 'external_browser_required'
+ | 'wechat_browser_required'
+ | 'native_app_required'
+
+export interface ResolvedWeChatOAuthStart {
+ mode: WeChatOAuthMode | null
+ openEnabled: boolean
+ mpEnabled: boolean
+ mobileEnabled: boolean
+ isWeChatBrowser: boolean
+ unavailableReason: WeChatOAuthUnavailableReason | null
+}
+
+export type WeChatOAuthPublicSettings = {
+ wechat_oauth_enabled?: boolean
+ wechat_oauth_open_enabled?: boolean
+ wechat_oauth_mp_enabled?: boolean
+ wechat_oauth_mobile_enabled?: boolean
+}
+
+export function isWeChatWebOAuthEnabled(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+): boolean {
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const hasExplicitCapabilities =
+ typeof settings?.wechat_oauth_open_enabled === 'boolean' ||
+ typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+
+ if (!hasExplicitCapabilities) {
+ return legacyEnabled
+ }
+
+ return settings?.wechat_oauth_open_enabled === true || settings?.wechat_oauth_mp_enabled === true
+}
+
+export function hasExplicitWeChatOAuthCapabilities(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+): settings is WeChatOAuthPublicSettings & {
+ wechat_oauth_open_enabled: boolean
+ wechat_oauth_mp_enabled: boolean
+} {
+ return typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ && typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+}
+
+export function resolveWeChatOAuthStart(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+ const legacyEnabled = settings?.wechat_oauth_enabled ?? false
+ const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean'
+ ? settings.wechat_oauth_open_enabled
+ : legacyEnabled
+ const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean'
+ ? settings.wechat_oauth_mp_enabled
+ : legacyEnabled
+ const mobileEnabled = typeof settings?.wechat_oauth_mobile_enabled === 'boolean'
+ ? settings.wechat_oauth_mobile_enabled
+ : false
+
+ if (isWeChatBrowser) {
+ if (mpEnabled) {
+ return { mode: 'mp', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (openEnabled) {
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+ }
+
+ if (openEnabled) {
+ return { mode: 'open', openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: null }
+ }
+ if (mpEnabled) {
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' }
+ }
+ return { mode: null, openEnabled, mpEnabled, mobileEnabled, isWeChatBrowser, unavailableReason: 'not_configured' }
+}
+
+export function resolveWeChatOAuthStartStrict(
+ settings: WeChatOAuthPublicSettings | null | undefined,
+ userAgent?: string,
+): ResolvedWeChatOAuthStart {
+ const normalizedUserAgent = (userAgent
+ ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '')
+ ?? '').trim()
+ const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent)
+
+ if (!hasExplicitWeChatOAuthCapabilities(settings)) {
+ return {
+ mode: null,
+ openEnabled: false,
+ mpEnabled: false,
+ mobileEnabled: false,
+ isWeChatBrowser,
+ unavailableReason: 'capability_unknown',
+ }
+ }
+
+ return resolveWeChatOAuthStart(settings, normalizedUserAgent)
+}
+
/**
* Send verification code to email
* @param request - Email and optional Turnstile token
@@ -246,6 +458,16 @@ export async function sendVerifyCode(
return data
}
+export async function sendPendingOAuthVerifyCode(
+ request: SendVerifyCodeRequest
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/send-verify-code',
+ request
+ )
+ return data
+}
+
/**
* Validate promo code response
*/
@@ -337,26 +559,96 @@ export async function resetPassword(request: ResetPasswordRequest): Promise {
- const { data } = await apiClient.post<{
- access_token: string
- refresh_token: string
- expires_in: number
- token_type: string
- }>('/auth/oauth/linuxdo/complete-registration', {
- pending_oauth_token: pendingOAuthToken,
- invitation_code: invitationCode
- })
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingLinuxDoOAuthAccount(invitationCode, decision, affiliateCode)
+}
+
+/**
+ * Complete OIDC OAuth registration by supplying an invitation code
+ * @param invitationCode - Invitation code entered by the user
+ * @returns Token pair on success
+ */
+export async function completeOIDCOAuthRegistration(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOIDCOAuthAccount(invitationCode, decision, affiliateCode)
+}
+
+export async function completeWeChatOAuthRegistration(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingWeChatOAuthAccount(invitationCode, decision, affiliateCode)
+}
+
+async function createPendingOAuthAccount(
+ provider: 'linuxdo' | 'oidc' | 'wechat',
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ const normalizedAffiliateCode = affiliateCode?.trim()
+ const { data } = await apiClient.post(
+ `/auth/oauth/${provider}/complete-registration`,
+ {
+ invitation_code: invitationCode,
+ ...(normalizedAffiliateCode ? { aff_code: normalizedAffiliateCode } : {}),
+ ...serializeOAuthAdoptionDecision(decision)
+ }
+ )
return data
}
+export async function createPendingLinuxDoOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOAuthAccount('linuxdo', invitationCode, decision, affiliateCode)
+}
+
+export async function createPendingOIDCOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOAuthAccount('oidc', invitationCode, decision, affiliateCode)
+}
+
+export async function createPendingWeChatOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode)
+}
+
+export async function completePendingOAuthBindLogin(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ const { data } = await apiClient.post(
+ '/auth/oauth/pending/exchange',
+ serializeOAuthAdoptionDecision(decision)
+ )
+ return data
+}
+
+export async function exchangePendingOAuthCompletion(
+ decision?: OAuthAdoptionDecision
+): Promise {
+ return completePendingOAuthBindLogin(decision)
+}
+
export const authAPI = {
login,
login2FA,
@@ -374,13 +666,24 @@ export const authAPI = {
clearAuthToken,
getPublicSettings,
sendVerifyCode,
+ sendPendingOAuthVerifyCode,
validatePromoCode,
validateInvitationCode,
forgotPassword,
resetPassword,
refreshToken,
revokeAllSessions,
- completeLinuxDoOAuthRegistration
+ getPendingOAuthBindLoginKind,
+ isPendingOAuthCreateAccountRequired,
+ hasPendingOAuthSuggestedProfile,
+ completePendingOAuthBindLogin,
+ createPendingLinuxDoOAuthAccount,
+ createPendingOIDCOAuthAccount,
+ createPendingWeChatOAuthAccount,
+ exchangePendingOAuthCompletion,
+ completeLinuxDoOAuthRegistration,
+ completeOIDCOAuthRegistration,
+ completeWeChatOAuthRegistration
}
export default authAPI
diff --git a/frontend/src/api/channelMonitor.ts b/frontend/src/api/channelMonitor.ts
new file mode 100644
index 00000000..38dd0c99
--- /dev/null
+++ b/frontend/src/api/channelMonitor.ts
@@ -0,0 +1,83 @@
+/**
+ * User-facing Channel Monitor API endpoints
+ * Read-only views for end users to inspect channel availability/status.
+ */
+
+import { apiClient } from './client'
+import type { Provider, MonitorStatus } from './admin/channelMonitor'
+
+export type { Provider, MonitorStatus } from './admin/channelMonitor'
+
+export interface UserMonitorExtraModel {
+ model: string
+ status: MonitorStatus
+ latency_ms: number | null
+}
+
+export interface MonitorTimelinePoint {
+ status: MonitorStatus
+ latency_ms: number | null
+ ping_latency_ms: number | null
+ checked_at: string
+}
+
+export interface UserMonitorView {
+ id: number
+ name: string
+ provider: Provider
+ group_name: string
+ primary_model: string
+ primary_status: MonitorStatus
+ primary_latency_ms: number | null
+ primary_ping_latency_ms: number | null
+ availability_7d: number
+ extra_models: UserMonitorExtraModel[]
+ timeline: MonitorTimelinePoint[]
+}
+
+export interface UserMonitorListResponse {
+ items: UserMonitorView[]
+}
+
+export interface UserMonitorModelDetail {
+ model: string
+ latest_status: MonitorStatus
+ latest_latency_ms: number | null
+ availability_7d: number
+ availability_15d: number
+ availability_30d: number
+ avg_latency_7d_ms: number | null
+}
+
+export interface UserMonitorDetail {
+ id: number
+ name: string
+ provider: Provider
+ group_name: string
+ models: UserMonitorModelDetail[]
+}
+
+/**
+ * List all monitor views available to the current user.
+ */
+export async function list(options?: { signal?: AbortSignal }): Promise {
+ const { data } = await apiClient.get('/channel-monitors', {
+ signal: options?.signal,
+ })
+ return data
+}
+
+/**
+ * Get detailed status (multi-window availability + latency) for a single monitor.
+ */
+export async function status(id: number): Promise {
+ const { data } = await apiClient.get(`/channel-monitors/${id}/status`)
+ return data
+}
+
+export const channelMonitorUserAPI = {
+ list,
+ status,
+}
+
+export default channelMonitorUserAPI
diff --git a/frontend/src/api/channels.ts b/frontend/src/api/channels.ts
new file mode 100644
index 00000000..8962af2c
--- /dev/null
+++ b/frontend/src/api/channels.ts
@@ -0,0 +1,76 @@
+/**
+ * User Channels API endpoints (non-admin)
+ * 用户侧「可用渠道」聚合查询:渠道 + 用户可访问的分组 + 支持模型(含定价)。
+ */
+
+import { apiClient } from './client'
+import type { BillingMode } from '@/constants/channel'
+
+export interface UserAvailableGroup {
+ id: number
+ name: string
+ platform: string
+ /** 'standard' | 'subscription' — 订阅分组视觉加深,和 API 密钥页保持一致。 */
+ subscription_type: string
+ /** 分组默认倍率。用户专属倍率(若有)通过 /groups/rates 获取后在前端 join。 */
+ rate_multiplier: number
+ /** true = 专属分组(小范围授权);false = 公开分组。 */
+ is_exclusive: boolean
+}
+
+export interface UserPricingInterval {
+ min_tokens: number
+ max_tokens: number | null
+ tier_label?: string
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ per_request_price: number | null
+}
+
+export interface UserSupportedModelPricing {
+ billing_mode: BillingMode
+ input_price: number | null
+ output_price: number | null
+ cache_write_price: number | null
+ cache_read_price: number | null
+ image_output_price: number | null
+ per_request_price: number | null
+ intervals: UserPricingInterval[]
+}
+
+export interface UserSupportedModel {
+ name: string
+ platform: string
+ pricing: UserSupportedModelPricing | null
+}
+
+/**
+ * 渠道下单个平台的子视图:用户可访问的分组 + 该平台支持的模型。
+ * 后端把一个渠道按平台聚合成 sections,前端可以把渠道名作为 row-group
+ * 一次渲染,后面按 sections 顺序用 rowspan 铺开。
+ */
+export interface UserChannelPlatformSection {
+ platform: string
+ groups: UserAvailableGroup[]
+ supported_models: UserSupportedModel[]
+}
+
+export interface UserAvailableChannel {
+ name: string
+ description: string
+ platforms: UserChannelPlatformSection[]
+}
+
+/** 列出当前用户可见的「可用渠道」(与 /groups/available 保持一致,返回扁平数组)。 */
+export async function getAvailable(options?: { signal?: AbortSignal }): Promise {
+ const { data } = await apiClient.get('/channels/available', {
+ signal: options?.signal
+ })
+ return data
+}
+
+export const userChannelsAPI = { getAvailable }
+
+export default userChannelsAPI
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index 95f9ff31..54ea4520 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -13,6 +13,7 @@ const API_BASE_URL = import.meta.env.VITE_API_BASE_URL || '/api/v1'
export const apiClient: AxiosInstance = axios.create({
baseURL: API_BASE_URL,
+ withCredentials: true,
timeout: 30000,
headers: {
'Content-Type': 'application/json'
@@ -92,10 +93,13 @@ apiClient.interceptors.response.use(
response.data = apiResponse.data
} else {
// API error
+ const resp = apiResponse as unknown as Record<string, unknown>
return Promise.reject({
status: response.status,
code: apiResponse.code,
- message: apiResponse.message || 'Unknown error'
+ message: apiResponse.message || 'Unknown error',
+ reason: resp.reason,
+ metadata: resp.metadata,
})
}
}
@@ -267,8 +271,10 @@ apiClient.interceptors.response.use(
return Promise.reject({
status,
code: apiData.code,
+ reason: apiData.reason,
error: apiData.error,
- message: apiData.message || apiData.detail || error.message
+ message: apiData.message || apiData.detail || error.message,
+ metadata: apiData.metadata,
})
}
diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts
index 99e3cf32..02adb0d2 100644
--- a/frontend/src/api/index.ts
+++ b/frontend/src/api/index.ts
@@ -14,9 +14,12 @@ export { keysAPI } from './keys'
export { usageAPI } from './usage'
export { userAPI } from './user'
export { redeemAPI, type RedeemHistoryItem } from './redeem'
+export { paymentAPI } from './payment'
export { userGroupsAPI } from './groups'
+export { userChannelsAPI } from './channels'
export { totpAPI } from './totp'
export { default as announcementsAPI } from './announcements'
+export { channelMonitorUserAPI } from './channelMonitor'
// Admin APIs
export { adminAPI } from './admin'
diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts
index 137e10ba..34dd5b4b 100644
--- a/frontend/src/api/keys.ts
+++ b/frontend/src/api/keys.ts
@@ -17,7 +17,13 @@ import type { ApiKey, CreateApiKeyRequest, UpdateApiKeyRequest, PaginatedRespons
export async function list(
page: number = 1,
pageSize: number = 10,
- filters?: { search?: string; status?: string; group_id?: number | string },
+ filters?: {
+ search?: string
+ status?: string
+ group_id?: number | string
+ sort_by?: string
+ sort_order?: 'asc' | 'desc'
+ },
options?: {
signal?: AbortSignal
}
diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts
new file mode 100644
index 00000000..92b0ec90
--- /dev/null
+++ b/frontend/src/api/payment.ts
@@ -0,0 +1,89 @@
+/**
+ * User Payment API endpoints
+ * Handles payment operations for regular users
+ */
+
+import { apiClient } from './client'
+import type {
+ PaymentConfig,
+ SubscriptionPlan,
+ PaymentChannel,
+ MethodLimitsResponse,
+ CheckoutInfoResponse,
+ CreateOrderRequest,
+ CreateOrderResult,
+ PaymentOrder
+} from '@/types/payment'
+import type { BasePaginationResponse } from '@/types'
+
+export const paymentAPI = {
+ /** Get payment configuration (enabled types, limits, etc.) */
+ getConfig() {
+ return apiClient.get('/payment/config')
+ },
+
+ /** Get available subscription plans */
+ getPlans() {
+ return apiClient.get('/payment/plans')
+ },
+
+ /** Get available payment channels */
+ getChannels() {
+ return apiClient.get('/payment/channels')
+ },
+
+ /** Get all checkout page data in a single call */
+ getCheckoutInfo() {
+ return apiClient.get('/payment/checkout-info')
+ },
+
+ /** Get payment method limits and fee rates */
+ getLimits() {
+ return apiClient.get('/payment/limits')
+ },
+
+ /** Create a new payment order */
+ createOrder(data: CreateOrderRequest) {
+ return apiClient.post('/payment/orders', data)
+ },
+
+ /** Get current user's orders */
+ getMyOrders(params?: { page?: number; page_size?: number; status?: string }) {
+ return apiClient.get>('/payment/orders/my', { params })
+ },
+
+ /** Get a specific order by ID */
+ getOrder(id: number) {
+ return apiClient.get(`/payment/orders/${id}`)
+ },
+
+ /** Cancel a pending order */
+ cancelOrder(id: number) {
+ return apiClient.post(`/payment/orders/${id}/cancel`)
+ },
+
+ /** Verify order payment status with upstream provider */
+ verifyOrder(outTradeNo: string) {
+ return apiClient.post('/payment/orders/verify', { out_trade_no: outTradeNo })
+ },
+
+ /** Legacy-compatible public order lookup by out_trade_no */
+ verifyOrderPublic(outTradeNo: string) {
+ return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo })
+ },
+
+ /** Resolve an order from a signed resume token without auth */
+ resolveOrderPublicByResumeToken(resumeToken: string) {
+ return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken })
+ },
+
+ /** Request a refund for a completed order */
+ requestRefund(id: number, data: { reason: string }) {
+ return apiClient.post(`/payment/orders/${id}/refund-request`, data)
+ },
+
+ /** Get provider instance IDs that allow user refund */
+ getRefundEligibleProviders() {
+ return apiClient.get<{ provider_instance_ids: string[] }>('/payment/orders/refund-eligible-providers')
+ }
+}
diff --git a/frontend/src/api/sora.ts b/frontend/src/api/sora.ts
deleted file mode 100644
index 45108454..00000000
--- a/frontend/src/api/sora.ts
+++ /dev/null
@@ -1,307 +0,0 @@
-/**
- * Sora 客户端 API
- * 封装所有 Sora 生成、作品库、配额等接口调用
- */
-
-import { apiClient } from './client'
-
-// ==================== 类型定义 ====================
-
-export interface SoraGeneration {
- id: number
- user_id: number
- model: string
- prompt: string
- media_type: string
- status: string // pending | generating | completed | failed | cancelled
- storage_type: string // upstream | s3 | local
- media_url: string
- media_urls: string[]
- s3_object_keys: string[]
- file_size_bytes: number
- error_message: string
- created_at: string
- completed_at?: string
-}
-
-export interface GenerateRequest {
- model: string
- prompt: string
- video_count?: number
- media_type?: string
- image_input?: string
- api_key_id?: number
-}
-
-export interface GenerateResponse {
- generation_id: number
- status: string
-}
-
-export interface GenerationListResponse {
- data: SoraGeneration[]
- total: number
- page: number
-}
-
-export interface QuotaInfo {
- quota_bytes: number
- used_bytes: number
- available_bytes: number
- quota_source: string // user | group | system | unlimited
- source?: string // 兼容旧字段
-}
-
-export interface StorageStatus {
- s3_enabled: boolean
- s3_healthy: boolean
- local_enabled: boolean
-}
-
-/** 单个扁平模型(旧接口,保留兼容) */
-export interface SoraModel {
- id: string
- name: string
- type: string // video | image
- orientation?: string
- duration?: number
-}
-
-/** 模型家族(新接口 — 后端从 soraModelConfigs 自动聚合) */
-export interface SoraModelFamily {
- id: string // 家族 ID,如 "sora2"
- name: string // 显示名,如 "Sora 2"
- type: string // "video" | "image"
- orientations: string[] // ["landscape", "portrait"] 或 ["landscape", "portrait", "square"]
- durations?: number[] // [10, 15, 25](仅视频模型)
-}
-
-type LooseRecord = Record
-
-function asRecord(value: unknown): LooseRecord | null {
- return value !== null && typeof value === 'object' ? value as LooseRecord : null
-}
-
-function asArray(value: unknown): T[] {
- return Array.isArray(value) ? value as T[] : []
-}
-
-function asPositiveInt(value: unknown): number | null {
- const n = Number(value)
- if (!Number.isFinite(n) || n <= 0) return null
- return Math.round(n)
-}
-
-function dedupeStrings(values: string[]): string[] {
- return Array.from(new Set(values))
-}
-
-function extractOrientationFromModelID(modelID: string): string | null {
- const m = modelID.match(/-(landscape|portrait|square)(?:-\d+s)?$/i)
- return m ? m[1].toLowerCase() : null
-}
-
-function extractDurationFromModelID(modelID: string): number | null {
- const m = modelID.match(/-(\d+)s$/i)
- return m ? asPositiveInt(m[1]) : null
-}
-
-function normalizeLegacyFamilies(candidates: unknown[]): SoraModelFamily[] {
- const familyMap = new Map()
-
- for (const item of candidates) {
- const model = asRecord(item)
- if (!model || typeof model.id !== 'string' || model.id.trim() === '') continue
-
- const rawID = model.id.trim()
- const type = model.type === 'image' ? 'image' : 'video'
- const name = typeof model.name === 'string' && model.name.trim() ? model.name.trim() : rawID
- const baseID = rawID.replace(/-(landscape|portrait|square)(?:-\d+s)?$/i, '')
- const orientation =
- typeof model.orientation === 'string' && model.orientation
- ? model.orientation.toLowerCase()
- : extractOrientationFromModelID(rawID)
- const duration = asPositiveInt(model.duration) ?? extractDurationFromModelID(rawID)
- const familyKey = baseID || rawID
-
- const family = familyMap.get(familyKey) ?? {
- id: familyKey,
- name,
- type,
- orientations: [],
- durations: []
- }
-
- if (orientation) {
- family.orientations.push(orientation)
- }
- if (type === 'video' && duration) {
- family.durations = family.durations || []
- family.durations.push(duration)
- }
-
- familyMap.set(familyKey, family)
- }
-
- return Array.from(familyMap.values())
- .map((family) => ({
- ...family,
- orientations:
- family.orientations.length > 0
- ? dedupeStrings(family.orientations)
- : (family.type === 'image' ? ['square'] : ['landscape']),
- durations:
- family.type === 'video'
- ? Array.from(new Set((family.durations || []).filter((d): d is number => Number.isFinite(d)))).sort((a, b) => a - b)
- : []
- }))
- .filter((family) => family.id !== '')
-}
-
-function normalizeModelFamilyRecord(item: unknown): SoraModelFamily | null {
- const model = asRecord(item)
- if (!model || typeof model.id !== 'string' || model.id.trim() === '') return null
- // 仅把明确的“家族结构”识别为 family;老结构(单模型)走 legacy 聚合逻辑。
- if (!Array.isArray(model.orientations) && !Array.isArray(model.durations)) return null
-
- const orientations = asArray(model.orientations).filter((o): o is string => typeof o === 'string' && o.length > 0)
- const durations = asArray(model.durations)
- .map(asPositiveInt)
- .filter((d): d is number => d !== null)
-
- return {
- id: model.id.trim(),
- name: typeof model.name === 'string' && model.name.trim() ? model.name.trim() : model.id.trim(),
- type: model.type === 'image' ? 'image' : 'video',
- orientations: dedupeStrings(orientations),
- durations: Array.from(new Set(durations)).sort((a, b) => a - b)
- }
-}
-
-function extractCandidateArray(payload: unknown): unknown[] {
- if (Array.isArray(payload)) return payload
- const record = asRecord(payload)
- if (!record) return []
-
- const keys: Array = ['data', 'items', 'models', 'families']
- for (const key of keys) {
- if (Array.isArray(record[key])) {
- return record[key] as unknown[]
- }
- }
- return []
-}
-
-export function normalizeModelFamiliesResponse(payload: unknown): SoraModelFamily[] {
- const candidates = extractCandidateArray(payload)
- if (candidates.length === 0) return []
-
- const normalized = candidates
- .map(normalizeModelFamilyRecord)
- .filter((item): item is SoraModelFamily => item !== null)
-
- if (normalized.length > 0) return normalized
- return normalizeLegacyFamilies(candidates)
-}
-
-export function normalizeGenerationListResponse(payload: unknown): GenerationListResponse {
- const record = asRecord(payload)
- if (!record) {
- return { data: [], total: 0, page: 1 }
- }
-
- const data = Array.isArray(record.data)
- ? (record.data as SoraGeneration[])
- : Array.isArray(record.items)
- ? (record.items as SoraGeneration[])
- : []
-
- const total = Number(record.total)
- const page = Number(record.page)
-
- return {
- data,
- total: Number.isFinite(total) ? total : data.length,
- page: Number.isFinite(page) && page > 0 ? page : 1
- }
-}
-
-// ==================== API 方法 ====================
-
-/** 异步生成 — 创建 pending 记录后立即返回 */
-export async function generate(req: GenerateRequest): Promise {
- const { data } = await apiClient.post('/sora/generate', req)
- return data
-}
-
-/** 查询生成记录列表 */
-export async function listGenerations(params?: {
- page?: number
- page_size?: number
- status?: string
- storage_type?: string
- media_type?: string
-}): Promise {
- const { data } = await apiClient.get('/sora/generations', { params })
- return normalizeGenerationListResponse(data)
-}
-
-/** 查询生成记录详情 */
-export async function getGeneration(id: number): Promise {
- const { data } = await apiClient.get(`/sora/generations/${id}`)
- return data
-}
-
-/** 删除生成记录 */
-export async function deleteGeneration(id: number): Promise<{ message: string }> {
- const { data } = await apiClient.delete<{ message: string }>(`/sora/generations/${id}`)
- return data
-}
-
-/** 取消生成任务 */
-export async function cancelGeneration(id: number): Promise<{ message: string }> {
- const { data } = await apiClient.post<{ message: string }>(`/sora/generations/${id}/cancel`)
- return data
-}
-
-/** 手动保存到 S3 */
-export async function saveToStorage(
- id: number
-): Promise<{ message: string; object_key: string; object_keys?: string[] }> {
- const { data } = await apiClient.post<{ message: string; object_key: string; object_keys?: string[] }>(
- `/sora/generations/${id}/save`
- )
- return data
-}
-
-/** 查询配额信息 */
-export async function getQuota(): Promise {
- const { data } = await apiClient.get('/sora/quota')
- return data
-}
-
-/** 获取可用模型家族列表 */
-export async function getModels(): Promise {
- const { data } = await apiClient.get('/sora/models')
- return normalizeModelFamiliesResponse(data)
-}
-
-/** 获取存储状态 */
-export async function getStorageStatus(): Promise {
- const { data } = await apiClient.get('/sora/storage-status')
- return data
-}
-
-const soraAPI = {
- generate,
- listGenerations,
- getGeneration,
- deleteGeneration,
- cancelGeneration,
- saveToStorage,
- getQuota,
- getModels,
- getStorageStatus
-}
-
-export default soraAPI
diff --git a/frontend/src/api/usage.ts b/frontend/src/api/usage.ts
index 6efd7657..802c428f 100644
--- a/frontend/src/api/usage.ts
+++ b/frontend/src/api/usage.ts
@@ -91,7 +91,7 @@ export async function list(
* @returns Paginated list of usage logs
*/
export async function query(
- params: UsageQueryParams,
+ params: UsageQueryParams & { sort_by?: string; sort_order?: 'asc' | 'desc' },
config: { signal?: AbortSignal } = {}
): Promise> {
const { data } = await apiClient.get>('/usage', {
diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts
index bfc0e30b..da7a91eb 100644
--- a/frontend/src/api/user.ts
+++ b/frontend/src/api/user.ts
@@ -4,7 +4,19 @@
*/
import { apiClient } from './client'
-import type { User, ChangePasswordRequest } from '@/types'
+import {
+ resolveWeChatOAuthStartStrict,
+ prepareOAuthBindAccessTokenCookie,
+ type WeChatOAuthPublicSettings,
+} from './auth'
+import type {
+ User,
+ ChangePasswordRequest,
+ NotifyEmailEntry,
+ UserAuthProvider,
+ UserAffiliateDetail,
+ AffiliateTransferResponse
+} from '@/types'
/**
* Get current user profile
@@ -22,6 +34,10 @@ export async function getProfile(): Promise {
*/
export async function updateProfile(profile: {
username?: string
+ avatar_url?: string | null
+ balance_notify_enabled?: boolean
+ balance_notify_threshold?: number | null
+ balance_notify_extra_emails?: NotifyEmailEntry[]
}): Promise {
const { data } = await apiClient.put('/user', profile)
return data
@@ -45,10 +61,145 @@ export async function changePassword(
return data
}
+/**
+ * Send verification code for adding a notify email
+ * @param email - Email address to verify
+ */
+export async function sendNotifyEmailCode(email: string): Promise {
+ await apiClient.post('/user/notify-email/send-code', { email })
+}
+
+/**
+ * Verify and add a notify email
+ * @param email - Email address to add
+ * @param code - Verification code
+ */
+export async function verifyNotifyEmail(email: string, code: string): Promise {
+ await apiClient.post('/user/notify-email/verify', { email, code })
+}
+
+/**
+ * Remove a notify email
+ * @param email - Email address to remove
+ */
+export async function removeNotifyEmail(email: string): Promise {
+ await apiClient.delete('/user/notify-email', { data: { email } })
+}
+
+/**
+ * Toggle a notify email's disabled state
+ * @param email - Email address (empty string for primary email placeholder)
+ * @param disabled - Whether to disable the email
+ */
+export async function toggleNotifyEmail(email: string, disabled: boolean): Promise {
+ const { data } = await apiClient.put('/user/notify-email/toggle', { email, disabled })
+ return data
+}
+
+export async function sendEmailBindingCode(email: string): Promise {
+ await apiClient.post('/user/account-bindings/email/send-code', { email })
+}
+
+export async function bindEmailIdentity(payload: {
+ email: string
+ verify_code: string
+ password: string
+}): Promise {
+ const { data } = await apiClient.post('/user/account-bindings/email', payload)
+ return data
+}
+
+export async function unbindAuthIdentity(provider: BindableOAuthProvider): Promise {
+ const { data } = await apiClient.delete(`/user/account-bindings/${provider}`)
+ return data
+}
+
+export type BindableOAuthProvider = Exclude
+
+interface BuildOAuthBindingStartURLOptions {
+ redirectTo?: string
+ wechatOAuthSettings?: WeChatOAuthPublicSettings | null
+}
+
+export function resolveWeChatOAuthMode(): 'open' | 'mp' {
+ if (typeof navigator === 'undefined') {
+ return 'open'
+ }
+ return /MicroMessenger/i.test(navigator.userAgent) ? 'mp' : 'open'
+}
+
+function resolveWeChatOAuthBindingMode(
+ settings?: WeChatOAuthPublicSettings | null
+): 'open' | 'mp' | null {
+ if (settings) {
+ return resolveWeChatOAuthStartStrict(settings).mode
+ }
+ return resolveWeChatOAuthMode()
+}
+
+export function buildOAuthBindingStartURL(
+ provider: BindableOAuthProvider,
+ options: BuildOAuthBindingStartURLOptions = {}
+): string | null {
+ const redirectTo = options.redirectTo?.trim() || '/profile'
+ const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1'
+ const normalized = apiBase.replace(/\/$/, '')
+ const params = new URLSearchParams({
+ redirect: redirectTo,
+ intent: 'bind_current_user'
+ })
+
+ if (provider === 'wechat') {
+ const mode = resolveWeChatOAuthBindingMode(options.wechatOAuthSettings)
+ if (!mode) {
+ return null
+ }
+ params.set('mode', mode)
+ }
+
+ return `${normalized}/auth/oauth/${provider}/bind/start?${params.toString()}`
+}
+
+export async function startOAuthBinding(
+ provider: BindableOAuthProvider,
+ options: BuildOAuthBindingStartURLOptions = {}
+): Promise