diff --git a/.dockerignore b/.dockerignore
index e4e8e72e..0670cd7d 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -4,4 +4,5 @@
.vscode
.gitignore
Makefile
-docs
\ No newline at end of file
+docs
+.eslintcache
\ No newline at end of file
diff --git a/.env.example b/.env.example
index bece06db..72645404 100644
--- a/.env.example
+++ b/.env.example
@@ -7,6 +7,8 @@
# 调试相关配置
# 启用pprof
# ENABLE_PPROF=true
+# 启用调试模式
+# DEBUG=true
# 数据库相关配置
# 数据库连接字符串
@@ -41,6 +43,14 @@
# 更新任务启用
# UPDATE_TASK=true
+# 对话超时设置
+# 所有请求超时时间,单位秒,默认为0,表示不限制
+# RELAY_TIMEOUT=0
+# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
+# STREAMING_TIMEOUT=300
+
+# Gemini 识别图片 最大图片数量
+# GEMINI_VISION_MAX_IMAGE_NUM=16
# 会话密钥
# SESSION_SECRET=random_string
@@ -58,8 +68,6 @@
# GET_MEDIA_TOKEN_NOT_STREAM=true
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
# DIFY_DEBUG=true
-# 设置流式一次回复的超时时间
-# STREAMING_TIMEOUT=90
# 节点类型
diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
new file mode 100644
index 00000000..4f6e41ac
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
@@ -0,0 +1,19 @@
+### PR 类型
+
+- [ ] Bug 修复
+- [ ] 新功能
+- [ ] 文档更新
+- [ ] 其他
+
+### PR 是否包含破坏性更新?
+
+- [ ] 是
+- [ ] 否
+
+### PR 描述
+
+**请在下方详细描述您的 PR,包括目的、实现细节等。**
+
+### **重要提示**
+
+**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-alpha.yml
similarity index 72%
rename from .github/workflows/docker-image-amd64.yml
rename to .github/workflows/docker-image-alpha.yml
index a823151c..c02bd409 100644
--- a/.github/workflows/docker-image-amd64.yml
+++ b/.github/workflows/docker-image-alpha.yml
@@ -1,14 +1,15 @@
-name: Publish Docker image (amd64)
+name: Publish Docker image (alpha)
on:
push:
- tags:
- - '*'
+ branches:
+ - alpha
workflow_dispatch:
inputs:
name:
- description: 'reason'
+ description: "reason"
required: false
+
jobs:
push_to_registries:
name: Push Docker image to multiple registries
@@ -22,7 +23,7 @@ jobs:
- name: Save version info
run: |
- git describe --tags > VERSION
+ echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v3
@@ -37,6 +38,9 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v5
@@ -44,11 +48,15 @@ jobs:
images: |
calciumion/new-api
ghcr.io/${{ github.repository }}
+ tags: |
+ type=raw,value=alpha
+ type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
- name: Build and push Docker images
uses: docker/build-push-action@v5
with:
context: .
+ platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
- labels: ${{ steps.meta.outputs.labels }}
\ No newline at end of file
+ labels: ${{ steps.meta.outputs.labels }}
diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml
index d7468c8e..8e4656aa 100644
--- a/.github/workflows/docker-image-arm64.yml
+++ b/.github/workflows/docker-image-arm64.yml
@@ -1,14 +1,9 @@
-name: Publish Docker image (arm64)
+name: Publish Docker image (Multi Registries)
on:
push:
tags:
- '*'
- workflow_dispatch:
- inputs:
- name:
- description: 'reason'
- required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml
index 3ddabc6d..c87fcfce 100644
--- a/.github/workflows/linux-release.yml
+++ b/.github/workflows/linux-release.yml
@@ -3,6 +3,11 @@ permissions:
contents: write
on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
push:
tags:
- '*'
@@ -15,16 +20,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- - uses: actions/setup-node@v3
+ - uses: oven-sh/setup-bun@v2
with:
- node-version: 18
+ bun-version: latest
- name: Build Frontend
env:
CI: ""
run: |
cd web
- npm install
- REACT_APP_VERSION=$(git describe --tags) npm run build
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml
index ccc480bf..1bc786ac 100644
--- a/.github/workflows/macos-release.yml
+++ b/.github/workflows/macos-release.yml
@@ -3,6 +3,11 @@ permissions:
contents: write
on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
push:
tags:
- '*'
@@ -15,16 +20,17 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- - uses: actions/setup-node@v3
+ - uses: oven-sh/setup-bun@v2
with:
- node-version: 18
+ bun-version: latest
- name: Build Frontend
env:
CI: ""
+ NODE_OPTIONS: "--max-old-space-size=4096"
run: |
cd web
- npm install
- REACT_APP_VERSION=$(git describe --tags) npm run build
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
diff --git a/.github/workflows/pr-target-branch-check.yml b/.github/workflows/pr-target-branch-check.yml
new file mode 100644
index 00000000..e7bd4c81
--- /dev/null
+++ b/.github/workflows/pr-target-branch-check.yml
@@ -0,0 +1,21 @@
+name: Check PR Branching Strategy
+on:
+ pull_request:
+ types: [opened, synchronize, reopened, edited]
+
+jobs:
+ check-branching-strategy:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Enforce branching strategy
+ run: |
+ if [[ "${{ github.base_ref }}" == "main" ]]; then
+ if [[ "${{ github.head_ref }}" != "alpha" ]]; then
+ echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
+ exit 1
+ fi
+ elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
+ echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
+ exit 1
+ fi
+ echo "Branching strategy check passed."
\ No newline at end of file
diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml
index f9500718..de3d83d5 100644
--- a/.github/workflows/windows-release.yml
+++ b/.github/workflows/windows-release.yml
@@ -3,6 +3,11 @@ permissions:
contents: write
on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
push:
tags:
- '*'
@@ -18,16 +23,16 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- - uses: actions/setup-node@v3
+ - uses: oven-sh/setup-bun@v2
with:
- node-version: 18
+ bun-version: latest
- name: Build Frontend
env:
CI: ""
run: |
cd web
- npm install
- REACT_APP_VERSION=$(git describe --tags) npm run build
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
diff --git a/.gitignore b/.gitignore
index 6a23f89e..1382829f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,5 @@ web/dist
.env
one-api
.DS_Store
-tiktoken_cache
\ No newline at end of file
+tiktoken_cache
+.eslintcache
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 214ceaa3..08cc86f7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,6 +2,7 @@ FROM oven/bun:latest AS builder
WORKDIR /build
COPY web/package.json .
+COPY web/bun.lock .
RUN bun install
COPY ./web .
COPY ./VERSION .
@@ -24,8 +25,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
FROM alpine
-RUN apk update \
- && apk upgrade \
+RUN apk upgrade --no-cache \
&& apk add --no-cache ca-certificates tzdata ffmpeg \
&& update-ca-certificates
diff --git a/LICENSE b/LICENSE
index 261eeb9e..71284f6d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,201 +1,103 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
+# **New API 许可协议 (Licensing)**
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
- 1. Definitions.
+**核心原则:**
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
+- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
+- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
+---
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
+## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
+- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
+- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
+- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
+- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
+## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
+在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API:
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
+- **场景一:移除品牌和版权信息**
+ 您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
+- **场景二:规避 AGPLv3 开源义务**
+ 您基于 New API 进行了修改,并希望:
+ - 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。
+ - 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
+- **场景三:企业政策与集成需求**
+ - 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
+ - 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
+- **场景四:需要商业支持与保障**
+ 您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
+**获取商业许可:**
+请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
+## **3. 贡献 (Contributions)**
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
+- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。
+- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
+- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
+## **4. 其他条款 (Other Terms)**
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
+- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
+- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
+---
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
+# **New API Licensing**
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
+This project uses a **Usage-Based Dual Licensing** model.
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
+**Core Principles:**
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
+- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
+- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
+---
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
+## **1. Open Source License: AGPLv3 – For Basic Usage**
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
+- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
+- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
+- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
+- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
- END OF TERMS AND CONDITIONS
+## **2. Commercial License – For Advanced Scenarios & Closed Source Needs**
- APPENDIX: How to apply the Apache License to your work.
+You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
+- **Scenario 1: Removal of Branding and Copyright**
+ You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
- Copyright [yyyy] [name of copyright owner]
+- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
+ You have modified New API and wish to:
+ - Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
+ - Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+- **Scenario 3: Enterprise Policy & Integration Needs**
+ - Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
+ - You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
- http://www.apache.org/licenses/LICENSE-2.0
+- **Scenario 4: Commercial Support and Assurances**
+ You require commercial assurances not provided by AGPLv3, such as official technical support.
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
+**Obtaining a Commercial License:**
+Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
+
+## **3. Contributions**
+
+- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
+- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
+- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
+
+## **4. Other Terms**
+
+- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
+- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).
diff --git a/README.en.md b/README.en.md
index 10a3cdb0..69fd32f8 100644
--- a/README.en.md
+++ b/README.en.md
@@ -40,6 +40,28 @@
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
+
🤝 Trusted Partners
+
+No particular order
+
+
+
+
+
+
+
+
+
## 📚 Documentation
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
@@ -100,7 +122,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
-- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
+- `STREAMING_TIMEOUT`: Streaming response timeout, default is 300 seconds
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
diff --git a/README.md b/README.md
index e9d1c154..45b04834 100644
--- a/README.md
+++ b/README.md
@@ -40,6 +40,28 @@
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
+🤝 我们信任的合作伙伴
+
+排名不分先后
+
+
+
+
+
+
+
+
+
## 📚 文档
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
@@ -100,7 +122,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
-- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
+- `STREAMING_TIMEOUT`:流式回复超时时间,默认300秒
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
@@ -180,7 +202,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
其他基于New API的项目:
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
-- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
## 帮助支持
diff --git a/common/api_type.go b/common/api_type.go
new file mode 100644
index 00000000..5ac46c86
--- /dev/null
+++ b/common/api_type.go
@@ -0,0 +1,75 @@
+package common
+
+import "one-api/constant"
+
+func ChannelType2APIType(channelType int) (int, bool) {
+ apiType := -1
+ switch channelType {
+ case constant.ChannelTypeOpenAI:
+ apiType = constant.APITypeOpenAI
+ case constant.ChannelTypeAnthropic:
+ apiType = constant.APITypeAnthropic
+ case constant.ChannelTypeBaidu:
+ apiType = constant.APITypeBaidu
+ case constant.ChannelTypePaLM:
+ apiType = constant.APITypePaLM
+ case constant.ChannelTypeZhipu:
+ apiType = constant.APITypeZhipu
+ case constant.ChannelTypeAli:
+ apiType = constant.APITypeAli
+ case constant.ChannelTypeXunfei:
+ apiType = constant.APITypeXunfei
+ case constant.ChannelTypeAIProxyLibrary:
+ apiType = constant.APITypeAIProxyLibrary
+ case constant.ChannelTypeTencent:
+ apiType = constant.APITypeTencent
+ case constant.ChannelTypeGemini:
+ apiType = constant.APITypeGemini
+ case constant.ChannelTypeZhipu_v4:
+ apiType = constant.APITypeZhipuV4
+ case constant.ChannelTypeOllama:
+ apiType = constant.APITypeOllama
+ case constant.ChannelTypePerplexity:
+ apiType = constant.APITypePerplexity
+ case constant.ChannelTypeAws:
+ apiType = constant.APITypeAws
+ case constant.ChannelTypeCohere:
+ apiType = constant.APITypeCohere
+ case constant.ChannelTypeDify:
+ apiType = constant.APITypeDify
+ case constant.ChannelTypeJina:
+ apiType = constant.APITypeJina
+ case constant.ChannelCloudflare:
+ apiType = constant.APITypeCloudflare
+ case constant.ChannelTypeSiliconFlow:
+ apiType = constant.APITypeSiliconFlow
+ case constant.ChannelTypeVertexAi:
+ apiType = constant.APITypeVertexAi
+ case constant.ChannelTypeMistral:
+ apiType = constant.APITypeMistral
+ case constant.ChannelTypeDeepSeek:
+ apiType = constant.APITypeDeepSeek
+ case constant.ChannelTypeMokaAI:
+ apiType = constant.APITypeMokaAI
+ case constant.ChannelTypeVolcEngine:
+ apiType = constant.APITypeVolcEngine
+ case constant.ChannelTypeBaiduV2:
+ apiType = constant.APITypeBaiduV2
+ case constant.ChannelTypeOpenRouter:
+ apiType = constant.APITypeOpenRouter
+ case constant.ChannelTypeXinference:
+ apiType = constant.APITypeXinference
+ case constant.ChannelTypeXai:
+ apiType = constant.APITypeXai
+ case constant.ChannelTypeCoze:
+ apiType = constant.APITypeCoze
+ case constant.ChannelTypeJimeng:
+ apiType = constant.APITypeJimeng
+ case constant.ChannelTypeMoonshot:
+ apiType = constant.APITypeMoonshot
+ }
+ if apiType == -1 {
+ return constant.APITypeOpenAI, false
+ }
+ return apiType, true
+}
diff --git a/common/constants.go b/common/constants.go
index bee00506..e6d59d10 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -83,6 +83,7 @@ var GitHubClientId = ""
var GitHubClientSecret = ""
var LinuxDOClientId = ""
var LinuxDOClientSecret = ""
+var LinuxDOMinimumTrustLevel = 0
var WeChatServerAddress = ""
var WeChatServerToken = ""
@@ -195,105 +196,7 @@ const (
)
const (
- ChannelTypeUnknown = 0
- ChannelTypeOpenAI = 1
- ChannelTypeMidjourney = 2
- ChannelTypeAzure = 3
- ChannelTypeOllama = 4
- ChannelTypeMidjourneyPlus = 5
- ChannelTypeOpenAIMax = 6
- ChannelTypeOhMyGPT = 7
- ChannelTypeCustom = 8
- ChannelTypeAILS = 9
- ChannelTypeAIProxy = 10
- ChannelTypePaLM = 11
- ChannelTypeAPI2GPT = 12
- ChannelTypeAIGC2D = 13
- ChannelTypeAnthropic = 14
- ChannelTypeBaidu = 15
- ChannelTypeZhipu = 16
- ChannelTypeAli = 17
- ChannelTypeXunfei = 18
- ChannelType360 = 19
- ChannelTypeOpenRouter = 20
- ChannelTypeAIProxyLibrary = 21
- ChannelTypeFastGPT = 22
- ChannelTypeTencent = 23
- ChannelTypeGemini = 24
- ChannelTypeMoonshot = 25
- ChannelTypeZhipu_v4 = 26
- ChannelTypePerplexity = 27
- ChannelTypeLingYiWanWu = 31
- ChannelTypeAws = 33
- ChannelTypeCohere = 34
- ChannelTypeMiniMax = 35
- ChannelTypeSunoAPI = 36
- ChannelTypeDify = 37
- ChannelTypeJina = 38
- ChannelCloudflare = 39
- ChannelTypeSiliconFlow = 40
- ChannelTypeVertexAi = 41
- ChannelTypeMistral = 42
- ChannelTypeDeepSeek = 43
- ChannelTypeMokaAI = 44
- ChannelTypeVolcEngine = 45
- ChannelTypeBaiduV2 = 46
- ChannelTypeXinference = 47
- ChannelTypeXai = 48
- ChannelTypeCoze = 49
- ChannelTypeDummy // this one is only for count, do not add any channel after this
-
+ TopUpStatusPending = "pending"
+ TopUpStatusSuccess = "success"
+ TopUpStatusExpired = "expired"
)
-
-var ChannelBaseURLs = []string{
- "", // 0
- "https://api.openai.com", // 1
- "https://oa.api2d.net", // 2
- "", // 3
- "http://localhost:11434", // 4
- "https://api.openai-sb.com", // 5
- "https://api.openaimax.com", // 6
- "https://api.ohmygpt.com", // 7
- "", // 8
- "https://api.caipacity.com", // 9
- "https://api.aiproxy.io", // 10
- "", // 11
- "https://api.api2gpt.com", // 12
- "https://api.aigc2d.com", // 13
- "https://api.anthropic.com", // 14
- "https://aip.baidubce.com", // 15
- "https://open.bigmodel.cn", // 16
- "https://dashscope.aliyuncs.com", // 17
- "", // 18
- "https://api.360.cn", // 19
- "https://openrouter.ai/api", // 20
- "https://api.aiproxy.io", // 21
- "https://fastgpt.run/api/openapi", // 22
- "https://hunyuan.tencentcloudapi.com", //23
- "https://generativelanguage.googleapis.com", //24
- "https://api.moonshot.cn", //25
- "https://open.bigmodel.cn", //26
- "https://api.perplexity.ai", //27
- "", //28
- "", //29
- "", //30
- "https://api.lingyiwanwu.com", //31
- "", //32
- "", //33
- "https://api.cohere.ai", //34
- "https://api.minimax.chat", //35
- "", //36
- "https://api.dify.ai", //37
- "https://api.jina.ai", //38
- "https://api.cloudflare.com", //39
- "https://api.siliconflow.cn", //40
- "", //41
- "https://api.mistral.ai", //42
- "https://api.deepseek.com", //43
- "https://api.moka.ai", //44
- "https://ark.cn-beijing.volces.com", //45
- "https://qianfan.baidubce.com", //46
- "", //47
- "https://api.x.ai", //48
- "https://api.coze.cn", //49
-}
diff --git a/common/copy.go b/common/copy.go
new file mode 100644
index 00000000..8573d6e0
--- /dev/null
+++ b/common/copy.go
@@ -0,0 +1,18 @@
+package common
+
+import (
+	"fmt"
+	"github.com/antlabs/pcopy"
+)
+
+func DeepCopy[T any](src *T) (*T, error) {
+	if src == nil {
+		return nil, fmt.Errorf("copy source cannot be nil")
+	}
+	var dst T
+	err := pcopy.Copy(&dst, src)
+	if err != nil {
+		return nil, err
+	}
+	return &dst, nil
+}
diff --git a/common/custom-event.go b/common/custom-event.go
index d8f9ec9f..256db546 100644
--- a/common/custom-event.go
+++ b/common/custom-event.go
@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"strings"
+ "sync"
)
type stringWriter interface {
@@ -52,6 +53,8 @@ type CustomEvent struct {
Id string
Retry uint
Data interface{}
+
+ Mutex sync.Mutex
}
func encode(writer io.Writer, event CustomEvent) error {
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
}
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
+ r.Mutex.Lock()
+ defer r.Mutex.Unlock()
header := w.Header()
header["Content-Type"] = contentType
diff --git a/common/database.go b/common/database.go
index 3c0a944b..71dbd94d 100644
--- a/common/database.go
+++ b/common/database.go
@@ -1,8 +1,15 @@
package common
+const (
+ DatabaseTypeMySQL = "mysql"
+ DatabaseTypeSQLite = "sqlite"
+ DatabaseTypePostgreSQL = "postgres"
+)
+
var UsingSQLite = false
var UsingPostgreSQL = false
+var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
var UsingMySQL = false
var UsingClickHouse = false
-var SQLitePath = "one-api.db?_busy_timeout=5000"
+var SQLitePath = "one-api.db?_busy_timeout=30000"
diff --git a/common/endpoint_defaults.go b/common/endpoint_defaults.go
new file mode 100644
index 00000000..ffc26350
--- /dev/null
+++ b/common/endpoint_defaults.go
@@ -0,0 +1,32 @@
+package common
+
+import "one-api/constant"
+
+// EndpointInfo 描述单个端点的默认请求信息
+// path: 上游路径
+// method: HTTP 请求方式,例如 POST/GET
+// 目前均为 POST,后续可扩展
+//
+// json 标签用于直接序列化到 API 输出
+// 例如:{"path":"/v1/chat/completions","method":"POST"}
+
+type EndpointInfo struct {
+ Path string `json:"path"`
+ Method string `json:"method"`
+}
+
+// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
+var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
+ constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
+ constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
+ constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
+ constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
+ constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
+ constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
+}
+
+// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
+func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
+ info, ok := defaultEndpointInfoMap[et]
+ return info, ok
+}
diff --git a/common/endpoint_type.go b/common/endpoint_type.go
new file mode 100644
index 00000000..a0ca73ea
--- /dev/null
+++ b/common/endpoint_type.go
@@ -0,0 +1,41 @@
+package common
+
+import "one-api/constant"
+
+// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
+func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
+ var endpointTypes []constant.EndpointType
+ switch channelType {
+ case constant.ChannelTypeJina:
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
+ //case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
+ //case constant.ChannelTypeSunoAPI:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
+ //case constant.ChannelTypeKling:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
+ //case constant.ChannelTypeJimeng:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
+ case constant.ChannelTypeAws:
+ fallthrough
+ case constant.ChannelTypeAnthropic:
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
+ case constant.ChannelTypeVertexAi:
+ fallthrough
+ case constant.ChannelTypeGemini:
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
+ case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+ default:
+ if IsOpenAIResponseOnlyModel(modelName) {
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
+ } else {
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+ }
+ }
+ if IsImageGenerationModel(modelName) {
+ // add to first
+ endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
+ }
+ return endpointTypes
+}
diff --git a/common/gin.go b/common/gin.go
index 4a909dfc..2cb35844 100644
--- a/common/gin.go
+++ b/common/gin.go
@@ -2,10 +2,13 @@ package common
import (
"bytes"
- "encoding/json"
- "github.com/gin-gonic/gin"
"io"
+ "net/http"
+ "one-api/constant"
"strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
)
const KeyRequestBody = "key_request_body"
@@ -29,9 +32,12 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil {
return err
}
+ //if DebugEnabled {
+ // println("UnmarshalBodyReusable request body:", string(requestBody))
+ //}
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
- err = json.Unmarshal(requestBody, &v)
+ err = Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
@@ -43,3 +49,67 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}
+
+func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
+ c.Set(string(key), value)
+}
+
+func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
+ return c.Get(string(key))
+}
+
+func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
+ return c.GetString(string(key))
+}
+
+func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
+ return c.GetInt(string(key))
+}
+
+func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
+ return c.GetBool(string(key))
+}
+
+func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
+ return c.GetStringSlice(string(key))
+}
+
+func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
+ return c.GetStringMap(string(key))
+}
+
+func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
+ return c.GetTime(string(key))
+}
+
+func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
+ if value, ok := c.Get(string(key)); ok {
+ if v, ok := value.(T); ok {
+ return v, true
+ }
+ }
+ var t T
+ return t, false
+}
+
+func ApiError(c *gin.Context, err error) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+}
+
+func ApiErrorMsg(c *gin.Context, msg string) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": msg,
+ })
+}
+
+func ApiSuccess(c *gin.Context, data any) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": data,
+ })
+}
diff --git a/common/hash.go b/common/hash.go
new file mode 100644
index 00000000..50191938
--- /dev/null
+++ b/common/hash.go
@@ -0,0 +1,34 @@
+package common
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/hex"
+)
+
+func Sha256Raw(data []byte) []byte {
+ h := sha256.New()
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func Sha1Raw(data []byte) []byte {
+ h := sha1.New()
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func Sha1(data []byte) string {
+ return hex.EncodeToString(Sha1Raw(data))
+}
+
+func HmacSha256Raw(message, key []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(message)
+ return h.Sum(nil)
+}
+
+func HmacSha256(message, key string) string {
+ return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
+}
diff --git a/common/init.go b/common/init.go
index c0caf0a1..c4626f9a 100644
--- a/common/init.go
+++ b/common/init.go
@@ -4,6 +4,7 @@ import (
"flag"
"fmt"
"log"
+ "one-api/constant"
"os"
"path/filepath"
"strconv"
@@ -24,7 +25,7 @@ func printHelp() {
fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]")
}
-func LoadEnv() {
+func InitEnv() {
flag.Parse()
if *PrintVersion {
@@ -95,4 +96,25 @@ func LoadEnv() {
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
+
+ initConstantEnv()
+}
+
+func initConstantEnv() {
+ constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
+ constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
+ constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
+ // ForceStreamOption 覆盖请求参数,强制返回usage信息
+ constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
+ constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
+ constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
+ constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
+ constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
+ constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
+ constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
+ constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+ // GenerateDefaultToken 是否生成初始令牌,默认关闭。
+ constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
+ // 是否启用错误日志
+ constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
}
diff --git a/common/json.go b/common/json.go
index cec8f16b..69aa952e 100644
--- a/common/json.go
+++ b/common/json.go
@@ -5,14 +5,18 @@ import (
"encoding/json"
)
-func DecodeJson(data []byte, v any) error {
- return json.NewDecoder(bytes.NewReader(data)).Decode(v)
+func Unmarshal(data []byte, v any) error {
+ return json.Unmarshal(data, v)
}
-func DecodeJsonStr(data string, v any) error {
- return DecodeJson(StringToByteSlice(data), v)
+func UnmarshalJsonStr(data string, v any) error {
+ return json.Unmarshal(StringToByteSlice(data), v)
}
-func EncodeJson(v any) ([]byte, error) {
+func DecodeJson(reader *bytes.Reader, v any) error {
+ return json.NewDecoder(reader).Decode(v)
+}
+
+func Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}
diff --git a/common/model.go b/common/model.go
new file mode 100644
index 00000000..14ca1911
--- /dev/null
+++ b/common/model.go
@@ -0,0 +1,42 @@
+package common
+
+import "strings"
+
+var (
+ // OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
+ OpenAIResponseOnlyModels = []string{
+ "o3-pro",
+ "o3-deep-research",
+ "o4-mini-deep-research",
+ }
+ ImageGenerationModels = []string{
+ "dall-e-3",
+ "dall-e-2",
+ "gpt-image-1",
+ "prefix:imagen-",
+ "flux-",
+ "flux.1-",
+ }
+)
+
+func IsOpenAIResponseOnlyModel(modelName string) bool {
+ for _, m := range OpenAIResponseOnlyModels {
+ if strings.Contains(modelName, m) {
+ return true
+ }
+ }
+ return false
+}
+
+func IsImageGenerationModel(modelName string) bool {
+ modelName = strings.ToLower(modelName)
+ for _, m := range ImageGenerationModels {
+ if strings.Contains(modelName, m) {
+ return true
+ }
+ if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/common/page_info.go b/common/page_info.go
new file mode 100644
index 00000000..2378a5d8
--- /dev/null
+++ b/common/page_info.go
@@ -0,0 +1,82 @@
+package common
+
+import (
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+)
+
+type PageInfo struct {
+ Page int `json:"page"` // page num 页码
+ PageSize int `json:"page_size"` // page size 页大小
+
+ Total int `json:"total"` // 总条数,后设置
+ Items any `json:"items"` // 数据,后设置
+}
+
+func (p *PageInfo) GetStartIdx() int {
+ return (p.Page - 1) * p.PageSize
+}
+
+func (p *PageInfo) GetEndIdx() int {
+ return p.Page * p.PageSize
+}
+
+func (p *PageInfo) GetPageSize() int {
+ return p.PageSize
+}
+
+func (p *PageInfo) GetPage() int {
+ return p.Page
+}
+
+func (p *PageInfo) SetTotal(total int) {
+ p.Total = total
+}
+
+func (p *PageInfo) SetItems(items any) {
+ p.Items = items
+}
+
+func GetPageQuery(c *gin.Context) *PageInfo {
+ pageInfo := &PageInfo{}
+ // 手动获取并处理每个参数
+ if page, err := strconv.Atoi(c.Query("p")); err == nil {
+ pageInfo.Page = page
+ }
+ if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
+ pageInfo.PageSize = pageSize
+ }
+ if pageInfo.Page < 1 {
+ // 兼容
+ page, _ := strconv.Atoi(c.Query("p"))
+ if page != 0 {
+ pageInfo.Page = page
+ } else {
+ pageInfo.Page = 1
+ }
+ }
+
+ if pageInfo.PageSize == 0 {
+ // 兼容
+ pageSize, _ := strconv.Atoi(c.Query("ps"))
+ if pageSize != 0 {
+ pageInfo.PageSize = pageSize
+ }
+ if pageInfo.PageSize == 0 {
+ pageSize, _ = strconv.Atoi(c.Query("size")) // token page
+ if pageSize != 0 {
+ pageInfo.PageSize = pageSize
+ }
+ }
+ if pageInfo.PageSize == 0 {
+ pageInfo.PageSize = ItemsPerPage
+ }
+ }
+
+ if pageInfo.PageSize > 100 {
+ pageInfo.PageSize = 100
+ }
+
+ return pageInfo
+}
diff --git a/common/quota.go b/common/quota.go
new file mode 100644
index 00000000..dfd65d27
--- /dev/null
+++ b/common/quota.go
@@ -0,0 +1,5 @@
+package common
+
+func GetTrustQuota() int {
+ return int(10 * QuotaPerUnit)
+}
diff --git a/common/redis.go b/common/redis.go
index 49d3ec78..c7287837 100644
--- a/common/redis.go
+++ b/common/redis.go
@@ -16,6 +16,10 @@ import (
var RDB *redis.Client
var RedisEnabled = true
+func RedisKeyCacheSeconds() int {
+ return SyncFrequency
+}
+
// InitRedisClient This function is called after init()
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
@@ -92,12 +96,12 @@ func RedisDel(key string) error {
return RDB.Del(ctx, key).Err()
}
-func RedisHDelObj(key string) error {
+func RedisDelKey(key string) error {
if DebugEnabled {
- SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
+ SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
}
ctx := context.Background()
- return RDB.HDel(ctx, key).Err()
+ return RDB.Del(ctx, key).Err()
}
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
@@ -141,7 +145,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
txn := RDB.TxPipeline()
txn.HSet(ctx, key, data)
- txn.Expire(ctx, key, expiration)
+
+ // 只有在 expiration 大于 0 时才设置过期时间
+ if expiration > 0 {
+ txn.Expire(ctx, key, expiration)
+ }
_, err := txn.Exec(ctx)
if err != nil {
diff --git a/common/str.go b/common/str.go
index d42fd837..6debce28 100644
--- a/common/str.go
+++ b/common/str.go
@@ -1,9 +1,13 @@
package common
import (
+ "encoding/base64"
"encoding/json"
"math/rand"
+ "net/url"
+ "regexp"
"strconv"
+ "strings"
"unsafe"
)
@@ -31,16 +35,30 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes)
}
-func StrToMap(str string) map[string]interface{} {
+func StrToMap(str string) (map[string]interface{}, error) {
m := make(map[string]interface{})
- err := json.Unmarshal([]byte(str), &m)
+ err := Unmarshal([]byte(str), &m)
if err != nil {
- return nil
+ return nil, err
}
- return m
+ return m, nil
}
-func IsJsonStr(str string) bool {
+func StrToJsonArray(str string) ([]interface{}, error) {
+ var js []interface{}
+ err := json.Unmarshal([]byte(str), &js)
+ if err != nil {
+ return nil, err
+ }
+ return js, nil
+}
+
+func IsJsonArray(str string) bool {
+ var js []interface{}
+ return json.Unmarshal([]byte(str), &js) == nil
+}
+
+func IsJsonObject(str string) bool {
var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
@@ -68,3 +86,152 @@ func StringToByteSlice(s string) []byte {
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2))
}
+
+func EncodeBase64(str string) string {
+ return base64.StdEncoding.EncodeToString([]byte(str))
+}
+
+func GetJsonString(data any) string {
+ if data == nil {
+ return ""
+ }
+ b, _ := json.Marshal(data)
+ return string(b)
+}
+
+// MaskEmail masks a user email to prevent PII leakage in logs
+// Returns "***masked***" if email is empty, otherwise shows only the domain part
+func MaskEmail(email string) string {
+ if email == "" {
+ return "***masked***"
+ }
+
+ // Find the @ symbol
+ atIndex := strings.Index(email, "@")
+ if atIndex == -1 {
+ // No @ symbol found, return masked
+ return "***masked***"
+ }
+
+ // Return only the domain part with @ symbol
+ return "***@" + email[atIndex+1:]
+}
+
+// maskHostTail returns the tail parts of a domain/host that should be preserved.
+// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
+func maskHostTail(parts []string) []string {
+ if len(parts) < 2 {
+ return parts
+ }
+ lastPart := parts[len(parts)-1]
+ secondLastPart := parts[len(parts)-2]
+ if len(lastPart) == 2 && len(secondLastPart) <= 3 {
+ // Likely country code TLD like co.uk, com.cn
+ return []string{secondLastPart, lastPart}
+ }
+ return []string{lastPart}
+}
+
+// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
+// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
+func maskHostForURL(host string) string {
+ parts := strings.Split(host, ".")
+ if len(parts) < 2 {
+ return "***"
+ }
+ tail := maskHostTail(parts)
+ return "***." + strings.Join(tail, ".")
+}
+
+// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
+// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
+func maskHostForPlainDomain(domain string) string {
+ parts := strings.Split(domain, ".")
+ if len(parts) < 2 {
+ return domain
+ }
+ tail := maskHostTail(parts)
+ numStars := len(parts) - len(tail)
+ if numStars < 1 {
+ numStars = 1
+ }
+ stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
+ return stars + "." + strings.Join(tail, ".")
+}
+
+// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
+// Example:
+// http://example.com -> http://***.com
+// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
+// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
+// 192.168.1.1 -> ***.***.***.***
+// openai.com -> ***.com
+// www.openai.com -> ***.***.com
+// api.openai.com -> ***.***.com
+func MaskSensitiveInfo(str string) string {
+ // Mask URLs
+ urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
+ str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
+ u, err := url.Parse(urlStr)
+ if err != nil {
+ return urlStr
+ }
+
+ host := u.Host
+ if host == "" {
+ return urlStr
+ }
+
+ // Mask host with unified logic
+ maskedHost := maskHostForURL(host)
+
+ result := u.Scheme + "://" + maskedHost
+
+ // Mask path
+ if u.Path != "" && u.Path != "/" {
+ pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
+ maskedPathParts := make([]string, len(pathParts))
+ for i := range pathParts {
+ if pathParts[i] != "" {
+ maskedPathParts[i] = "***"
+ }
+ }
+ if len(maskedPathParts) > 0 {
+ result += "/" + strings.Join(maskedPathParts, "/")
+ }
+ } else if u.Path == "/" {
+ result += "/"
+ }
+
+ // Mask query parameters
+ if u.RawQuery != "" {
+ values, err := url.ParseQuery(u.RawQuery)
+ if err != nil {
+ // If can't parse query, just mask the whole query string
+ result += "?***"
+ } else {
+ maskedParams := make([]string, 0, len(values))
+ for key := range values {
+ maskedParams = append(maskedParams, key+"=***")
+ }
+ if len(maskedParams) > 0 {
+ result += "?" + strings.Join(maskedParams, "&")
+ }
+ }
+ }
+
+ return result
+ })
+
+ // Mask domain names without protocol (like openai.com, www.openai.com)
+ domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
+ str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
+ return maskHostForPlainDomain(domain)
+ })
+
+ // Mask IP addresses
+ ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
+ str = ipPattern.ReplaceAllString(str, "***.***.***.***")
+
+ return str
+}
diff --git a/common/sys_log.go b/common/sys_log.go
new file mode 100644
index 00000000..478015f0
--- /dev/null
+++ b/common/sys_log.go
@@ -0,0 +1,24 @@
+package common
+
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "os"
+ "time"
+)
+
+func SysLog(s string) {
+ t := time.Now()
+ _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+}
+
+func SysError(s string) {
+ t := time.Now()
+ _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+}
+
+func FatalLog(v ...any) {
+ t := time.Now()
+ _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
+ os.Exit(1)
+}
diff --git a/common/totp.go b/common/totp.go
new file mode 100644
index 00000000..400f9d05
--- /dev/null
+++ b/common/totp.go
@@ -0,0 +1,150 @@
+package common
+
+import (
+ "crypto/rand"
+ "fmt"
+ "os"
+ "strconv"
+ "strings"
+
+ "github.com/pquerna/otp"
+ "github.com/pquerna/otp/totp"
+)
+
+const (
+ // 备用码配置
+ BackupCodeLength = 8 // 备用码长度
+ BackupCodeCount = 4 // 生成备用码数量
+
+ // 限制配置
+ MaxFailAttempts = 5 // 最大失败尝试次数
+ LockoutDuration = 300 // 锁定时间(秒)
+)
+
+// GenerateTOTPSecret 生成TOTP密钥和配置
+func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
+ issuer := Get2FAIssuer()
+ return totp.Generate(totp.GenerateOpts{
+ Issuer: issuer,
+ AccountName: accountName,
+ Period: 30,
+ Digits: otp.DigitsSix,
+ Algorithm: otp.AlgorithmSHA1,
+ })
+}
+
+// ValidateTOTPCode 验证TOTP验证码
+func ValidateTOTPCode(secret, code string) bool {
+ // 清理验证码格式
+ cleanCode := strings.ReplaceAll(code, " ", "")
+ if len(cleanCode) != 6 {
+ return false
+ }
+
+ // 验证验证码
+ return totp.Validate(cleanCode, secret)
+}
+
+// GenerateBackupCodes 生成备用恢复码
+func GenerateBackupCodes() ([]string, error) {
+ codes := make([]string, BackupCodeCount)
+
+ for i := 0; i < BackupCodeCount; i++ {
+ code, err := generateRandomBackupCode()
+ if err != nil {
+ return nil, err
+ }
+ codes[i] = code
+ }
+
+ return codes, nil
+}
+
+// generateRandomBackupCode 生成单个备用码
+func generateRandomBackupCode() (string, error) {
+ const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+ code := make([]byte, BackupCodeLength)
+
+ for i := range code {
+ randomBytes := make([]byte, 1)
+ _, err := rand.Read(randomBytes)
+ if err != nil {
+ return "", err
+ }
+ code[i] = charset[int(randomBytes[0])%len(charset)]
+ }
+
+ // 格式化为 XXXX-XXXX 格式
+ return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
+}
+
+// ValidateBackupCode 验证备用码格式
+func ValidateBackupCode(code string) bool {
+ // 移除所有分隔符并转为大写
+ cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
+ if len(cleanCode) != BackupCodeLength {
+ return false
+ }
+
+ // 检查字符是否合法
+ for _, char := range cleanCode {
+ if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
+ return false
+ }
+ }
+
+ return true
+}
+
+// NormalizeBackupCode 标准化备用码格式
+func NormalizeBackupCode(code string) string {
+ cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
+ if len(cleanCode) == BackupCodeLength {
+ return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
+ }
+ return code
+}
+
+// HashBackupCode 对备用码进行哈希
+func HashBackupCode(code string) (string, error) {
+ normalizedCode := NormalizeBackupCode(code)
+ return Password2Hash(normalizedCode)
+}
+
+// Get2FAIssuer 获取2FA发行者名称
+func Get2FAIssuer() string {
+ return SystemName
+}
+
+// getEnvOrDefault 获取环境变量或默认值
+func getEnvOrDefault(key, defaultValue string) string {
+ if value, exists := os.LookupEnv(key); exists {
+ return value
+ }
+ return defaultValue
+}
+
+// ValidateNumericCode 验证数字验证码格式
+func ValidateNumericCode(code string) (string, error) {
+ // 移除空格
+ code = strings.ReplaceAll(code, " ", "")
+
+ if len(code) != 6 {
+ return "", fmt.Errorf("验证码必须是6位数字")
+ }
+
+	// 检查是否为纯数字(ParseUint 可同时排除 "+"/"-" 符号,Atoi 会误放过带符号输入)
+	if _, err := strconv.ParseUint(code, 10, 32); err != nil {
+		return "", fmt.Errorf("验证码只能包含数字")
+	}
+
+ return code, nil
+}
+
+// GenerateQRCodeData 生成二维码数据
+func GenerateQRCodeData(secret, username string) string {
+	issuer := Get2FAIssuer()
+	// 标准 otpauth label 为 "issuer:account",账户名中不应再重复 issuer
+	return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
+		issuer, username, secret, issuer)
+}
diff --git a/common/utils.go b/common/utils.go
index 587de537..17aecd95 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -13,6 +13,7 @@ import (
"math/big"
"math/rand"
"net"
+ "net/url"
"os"
"os/exec"
"runtime"
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
}
// GetAudioDuration returns the duration of an audio file in seconds.
-func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
+func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration")
}
+ durationStr := string(bytes.TrimSpace(output))
+ if durationStr == "N/A" {
+ // Create a temporary output file name
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to create temporary file")
+ }
+ tmpName := tmpFp.Name()
+ // Close immediately so ffmpeg can open the file on Windows.
+ _ = tmpFp.Close()
+ defer os.Remove(tmpName)
- return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
+ // ffmpeg -y -i filename -vcodec copy -acodec copy
+ ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
+ if err := ffmpegCmd.Run(); err != nil {
+ return 0, errors.Wrap(err, "failed to run ffmpeg")
+ }
+
+ // Recalculate the duration of the new file
+ c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
+ output, err := c.Output()
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
+ }
+ durationStr = string(bytes.TrimSpace(output))
+ }
+ return strconv.ParseFloat(durationStr, 64)
+}
+
+// BuildURL concatenates base and endpoint, returns the complete url string
+func BuildURL(base string, endpoint string) string {
+ u, err := url.Parse(base)
+ if err != nil {
+ return base + endpoint
+ }
+ end := endpoint
+ if end == "" {
+ end = "/"
+ }
+ ref, err := url.Parse(end)
+ if err != nil {
+ return base + endpoint
+ }
+ return u.ResolveReference(ref).String()
}
diff --git a/constant/README.md b/constant/README.md
new file mode 100644
index 00000000..12a9ffad
--- /dev/null
+++ b/constant/README.md
@@ -0,0 +1,26 @@
+# constant 包 (`/constant`)
+
+该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
+
+## 当前文件
+
+| 文件 | 说明 |
+|----------------------|---------------------------------------------------------------------|
+| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
+| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
+| `channel.go`         | 渠道类型常量(`ChannelTypeXxx`)及各渠道默认 Base URL 列表。 |
+| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
+| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
+| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
+| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
+| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
+| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
+| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
+
+## 使用约定
+
+1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
+2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
+3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
+
+> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
\ No newline at end of file
diff --git a/constant/api_type.go b/constant/api_type.go
new file mode 100644
index 00000000..f62d91d5
--- /dev/null
+++ b/constant/api_type.go
@@ -0,0 +1,36 @@
+package constant
+
+const (
+ APITypeOpenAI = iota
+ APITypeAnthropic
+ APITypePaLM
+ APITypeBaidu
+ APITypeZhipu
+ APITypeAli
+ APITypeXunfei
+ APITypeAIProxyLibrary
+ APITypeTencent
+ APITypeGemini
+ APITypeZhipuV4
+ APITypeOllama
+ APITypePerplexity
+ APITypeAws
+ APITypeCohere
+ APITypeDify
+ APITypeJina
+ APITypeCloudflare
+ APITypeSiliconFlow
+ APITypeVertexAi
+ APITypeMistral
+ APITypeDeepSeek
+ APITypeMokaAI
+ APITypeVolcEngine
+ APITypeBaiduV2
+ APITypeOpenRouter
+ APITypeXinference
+ APITypeXai
+ APITypeCoze
+ APITypeJimeng
+	APITypeMoonshot
+ APITypeDummy // this one is only for count, do not add any channel after this
+)
diff --git a/constant/cache_key.go b/constant/cache_key.go
index 27cb3b75..0601396a 100644
--- a/constant/cache_key.go
+++ b/constant/cache_key.go
@@ -1,14 +1,5 @@
package constant
-import "one-api/common"
-
-var (
- TokenCacheSeconds = common.SyncFrequency
- UserId2GroupCacheSeconds = common.SyncFrequency
- UserId2QuotaCacheSeconds = common.SyncFrequency
- UserId2StatusCacheSeconds = common.SyncFrequency
-)
-
// Cache keys
const (
UserGroupKeyFmt = "user_group:%d"
diff --git a/constant/channel.go b/constant/channel.go
new file mode 100644
index 00000000..2e1cc5b0
--- /dev/null
+++ b/constant/channel.go
@@ -0,0 +1,111 @@
+package constant
+
+const (
+ ChannelTypeUnknown = 0
+ ChannelTypeOpenAI = 1
+ ChannelTypeMidjourney = 2
+ ChannelTypeAzure = 3
+ ChannelTypeOllama = 4
+ ChannelTypeMidjourneyPlus = 5
+ ChannelTypeOpenAIMax = 6
+ ChannelTypeOhMyGPT = 7
+ ChannelTypeCustom = 8
+ ChannelTypeAILS = 9
+ ChannelTypeAIProxy = 10
+ ChannelTypePaLM = 11
+ ChannelTypeAPI2GPT = 12
+ ChannelTypeAIGC2D = 13
+ ChannelTypeAnthropic = 14
+ ChannelTypeBaidu = 15
+ ChannelTypeZhipu = 16
+ ChannelTypeAli = 17
+ ChannelTypeXunfei = 18
+ ChannelType360 = 19
+ ChannelTypeOpenRouter = 20
+ ChannelTypeAIProxyLibrary = 21
+ ChannelTypeFastGPT = 22
+ ChannelTypeTencent = 23
+ ChannelTypeGemini = 24
+ ChannelTypeMoonshot = 25
+ ChannelTypeZhipu_v4 = 26
+ ChannelTypePerplexity = 27
+ ChannelTypeLingYiWanWu = 31
+ ChannelTypeAws = 33
+ ChannelTypeCohere = 34
+ ChannelTypeMiniMax = 35
+ ChannelTypeSunoAPI = 36
+ ChannelTypeDify = 37
+ ChannelTypeJina = 38
+ ChannelCloudflare = 39
+ ChannelTypeSiliconFlow = 40
+ ChannelTypeVertexAi = 41
+ ChannelTypeMistral = 42
+ ChannelTypeDeepSeek = 43
+ ChannelTypeMokaAI = 44
+ ChannelTypeVolcEngine = 45
+ ChannelTypeBaiduV2 = 46
+ ChannelTypeXinference = 47
+ ChannelTypeXai = 48
+ ChannelTypeCoze = 49
+ ChannelTypeKling = 50
+ ChannelTypeJimeng = 51
+ ChannelTypeVidu = 52
+	ChannelTypeDummy = 53 // this one is only for count (largest type + 1), do not add any channel after this
+
+)
+
+var ChannelBaseURLs = []string{
+ "", // 0
+ "https://api.openai.com", // 1
+ "https://oa.api2d.net", // 2
+ "", // 3
+ "http://localhost:11434", // 4
+ "https://api.openai-sb.com", // 5
+ "https://api.openaimax.com", // 6
+ "https://api.ohmygpt.com", // 7
+ "", // 8
+ "https://api.caipacity.com", // 9
+ "https://api.aiproxy.io", // 10
+ "", // 11
+ "https://api.api2gpt.com", // 12
+ "https://api.aigc2d.com", // 13
+ "https://api.anthropic.com", // 14
+ "https://aip.baidubce.com", // 15
+ "https://open.bigmodel.cn", // 16
+ "https://dashscope.aliyuncs.com", // 17
+ "", // 18
+ "https://api.360.cn", // 19
+ "https://openrouter.ai/api", // 20
+ "https://api.aiproxy.io", // 21
+ "https://fastgpt.run/api/openapi", // 22
+ "https://hunyuan.tencentcloudapi.com", //23
+ "https://generativelanguage.googleapis.com", //24
+ "https://api.moonshot.cn", //25
+ "https://open.bigmodel.cn", //26
+ "https://api.perplexity.ai", //27
+ "", //28
+ "", //29
+ "", //30
+ "https://api.lingyiwanwu.com", //31
+ "", //32
+ "", //33
+ "https://api.cohere.ai", //34
+ "https://api.minimax.chat", //35
+ "", //36
+ "https://api.dify.ai", //37
+ "https://api.jina.ai", //38
+ "https://api.cloudflare.com", //39
+ "https://api.siliconflow.cn", //40
+ "", //41
+ "https://api.mistral.ai", //42
+ "https://api.deepseek.com", //43
+ "https://api.moka.ai", //44
+ "https://ark.cn-beijing.volces.com", //45
+ "https://qianfan.baidubce.com", //46
+ "", //47
+ "https://api.x.ai", //48
+ "https://api.coze.cn", //49
+ "https://api.klingai.com", //50
+ "https://visual.volcengineapi.com", //51
+ "https://api.vidu.cn", //52
+}
diff --git a/constant/channel_setting.go b/constant/channel_setting.go
deleted file mode 100644
index e06e7eb1..00000000
--- a/constant/channel_setting.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package constant
-
-var (
- ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
- ChanelSettingProxy = "proxy" // Proxy 代理
- ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
-)
diff --git a/constant/context_key.go b/constant/context_key.go
index 4b4d5cae..3945243c 100644
--- a/constant/context_key.go
+++ b/constant/context_key.go
@@ -1,10 +1,49 @@
package constant
+type ContextKey string
+
const (
- ContextKeyRequestStartTime = "request_start_time"
- ContextKeyUserSetting = "user_setting"
- ContextKeyUserQuota = "user_quota"
- ContextKeyUserStatus = "user_status"
- ContextKeyUserEmail = "user_email"
- ContextKeyUserGroup = "user_group"
+ ContextKeyTokenCountMeta ContextKey = "token_count_meta"
+ ContextKeyPromptTokens ContextKey = "prompt_tokens"
+
+ ContextKeyOriginalModel ContextKey = "original_model"
+ ContextKeyRequestStartTime ContextKey = "request_start_time"
+
+ /* token related keys */
+ ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
+ ContextKeyTokenKey ContextKey = "token_key"
+ ContextKeyTokenId ContextKey = "token_id"
+ ContextKeyTokenGroup ContextKey = "token_group"
+ ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
+ ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
+ ContextKeyTokenModelLimit ContextKey = "token_model_limit"
+
+ /* channel related keys */
+ ContextKeyChannelId ContextKey = "channel_id"
+ ContextKeyChannelName ContextKey = "channel_name"
+ ContextKeyChannelCreateTime ContextKey = "channel_create_time"
+ ContextKeyChannelBaseUrl ContextKey = "base_url"
+ ContextKeyChannelType ContextKey = "channel_type"
+ ContextKeyChannelSetting ContextKey = "channel_setting"
+ ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
+ ContextKeyChannelParamOverride ContextKey = "param_override"
+ ContextKeyChannelOrganization ContextKey = "channel_organization"
+ ContextKeyChannelAutoBan ContextKey = "auto_ban"
+ ContextKeyChannelModelMapping ContextKey = "model_mapping"
+ ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
+ ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
+ ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
+ ContextKeyChannelKey ContextKey = "channel_key"
+
+ /* user related keys */
+ ContextKeyUserId ContextKey = "id"
+ ContextKeyUserSetting ContextKey = "user_setting"
+ ContextKeyUserQuota ContextKey = "user_quota"
+ ContextKeyUserStatus ContextKey = "user_status"
+ ContextKeyUserEmail ContextKey = "user_email"
+ ContextKeyUserGroup ContextKey = "user_group"
+ ContextKeyUsingGroup ContextKey = "group"
+ ContextKeyUserName ContextKey = "username"
+
+ ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
)
diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go
new file mode 100644
index 00000000..ef096b75
--- /dev/null
+++ b/constant/endpoint_type.go
@@ -0,0 +1,16 @@
+package constant
+
+type EndpointType string
+
+const (
+ EndpointTypeOpenAI EndpointType = "openai"
+ EndpointTypeOpenAIResponse EndpointType = "openai-response"
+ EndpointTypeAnthropic EndpointType = "anthropic"
+ EndpointTypeGemini EndpointType = "gemini"
+ EndpointTypeJinaRerank EndpointType = "jina-rerank"
+ EndpointTypeImageGeneration EndpointType = "image-generation"
+ //EndpointTypeMidjourney EndpointType = "midjourney-proxy"
+ //EndpointTypeSuno EndpointType = "suno-proxy"
+ //EndpointTypeKling EndpointType = "kling"
+ //EndpointTypeJimeng EndpointType = "jimeng"
+)
diff --git a/constant/env.go b/constant/env.go
index 612f3e8b..8bc2f131 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -1,9 +1,5 @@
package constant
-import (
- "one-api/common"
-)
-
var StreamingTimeout int
var DifyDebug bool
var MaxFileDownloadMB int
@@ -17,39 +13,3 @@ var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
-
-//var GeminiModelMap = map[string]string{
-// "gemini-1.0-pro": "v1",
-//}
-
-func InitEnv() {
- StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
- DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
- MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
- // ForceStreamOption 覆盖请求参数,强制返回usage信息
- ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
- GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
- GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
- UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
- AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
- GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
- NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
- NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
- // GenerateDefaultToken 是否生成初始令牌,默认关闭。
- GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
- // 是否启用错误日志
- ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
-
- //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
- //if modelVersionMapStr == "" {
- // return
- //}
- //for _, pair := range strings.Split(modelVersionMapStr, ",") {
- // parts := strings.Split(pair, ":")
- // if len(parts) == 2 {
- // GeminiModelMap[parts[0]] = parts[1]
- // } else {
- // common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
- // }
- //}
-}
diff --git a/constant/midjourney.go b/constant/midjourney.go
index 1bf4d549..5934be2f 100644
--- a/constant/midjourney.go
+++ b/constant/midjourney.go
@@ -22,6 +22,8 @@ const (
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
+ MjActionVideo = "VIDEO"
+ MjActionEdits = "EDITS"
)
var MidjourneyModel2Action = map[string]string{
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
+ "mj_video": MjActionVideo,
+ "mj_edits": MjActionEdits,
}
diff --git a/constant/multi_key_mode.go b/constant/multi_key_mode.go
new file mode 100644
index 00000000..cd0cdbff
--- /dev/null
+++ b/constant/multi_key_mode.go
@@ -0,0 +1,8 @@
+package constant
+
+type MultiKeyMode string
+
+const (
+ MultiKeyModeRandom MultiKeyMode = "random" // 随机
+ MultiKeyModePolling MultiKeyMode = "polling" // 轮询
+)
diff --git a/constant/task.go b/constant/task.go
index 1a68b812..21790145 100644
--- a/constant/task.go
+++ b/constant/task.go
@@ -10,6 +10,9 @@ const (
const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
+
+ TaskActionGenerate = "generate"
+ TaskActionTextGenerate = "textGenerate"
)
var SunoModel2Action = map[string]string{
diff --git a/constant/user_setting.go b/constant/user_setting.go
deleted file mode 100644
index 055884f7..00000000
--- a/constant/user_setting.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package constant
-
-var (
- UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
- UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
- UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
- UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
- UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
- UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
-)
-
-var (
- NotifyTypeEmail = "email" // Email 邮件
- NotifyTypeWebhook = "webhook" // Webhook
-)
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index 2bda0fd2..18acf231 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -7,11 +7,16 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/model"
"one-api/service"
+ "one-api/setting"
+ "one-api/types"
"strconv"
"time"
+ "github.com/shopspring/decimal"
+
"github.com/gin-gonic/gin"
)
@@ -130,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
- res, err := service.GetHttpClient().Do(req)
+ client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
+ if err != nil {
+ return nil, err
+ }
+ res, err := client.Do(req)
if err != nil {
return nil, err
}
@@ -304,34 +313,70 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
return balance, nil
}
+func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
+ url := "https://api.moonshot.cn/v1/users/me/balance"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+
+ type MoonshotBalanceData struct {
+ AvailableBalance float64 `json:"available_balance"`
+ VoucherBalance float64 `json:"voucher_balance"`
+ CashBalance float64 `json:"cash_balance"`
+ }
+
+ type MoonshotBalanceResponse struct {
+ Code int `json:"code"`
+ Data MoonshotBalanceData `json:"data"`
+ Scode string `json:"scode"`
+ Status bool `json:"status"`
+ }
+
+ response := MoonshotBalanceResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ if !response.Status || response.Code != 0 {
+ return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
+ }
+ availableBalanceCny := response.Data.AvailableBalance
+ availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
+ channel.UpdateBalance(availableBalanceUsd)
+ return availableBalanceUsd, nil
+}
+
func updateChannelBalance(channel *model.Channel) (float64, error) {
- baseURL := common.ChannelBaseURLs[channel.Type]
+ baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
- case common.ChannelTypeOpenAI:
+ case constant.ChannelTypeOpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
- case common.ChannelTypeAzure:
+ case constant.ChannelTypeAzure:
return 0, errors.New("尚未实现")
- case common.ChannelTypeCustom:
+ case constant.ChannelTypeCustom:
baseURL = channel.GetBaseURL()
//case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel)
- case common.ChannelTypeAIProxy:
+ case constant.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel)
- case common.ChannelTypeAPI2GPT:
+ case constant.ChannelTypeAPI2GPT:
return updateChannelAPI2GPTBalance(channel)
- case common.ChannelTypeAIGC2D:
+ case constant.ChannelTypeAIGC2D:
return updateChannelAIGC2DBalance(channel)
- case common.ChannelTypeSiliconFlow:
+ case constant.ChannelTypeSiliconFlow:
return updateChannelSiliconFlowBalance(channel)
- case common.ChannelTypeDeepSeek:
+ case constant.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel)
- case common.ChannelTypeOpenRouter:
+ case constant.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel)
+ case constant.ChannelTypeMoonshot:
+ return updateChannelMoonshotBalance(channel)
default:
return 0, errors.New("尚未实现")
}
@@ -370,26 +415,24 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
func UpdateChannelBalance(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- channel, err := model.GetChannelById(id, true)
+ channel, err := model.CacheGetChannel(id)
if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if channel.ChannelInfo.IsMultiKey {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": err.Error(),
+ "message": "多密钥渠道不支持余额查询",
})
return
}
balance, err := updateChannelBalance(channel)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -397,7 +440,6 @@ func UpdateChannelBalance(c *gin.Context) {
"message": "",
"balance": balance,
})
- return
}
func updateAllChannelsBalance() error {
@@ -409,6 +451,9 @@ func updateAllChannelsBalance() error {
if channel.Status != common.ChannelStatusEnabled {
continue
}
+ if channel.ChannelInfo.IsMultiKey {
+ continue // skip multi-key channels
+ }
// TODO: support Azure
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
// continue
@@ -419,7 +464,7 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
- service.DisableChannel(channel.Id, channel.Name, "余额不足")
+ service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
}
}
time.Sleep(common.RequestInterval)
@@ -431,10 +476,7 @@ func UpdateAllChannelsBalance(c *gin.Context) {
// TODO: make it async
err := updateAllChannelsBalance()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/channel-test.go b/controller/channel-test.go
index d1cb4093..81f7e19a 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -11,14 +11,16 @@ import (
"net/http/httptest"
"net/url"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/relay"
relaycommon "one-api/relay/common"
- "one-api/relay/constant"
+ relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strconv"
"strings"
"sync"
@@ -29,16 +31,49 @@ import (
"github.com/gin-gonic/gin"
)
-func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
+type testResult struct {
+ context *gin.Context
+ localErr error
+ newAPIError *types.NewAPIError
+}
+
+func testChannel(channel *model.Channel, testModel string) testResult {
tik := time.Now()
- if channel.Type == common.ChannelTypeMidjourney {
- return errors.New("midjourney channel test is not supported"), nil
+ if channel.Type == constant.ChannelTypeMidjourney {
+ return testResult{
+ localErr: errors.New("midjourney channel test is not supported"),
+ newAPIError: nil,
+ }
}
- if channel.Type == common.ChannelTypeMidjourneyPlus {
- return errors.New("midjourney plus channel test is not supported!!!"), nil
+ if channel.Type == constant.ChannelTypeMidjourneyPlus {
+ return testResult{
+ localErr: errors.New("midjourney plus channel test is not supported"),
+ newAPIError: nil,
+ }
}
- if channel.Type == common.ChannelTypeSunoAPI {
- return errors.New("suno channel test is not supported"), nil
+ if channel.Type == constant.ChannelTypeSunoAPI {
+ return testResult{
+ localErr: errors.New("suno channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ if channel.Type == constant.ChannelTypeKling {
+ return testResult{
+ localErr: errors.New("kling channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ if channel.Type == constant.ChannelTypeJimeng {
+ return testResult{
+ localErr: errors.New("jimeng channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ if channel.Type == constant.ChannelTypeVidu {
+ return testResult{
+ localErr: errors.New("vidu channel test is not supported"),
+ newAPIError: nil,
+ }
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -50,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型
strings.Contains(testModel, "embed") ||
- channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
+ channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径
}
@@ -75,80 +110,162 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
cache, err := model.GetUserCache(1)
if err != nil {
- return err, nil
+ return testResult{
+ localErr: err,
+ newAPIError: nil,
+ }
}
cache.WriteContext(c)
- c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+ //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(1, false)
c.Set("group", group)
- middleware.SetupContextForSelectedChannel(c, channel, testModel)
-
- info := relaycommon.GenRelayInfo(c)
-
- err = helper.ModelMappedHelper(c, info)
- if err != nil {
- return err, nil
+ newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
+ if newAPIError != nil {
+ return testResult{
+ context: c,
+ localErr: newAPIError,
+ newAPIError: newAPIError,
+ }
}
- testModel = info.UpstreamModelName
+ request := buildTestRequest(testModel)
- apiType, _ := constant.ChannelType2APIType(channel.Type)
+ // Determine relay format based on request path
+ relayFormat := types.RelayFormatOpenAI
+ if c.Request.URL.Path == "/v1/embeddings" {
+ relayFormat = types.RelayFormatEmbedding
+ }
+
+ info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
+
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
+ }
+ }
+
+ info.InitChannelMeta(c)
+
+ err = helper.ModelMappedHelper(c, info, request)
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
+ }
+ }
+
+ testModel = info.UpstreamModelName
+ request.Model = testModel
+
+ apiType, _ := common.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
- return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
+ return testResult{
+ context: c,
+ localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
+ newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
+ }
}
- request := buildTestRequest(testModel)
- // 创建一个用于日志的 info 副本,移除 ApiKey
- logInfo := *info
- logInfo.ApiKey = ""
- common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
+ //// 创建一个用于日志的 info 副本,移除 ApiKey
+ //logInfo := info
+ //logInfo.ApiKey = ""
+ common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
- priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
+ priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
if err != nil {
- return err, nil
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
+ }
}
adaptor.Init(info)
- convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
+ var convertedRequest any
+ // 根据 RelayMode 选择正确的转换函数
+ if info.RelayMode == relayconstant.RelayModeEmbeddings {
+ // 创建一个 EmbeddingRequest
+ embeddingRequest := dto.EmbeddingRequest{
+ Input: request.Input,
+ Model: request.Model,
+ }
+ // 调用专门用于 Embedding 的转换函数
+ convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
+ } else {
+ // 对其他所有请求类型(如 Chat),保持原有逻辑
+ convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
+ }
+
if err != nil {
- return err, nil
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
+ }
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
- return err, nil
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
+ }
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
- return err, nil
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
+ }
}
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp, true)
- return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
+ }
}
}
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
if respErr != nil {
- return fmt.Errorf("%s", respErr.Error.Message), respErr
+ return testResult{
+ context: c,
+ localErr: respErr,
+ newAPIError: respErr,
+ }
}
if usageA == nil {
- return errors.New("usage is nil"), nil
+ return testResult{
+ context: c,
+ localErr: errors.New("usage is nil"),
+ newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
+ }
}
usage := usageA.(*dto.Usage)
result := w.Result()
respBody, err := io.ReadAll(result.Body)
if err != nil {
- return err, nil
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
+ }
}
info.PromptTokens = usage.PromptTokens
@@ -165,12 +282,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
- other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
- usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
- model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
- quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
+ other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+ usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
+ ChannelId: channel.Id,
+ PromptTokens: usage.PromptTokens,
+ CompletionTokens: usage.CompletionTokens,
+ ModelName: info.OriginModelName,
+ TokenName: "模型测试",
+ Quota: quota,
+ Content: "模型测试",
+ UseTimeSeconds: int(consumedTime),
+ IsStream: info.IsStream,
+ Group: info.UsingGroup,
+ Other: other,
+ })
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
- return nil, nil
+ return testResult{
+ context: c,
+ localErr: nil,
+ newAPIError: nil,
+ }
}
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
@@ -185,7 +317,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
strings.Contains(model, "bge-") {
testRequest.Model = model
// Embedding 请求
- testRequest.Input = []string{"hello world"}
+ testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
return testRequest
}
// 并非Embedding 模型
@@ -196,14 +328,14 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest.MaxTokens = 50
}
} else if strings.Contains(model, "gemini") {
- testRequest.MaxTokens = 300
+ testRequest.MaxTokens = 3000
} else {
testRequest.MaxTokens = 10
}
- content, _ := json.Marshal("hi")
+
testMessage := dto.Message{
Role: "user",
- Content: content,
+ Content: "hi",
}
testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage)
@@ -213,31 +345,41 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
func TestChannel(c *gin.Context) {
channelId, err := strconv.Atoi(c.Param("id"))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- channel, err := model.GetChannelById(channelId, true)
+ channel, err := model.CacheGetChannel(channelId)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
+ channel, err = model.GetChannelById(channelId, true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
}
+ //defer func() {
+ // if channel.ChannelInfo.IsMultiKey {
+ // go func() { _ = channel.SaveChannelInfo() }()
+ // }
+ //}()
testModel := c.Query("model")
tik := time.Now()
- err, _ = testChannel(channel, testModel)
+ result := testChannel(channel, testModel)
+ if result.localErr != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": result.localErr.Error(),
+ "time": 0.0,
+ })
+ return
+ }
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
- if err != nil {
+ if result.newAPIError != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": err.Error(),
+ "message": result.newAPIError.Error(),
"time": consumedTime,
})
return
@@ -262,52 +404,59 @@ func testAllChannels(notify bool) error {
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
- channels, err := model.GetAllChannels(0, 0, true, false)
- if err != nil {
- return err
+ channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
+ if getChannelErr != nil {
+ return getChannelErr
}
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
gopool.Go(func() {
+ // 使用 defer 确保无论如何都会重置运行状态,防止死锁
+ defer func() {
+ testAllChannelsLock.Lock()
+ testAllChannelsRunning = false
+ testAllChannelsLock.Unlock()
+ }()
+
for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now()
- err, openaiWithStatusErr := testChannel(channel, "")
+ result := testChannel(channel, "")
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
shouldBanChannel := false
-
+ newAPIError := result.newAPIError
// request error disables the channel
- if openaiWithStatusErr != nil {
- oaiErr := openaiWithStatusErr.Error
- err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
- shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
+ if newAPIError != nil {
+ shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
}
- if milliseconds > disableThreshold {
- err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
- shouldBanChannel = true
+ // 当错误检查通过,才检查响应时间
+ if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
+ if milliseconds > disableThreshold {
+ err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+ newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
+ shouldBanChannel = true
+ }
}
// disable channel
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
- service.DisableChannel(channel.Id, channel.Name, err.Error())
+ processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
}
// enable channel
- if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
- service.EnableChannel(channel.Id, channel.Name)
+ if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
+ service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval)
}
- testAllChannelsLock.Lock()
- testAllChannelsRunning = false
- testAllChannelsLock.Unlock()
+
if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
@@ -318,10 +467,7 @@ func testAllChannels(notify bool) error {
func TestAllChannels(c *gin.Context) {
err := testAllChannels(true)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -332,6 +478,10 @@ func TestAllChannels(c *gin.Context) {
}
func AutomaticallyTestChannels(frequency int) {
+ if frequency <= 0 {
+ common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
+ return
+ }
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels")
diff --git a/controller/channel.go b/controller/channel.go
index a31e1f47..020a3327 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/model"
"strconv"
"strings"
@@ -40,50 +41,123 @@ type OpenAIModelsResponse struct {
Success bool `json:"success"`
}
+func parseStatusFilter(statusParam string) int {
+ switch strings.ToLower(statusParam) {
+ case "enabled", "1":
+ return common.ChannelStatusEnabled
+ case "disabled", "0":
+ return 0
+ default:
+ return -1
+ }
+}
+
+func clearChannelInfo(channel *model.Channel) {
+ if channel.ChannelInfo.IsMultiKey {
+ channel.ChannelInfo.MultiKeyDisabledReason = nil
+ channel.ChannelInfo.MultiKeyDisabledTime = nil
+ }
+}
+
func GetAllChannels(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 0 {
- p = 0
- }
- if pageSize < 0 {
- pageSize = common.ItemsPerPage
- }
+ pageInfo := common.GetPageQuery(c)
channelData := make([]*model.Channel, 0)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+ statusParam := c.Query("status")
+ // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
+ statusFilter := parseStatusFilter(statusParam)
+ // type filter
+ typeStr := c.Query("type")
+ typeFilter := -1
+ if typeStr != "" {
+ if t, err := strconv.Atoi(typeStr); err == nil {
+ typeFilter = t
+ }
+ }
+
+ var total int64
+
if enableTagMode {
- tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
+ tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
for _, tag := range tags {
- if tag != nil && *tag != "" {
- tagChannel, err := model.GetChannelsByTag(*tag, idSort)
- if err == nil {
- channelData = append(channelData, tagChannel...)
- }
+ if tag == nil || *tag == "" {
+ continue
}
+ tagChannels, err := model.GetChannelsByTag(*tag, idSort)
+ if err != nil {
+ continue
+ }
+ filtered := make([]*model.Channel, 0)
+ for _, ch := range tagChannels {
+ if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
+ continue
+ }
+ if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
+ continue
+ }
+ if typeFilter >= 0 && ch.Type != typeFilter {
+ continue
+ }
+ filtered = append(filtered, ch)
+ }
+ channelData = append(channelData, filtered...)
}
+ total, _ = model.CountAllTags()
} else {
- channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
+ baseQuery := model.DB.Model(&model.Channel{})
+ if typeFilter >= 0 {
+ baseQuery = baseQuery.Where("type = ?", typeFilter)
+ }
+ if statusFilter == common.ChannelStatusEnabled {
+ baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
+ } else if statusFilter == 0 {
+ baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
+ }
+
+ baseQuery.Count(&total)
+
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+
+ err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
- channelData = channels
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": channelData,
+
+ for _, datum := range channelData {
+ clearChannelInfo(datum)
+ }
+
+ countQuery := model.DB.Model(&model.Channel{})
+ if statusFilter == common.ChannelStatusEnabled {
+ countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
+ } else if statusFilter == 0 {
+ countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
+ }
+ var results []struct {
+ Type int64
+ Count int64
+ }
+ _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
+ typeCounts := make(map[int64]int64)
+ for _, r := range results {
+ typeCounts[r.Type] = r.Count
+ }
+ common.ApiSuccess(c, gin.H{
+ "items": channelData,
+ "total": total,
+ "page": pageInfo.GetPage(),
+ "page_size": pageInfo.GetPageSize(),
+ "type_counts": typeCounts,
})
return
}
@@ -91,46 +165,42 @@ func GetAllChannels(c *gin.Context) {
func FetchUpstreamModels(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- //if channel.Type != common.ChannelTypeOpenAI {
- // c.JSON(http.StatusOK, gin.H{
- // "success": false,
- // "message": "仅支持 OpenAI 类型渠道",
- // })
- // return
- //}
- baseURL := common.ChannelBaseURLs[channel.Type]
+ baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
- url := fmt.Sprintf("%s/v1/models", baseURL)
+
+ var url string
switch channel.Type {
- case common.ChannelTypeGemini:
- url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
- case common.ChannelTypeAli:
+ case constant.ChannelTypeGemini:
+ // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
+ url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
+ case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
+ default:
+ url = fmt.Sprintf("%s/v1/models", baseURL)
+ }
+
+ // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
+ var body []byte
+ key := strings.Split(channel.Key, "\n")[0]
+ if channel.Type == constant.ChannelTypeGemini {
+ body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
+ } else {
+ body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
}
- body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -146,7 +216,7 @@ func FetchUpstreamModels(c *gin.Context) {
var ids []string
for _, model := range result.Data {
id := model.ID
- if channel.Type == common.ChannelTypeGemini {
+ if channel.Type == constant.ChannelTypeGemini {
id = strings.TrimPrefix(id, "models/")
}
ids = append(ids, id)
@@ -160,18 +230,18 @@ func FetchUpstreamModels(c *gin.Context) {
}
func FixChannelsAbilities(c *gin.Context) {
- count, err := model.FixAbility()
+ success, fails, err := model.FixAbility()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": count,
+ "data": gin.H{
+ "success": success,
+ "fails": fails,
+ },
})
}
@@ -179,6 +249,8 @@ func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
modelKeyword := c.Query("model")
+ statusParam := c.Query("status")
+ statusFilter := parseStatusFilter(statusParam)
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
channelData := make([]*model.Channel, 0)
@@ -210,10 +282,78 @@ func SearchChannels(c *gin.Context) {
}
channelData = channels
}
+
+ if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
+ filtered := make([]*model.Channel, 0, len(channelData))
+ for _, ch := range channelData {
+ if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
+ continue
+ }
+ if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
+ continue
+ }
+ filtered = append(filtered, ch)
+ }
+ channelData = filtered
+ }
+
+ // calculate type counts for search results
+ typeCounts := make(map[int64]int64)
+ for _, channel := range channelData {
+ typeCounts[int64(channel.Type)]++
+ }
+
+ typeParam := c.Query("type")
+ typeFilter := -1
+ if typeParam != "" {
+ if tp, err := strconv.Atoi(typeParam); err == nil {
+ typeFilter = tp
+ }
+ }
+
+ if typeFilter >= 0 {
+ filtered := make([]*model.Channel, 0, len(channelData))
+ for _, ch := range channelData {
+ if ch.Type == typeFilter {
+ filtered = append(filtered, ch)
+ }
+ }
+ channelData = filtered
+ }
+
+ page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
+ pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
+ if page < 1 {
+ page = 1
+ }
+ if pageSize <= 0 {
+ pageSize = 20
+ }
+
+ total := len(channelData)
+ startIdx := (page - 1) * pageSize
+ if startIdx > total {
+ startIdx = total
+ }
+ endIdx := startIdx + pageSize
+ if endIdx > total {
+ endIdx = total
+ }
+
+ pagedData := channelData[startIdx:endIdx]
+
+ for _, datum := range pagedData {
+ clearChannelInfo(datum)
+ }
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": channelData,
+ "data": gin.H{
+ "items": pagedData,
+ "total": total,
+ "type_counts": typeCounts,
+ },
})
return
}
@@ -221,20 +361,17 @@ func SearchChannels(c *gin.Context) {
func GetChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
channel, err := model.GetChannelById(id, false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ if channel != nil {
+ clearChannelInfo(channel)
+ }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -243,66 +380,167 @@ func GetChannel(c *gin.Context) {
return
}
-func AddChannel(c *gin.Context) {
- channel := model.Channel{}
- err := c.ShouldBindJSON(&channel)
+// validateChannel 通用的渠道校验函数
+func validateChannel(channel *model.Channel, isAdd bool) error {
+ // 校验 channel settings
+ if err := channel.ValidateSettings(); err != nil {
+ return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
+ }
+
+ // 如果是添加操作,检查 channel 和 key 是否为空
+ if isAdd {
+ if channel == nil || channel.Key == "" {
+ return fmt.Errorf("channel cannot be empty")
+ }
+
+ // 检查模型名称长度是否超过 255
+ for _, m := range channel.GetModels() {
+ if len(m) > 255 {
+ return fmt.Errorf("模型名称过长: %s", m)
+ }
+ }
+ }
+
+ // VertexAI 特殊校验
+ if channel.Type == constant.ChannelTypeVertexAi {
+ if channel.Other == "" {
+ return fmt.Errorf("部署地区不能为空")
+ }
+
+ regionMap, err := common.StrToMap(channel.Other)
+ if err != nil {
+ return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
+ }
+
+ if regionMap["default"] == nil {
+ return fmt.Errorf("部署地区必须包含default字段")
+ }
+ }
+
+ return nil
+}
+
+type AddChannelRequest struct {
+ Mode string `json:"mode"`
+ MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+ Channel *model.Channel `json:"channel"`
+}
+
+func getVertexArrayKeys(keys string) ([]string, error) {
+ if keys == "" {
+ return nil, nil
+ }
+ var keyArray []interface{}
+ err := common.Unmarshal([]byte(keys), &keyArray)
if err != nil {
+ return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
+ }
+ cleanKeys := make([]string, 0, len(keyArray))
+ for _, key := range keyArray {
+ var keyStr string
+ switch v := key.(type) {
+ case string:
+ keyStr = strings.TrimSpace(v)
+ default:
+ bytes, err := json.Marshal(v)
+ if err != nil {
+ return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
+ }
+ keyStr = string(bytes)
+ }
+ if keyStr != "" {
+ cleanKeys = append(cleanKeys, keyStr)
+ }
+ }
+ if len(cleanKeys) == 0 {
+ return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
+ }
+ return cleanKeys, nil
+}
+
+func AddChannel(c *gin.Context) {
+ addChannelRequest := AddChannelRequest{}
+ err := c.ShouldBindJSON(&addChannelRequest)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // 使用统一的校验函数
+ if err := validateChannel(addChannelRequest.Channel, true); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
- channel.CreatedTime = common.GetTimestamp()
- keys := strings.Split(channel.Key, "\n")
- if channel.Type == common.ChannelTypeVertexAi {
- if channel.Other == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "部署地区不能为空",
- })
- return
- } else {
- if common.IsJsonStr(channel.Other) {
- // must have default
- regionMap := common.StrToMap(channel.Other)
- if regionMap["default"] == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "部署地区必须包含default字段",
- })
- return
- }
+
+ addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
+ keys := make([]string, 0)
+ switch addChannelRequest.Mode {
+ case "multi_to_single":
+ addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
+ addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
+ if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+ array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
}
+ addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
+ addChannelRequest.Channel.Key = strings.Join(array, "\n")
+ } else {
+ cleanKeys := make([]string, 0)
+ for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
+ if key == "" {
+ continue
+ }
+ key = strings.TrimSpace(key)
+ cleanKeys = append(cleanKeys, key)
+ }
+ addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
+ addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
}
- keys = []string{channel.Key}
+ keys = []string{addChannelRequest.Channel.Key}
+ case "batch":
+ if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+ // multi json
+ keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ keys = strings.Split(addChannelRequest.Channel.Key, "\n")
+ }
+ case "single":
+ keys = []string{addChannelRequest.Channel.Key}
+ default:
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "不支持的添加模式",
+ })
+ return
}
+
channels := make([]model.Channel, 0, len(keys))
for _, key := range keys {
if key == "" {
continue
}
- localChannel := channel
+ localChannel := addChannelRequest.Channel
localChannel.Key = key
- // Validate the length of the model name
- models := strings.Split(localChannel.Models, ",")
- for _, model := range models {
- if len(model) > 255 {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": fmt.Sprintf("模型名称过长: %s", model),
- })
- return
- }
- }
- channels = append(channels, localChannel)
+ channels = append(channels, *localChannel)
}
err = model.BatchInsertChannels(channels)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -317,12 +555,10 @@ func DeleteChannel(c *gin.Context) {
channel := model.Channel{Id: id}
err := channel.Delete()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -333,12 +569,10 @@ func DeleteChannel(c *gin.Context) {
func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -369,12 +603,10 @@ func DisableTagChannels(c *gin.Context) {
}
err = model.DisableChannelByTag(channelTag.Tag)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -394,12 +626,10 @@ func EnableTagChannels(c *gin.Context) {
}
err = model.EnableChannelByTag(channelTag.Tag)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -426,12 +656,10 @@ func EditTagChannels(c *gin.Context) {
}
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -456,12 +684,10 @@ func DeleteChannelBatch(c *gin.Context) {
}
err = model.BatchDeleteChannels(channelBatch.Ids)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -470,9 +696,30 @@ func DeleteChannelBatch(c *gin.Context) {
return
}
+type PatchChannel struct {
+ model.Channel
+ MultiKeyMode *string `json:"multi_key_mode"`
+ KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
+}
+
func UpdateChannel(c *gin.Context) {
- channel := model.Channel{}
+ channel := PatchChannel{}
err := c.ShouldBindJSON(&channel)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // 使用统一的校验函数
+ if err := validateChannel(&channel.Channel, false); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
+ originChannel, err := model.GetChannelById(channel.Id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -480,35 +727,85 @@ func UpdateChannel(c *gin.Context) {
})
return
}
- if channel.Type == common.ChannelTypeVertexAi {
- if channel.Other == "" {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "部署地区不能为空",
- })
- return
- } else {
- if common.IsJsonStr(channel.Other) {
- // must have default
- regionMap := common.StrToMap(channel.Other)
- if regionMap["default"] == nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "部署地区必须包含default字段",
- })
- return
+
+ // Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
+ channel.ChannelInfo = originChannel.ChannelInfo
+
+ // If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
+ if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
+ channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
+ }
+
+ // 处理多key模式下的密钥追加/覆盖逻辑
+ if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
+ switch *channel.KeyMode {
+ case "append":
+ // 追加模式:将新密钥添加到现有密钥列表
+ if originChannel.Key != "" {
+ var newKeys []string
+ var existingKeys []string
+
+ // 解析现有密钥
+ if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
+ // JSON数组格式
+ var arr []json.RawMessage
+ if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
+ existingKeys = make([]string, len(arr))
+ for i, v := range arr {
+ existingKeys[i] = string(v)
+ }
+ }
+ } else {
+ // 换行分隔格式
+ existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
+ }
+
+ // 处理 Vertex AI 的特殊情况
+ if channel.Type == constant.ChannelTypeVertexAi {
+ // 尝试解析新密钥为JSON数组
+ if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
+ array, err := getVertexArrayKeys(channel.Key)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "追加密钥解析失败: " + err.Error(),
+ })
+ return
+ }
+ newKeys = array
+ } else {
+ // 单个JSON密钥
+ newKeys = []string{channel.Key}
+ }
+ // 合并密钥
+ allKeys := append(existingKeys, newKeys...)
+ channel.Key = strings.Join(allKeys, "\n")
+ } else {
+ // 普通渠道的处理
+ inputKeys := strings.Split(channel.Key, "\n")
+ for _, key := range inputKeys {
+ key = strings.TrimSpace(key)
+ if key != "" {
+ newKeys = append(newKeys, key)
+ }
+ }
+ // 合并密钥
+ allKeys := append(existingKeys, newKeys...)
+ channel.Key = strings.Join(allKeys, "\n")
}
}
+ case "replace":
+ // 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
}
}
err = channel.Update()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
+ channel.Key = ""
+ clearChannelInfo(&channel.Channel)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -534,7 +831,7 @@ func FetchModels(c *gin.Context) {
baseURL := req.BaseURL
if baseURL == "" {
- baseURL = common.ChannelBaseURLs[req.Type]
+ baseURL = constant.ChannelBaseURLs[req.Type]
}
client := &http.Client{}
@@ -610,12 +907,10 @@ func BatchSetChannelTag(c *gin.Context) {
}
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -623,3 +918,504 @@ func BatchSetChannelTag(c *gin.Context) {
})
return
}
+
+func GetTagModels(c *gin.Context) {
+ tag := c.Query("tag")
+ if tag == "" {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "message": "tag不能为空",
+ })
+ return
+ }
+
+ channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ var longestModels string
+ maxLength := 0
+
+ // Find the longest models string among all channels with the given tag
+ for _, channel := range channels {
+ if channel.Models != "" {
+ currentModels := strings.Split(channel.Models, ",")
+ if len(currentModels) > maxLength {
+ maxLength = len(currentModels)
+ longestModels = channel.Models
+ }
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": longestModels,
+ })
+ return
+}
+
+// CopyChannel handles cloning an existing channel with its key.
+// POST /api/channel/copy/:id
+// Optional query params:
+//
+// suffix - string appended to the original name (default "_复制")
+// reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
+func CopyChannel(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
+ return
+ }
+
+ suffix := c.DefaultQuery("suffix", "_复制")
+ resetBalance := true
+ if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
+ if v, err := strconv.ParseBool(rbStr); err == nil {
+ resetBalance = v
+ }
+ }
+
+ // fetch original channel with key
+ origin, err := model.GetChannelById(id, true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ // clone channel
+ clone := *origin // shallow copy is sufficient as we will overwrite primitives
+ clone.Id = 0 // let DB auto-generate
+ clone.CreatedTime = common.GetTimestamp()
+ clone.Name = origin.Name + suffix
+ clone.TestTime = 0
+ clone.ResponseTime = 0
+ if resetBalance {
+ clone.Balance = 0
+ clone.UsedQuota = 0
+ }
+
+ // insert
+ if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ model.InitChannelCache()
+ // success
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
+}
+
+// MultiKeyManageRequest represents the request for multi-key management operations
+type MultiKeyManageRequest struct {
+ ChannelId int `json:"channel_id"`
+	Action    string `json:"action"` // "disable_key", "enable_key", "enable_all_keys", "disable_all_keys", "delete_disabled_keys", "get_key_status"
+ KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
+ Page int `json:"page,omitempty"` // for get_key_status pagination
+ PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
+ Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
+}
+
+// MultiKeyStatusResponse represents the response for key status query
+type MultiKeyStatusResponse struct {
+ Keys []KeyStatus `json:"keys"`
+ Total int `json:"total"`
+ Page int `json:"page"`
+ PageSize int `json:"page_size"`
+ TotalPages int `json:"total_pages"`
+ // Statistics
+ EnabledCount int `json:"enabled_count"`
+ ManualDisabledCount int `json:"manual_disabled_count"`
+ AutoDisabledCount int `json:"auto_disabled_count"`
+}
+
+type KeyStatus struct {
+ Index int `json:"index"`
+	Status       int    `json:"status"` // 1: enabled, 2: manually disabled, 3: auto disabled
+ DisabledTime int64 `json:"disabled_time,omitempty"`
+ Reason string `json:"reason,omitempty"`
+ KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
+}
+
+// ManageMultiKeys handles multi-key management operations
+func ManageMultiKeys(c *gin.Context) {
+ request := MultiKeyManageRequest{}
+ err := c.ShouldBindJSON(&request)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ channel, err := model.GetChannelById(request.ChannelId, true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "渠道不存在",
+ })
+ return
+ }
+
+ if !channel.ChannelInfo.IsMultiKey {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该渠道不是多密钥模式",
+ })
+ return
+ }
+
+ lock := model.GetChannelPollingLock(channel.Id)
+ lock.Lock()
+ defer lock.Unlock()
+
+ switch request.Action {
+ case "get_key_status":
+ keys := channel.GetKeys()
+
+ // Default pagination parameters
+ page := request.Page
+ pageSize := request.PageSize
+ if page <= 0 {
+ page = 1
+ }
+ if pageSize <= 0 {
+ pageSize = 50 // Default page size
+ }
+
+ // Statistics for all keys (unchanged by filtering)
+ var enabledCount, manualDisabledCount, autoDisabledCount int
+
+ // Build all key status data first
+ var allKeyStatusList []KeyStatus
+ for i, key := range keys {
+ status := 1 // default enabled
+ var disabledTime int64
+ var reason string
+
+ if channel.ChannelInfo.MultiKeyStatusList != nil {
+ if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
+ status = s
+ }
+ }
+
+ // Count for statistics (all keys)
+ switch status {
+ case 1:
+ enabledCount++
+ case 2:
+ manualDisabledCount++
+ case 3:
+ autoDisabledCount++
+ }
+
+ if status != 1 {
+ if channel.ChannelInfo.MultiKeyDisabledTime != nil {
+ disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
+ }
+ if channel.ChannelInfo.MultiKeyDisabledReason != nil {
+ reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
+ }
+ }
+
+ // Create key preview (first 10 chars)
+ keyPreview := key
+ if len(key) > 10 {
+ keyPreview = key[:10] + "..."
+ }
+
+ allKeyStatusList = append(allKeyStatusList, KeyStatus{
+ Index: i,
+ Status: status,
+ DisabledTime: disabledTime,
+ Reason: reason,
+ KeyPreview: keyPreview,
+ })
+ }
+
+ // Apply status filter if specified
+ var filteredKeyStatusList []KeyStatus
+ if request.Status != nil {
+ for _, keyStatus := range allKeyStatusList {
+ if keyStatus.Status == *request.Status {
+ filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
+ }
+ }
+ } else {
+ filteredKeyStatusList = allKeyStatusList
+ }
+
+ // Calculate pagination based on filtered results
+ filteredTotal := len(filteredKeyStatusList)
+ totalPages := (filteredTotal + pageSize - 1) / pageSize
+ if totalPages == 0 {
+ totalPages = 1
+ }
+ if page > totalPages {
+ page = totalPages
+ }
+
+ // Calculate range for current page
+ start := (page - 1) * pageSize
+ end := start + pageSize
+ if end > filteredTotal {
+ end = filteredTotal
+ }
+
+ // Get the page data
+ var pageKeyStatusList []KeyStatus
+ if start < filteredTotal {
+ pageKeyStatusList = filteredKeyStatusList[start:end]
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": MultiKeyStatusResponse{
+ Keys: pageKeyStatusList,
+ Total: filteredTotal, // Total of filtered results
+ Page: page,
+ PageSize: pageSize,
+ TotalPages: totalPages,
+ EnabledCount: enabledCount, // Overall statistics
+ ManualDisabledCount: manualDisabledCount, // Overall statistics
+ AutoDisabledCount: autoDisabledCount, // Overall statistics
+ },
+ })
+ return
+
+ case "disable_key":
+ if request.KeyIndex == nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "未指定要禁用的密钥索引",
+ })
+ return
+ }
+
+ keyIndex := *request.KeyIndex
+ if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "密钥索引超出范围",
+ })
+ return
+ }
+
+ if channel.ChannelInfo.MultiKeyStatusList == nil {
+ channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
+ }
+ if channel.ChannelInfo.MultiKeyDisabledTime == nil {
+ channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
+ }
+ if channel.ChannelInfo.MultiKeyDisabledReason == nil {
+ channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
+ }
+
+ channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
+
+ err = channel.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "密钥已禁用",
+ })
+ return
+
+ case "enable_key":
+ if request.KeyIndex == nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "未指定要启用的密钥索引",
+ })
+ return
+ }
+
+ keyIndex := *request.KeyIndex
+ if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "密钥索引超出范围",
+ })
+ return
+ }
+
+ // 从状态列表中删除该密钥的记录,使其回到默认启用状态
+ if channel.ChannelInfo.MultiKeyStatusList != nil {
+ delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
+ }
+ if channel.ChannelInfo.MultiKeyDisabledTime != nil {
+ delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
+ }
+ if channel.ChannelInfo.MultiKeyDisabledReason != nil {
+ delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
+ }
+
+ err = channel.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "密钥已启用",
+ })
+ return
+
+ case "enable_all_keys":
+ // 清空所有禁用状态,使所有密钥回到默认启用状态
+ var enabledCount int
+ if channel.ChannelInfo.MultiKeyStatusList != nil {
+ enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
+ }
+
+ channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
+ channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
+ channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
+
+ err = channel.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
+ })
+ return
+
+ case "disable_all_keys":
+ // 禁用所有启用的密钥
+ if channel.ChannelInfo.MultiKeyStatusList == nil {
+ channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
+ }
+ if channel.ChannelInfo.MultiKeyDisabledTime == nil {
+ channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
+ }
+ if channel.ChannelInfo.MultiKeyDisabledReason == nil {
+ channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
+ }
+
+ var disabledCount int
+ for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
+ status := 1 // default enabled
+ if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
+ status = s
+ }
+
+ // 只禁用当前启用的密钥
+ if status == 1 {
+ channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
+ disabledCount++
+ }
+ }
+
+ if disabledCount == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "没有可禁用的密钥",
+ })
+ return
+ }
+
+ err = channel.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
+ })
+ return
+
+ case "delete_disabled_keys":
+ keys := channel.GetKeys()
+ var remainingKeys []string
+ var deletedCount int
+ var newStatusList = make(map[int]int)
+ var newDisabledTime = make(map[int]int64)
+ var newDisabledReason = make(map[int]string)
+
+ newIndex := 0
+ for i, key := range keys {
+ status := 1 // default enabled
+ if channel.ChannelInfo.MultiKeyStatusList != nil {
+ if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
+ status = s
+ }
+ }
+
+ // 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
+ if status == 3 {
+ deletedCount++
+ } else {
+ remainingKeys = append(remainingKeys, key)
+ // 保留非自动禁用密钥的状态信息,重新索引
+ if status != 1 {
+ newStatusList[newIndex] = status
+ if channel.ChannelInfo.MultiKeyDisabledTime != nil {
+ if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
+ newDisabledTime[newIndex] = t
+ }
+ }
+ if channel.ChannelInfo.MultiKeyDisabledReason != nil {
+ if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
+ newDisabledReason[newIndex] = r
+ }
+ }
+ }
+ newIndex++
+ }
+ }
+
+ if deletedCount == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "没有需要删除的自动禁用密钥",
+ })
+ return
+ }
+
+ // Update channel with remaining keys
+ channel.Key = strings.Join(remainingKeys, "\n")
+ channel.ChannelInfo.MultiKeySize = len(remainingKeys)
+ channel.ChannelInfo.MultiKeyStatusList = newStatusList
+ channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
+ channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
+
+ err = channel.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
+ "data": deletedCount,
+ })
+ return
+
+ default:
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "不支持的操作",
+ })
+ return
+ }
+}
diff --git a/controller/console_migrate.go b/controller/console_migrate.go
new file mode 100644
index 00000000..f0812c3d
--- /dev/null
+++ b/controller/console_migrate.go
@@ -0,0 +1,104 @@
+// 用于迁移检测的旧键,该文件下个版本会删除
+
+package controller
+
+import (
+ "encoding/json"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+
+ "github.com/gin-gonic/gin"
+)
+
+// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
+func MigrateConsoleSetting(c *gin.Context) {
+ // 读取全部 option
+ opts, err := model.AllOption()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ // 建立 map
+ valMap := map[string]string{}
+ for _, o := range opts {
+ valMap[o.Key] = o.Value
+ }
+
+ // 处理 APIInfo
+ if v := valMap["ApiInfo"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ if len(arr) > 50 {
+ arr = arr[:50]
+ }
+ bytes, _ := json.Marshal(arr)
+ model.UpdateOption("console_setting.api_info", string(bytes))
+ }
+ model.UpdateOption("ApiInfo", "")
+ }
+ // Announcements 直接搬
+ if v := valMap["Announcements"]; v != "" {
+ model.UpdateOption("console_setting.announcements", v)
+ model.UpdateOption("Announcements", "")
+ }
+ // FAQ 转换
+ if v := valMap["FAQ"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ out := []map[string]interface{}{}
+ for _, item := range arr {
+ q, _ := item["question"].(string)
+ if q == "" {
+ q, _ = item["title"].(string)
+ }
+ a, _ := item["answer"].(string)
+ if a == "" {
+ a, _ = item["content"].(string)
+ }
+ if q != "" && a != "" {
+ out = append(out, map[string]interface{}{"question": q, "answer": a})
+ }
+ }
+ if len(out) > 50 {
+ out = out[:50]
+ }
+ bytes, _ := json.Marshal(out)
+ model.UpdateOption("console_setting.faq", string(bytes))
+ }
+ model.UpdateOption("FAQ", "")
+ }
+ // Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
+ url := valMap["UptimeKumaUrl"]
+ slug := valMap["UptimeKumaSlug"]
+ if url != "" && slug != "" {
+ // 仅当同时存在 URL 与 Slug 时才进行迁移
+ groups := []map[string]interface{}{
+ {
+ "id": 1,
+ "categoryName": "old",
+ "url": url,
+ "slug": slug,
+ "description": "",
+ },
+ }
+ bytes, _ := json.Marshal(groups)
+ model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+ }
+ // 清空旧键内容
+ if url != "" {
+ model.UpdateOption("UptimeKumaUrl", "")
+ }
+ if slug != "" {
+ model.UpdateOption("UptimeKumaSlug", "")
+ }
+
+ // 删除旧键记录
+ oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+ model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+
+ // 重新加载 OptionMap
+ model.InitOptionMap()
+ common.SysLog("console setting migrated")
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+}
diff --git a/controller/github.go b/controller/github.go
index 79711841..881d6dc1 100644
--- a/controller/github.go
+++ b/controller/github.go
@@ -5,13 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
)
type GitHubOAuthResponse struct {
@@ -103,10 +104,7 @@ func GitHubOAuth(c *gin.Context) {
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user := model.User{
@@ -185,10 +183,7 @@ func GitHubBind(c *gin.Context) {
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user := model.User{
@@ -207,19 +202,13 @@ func GitHubBind(c *gin.Context) {
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user.GitHubId = githubUser.Login
err = user.Update(false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -239,10 +228,7 @@ func GenerateOAuthCode(c *gin.Context) {
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/group.go b/controller/group.go
index 2c725a4d..2565b6ea 100644
--- a/controller/group.go
+++ b/controller/group.go
@@ -1,15 +1,17 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"one-api/setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
- for groupName, _ := range setting.GetGroupRatioCopy() {
+ for groupName := range ratio_setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.GetUserGroup(userId, false)
- for groupName, ratio := range setting.GetGroupRatioCopy() {
+ for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok {
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
}
}
}
+ if setting.GroupInUserUsableGroups("auto") {
+ usableGroups["auto"] = map[string]interface{}{
+ "ratio": "自动",
+ "desc": setting.GetUsableGroupDescription("auto"),
+ }
+ }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
diff --git a/controller/linuxdo.go b/controller/linuxdo.go
index 2cdb3517..9fa15615 100644
--- a/controller/linuxdo.go
+++ b/controller/linuxdo.go
@@ -38,10 +38,7 @@ func LinuxDoBind(c *gin.Context) {
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -63,20 +60,14 @@ func LinuxDoBind(c *gin.Context) {
err = user.FillUserById()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
err = user.Update(false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -202,10 +193,7 @@ func LinuxdoOAuth(c *gin.Context) {
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -232,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) {
}
} else {
if common.RegisterEnabled {
- user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
- user.DisplayName = linuxdoUser.Name
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
+ if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
+ user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
+ user.DisplayName = linuxdoUser.Name
+ user.Role = common.RoleCommonUser
+ user.Status = common.UserStatusEnabled
- affCode := session.Get("aff")
- inviterId := 0
- if affCode != nil {
- inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
- }
+ affCode := session.Get("aff")
+ inviterId := 0
+ if affCode != nil {
+ inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
+ }
- if err := user.Insert(inviterId); err != nil {
+ if err := user.Insert(inviterId); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": err.Error(),
+ "message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
})
return
}
diff --git a/controller/log.go b/controller/log.go
index 8d67c83e..042fa725 100644
--- a/controller/log.go
+++ b/controller/log.go
@@ -10,14 +10,7 @@ import (
)
func GetAllLogs(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 1 {
- p = 1
- }
- if pageSize < 0 {
- pageSize = common.ItemsPerPage
- }
+ pageInfo := common.GetPageQuery(c)
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
@@ -26,38 +19,19 @@ func GetAllLogs(c *gin.Context) {
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group")
- logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
+ logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": map[string]any{
- "items": logs,
- "total": total,
- "page": p,
- "page_size": pageSize,
- },
- })
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(logs)
+ common.ApiSuccess(c, pageInfo)
+ return
}
func GetUserLogs(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 1 {
- p = 1
- }
- if pageSize < 0 {
- pageSize = common.ItemsPerPage
- }
- if pageSize > 100 {
- pageSize = 100
- }
+ pageInfo := common.GetPageQuery(c)
userId := c.GetInt("id")
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
@@ -65,24 +39,14 @@ func GetUserLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
group := c.Query("group")
- logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
+ logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": map[string]any{
- "items": logs,
- "total": total,
- "page": p,
- "page_size": pageSize,
- },
- })
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(logs)
+ common.ApiSuccess(c, pageInfo)
return
}
@@ -90,10 +54,7 @@ func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -109,10 +70,7 @@ func SearchUserLogs(c *gin.Context) {
userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -198,10 +156,7 @@ func DeleteHistoryLogs(c *gin.Context) {
}
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/midjourney.go b/controller/midjourney.go
index 21027d8f..a67d39c2 100644
--- a/controller/midjourney.go
+++ b/controller/midjourney.go
@@ -5,17 +5,17 @@ import (
"context"
"encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
"io"
- "log"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
"one-api/service"
"one-api/setting"
- "strconv"
"time"
+
+ "github.com/gin-gonic/gin"
)
func UpdateMidjourneyTaskBulk() {
@@ -29,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
continue
}
- common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
+ logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
@@ -48,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
"progress": "100%",
})
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
- common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
+ logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
@@ -58,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
}
for channelId, taskIds := range taskChannelM {
- common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+ logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
- common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
+ logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
@@ -82,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
@@ -94,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
if resp.StatusCode != http.StatusOK {
- common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+ logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
@@ -146,9 +146,25 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr)
}
+ // 映射 VideoUrl
+ task.VideoUrl = responseItem.VideoUrl
+
+ // 映射 VideoUrls - 将数组序列化为 JSON 字符串
+ if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
+ videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
+ if err != nil {
+ logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
+ task.VideoUrls = "[]" // 失败时设置为空数组
+ } else {
+ task.VideoUrls = string(videoUrlsStr)
+ }
+ } else {
+ task.VideoUrls = "" // 空值时清空字段
+ }
+
shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
- common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
+ logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
if task.Quota != 0 {
shouldReturnQuota = true
@@ -156,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
}
err = task.Update()
if err != nil {
- common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+ logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else {
if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil {
- common.LogError(ctx, "fail to increase user quota: "+err.Error())
+ logger.LogError(ctx, "fail to increase user quota: "+err.Error())
}
- logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+ logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
@@ -209,15 +225,26 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true
}
+ // 检查 VideoUrl 是否需要更新
+ if oldTask.VideoUrl != newTask.VideoUrl {
+ return true
+ }
+ // 检查 VideoUrls 是否需要更新
+ if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
+ newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
+ if oldTask.VideoUrls != string(newVideoUrlsStr) {
+ return true
+ }
+ } else if oldTask.VideoUrls != "" {
+ // 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
+ return true
+ }
return false
}
func GetAllMidjourney(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
- }
+ pageInfo := common.GetPageQuery(c)
// 解析其他查询参数
queryParams := model.TaskQueryParams{
@@ -227,31 +254,24 @@ func GetAllMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
- logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Midjourney, 0)
- }
+ items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.CountAllTasks(queryParams)
+
if setting.MjForwardUrlEnabled {
- for i, midjourney := range logs {
+ for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- logs[i] = midjourney
+ items[i] = midjourney
}
}
- c.JSON(200, gin.H{
- "success": true,
- "message": "",
- "data": logs,
- })
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
}
func GetUserMidjourney(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
- }
+ pageInfo := common.GetPageQuery(c)
userId := c.GetInt("id")
- log.Printf("userId = %d \n", userId)
queryParams := model.TaskQueryParams{
MjID: c.Query("mj_id"),
@@ -259,19 +279,16 @@ func GetUserMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"),
}
- logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Midjourney, 0)
- }
+ items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.CountAllUserTask(userId, queryParams)
+
if setting.MjForwardUrlEnabled {
- for i, midjourney := range logs {
+ for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
- logs[i] = midjourney
+ items[i] = midjourney
}
}
- c.JSON(200, gin.H{
- "success": true,
- "message": "",
- "data": logs,
- })
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
}
diff --git a/controller/misc.go b/controller/misc.go
index 4d265c3f..f30ab8c7 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -6,8 +6,10 @@ import (
"net/http"
"one-api/common"
"one-api/constant"
+ "one-api/middleware"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strings"
@@ -24,57 +26,90 @@ func TestStatus(c *gin.Context) {
})
return
}
+ // 获取HTTP统计信息
+ httpStats := middleware.GetStats()
c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "Server is running",
+ "success": true,
+ "message": "Server is running",
+ "http_stats": httpStats,
})
return
}
func GetStatus(c *gin.Context) {
+
+ cs := console_setting.GetConsoleSetting()
+
+ data := gin.H{
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "linuxdo_oauth": common.LinuxDOOAuthEnabled,
+ "linuxdo_client_id": common.LinuxDOClientId,
+ "linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
+ "telegram_oauth": common.TelegramOAuthEnabled,
+ "telegram_bot_name": common.TelegramBotName,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": setting.ServerAddress,
+ "price": setting.Price,
+ "stripe_unit_price": setting.StripeUnitPrice,
+ "min_topup": setting.MinTopUp,
+ "stripe_min_topup": setting.StripeMinTopUp,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "docs_link": operation_setting.GetGeneralSetting().DocsLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
+ "enable_batch_update": common.BatchUpdateEnabled,
+ "enable_drawing": common.DrawingEnabled,
+ "enable_task": common.TaskEnabled,
+ "enable_data_export": common.DataExportEnabled,
+ "data_export_default_time": common.DataExportDefaultTime,
+ "default_collapse_sidebar": common.DefaultCollapseSidebar,
+ "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+ "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
+ "mj_notify_enabled": setting.MjNotifyEnabled,
+ "chats": setting.Chats,
+ "demo_site_enabled": operation_setting.DemoSiteEnabled,
+ "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "default_use_auto_group": setting.DefaultUseAutoGroup,
+ "pay_methods": setting.PayMethods,
+ "usd_exchange_rate": setting.USDExchangeRate,
+
+ // 面板启用开关
+ "api_info_enabled": cs.ApiInfoEnabled,
+ "uptime_kuma_enabled": cs.UptimeKumaEnabled,
+ "announcements_enabled": cs.AnnouncementsEnabled,
+ "faq_enabled": cs.FAQEnabled,
+
+ "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
+ "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
+ "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
+ "setup": constant.Setup,
+ }
+
+ // 根据启用状态注入可选内容
+ if cs.ApiInfoEnabled {
+ data["api_info"] = console_setting.GetApiInfo()
+ }
+ if cs.AnnouncementsEnabled {
+ data["announcements"] = console_setting.GetAnnouncements()
+ }
+ if cs.FAQEnabled {
+ data["faq"] = console_setting.GetFAQ()
+ }
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": gin.H{
- "version": common.Version,
- "start_time": common.StartTime,
- "email_verification": common.EmailVerificationEnabled,
- "github_oauth": common.GitHubOAuthEnabled,
- "github_client_id": common.GitHubClientId,
- "linuxdo_oauth": common.LinuxDOOAuthEnabled,
- "linuxdo_client_id": common.LinuxDOClientId,
- "telegram_oauth": common.TelegramOAuthEnabled,
- "telegram_bot_name": common.TelegramBotName,
- "system_name": common.SystemName,
- "logo": common.Logo,
- "footer_html": common.Footer,
- "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
- "wechat_login": common.WeChatAuthEnabled,
- "server_address": setting.ServerAddress,
- "price": setting.Price,
- "min_topup": setting.MinTopUp,
- "turnstile_check": common.TurnstileCheckEnabled,
- "turnstile_site_key": common.TurnstileSiteKey,
- "top_up_link": common.TopUpLink,
- "docs_link": operation_setting.GetGeneralSetting().DocsLink,
- "quota_per_unit": common.QuotaPerUnit,
- "display_in_currency": common.DisplayInCurrencyEnabled,
- "enable_batch_update": common.BatchUpdateEnabled,
- "enable_drawing": common.DrawingEnabled,
- "enable_task": common.TaskEnabled,
- "enable_data_export": common.DataExportEnabled,
- "data_export_default_time": common.DataExportDefaultTime,
- "default_collapse_sidebar": common.DefaultCollapseSidebar,
- "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
- "mj_notify_enabled": setting.MjNotifyEnabled,
- "chats": setting.Chats,
- "demo_site_enabled": operation_setting.DemoSiteEnabled,
- "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
- "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
- "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
- "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
- "setup": constant.Setup,
- },
+ "data": data,
})
return
}
@@ -184,10 +219,7 @@ func SendEmailVerification(c *gin.Context) {
"验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -223,10 +255,7 @@ func SendPasswordResetEmail(c *gin.Context) {
"重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -261,10 +290,7 @@ func ResetPassword(c *gin.Context) {
password := common.GenerateVerificationCode(12)
err = model.ResetUserPasswordByEmail(req.Email, password)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
common.DeleteKey(req.Email, common.PasswordResetPurpose)
diff --git a/controller/missing_models.go b/controller/missing_models.go
new file mode 100644
index 00000000..425f9b25
--- /dev/null
+++ b/controller/missing_models.go
@@ -0,0 +1,27 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/model"
+
+ "github.com/gin-gonic/gin"
+)
+
+// GetMissingModels returns the list of model names that are referenced by channels
+// but do not have corresponding records in the models meta table.
+// This helps administrators quickly discover models that need configuration.
+func GetMissingModels(c *gin.Context) {
+ missing, err := model.GetMissingModels()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "data": missing,
+ })
+}
diff --git a/controller/model.go b/controller/model.go
index df7e59a6..398503e8 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -3,6 +3,7 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
+ "github.com/samber/lo"
"net/http"
"one-api/common"
"one-api/constant"
@@ -14,7 +15,8 @@ import (
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
+ "one-api/setting"
+ "time"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -23,30 +25,10 @@ var openAIModels []dto.OpenAIModels
var openAIModelsMap map[string]dto.OpenAIModels
var channelId2Models map[int][]string
-func getPermission() []dto.OpenAIModelPermission {
- var permission []dto.OpenAIModelPermission
- permission = append(permission, dto.OpenAIModelPermission{
- Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
- Object: "model_permission",
- Created: 1626777600,
- AllowCreateEngine: true,
- AllowSampling: true,
- AllowLogprobs: true,
- AllowSearchIndices: false,
- AllowView: true,
- AllowFineTuning: false,
- Organization: "*",
- Group: nil,
- IsBlocking: false,
- })
- return permission
-}
-
func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
- permission := getPermission()
- for i := 0; i < relayconstant.APITypeDummy; i++ {
- if i == relayconstant.APITypeAIProxyLibrary {
+ for i := 0; i < constant.APITypeDummy; i++ {
+ if i == constant.APITypeAIProxyLibrary {
continue
}
adaptor := relay.GetAdaptor(i)
@@ -54,69 +36,51 @@ func init() {
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: channelName,
- Permission: permission,
- Root: modelName,
- Parent: nil,
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: channelName,
})
}
}
for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: ai360.ChannelName,
- Permission: permission,
- Root: modelName,
- Parent: nil,
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: ai360.ChannelName,
})
}
for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: moonshot.ChannelName,
- Permission: permission,
- Root: modelName,
- Parent: nil,
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: moonshot.ChannelName,
})
}
for _, modelName := range lingyiwanwu.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: lingyiwanwu.ChannelName,
- Permission: permission,
- Root: modelName,
- Parent: nil,
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: lingyiwanwu.ChannelName,
})
}
for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: minimax.ChannelName,
- Permission: permission,
- Root: modelName,
- Parent: nil,
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: minimax.ChannelName,
})
}
for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, dto.OpenAIModels{
- Id: modelName,
- Object: "model",
- Created: 1626777600,
- OwnedBy: "midjourney",
- Permission: permission,
- Root: modelName,
- Parent: nil,
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "midjourney",
})
}
openAIModelsMap = make(map[string]dto.OpenAIModels)
@@ -124,25 +88,29 @@ func init() {
openAIModelsMap[aiModel.Id] = aiModel
}
channelId2Models = make(map[int][]string)
- for i := 1; i <= common.ChannelTypeDummy; i++ {
- apiType, success := relayconstant.ChannelType2APIType(i)
- if !success || apiType == relayconstant.APITypeAIProxyLibrary {
+ for i := 1; i <= constant.ChannelTypeDummy; i++ {
+ apiType, success := common.ChannelType2APIType(i)
+ if !success || apiType == constant.APITypeAIProxyLibrary {
continue
}
- meta := &relaycommon.RelayInfo{ChannelType: i}
+ meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
+ ChannelType: i,
+ }}
adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList()
}
+ openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
+ return m.Id
+ })
}
-func ListModels(c *gin.Context) {
+func ListModels(c *gin.Context, modelType int) {
userOpenAiModels := make([]dto.OpenAIModels, 0)
- permission := getPermission()
- modelLimitEnable := c.GetBool("token_model_limit_enabled")
+ modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
- s, ok := c.Get("token_model_limit")
+ s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
@@ -150,23 +118,22 @@ func ListModels(c *gin.Context) {
tokenModelLimit = map[string]bool{}
}
for allowModel, _ := range tokenModelLimit {
- if _, ok := openAIModelsMap[allowModel]; ok {
- userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
+ if oaiModel, ok := openAIModelsMap[allowModel]; ok {
+ oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
+ userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
- Id: allowModel,
- Object: "model",
- Created: 1626777600,
- OwnedBy: "custom",
- Permission: permission,
- Root: allowModel,
- Parent: nil,
+ Id: allowModel,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "custom",
+ SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
})
}
}
} else {
userId := c.GetInt("id")
- userGroup, err := model.GetUserGroup(userId, true)
+ userGroup, err := model.GetUserGroup(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -175,31 +142,73 @@ func ListModels(c *gin.Context) {
return
}
group := userGroup
- tokenGroup := c.GetString("token_group")
+ tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" {
group = tokenGroup
}
- models := model.GetGroupModels(group)
- for _, s := range models {
- if _, ok := openAIModelsMap[s]; ok {
- userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
+ var models []string
+ if tokenGroup == "auto" {
+ for _, autoGroup := range setting.AutoGroups {
+ groupModels := model.GetGroupEnabledModels(autoGroup)
+ for _, g := range groupModels {
+ if !common.StringsContains(models, g) {
+ models = append(models, g)
+ }
+ }
+ }
+ } else {
+ models = model.GetGroupEnabledModels(group)
+ }
+ for _, modelName := range models {
+ if oaiModel, ok := openAIModelsMap[modelName]; ok {
+ oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
+ userOpenAiModels = append(userOpenAiModels, oaiModel)
} else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
- Id: s,
- Object: "model",
- Created: 1626777600,
- OwnedBy: "custom",
- Permission: permission,
- Root: s,
- Parent: nil,
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "custom",
+ SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
})
}
}
}
- c.JSON(200, gin.H{
- "success": true,
- "data": userOpenAiModels,
- })
+ switch modelType {
+ case constant.ChannelTypeAnthropic:
+ useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
+ for i, model := range userOpenAiModels {
+ useranthropicModels[i] = dto.AnthropicModel{
+ ID: model.Id,
+ CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
+ DisplayName: model.Id,
+ Type: "model",
+ }
+ }
+		firstId, lastId := "", ""
+		if n := len(useranthropicModels); n > 0 {
+			firstId, lastId = useranthropicModels[0].ID, useranthropicModels[n-1].ID
+		}
+		c.JSON(200, gin.H{"data": useranthropicModels,
+			"first_id": firstId, "has_more": false, "last_id": lastId})
+ case constant.ChannelTypeGemini:
+ userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
+ for i, model := range userOpenAiModels {
+ userGeminiModels[i] = dto.GeminiModel{
+ Name: model.Id,
+ DisplayName: model.Id,
+ }
+ }
+ c.JSON(200, gin.H{
+ "models": userGeminiModels,
+ "nextPageToken": nil,
+ })
+ default:
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": userOpenAiModels,
+ })
+ }
}
func ChannelListModels(c *gin.Context) {
@@ -223,10 +232,20 @@ func EnabledListModels(c *gin.Context) {
})
}
-func RetrieveModel(c *gin.Context) {
+func RetrieveModel(c *gin.Context, modelType int) {
modelId := c.Param("model")
if aiModel, ok := openAIModelsMap[modelId]; ok {
- c.JSON(200, aiModel)
+ switch modelType {
+ case constant.ChannelTypeAnthropic:
+ c.JSON(200, dto.AnthropicModel{
+ ID: aiModel.Id,
+ CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
+ DisplayName: aiModel.Id,
+ Type: "model",
+ })
+ default:
+ c.JSON(200, aiModel)
+ }
} else {
openAIError := dto.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
diff --git a/controller/model_meta.go b/controller/model_meta.go
new file mode 100644
index 00000000..31ea64f3
--- /dev/null
+++ b/controller/model_meta.go
@@ -0,0 +1,330 @@
+package controller
+
+import (
+ "encoding/json"
+ "sort"
+ "strconv"
+ "strings"
+
+ "one-api/common"
+ "one-api/constant"
+ "one-api/model"
+
+ "github.com/gin-gonic/gin"
+)
+
+// GetAllModelsMeta 获取模型列表(分页)
+func GetAllModelsMeta(c *gin.Context) {
+
+ pageInfo := common.GetPageQuery(c)
+ modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ // 批量填充附加字段,提升列表接口性能
+ enrichModels(modelsMeta)
+ var total int64
+ model.DB.Model(&model.Model{}).Count(&total)
+
+ // 统计供应商计数(全部数据,不受分页影响)
+ vendorCounts, _ := model.GetVendorModelCounts()
+
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(modelsMeta)
+ common.ApiSuccess(c, gin.H{
+ "items": modelsMeta,
+ "total": total,
+ "page": pageInfo.GetPage(),
+ "page_size": pageInfo.GetPageSize(),
+ "vendor_counts": vendorCounts,
+ })
+}
+
+// SearchModelsMeta 搜索模型列表
+func SearchModelsMeta(c *gin.Context) {
+
+ keyword := c.Query("keyword")
+ vendor := c.Query("vendor")
+ pageInfo := common.GetPageQuery(c)
+
+ modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ // 批量填充附加字段,提升列表接口性能
+ enrichModels(modelsMeta)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(modelsMeta)
+ common.ApiSuccess(c, pageInfo)
+}
+
+// GetModelMeta 根据 ID 获取单条模型信息
+func GetModelMeta(c *gin.Context) {
+ idStr := c.Param("id")
+ id, err := strconv.Atoi(idStr)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ var m model.Model
+ if err := model.DB.First(&m, id).Error; err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ enrichModels([]*model.Model{&m})
+ common.ApiSuccess(c, &m)
+}
+
+// CreateModelMeta 新建模型
+func CreateModelMeta(c *gin.Context) {
+ var m model.Model
+ if err := c.ShouldBindJSON(&m); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if m.ModelName == "" {
+ common.ApiErrorMsg(c, "模型名称不能为空")
+ return
+ }
+ // 名称冲突检查
+ if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
+ common.ApiError(c, err)
+ return
+ } else if dup {
+ common.ApiErrorMsg(c, "模型名称已存在")
+ return
+ }
+
+ if err := m.Insert(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.RefreshPricing()
+ common.ApiSuccess(c, &m)
+}
+
+// UpdateModelMeta 更新模型
+func UpdateModelMeta(c *gin.Context) {
+ statusOnly := c.Query("status_only") == "true"
+
+ var m model.Model
+ if err := c.ShouldBindJSON(&m); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if m.Id == 0 {
+ common.ApiErrorMsg(c, "缺少模型 ID")
+ return
+ }
+
+ if statusOnly {
+ // 只更新状态,防止误清空其他字段
+ if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ } else {
+ // 名称冲突检查
+ if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
+ common.ApiError(c, err)
+ return
+ } else if dup {
+ common.ApiErrorMsg(c, "模型名称已存在")
+ return
+ }
+
+ if err := m.Update(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ }
+ model.RefreshPricing()
+ common.ApiSuccess(c, &m)
+}
+
+// DeleteModelMeta 删除模型
+func DeleteModelMeta(c *gin.Context) {
+ idStr := c.Param("id")
+ id, err := strconv.Atoi(idStr)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.RefreshPricing()
+ common.ApiSuccess(c, nil)
+}
+
+// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
+func enrichModels(models []*model.Model) {
+ if len(models) == 0 {
+ return
+ }
+
+ // 1) 拆分精确与规则匹配
+ exactNames := make([]string, 0)
+ exactIdx := make(map[string][]int) // modelName -> indices in models
+ ruleIndices := make([]int, 0)
+ for i, m := range models {
+ if m == nil {
+ continue
+ }
+ if m.NameRule == model.NameRuleExact {
+ exactNames = append(exactNames, m.ModelName)
+ exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
+ } else {
+ ruleIndices = append(ruleIndices, i)
+ }
+ }
+
+ // 2) 批量查询精确模型的绑定渠道
+ channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
+
+ // 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
+ for name, indices := range exactIdx {
+ chs := channelsByModel[name]
+ for _, idx := range indices {
+ mm := models[idx]
+ if mm.Endpoints == "" {
+ eps := model.GetModelSupportEndpointTypes(mm.ModelName)
+ if b, err := json.Marshal(eps); err == nil {
+ mm.Endpoints = string(b)
+ }
+ }
+ mm.BoundChannels = chs
+ mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
+ mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
+ }
+ }
+
+ if len(ruleIndices) == 0 {
+ return
+ }
+
+ // 4) 一次性读取定价缓存,内存匹配所有规则模型
+ pricings := model.GetPricing()
+
+ // 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
+ matchedNamesByIdx := make(map[int][]string)
+ endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
+ groupSetByIdx := make(map[int]map[string]struct{})
+ quotaSetByIdx := make(map[int]map[int]struct{})
+
+ for _, p := range pricings {
+ for _, idx := range ruleIndices {
+ mm := models[idx]
+ var matched bool
+ switch mm.NameRule {
+ case model.NameRulePrefix:
+ matched = strings.HasPrefix(p.ModelName, mm.ModelName)
+ case model.NameRuleSuffix:
+ matched = strings.HasSuffix(p.ModelName, mm.ModelName)
+ case model.NameRuleContains:
+ matched = strings.Contains(p.ModelName, mm.ModelName)
+ }
+ if !matched {
+ continue
+ }
+ matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
+
+ es := endpointSetByIdx[idx]
+ if es == nil {
+ es = make(map[constant.EndpointType]struct{})
+ endpointSetByIdx[idx] = es
+ }
+ for _, et := range p.SupportedEndpointTypes {
+ es[et] = struct{}{}
+ }
+
+ gs := groupSetByIdx[idx]
+ if gs == nil {
+ gs = make(map[string]struct{})
+ groupSetByIdx[idx] = gs
+ }
+ for _, g := range p.EnableGroup {
+ gs[g] = struct{}{}
+ }
+
+ qs := quotaSetByIdx[idx]
+ if qs == nil {
+ qs = make(map[int]struct{})
+ quotaSetByIdx[idx] = qs
+ }
+ qs[p.QuotaType] = struct{}{}
+ }
+ }
+
+ // 5) 汇总所有匹配到的模型名称,批量查询一次渠道
+ allMatchedSet := make(map[string]struct{})
+ for _, names := range matchedNamesByIdx {
+ for _, n := range names {
+ allMatchedSet[n] = struct{}{}
+ }
+ }
+ allMatched := make([]string, 0, len(allMatchedSet))
+ for n := range allMatchedSet {
+ allMatched = append(allMatched, n)
+ }
+ matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
+
+ // 6) 回填每个规则模型的并集信息
+ for _, idx := range ruleIndices {
+ mm := models[idx]
+
+ // 端点并集 -> 序列化
+ if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
+ eps := make([]constant.EndpointType, 0, len(es))
+ for et := range es {
+ eps = append(eps, et)
+ }
+ if b, err := json.Marshal(eps); err == nil {
+ mm.Endpoints = string(b)
+ }
+ }
+
+ // 分组并集
+ if gs, ok := groupSetByIdx[idx]; ok {
+ groups := make([]string, 0, len(gs))
+ for g := range gs {
+ groups = append(groups, g)
+ }
+ mm.EnableGroups = groups
+ }
+
+ // 配额类型集合(保持去重并排序)
+ if qs, ok := quotaSetByIdx[idx]; ok {
+ arr := make([]int, 0, len(qs))
+ for k := range qs {
+ arr = append(arr, k)
+ }
+ sort.Ints(arr)
+ mm.QuotaTypes = arr
+ }
+
+ // 渠道并集
+ names := matchedNamesByIdx[idx]
+ channelSet := make(map[string]model.BoundChannel)
+ for _, n := range names {
+ for _, ch := range matchedChannelsByModel[n] {
+ key := ch.Name + "_" + strconv.Itoa(ch.Type)
+ channelSet[key] = ch
+ }
+ }
+ if len(channelSet) > 0 {
+ chs := make([]model.BoundChannel, 0, len(channelSet))
+ for _, ch := range channelSet {
+ chs = append(chs, ch)
+ }
+ mm.BoundChannels = chs
+ }
+
+ // 匹配信息
+ mm.MatchedModels = names
+ mm.MatchedCount = len(names)
+ }
+}
diff --git a/controller/oidc.go b/controller/oidc.go
index 440e0964..f3def0e3 100644
--- a/controller/oidc.go
+++ b/controller/oidc.go
@@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
}
if oidcResponse.AccessToken == "" {
- common.SysError("OIDC 获取 Token 失败,请检查设置!")
+ common.SysLog("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
@@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
- common.SysError("OIDC 获取用户信息失败!请检查设置!")
+ common.SysLog("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
@@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
- common.SysError("OIDC 获取用户信息为空!请检查设置!")
+ common.SysLog("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
@@ -126,10 +126,7 @@ func OidcAuth(c *gin.Context) {
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user := model.User{
@@ -195,10 +192,7 @@ func OidcBind(c *gin.Context) {
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user := model.User{
@@ -217,19 +211,13 @@ func OidcBind(c *gin.Context) {
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/option.go b/controller/option.go
index 250f16bb..decdb0d4 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -6,6 +6,8 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
+ "one-api/setting/console_setting"
+ "one-api/setting/ratio_setting"
"one-api/setting/system_setting"
"strings"
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
- err = setting.CheckGroupRatio(option.Value)
+ err = ratio_setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -119,14 +121,46 @@ func UpdateOption(c *gin.Context) {
})
return
}
-
+ case "console_setting.api_info":
+ err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.announcements":
+ err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.faq":
+ err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.uptime_kuma_groups":
+ err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
}
err = model.UpdateOption(option.Key, option.Value)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/playground.go b/controller/playground.go
index a2b54790..8a1cb2b6 100644
--- a/controller/playground.go
+++ b/controller/playground.go
@@ -3,67 +3,58 @@ package controller
import (
"errors"
"fmt"
- "github.com/gin-gonic/gin"
- "net/http"
"one-api/common"
"one-api/constant"
- "one-api/dto"
"one-api/middleware"
"one-api/model"
- "one-api/service"
- "one-api/setting"
+ "one-api/types"
"time"
+
+ "github.com/gin-gonic/gin"
)
func Playground(c *gin.Context) {
- var openaiErr *dto.OpenAIErrorWithStatusCode
+ var newAPIError *types.NewAPIError
defer func() {
- if openaiErr != nil {
- c.JSON(openaiErr.StatusCode, gin.H{
- "error": openaiErr.Error,
+ if newAPIError != nil {
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "error": newAPIError.ToOpenAIError(),
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
- openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
+ newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
return
}
- playgroundRequest := &dto.PlayGroundRequest{}
- err := common.UnmarshalBodyReusable(c, playgroundRequest)
+ group := c.GetString("group")
+ modelName := c.GetString("original_model")
+
+ userId := c.GetInt("id")
+
+ // Write user context to ensure acceptUnsetRatio is available
+ userCache, err := model.GetUserCache(userId)
if err != nil {
- openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
+ newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
return
}
+ userCache.WriteContext(c)
- if playgroundRequest.Model == "" {
- openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
+ tempToken := &model.Token{
+ UserId: userId,
+ Name: fmt.Sprintf("playground-%s", group),
+ Group: group,
+ }
+ _ = middleware.SetupContextForToken(c, tempToken)
+ _, newAPIError = getChannel(c, group, modelName, 0)
+ if newAPIError != nil {
return
}
- c.Set("original_model", playgroundRequest.Model)
- group := playgroundRequest.Group
- userGroup := c.GetString("group")
+ //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
+ common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
- if group == "" {
- group = userGroup
- } else {
- if !setting.GroupInUserUsableGroups(group) && group != userGroup {
- openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
- return
- }
- c.Set("group", group)
- }
- c.Set("token_name", "playground-"+group)
- channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
- if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
- openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
- return
- }
- middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
- c.Set(constant.ContextKeyRequestStartTime, time.Now())
- Relay(c)
+ Relay(c, types.RelayFormatOpenAI)
}
diff --git a/controller/prefill_group.go b/controller/prefill_group.go
new file mode 100644
index 00000000..d912d609
--- /dev/null
+++ b/controller/prefill_group.go
@@ -0,0 +1,90 @@
+package controller
+
+import (
+ "strconv"
+
+ "one-api/common"
+ "one-api/model"
+
+ "github.com/gin-gonic/gin"
+)
+
+// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
+func GetPrefillGroups(c *gin.Context) {
+ groupType := c.Query("type")
+ groups, err := model.GetAllPrefillGroups(groupType)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, groups)
+}
+
+// CreatePrefillGroup 创建新的预填组
+func CreatePrefillGroup(c *gin.Context) {
+ var g model.PrefillGroup
+ if err := c.ShouldBindJSON(&g); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if g.Name == "" || g.Type == "" {
+ common.ApiErrorMsg(c, "组名称和类型不能为空")
+ return
+ }
+ // 创建前检查名称
+ if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
+ common.ApiError(c, err)
+ return
+ } else if dup {
+ common.ApiErrorMsg(c, "组名称已存在")
+ return
+ }
+
+ if err := g.Insert(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, &g)
+}
+
+// UpdatePrefillGroup 更新预填组
+func UpdatePrefillGroup(c *gin.Context) {
+ var g model.PrefillGroup
+ if err := c.ShouldBindJSON(&g); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if g.Id == 0 {
+ common.ApiErrorMsg(c, "缺少组 ID")
+ return
+ }
+ // 名称冲突检查
+ if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
+ common.ApiError(c, err)
+ return
+ } else if dup {
+ common.ApiErrorMsg(c, "组名称已存在")
+ return
+ }
+
+ if err := g.Update(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, &g)
+}
+
+// DeletePrefillGroup 删除预填组
+func DeletePrefillGroup(c *gin.Context) {
+ idStr := c.Param("id")
+ id, err := strconv.Atoi(idStr)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if err := model.DeletePrefillGroupByID(id); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, nil)
+}
diff --git a/controller/pricing.go b/controller/pricing.go
index 1cbfe731..4b7cc86d 100644
--- a/controller/pricing.go
+++ b/controller/pricing.go
@@ -1,10 +1,11 @@
package controller
import (
- "github.com/gin-gonic/gin"
"one-api/model"
"one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
)
func GetPricing(c *gin.Context) {
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := map[string]float64{}
- for s, f := range setting.GetGroupRatioCopy() {
+ for s, f := range ratio_setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
@@ -20,27 +21,36 @@ func GetPricing(c *gin.Context) {
user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
+ for g := range groupRatio {
+ ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
+ if ok {
+ groupRatio[g] = ratio
+ }
+ }
}
}
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
- for group := range setting.GetGroupRatioCopy() {
+ for group := range ratio_setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}
}
c.JSON(200, gin.H{
- "success": true,
- "data": pricing,
- "group_ratio": groupRatio,
- "usable_group": usableGroup,
+ "success": true,
+ "data": pricing,
+ "vendors": model.GetVendors(),
+ "group_ratio": groupRatio,
+ "usable_group": usableGroup,
+ "supported_endpoint": model.GetSupportedEndpointMap(),
+ "auto_groups": setting.AutoGroups,
})
}
func ResetModelRatio(c *gin.Context) {
- defaultStr := operation_setting.DefaultModelRatio2JSONString()
+ defaultStr := ratio_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil {
c.JSON(200, gin.H{
@@ -49,7 +59,7 @@ func ResetModelRatio(c *gin.Context) {
})
return
}
- err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
+ err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil {
c.JSON(200, gin.H{
"success": false,
diff --git a/controller/ratio_config.go b/controller/ratio_config.go
new file mode 100644
index 00000000..6ddc3d9e
--- /dev/null
+++ b/controller/ratio_config.go
@@ -0,0 +1,24 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetRatioConfig(c *gin.Context) {
+ if !ratio_setting.IsExposeRatioEnabled() {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "倍率配置接口未启用",
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": ratio_setting.GetExposedData(),
+ })
+}
\ No newline at end of file
diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go
new file mode 100644
index 00000000..6fba0aac
--- /dev/null
+++ b/controller/ratio_sync.go
@@ -0,0 +1,474 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/logger"
+ "strings"
+ "sync"
+ "time"
+
+ "one-api/dto"
+ "one-api/model"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ defaultTimeoutSeconds = 10
+ defaultEndpoint = "/api/ratio_config"
+ maxConcurrentFetches = 8
+)
+
+var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
+
+type upstreamResult struct {
+ Name string `json:"name"`
+ Data map[string]any `json:"data,omitempty"`
+ Err string `json:"err,omitempty"`
+}
+
+func FetchUpstreamRatios(c *gin.Context) {
+ var req dto.UpstreamRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ if req.Timeout <= 0 {
+ req.Timeout = defaultTimeoutSeconds
+ }
+
+ var upstreams []dto.UpstreamDTO
+
+ if len(req.Upstreams) > 0 {
+ for _, u := range req.Upstreams {
+ if strings.HasPrefix(u.BaseURL, "http") {
+ if u.Endpoint == "" {
+ u.Endpoint = defaultEndpoint
+ }
+ u.BaseURL = strings.TrimRight(u.BaseURL, "/")
+ upstreams = append(upstreams, u)
+ }
+ }
+ } else if len(req.ChannelIDs) > 0 {
+ intIds := make([]int, 0, len(req.ChannelIDs))
+ for _, id64 := range req.ChannelIDs {
+ intIds = append(intIds, int(id64))
+ }
+ dbChannels, err := model.GetChannelsByIds(intIds)
+ if err != nil {
+ logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+ return
+ }
+ for _, ch := range dbChannels {
+ if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+ upstreams = append(upstreams, dto.UpstreamDTO{
+ ID: ch.Id,
+ Name: ch.Name,
+ BaseURL: strings.TrimRight(base, "/"),
+ Endpoint: "",
+ })
+ }
+ }
+ }
+
+ if len(upstreams) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+ return
+ }
+
+ var wg sync.WaitGroup
+ ch := make(chan upstreamResult, len(upstreams))
+
+ sem := make(chan struct{}, maxConcurrentFetches)
+
+ client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+
+ for _, chn := range upstreams {
+ wg.Add(1)
+ go func(chItem dto.UpstreamDTO) {
+ defer wg.Done()
+
+ sem <- struct{}{}
+ defer func() { <-sem }()
+
+ endpoint := chItem.Endpoint
+ if endpoint == "" {
+ endpoint = defaultEndpoint
+ } else if !strings.HasPrefix(endpoint, "/") {
+ endpoint = "/" + endpoint
+ }
+ fullURL := chItem.BaseURL + endpoint
+
+ uniqueName := chItem.Name
+ if chItem.ID != 0 {
+ uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+ defer cancel()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+ if err != nil {
+ logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
+
+ resp, err := client.Do(httpReq)
+ if err != nil {
+ logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
+ ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
+ return
+ }
+ // 兼容两种上游接口格式:
+ // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
+ // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
+ var body struct {
+ Success bool `json:"success"`
+ Data json.RawMessage `json:"data"`
+ Message string `json:"message"`
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+ logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
+
+ if !body.Success {
+ ch <- upstreamResult{Name: uniqueName, Err: body.Message}
+ return
+ }
+
+ // 尝试按 type1 解析
+ var type1Data map[string]any
+ if err := json.Unmarshal(body.Data, &type1Data); err == nil {
+ // 如果包含至少一个 ratioTypes 字段,则认为是 type1
+ isType1 := false
+ for _, rt := range ratioTypes {
+ if _, ok := type1Data[rt]; ok {
+ isType1 = true
+ break
+ }
+ }
+ if isType1 {
+ ch <- upstreamResult{Name: uniqueName, Data: type1Data}
+ return
+ }
+ }
+
+ // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
+ var pricingItems []struct {
+ ModelName string `json:"model_name"`
+ QuotaType int `json:"quota_type"`
+ ModelRatio float64 `json:"model_ratio"`
+ ModelPrice float64 `json:"model_price"`
+ CompletionRatio float64 `json:"completion_ratio"`
+ }
+ if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
+ logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
+ return
+ }
+
+ modelRatioMap := make(map[string]float64)
+ completionRatioMap := make(map[string]float64)
+ modelPriceMap := make(map[string]float64)
+
+ for _, item := range pricingItems {
+ if item.QuotaType == 1 {
+ modelPriceMap[item.ModelName] = item.ModelPrice
+ } else {
+ modelRatioMap[item.ModelName] = item.ModelRatio
+ // completionRatio 可能为 0,此时也直接赋值,保持与上游一致
+ completionRatioMap[item.ModelName] = item.CompletionRatio
+ }
+ }
+
+ converted := make(map[string]any)
+
+ if len(modelRatioMap) > 0 {
+ ratioAny := make(map[string]any, len(modelRatioMap))
+ for k, v := range modelRatioMap {
+ ratioAny[k] = v
+ }
+ converted["model_ratio"] = ratioAny
+ }
+
+ if len(completionRatioMap) > 0 {
+ compAny := make(map[string]any, len(completionRatioMap))
+ for k, v := range completionRatioMap {
+ compAny[k] = v
+ }
+ converted["completion_ratio"] = compAny
+ }
+
+ if len(modelPriceMap) > 0 {
+ priceAny := make(map[string]any, len(modelPriceMap))
+ for k, v := range modelPriceMap {
+ priceAny[k] = v
+ }
+ converted["model_price"] = priceAny
+ }
+
+ ch <- upstreamResult{Name: uniqueName, Data: converted}
+ }(chn)
+ }
+
+ wg.Wait()
+ close(ch)
+
+ localData := ratio_setting.GetExposedData()
+
+ var testResults []dto.TestResult
+ var successfulChannels []struct {
+ name string
+ data map[string]any
+ }
+
+ for r := range ch {
+ if r.Err != "" {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "error",
+ Error: r.Err,
+ })
+ } else {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "success",
+ })
+ successfulChannels = append(successfulChannels, struct {
+ name string
+ data map[string]any
+ }{name: r.Name, data: r.Data})
+ }
+ }
+
+ differences := buildDifferences(localData, successfulChannels)
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "data": gin.H{
+ "differences": differences,
+ "test_results": testResults,
+ },
+ })
+}
+
+func buildDifferences(localData map[string]any, successfulChannels []struct {
+ name string
+ data map[string]any
+}) map[string]map[string]dto.DifferenceItem {
+ differences := make(map[string]map[string]dto.DifferenceItem)
+
+ allModels := make(map[string]struct{})
+
+ for _, ratioType := range ratioTypes {
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ for modelName := range localRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
+
+ for _, channel := range successfulChannels {
+ for _, ratioType := range ratioTypes {
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ for modelName := range upstreamRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
+
+ confidenceMap := make(map[string]map[string]bool)
+
+ // 预处理阶段:检查pricing接口的可信度
+ for _, channel := range successfulChannels {
+ confidenceMap[channel.name] = make(map[string]bool)
+
+ modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
+ completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
+
+ if hasModelRatio && hasCompletionRatio {
+ // 遍历所有模型,检查是否满足不可信条件
+ for modelName := range allModels {
+ // 默认为可信
+ confidenceMap[channel.name][modelName] = true
+
+ // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
+ if modelRatioVal, ok := modelRatios[modelName]; ok {
+ if completionRatioVal, ok := completionRatios[modelName]; ok {
+ // 转换为float64进行比较
+ if modelRatioFloat, ok := modelRatioVal.(float64); ok {
+ if completionRatioFloat, ok := completionRatioVal.(float64); ok {
+ if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
+ confidenceMap[channel.name][modelName] = false
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ // 如果不是从pricing接口获取的数据,则全部标记为可信
+ for modelName := range allModels {
+ confidenceMap[channel.name][modelName] = true
+ }
+ }
+ }
+
+ for modelName := range allModels {
+ for _, ratioType := range ratioTypes {
+ var localValue interface{} = nil
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ if val, exists := localRatio[modelName]; exists {
+ localValue = val
+ }
+ }
+ }
+
+ upstreamValues := make(map[string]interface{})
+ confidenceValues := make(map[string]bool)
+ hasUpstreamValue := false
+ hasDifference := false
+
+ for _, channel := range successfulChannels {
+ var upstreamValue interface{} = nil
+
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ if val, exists := upstreamRatio[modelName]; exists {
+ upstreamValue = val
+ hasUpstreamValue = true
+
+ if localValue != nil && localValue != val {
+ hasDifference = true
+ } else if localValue == val {
+ upstreamValue = "same"
+ }
+ }
+ }
+ if upstreamValue == nil && localValue == nil {
+ upstreamValue = "same"
+ }
+
+ if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
+ hasDifference = true
+ }
+
+ upstreamValues[channel.name] = upstreamValue
+
+ confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
+ }
+
+ shouldInclude := false
+
+ if localValue != nil {
+ if hasDifference {
+ shouldInclude = true
+ }
+ } else {
+ if hasUpstreamValue {
+ shouldInclude = true
+ }
+ }
+
+ if shouldInclude {
+ if differences[modelName] == nil {
+ differences[modelName] = make(map[string]dto.DifferenceItem)
+ }
+ differences[modelName][ratioType] = dto.DifferenceItem{
+ Current: localValue,
+ Upstreams: upstreamValues,
+ Confidence: confidenceValues,
+ }
+ }
+ }
+ }
+
+ channelHasDiff := make(map[string]bool)
+ for _, ratioMap := range differences {
+ for _, item := range ratioMap {
+ for chName, val := range item.Upstreams {
+ if val != nil && val != "same" {
+ channelHasDiff[chName] = true
+ }
+ }
+ }
+ }
+
+ for modelName, ratioMap := range differences {
+ for ratioType, item := range ratioMap {
+ for chName := range item.Upstreams {
+ if !channelHasDiff[chName] {
+ delete(item.Upstreams, chName)
+ delete(item.Confidence, chName)
+ }
+ }
+
+ allSame := true
+ for _, v := range item.Upstreams {
+ if v != "same" {
+ allSame = false
+ break
+ }
+ }
+ if len(item.Upstreams) == 0 || allSame {
+ delete(ratioMap, ratioType)
+ } else {
+ differences[modelName][ratioType] = item
+ }
+ }
+
+ if len(ratioMap) == 0 {
+ delete(differences, modelName)
+ }
+ }
+
+ return differences
+}
+
+func GetSyncableChannels(c *gin.Context) {
+ channels, err := model.GetAllChannels(0, 0, true, false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ var syncableChannels []dto.SyncableChannel
+ for _, channel := range channels {
+ if channel.GetBaseURL() != "" {
+ syncableChannels = append(syncableChannels, dto.SyncableChannel{
+ ID: channel.Id,
+ Name: channel.Name,
+ BaseURL: channel.GetBaseURL(),
+ Status: channel.Status,
+ })
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": syncableChannels,
+ })
+}
diff --git a/controller/redemption.go b/controller/redemption.go
index a7e09a8a..1e305e3d 100644
--- a/controller/redemption.go
+++ b/controller/redemption.go
@@ -1,90 +1,52 @@
package controller
import (
+ "errors"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
+ "unicode/utf8"
"github.com/gin-gonic/gin"
)
func GetAllRedemptions(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 0 {
- p = 0
- }
- if pageSize < 1 {
- pageSize = common.ItemsPerPage
- }
- redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
+ pageInfo := common.GetPageQuery(c)
+ redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "items": redemptions,
- "total": total,
- "page": p,
- "page_size": pageSize,
- },
- })
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(redemptions)
+ common.ApiSuccess(c, pageInfo)
return
}
func SearchRedemptions(c *gin.Context) {
keyword := c.Query("keyword")
- p, _ := strconv.Atoi(c.Query("p"))
- pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 0 {
- p = 0
- }
- if pageSize < 1 {
- pageSize = common.ItemsPerPage
- }
- redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
+ pageInfo := common.GetPageQuery(c)
+ redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "items": redemptions,
- "total": total,
- "page": p,
- "page_size": pageSize,
- },
- })
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(redemptions)
+ common.ApiSuccess(c, pageInfo)
return
}
func GetRedemption(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
redemption, err := model.GetRedemptionById(id)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -99,13 +61,10 @@ func AddRedemption(c *gin.Context) {
redemption := model.Redemption{}
err := c.ShouldBindJSON(&redemption)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
+ if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "兑换码名称长度必须在1-20之间",
@@ -126,6 +85,10 @@ func AddRedemption(c *gin.Context) {
})
return
}
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
var keys []string
for i := 0; i < redemption.Count; i++ {
key := common.GetUUID()
@@ -135,6 +98,7 @@ func AddRedemption(c *gin.Context) {
Key: key,
CreatedTime: common.GetTimestamp(),
Quota: redemption.Quota,
+ ExpiredTime: redemption.ExpiredTime,
}
err = cleanRedemption.Insert()
if err != nil {
@@ -159,10 +123,7 @@ func DeleteRedemption(c *gin.Context) {
id, _ := strconv.Atoi(c.Param("id"))
err := model.DeleteRedemptionById(id)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -177,33 +138,30 @@ func UpdateRedemption(c *gin.Context) {
redemption := model.Redemption{}
err := c.ShouldBindJSON(&redemption)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- if statusOnly != "" {
- cleanRedemption.Status = redemption.Status
- } else {
+ if statusOnly == "" {
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
// If you add more fields, please also update redemption.Update()
cleanRedemption.Name = redemption.Name
cleanRedemption.Quota = redemption.Quota
+ cleanRedemption.ExpiredTime = redemption.ExpiredTime
+ }
+ if statusOnly != "" {
+ cleanRedemption.Status = redemption.Status
}
err = cleanRedemption.Update()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -213,3 +171,24 @@ func UpdateRedemption(c *gin.Context) {
})
return
}
+
+func DeleteInvalidRedemption(c *gin.Context) {
+ rows, err := model.DeleteInvalidRedemptions()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": rows,
+ })
+ return
+}
+
+func validateExpiredTime(expired int64) error {
+ if expired != 0 && expired < common.GetTimestamp() {
+ return errors.New("过期时间不能早于当前时间")
+ }
+ return nil
+}
diff --git a/controller/relay.go b/controller/relay.go
index 1a875dbc..c055ef71 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -2,114 +2,192 @@ package controller
import (
"bytes"
- "errors"
"fmt"
+ "github.com/bytedance/gopkg/util/gopool"
"io"
"log"
"net/http"
"one-api/common"
- constant2 "one-api/constant"
+ "one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/middleware"
"one-api/model"
"one-api/relay"
- "one-api/relay/constant"
+ relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
+ "one-api/setting"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
-func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
- var err *dto.OpenAIErrorWithStatusCode
- switch relayMode {
+func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
+ var err *types.NewAPIError
+ switch info.RelayMode {
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
- err = relay.ImageHelper(c)
+ err = relay.ImageHelper(c, info)
case relayconstant.RelayModeAudioSpeech:
fallthrough
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
- err = relay.AudioHelper(c)
+ err = relay.AudioHelper(c, info)
case relayconstant.RelayModeRerank:
- err = relay.RerankHelper(c, relayMode)
+ err = relay.RerankHelper(c, info)
case relayconstant.RelayModeEmbeddings:
- err = relay.EmbeddingHelper(c)
+ err = relay.EmbeddingHelper(c, info)
case relayconstant.RelayModeResponses:
- err = relay.ResponsesHelper(c)
- case relayconstant.RelayModeGemini:
- err = relay.GeminiHelper(c)
+ err = relay.ResponsesHelper(c, info)
default:
- err = relay.TextHelper(c)
+ err = relay.TextHelper(c, info)
}
-
- if constant2.ErrorLogEnabled && err != nil {
- // 保存错误日志到mysql中
- userId := c.GetInt("id")
- tokenName := c.GetString("token_name")
- modelName := c.GetString("original_model")
- tokenId := c.GetInt("token_id")
- userGroup := c.GetString("group")
- channelId := c.GetInt("channel_id")
- other := make(map[string]interface{})
- other["error_type"] = err.Error.Type
- other["error_code"] = err.Error.Code
- other["status_code"] = err.StatusCode
- other["channel_id"] = channelId
- other["channel_name"] = c.GetString("channel_name")
- other["channel_type"] = c.GetInt("channel_type")
-
- model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
- }
-
return err
}
-func Relay(c *gin.Context) {
- relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
+ var err *types.NewAPIError
+ if strings.Contains(c.Request.URL.Path, "embed") {
+ err = relay.GeminiEmbeddingHandler(c, info)
+ } else {
+ err = relay.GeminiHelper(c, info)
+ }
+ return err
+}
+
+func Relay(c *gin.Context, relayFormat types.RelayFormat) {
+
requestId := c.GetString(common.RequestIdKey)
- group := c.GetString("group")
- originalModel := c.GetString("original_model")
- var openaiErr *dto.OpenAIErrorWithStatusCode
+ group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
+ originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
+
+ var (
+ newAPIError *types.NewAPIError
+ ws *websocket.Conn
+ )
+
+ if relayFormat == types.RelayFormatOpenAIRealtime {
+ var err error
+ ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
+ if err != nil {
+ helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
+ return
+ }
+ defer ws.Close()
+ }
+
+ defer func() {
+ if newAPIError != nil {
+ newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+ switch relayFormat {
+ case types.RelayFormatOpenAIRealtime:
+ helper.WssError(c, ws, newAPIError.ToOpenAIError())
+ case types.RelayFormatClaude:
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "type": "error",
+ "error": newAPIError.ToClaudeError(),
+ })
+ default:
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "error": newAPIError.ToOpenAIError(),
+ })
+ }
+ }
+ }()
+
+ request, err := helper.GetAndValidateRequest(c, relayFormat)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+ return
+ }
+
+ relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
+ return
+ }
+
+ meta := request.GetTokenCountMeta()
+
+ if setting.ShouldCheckPromptSensitive() {
+ contains, words := service.CheckSensitiveText(meta.CombineText)
+ if contains {
+ logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
+			newAPIError = types.NewError(fmt.Errorf("sensitive words detected: %s", strings.Join(words, ", ")), types.ErrorCodeSensitiveWordsDetected)
+ return
+ }
+ }
+
+ tokens, err := service.CountRequestToken(c, meta, relayInfo)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
+ return
+ }
+
+ relayInfo.SetPromptTokens(tokens)
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
+ return
+ }
+
+ // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
+
+ preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newAPIError != nil {
+ return
+ }
+
+ defer func() {
+ // Only return quota if downstream failed and quota was actually pre-consumed
+ if newAPIError != nil && preConsumedQuota != 0 {
+ service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
+ }
+ }()
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
- common.LogError(c, err.Error())
- openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
+ logger.LogError(c, err.Error())
+ newAPIError = err
break
}
- openaiErr = relayRequest(c, relayMode, channel)
+ addUsedChannel(c, channel.Id)
+ requestBody, _ := common.GetRequestBody(c)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- if openaiErr == nil {
- return // 成功处理请求,直接返回
+ switch relayFormat {
+ case types.RelayFormatOpenAIRealtime:
+ newAPIError = relay.WssHelper(c, relayInfo)
+ case types.RelayFormatClaude:
+ newAPIError = relay.ClaudeHelper(c, relayInfo)
+ case types.RelayFormatGemini:
+ newAPIError = geminiRelayHandler(c, relayInfo)
+ default:
+ newAPIError = relayHandler(c, relayInfo)
}
- go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
+ if newAPIError == nil {
+ return
+ }
- if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
+ processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+
+ if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break
}
}
+
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
- }
-
- if openaiErr != nil {
- if openaiErr.StatusCode == http.StatusTooManyRequests {
- common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
- openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
- }
- openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
- c.JSON(openaiErr.StatusCode, gin.H{
- "error": openaiErr.Error,
- })
+ logger.LogInfo(c, retryLogStr)
}
}
@@ -120,132 +198,13 @@ var upgrader = websocket.Upgrader{
},
}
-func WssRelay(c *gin.Context) {
- // 将 HTTP 连接升级为 WebSocket 连接
-
- ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
- defer ws.Close()
-
- if err != nil {
- openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
- helper.WssError(c, ws, openaiErr.Error)
- return
- }
-
- relayMode := constant.Path2RelayMode(c.Request.URL.Path)
- requestId := c.GetString(common.RequestIdKey)
- group := c.GetString("group")
- //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
- originalModel := c.GetString("original_model")
- var openaiErr *dto.OpenAIErrorWithStatusCode
-
- for i := 0; i <= common.RetryTimes; i++ {
- channel, err := getChannel(c, group, originalModel, i)
- if err != nil {
- common.LogError(c, err.Error())
- openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
- break
- }
-
- openaiErr = wssRequest(c, ws, relayMode, channel)
-
- if openaiErr == nil {
- return // 成功处理请求,直接返回
- }
-
- go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
-
- if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
- break
- }
- }
- useChannel := c.GetStringSlice("use_channel")
- if len(useChannel) > 1 {
- retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
- }
-
- if openaiErr != nil {
- if openaiErr.StatusCode == http.StatusTooManyRequests {
- openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
- }
- openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
- helper.WssError(c, ws, openaiErr.Error)
- }
-}
-
-func RelayClaude(c *gin.Context) {
- //relayMode := constant.Path2RelayMode(c.Request.URL.Path)
- requestId := c.GetString(common.RequestIdKey)
- group := c.GetString("group")
- originalModel := c.GetString("original_model")
- var claudeErr *dto.ClaudeErrorWithStatusCode
-
- for i := 0; i <= common.RetryTimes; i++ {
- channel, err := getChannel(c, group, originalModel, i)
- if err != nil {
- common.LogError(c, err.Error())
- claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
- break
- }
-
- claudeErr = claudeRequest(c, channel)
-
- if claudeErr == nil {
- return // 成功处理请求,直接返回
- }
-
- openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
-
- go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
-
- if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
- break
- }
- }
- useChannel := c.GetStringSlice("use_channel")
- if len(useChannel) > 1 {
- retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
- }
-
- if claudeErr != nil {
- claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
- c.JSON(claudeErr.StatusCode, gin.H{
- "type": "error",
- "error": claudeErr.Error,
- })
- }
-}
-
-func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
- addUsedChannel(c, channel.Id)
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- return relayHandler(c, relayMode)
-}
-
-func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
- addUsedChannel(c, channel.Id)
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- return relay.WssHelper(c, ws)
-}
-
-func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
- addUsedChannel(c, channel.Id)
- requestBody, _ := common.GetRequestBody(c)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- return relay.ClaudeHelper(c)
-}
-
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
-func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
+func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
@@ -259,19 +218,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt,
}, nil
}
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
+ channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
- return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
+ return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+ }
+ if channel == nil {
+ return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+ }
+ newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+ if newAPIError != nil {
+ return nil, newAPIError
}
- middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
-func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
+func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
if openaiErr == nil {
return false
}
- if openaiErr.LocalError {
+ if types.IsChannelError(openaiErr) {
+ return true
+ }
+ if types.IsSkipRetryError(openaiErr) {
return false
}
if retryTimes <= 0 {
@@ -294,10 +262,6 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
- channelType := c.GetInt("channel_type")
- if channelType == common.ChannelTypeAnthropic {
- return true
- }
return false
}
if openaiErr.StatusCode == 408 {
@@ -310,45 +274,85 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
return true
}
-func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
- // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
- // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
- common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
- if service.ShouldDisableChannel(channelType, err) && autoBan {
- service.DisableChannel(channelId, channelName, err.Error.Message)
+func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
+ logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
+
+ gopool.Go(func() {
+ // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
+ // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
+ if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+ service.DisableChannel(channelError, err.Error())
+ }
+ })
+
+ if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
+ // 保存错误日志到mysql中
+ userId := c.GetInt("id")
+ tokenName := c.GetString("token_name")
+ modelName := c.GetString("original_model")
+ tokenId := c.GetInt("token_id")
+ userGroup := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ other := make(map[string]interface{})
+ other["error_type"] = err.GetErrorType()
+ other["error_code"] = err.GetErrorCode()
+ other["status_code"] = err.StatusCode
+ other["channel_id"] = channelId
+ other["channel_name"] = c.GetString("channel_name")
+ other["channel_type"] = c.GetInt("channel_type")
+ adminInfo := make(map[string]interface{})
+ adminInfo["use_channel"] = c.GetStringSlice("use_channel")
+ isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
+ if isMultiKey {
+ adminInfo["is_multi_key"] = true
+ adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
+ }
+ other["admin_info"] = adminInfo
+ model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
}
+
}
func RelayMidjourney(c *gin.Context) {
- relayMode := c.GetInt("relay_mode")
- var err *dto.MidjourneyResponse
- switch relayMode {
+ relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
+
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
+ "type": "upstream_error",
+ "code": 4,
+ })
+ return
+ }
+
+ var mjErr *dto.MidjourneyResponse
+ switch relayInfo.RelayMode {
case relayconstant.RelayModeMidjourneyNotify:
- err = relay.RelayMidjourneyNotify(c)
+ mjErr = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
- err = relay.RelayMidjourneyTask(c, relayMode)
+ mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed:
- err = relay.RelayMidjourneyTaskImageSeed(c)
+ mjErr = relay.RelayMidjourneyTaskImageSeed(c)
case relayconstant.RelayModeSwapFace:
- err = relay.RelaySwapFace(c)
+ mjErr = relay.RelaySwapFace(c, relayInfo)
default:
- err = relay.RelayMidjourneySubmit(c, relayMode)
+ mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
}
//err = relayMidjourneySubmit(c, relayMode)
- log.Println(err)
- if err != nil {
+ log.Println(mjErr)
+ if mjErr != nil {
statusCode := http.StatusBadRequest
- if err.Code == 30 {
- err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+ if mjErr.Code == 30 {
+ mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests
}
c.JSON(statusCode, gin.H{
- "description": fmt.Sprintf("%s %s", err.Description, err.Result),
+ "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
"type": "upstream_error",
- "code": err.Code,
+ "code": mjErr.Code,
})
channelId := c.GetInt("channel_id")
- common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
+ logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
}
}
@@ -388,26 +392,27 @@ func RelayTask(c *gin.Context) {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
- channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
- if err != nil {
- common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+ channel, newAPIError := getChannel(c, group, originalModel, i)
+ if newAPIError != nil {
+ logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
+ taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
- common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
- middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+ logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+ //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
- requestBody, err := common.GetRequestBody(c)
+ requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode)
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
- common.LogInfo(c, retryLogStr)
+ logger.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
@@ -420,7 +425,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
- case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
+ case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
diff --git a/controller/setup.go b/controller/setup.go
index 0a13bcf9..8943a1a0 100644
--- a/controller/setup.go
+++ b/controller/setup.go
@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
// If root doesn't exist, validate and create admin account
if !rootExists {
+ // Validate username length: max 12 characters to align with model.User validation
+ if len(req.Username) > 12 {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "用户名长度不能超过12个字符",
+ })
+ return
+ }
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{
diff --git a/controller/swag_video.go b/controller/swag_video.go
new file mode 100644
index 00000000..68dd6345
--- /dev/null
+++ b/controller/swag_video.go
@@ -0,0 +1,136 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+)
+
+// VideoGenerations
+// @Summary 生成视频
+// @Description 调用视频生成接口生成视频
+// @Description 支持多种视频生成服务:
+// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
+// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Access-Token: sk-xxxx)"
+// @Param request body dto.VideoRequest true "视频生成请求参数"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations [post]
+func VideoGenerations(c *gin.Context) {
+}
+
+// VideoGenerationsTaskId
+// @Summary 查询视频
+// @Description 根据任务ID查询视频生成任务的状态和结果
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Security BearerAuth
+// @Param task_id path string true "Task ID"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations/{task_id} [get]
+func VideoGenerationsTaskId(c *gin.Context) {
+}
+
+// KlingText2VideoGenerations
+// @Summary 可灵文生视频
+// @Description 调用可灵AI文生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Access-Token: sk-xxxx)"
+// @Param request body KlingText2VideoRequest true "视频生成请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/text2video [post]
+func KlingText2VideoGenerations(c *gin.Context) {
+}
+
+type KlingText2VideoRequest struct {
+ ModelName string `json:"model_name,omitempty" example:"kling-v1"`
+ Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
+ NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+ CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
+ Mode string `json:"mode,omitempty" example:"std"`
+ CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
+ Duration string `json:"duration,omitempty" example:"5"`
+ CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+ ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
+}
+
+type KlingCameraControl struct {
+ Type string `json:"type,omitempty" example:"simple"`
+ Config *KlingCameraConfig `json:"config,omitempty"`
+}
+
+type KlingCameraConfig struct {
+ Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
+ Vertical float64 `json:"vertical,omitempty" example:"0"`
+ Pan float64 `json:"pan,omitempty" example:"0"`
+ Tilt float64 `json:"tilt,omitempty" example:"0"`
+ Roll float64 `json:"roll,omitempty" example:"0"`
+ Zoom float64 `json:"zoom,omitempty" example:"0"`
+}
+
+// KlingImage2VideoGenerations
+// @Summary 可灵官方-图生视频
+// @Description 调用可灵AI图生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Access-Token: sk-xxxx)"
+// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/image2video [post]
+func KlingImage2VideoGenerations(c *gin.Context) {
+}
+
+type KlingImage2VideoRequest struct {
+ ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
+ Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
+ Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
+ NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+ CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
+ Mode string `json:"mode,omitempty" example:"std"`
+ CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
+ Duration string `json:"duration,omitempty" example:"5"`
+ CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+ ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
+}
+
+// KlingImage2videoTaskId godoc
+// @Summary 可灵任务查询--图生视频
+// @Description Query the status and result of a Kling video generation task by task ID
+// @Tags Origin
+// @Accept json
+// @Produce json
+// @Param task_id path string true "Task ID"
+// @Router /kling/v1/videos/image2video/{task_id} [get]
+func KlingImage2videoTaskId(c *gin.Context) {}
+
+// KlingText2videoTaskId godoc
+// @Summary 可灵任务查询--文生视频
+// @Description Query the status and result of a Kling text-to-video generation task by task ID
+// @Tags Origin
+// @Accept json
+// @Produce json
+// @Param task_id path string true "Task ID"
+// @Router /kling/v1/videos/text2video/{task_id} [get]
+func KlingText2videoTaskId(c *gin.Context) {}
diff --git a/controller/task.go b/controller/task.go
index 65f79ead..1082d7a1 100644
--- a/controller/task.go
+++ b/controller/task.go
@@ -5,18 +5,20 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
- "github.com/samber/lo"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
"one-api/relay"
"sort"
"strconv"
"time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/samber/lo"
)
func UpdateTaskBulk() {
@@ -53,9 +55,9 @@ func UpdateTaskBulk() {
"progress": "100%",
})
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+ logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
} else {
- common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+ logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
@@ -75,7 +77,9 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
default:
- common.SysLog("未知平台")
+ if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
+ common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
+ }
}
}
@@ -83,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
for channelId, taskIds := range taskChannelM {
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+ logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
}
}
return nil
}
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
- common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+ logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
@@ -103,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
"progress": "100%",
})
if err != nil {
- common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
+ common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
}
return err
}
@@ -115,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
"ids": taskIds,
})
if err != nil {
- common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
+ common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
return err
}
if resp.StatusCode != http.StatusOK {
- common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
+ common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
return err
}
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
- common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+ logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
return err
}
if !responseItems.IsSuccess() {
@@ -151,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
- common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+ logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%"
//err = model.CacheUpdateUserQuota(task.UserId) ?
if err != nil {
- common.LogError(ctx, "error update user quota cache: "+err.Error())
+ logger.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil {
- common.LogError(ctx, "fail to increase user quota: "+err.Error())
+ logger.LogError(ctx, "fail to increase user quota: "+err.Error())
}
- logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
+ logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
@@ -175,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
err = task.Update()
if err != nil {
- common.SysError("UpdateMidjourneyTask task error: " + err.Error())
+ common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
}
}
return nil
@@ -223,10 +227,8 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
}
func GetAllTask(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
- }
+ pageInfo := common.GetPageQuery(c)
+
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 解析其他查询参数
@@ -237,25 +239,18 @@ func GetAllTask(c *gin.Context) {
Action: c.Query("action"),
StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp,
+ ChannelID: c.Query("channel_id"),
}
- logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Task, 0)
- }
-
- c.JSON(200, gin.H{
- "success": true,
- "message": "",
- "data": logs,
- })
+ items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.TaskCountAllTasks(queryParams)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
}
func GetUserTask(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
- }
+ pageInfo := common.GetPageQuery(c)
userId := c.GetInt("id")
@@ -271,14 +266,9 @@ func GetUserTask(c *gin.Context) {
EndTimestamp: endTimestamp,
}
- logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
- if logs == nil {
- logs = make([]*model.Task, 0)
- }
-
- c.JSON(200, gin.H{
- "success": true,
- "message": "",
- "data": logs,
- })
+ items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.TaskCountAllUserTask(userId, queryParams)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
}
diff --git a/controller/task_video.go b/controller/task_video.go
new file mode 100644
index 00000000..ffb6728b
--- /dev/null
+++ b/controller/task_video.go
@@ -0,0 +1,148 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/logger"
+ "one-api/model"
+ "one-api/relay"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "time"
+)
+
+// UpdateVideoTaskAll refreshes the status of all pending video tasks of the
+// given platform, grouped by channel. Per-channel failures are logged and do
+// not abort the remaining channels; the function itself always returns nil.
+func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+ for channelId, taskIds := range taskChannelM {
+ if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
+ logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+ }
+ }
+ return nil
+}
+
+// updateVideoTaskAll updates every pending video task of a single channel.
+// If the channel itself cannot be loaded, all of its tasks are marked
+// FAILURE so they stop being polled. Per-task errors are logged and the
+// remaining tasks are still processed.
+func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+ logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+ if len(taskIds) == 0 {
+ return nil
+ }
+ cacheGetChannel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ // Channel is gone/unavailable: fail the whole batch so these tasks are
+ // not retried forever.
+ errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
+ "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if errUpdate != nil {
+ common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+ }
+ return fmt.Errorf("CacheGetChannel failed: %w", err)
+ }
+ adaptor := relay.GetTaskAdaptor(platform)
+ if adaptor == nil {
+ return fmt.Errorf("video adaptor not found")
+ }
+ for _, taskId := range taskIds {
+ if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+ logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+ }
+ }
+ return nil
+}
+
+// updateVideoSingleTask fetches the upstream status of one video task and
+// syncs it into the local task record, refunding the user's quota when the
+// task ended in failure. The local DB update error is only logged.
+func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
+ baseURL := constant.ChannelBaseURLs[channel.Type]
+ if channel.GetBaseURL() != "" {
+ baseURL = channel.GetBaseURL()
+ }
+
+ task := taskM[taskId]
+ if task == nil {
+ logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+ return fmt.Errorf("task %s not found", taskId)
+ }
+ resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
+ "task_id": taskId,
+ "action": task.Action,
+ })
+ if err != nil {
+ return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
+ }
+ // NOTE(review): the HTTP status check is deliberately disabled — non-200
+ // bodies fall through to the parsing below; confirm this is intended.
+ //if resp.StatusCode != http.StatusOK {
+ //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
+ //}
+ defer resp.Body.Close()
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
+ }
+
+ taskResult := &relaycommon.TaskInfo{}
+ // try parse as New API response format
+ var responseItems dto.TaskResponse[model.Task]
+ if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
+ t := responseItems.Data
+ taskResult.TaskID = t.TaskID
+ taskResult.Status = string(t.Status)
+ // NOTE(review): Url is populated from FailReason — the task model appears
+ // to reuse FailReason to carry the result URL on success; confirm.
+ taskResult.Url = t.FailReason
+ taskResult.Progress = t.Progress
+ taskResult.Reason = t.FailReason
+ } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
+ return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+ } else {
+ // Raw upstream payload is only persisted for adaptor-parsed responses.
+ task.Data = responseBody
+ }
+
+ now := time.Now().Unix()
+ if taskResult.Status == "" {
+ return fmt.Errorf("task %s status is empty", taskId)
+ }
+ task.Status = model.TaskStatus(taskResult.Status)
+ // Map the coarse status to a display progress; an explicit upstream
+ // progress value (if any) overrides this below.
+ switch taskResult.Status {
+ case model.TaskStatusSubmitted:
+ task.Progress = "10%"
+ case model.TaskStatusQueued:
+ task.Progress = "20%"
+ case model.TaskStatusInProgress:
+ task.Progress = "30%"
+ if task.StartTime == 0 {
+ task.StartTime = now
+ }
+ case model.TaskStatusSuccess:
+ task.Progress = "100%"
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ // On success the result URL is stored in FailReason (see note above).
+ task.FailReason = taskResult.Url
+ case model.TaskStatusFailure:
+ task.Status = model.TaskStatusFailure // redundant: already assigned above
+ task.Progress = "100%"
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ task.FailReason = taskResult.Reason
+ logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+ // Refund the quota that was reserved for the failed task.
+ quota := task.Quota
+ if quota != 0 {
+ if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
+ logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
+ }
+ logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ }
+ default:
+ return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
+ }
+ if taskResult.Progress != "" {
+ task.Progress = taskResult.Progress
+ }
+ if err := task.Update(); err != nil {
+ common.SysLog("UpdateVideoTask task error: " + err.Error())
+ }
+
+ return nil
+}
diff --git a/controller/token.go b/controller/token.go
index 0afb1391..5b96a2b7 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -12,29 +12,16 @@ import (
func GetAllTokens(c *gin.Context) {
userId := c.GetInt("id")
- p, _ := strconv.Atoi(c.Query("p"))
- size, _ := strconv.Atoi(c.Query("size"))
- if p < 0 {
- p = 0
- }
- if size <= 0 {
- size = common.ItemsPerPage
- } else if size > 100 {
- size = 100
- }
- tokens, err := model.GetAllUserTokens(userId, p*size, size)
+ pageInfo := common.GetPageQuery(c)
+ tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": tokens,
- })
+ total, _ := model.CountUserTokens(userId)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(tokens)
+ common.ApiSuccess(c, pageInfo)
return
}
@@ -44,10 +31,7 @@ func SearchTokens(c *gin.Context) {
token := c.Query("token")
tokens, err := model.SearchUserTokens(userId, keyword, token)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -62,18 +46,12 @@ func GetToken(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
userId := c.GetInt("id")
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
token, err := model.GetTokenByIds(id, userId)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -89,10 +67,7 @@ func GetTokenStatus(c *gin.Context) {
userId := c.GetInt("id")
token, err := model.GetTokenByIds(tokenId, userId)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
expiredAt := token.ExpiredTime
@@ -162,10 +137,7 @@ func AddToken(c *gin.Context) {
token := model.Token{}
err := c.ShouldBindJSON(&token)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
if len(token.Name) > 30 {
@@ -181,7 +153,7 @@ func AddToken(c *gin.Context) {
"success": false,
"message": "生成令牌失败",
})
- common.SysError("failed to generate token key: " + err.Error())
+ common.SysLog("failed to generate token key: " + err.Error())
return
}
cleanToken := model.Token{
@@ -200,10 +172,7 @@ func AddToken(c *gin.Context) {
}
err = cleanToken.Insert()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -218,10 +187,7 @@ func DeleteToken(c *gin.Context) {
userId := c.GetInt("id")
err := model.DeleteTokenById(id, userId)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -237,10 +203,7 @@ func UpdateToken(c *gin.Context) {
token := model.Token{}
err := c.ShouldBindJSON(&token)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
if len(token.Name) > 30 {
@@ -252,10 +215,7 @@ func UpdateToken(c *gin.Context) {
}
cleanToken, err := model.GetTokenByIds(token.Id, userId)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
if token.Status == common.TokenStatusEnabled {
@@ -289,10 +249,7 @@ func UpdateToken(c *gin.Context) {
}
err = cleanToken.Update()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -302,3 +259,29 @@ func UpdateToken(c *gin.Context) {
})
return
}
+
+// TokenBatch is the request body for batch token operations: a list of
+// token IDs belonging to the calling user.
+type TokenBatch struct {
+ Ids []int `json:"ids"`
+}
+
+// DeleteTokenBatch deletes the given token IDs for the current user and
+// returns the number of rows removed (presumably IDs not owned by the
+// caller are skipped by the model layer — TODO confirm).
+func DeleteTokenBatch(c *gin.Context) {
+ tokenBatch := TokenBatch{}
+ // Reject malformed bodies and empty ID lists alike.
+ if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ userId := c.GetInt("id")
+ count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": count,
+ })
+}
diff --git a/controller/topup.go b/controller/topup.go
index 4654b6ea..3f3c8623 100644
--- a/controller/topup.go
+++ b/controller/topup.go
@@ -5,6 +5,7 @@ import (
"log"
"net/url"
"one-api/common"
+ "one-api/logger"
"one-api/model"
"one-api/service"
"one-api/setting"
@@ -97,16 +98,14 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
- payType := "wxpay"
- if req.PaymentMethod == "zfb" {
- payType = "alipay"
- }
- if req.PaymentMethod == "wx" {
- req.PaymentMethod = "wxpay"
- payType = "wxpay"
+
+ if !setting.ContainsPayMethod(req.PaymentMethod) {
+ c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
+ return
}
+
callBackAddress := service.GetCallbackAddress()
- returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
+ returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
@@ -116,7 +115,7 @@ func RequestEpay(c *gin.Context) {
return
}
uri, params, err := client.Purchase(&epay.PurchaseArgs{
- Type: payType,
+ Type: req.PaymentMethod,
ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
@@ -233,7 +232,7 @@ func EpayNotify(c *gin.Context) {
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
- model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
+ model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go
new file mode 100644
index 00000000..eb320809
--- /dev/null
+++ b/controller/topup_stripe.go
@@ -0,0 +1,275 @@
+package controller
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "one-api/setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stripe/stripe-go/v81"
+ "github.com/stripe/stripe-go/v81/checkout/session"
+ "github.com/stripe/stripe-go/v81/webhook"
+ "github.com/thanhpk/randstr"
+)
+
+const (
+ // PaymentMethodStripe identifies the Stripe payment channel in pay requests.
+ PaymentMethodStripe = "stripe"
+)
+
+// stripeAdaptor is the shared, stateless adaptor used by the Stripe handlers.
+var stripeAdaptor = &StripeAdaptor{}
+
+// StripePayRequest is the JSON body of the Stripe amount/pay endpoints.
+type StripePayRequest struct {
+ Amount int64 `json:"amount"`
+ PaymentMethod string `json:"payment_method"`
+}
+
+// StripeAdaptor implements Stripe top-up quoting and checkout creation.
+type StripeAdaptor struct {
+}
+
+// RequestAmount quotes the payable money for a requested top-up amount,
+// applying the user's top-up group ratio. It responds with message
+// "success" and the formatted price, or message "error" and a reason.
+func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
+ if req.Amount < getStripeMinTopup() {
+ c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
+ return
+ }
+ id := c.GetInt("id")
+ group, err := model.GetUserGroup(id, true)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
+ return
+ }
+ payMoney := getStripePayMoney(float64(req.Amount), group)
+ // Guard against amounts too small for Stripe to charge.
+ if payMoney <= 0.01 {
+ c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
+ return
+ }
+ c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
+}
+
+// RequestPay creates a Stripe Checkout session for the requested amount and
+// records a pending top-up order keyed by a unique reference ID; the order
+// is later settled or expired by the webhook handlers.
+func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
+ if req.PaymentMethod != PaymentMethodStripe {
+ c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
+ return
+ }
+ if req.Amount < getStripeMinTopup() {
+ c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
+ return
+ }
+ if req.Amount > 10000 {
+ c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
+ return
+ }
+
+ id := c.GetInt("id")
+ // NOTE(review): the error from GetUserById is discarded; a lookup failure
+ // would leave user nil and panic on the dereference below — confirm and handle.
+ user, _ := model.GetUserById(id, false)
+ chargedMoney := GetChargedAmount(float64(req.Amount), *user)
+
+ // Build a unique, non-guessable order reference (hashed so the raw parts
+ // are not exposed to the client).
+ reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
+ referenceId := "ref_" + common.Sha1([]byte(reference))
+
+ payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
+ if err != nil {
+ log.Println("获取Stripe Checkout支付链接失败", err)
+ c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
+ return
+ }
+
+ // Persist the pending order; the webhook flips it to success/expired.
+ topUp := &model.TopUp{
+ UserId: id,
+ Amount: req.Amount,
+ Money: chargedMoney,
+ TradeNo: referenceId,
+ CreateTime: time.Now().Unix(),
+ Status: common.TopUpStatusPending,
+ }
+ err = topUp.Insert()
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
+ return
+ }
+ c.JSON(200, gin.H{
+ "message": "success",
+ "data": gin.H{
+ "pay_link": payLink,
+ },
+ })
+}
+
+// RequestStripeAmount binds the JSON body and delegates to the Stripe
+// adaptor's amount quote.
+func RequestStripeAmount(c *gin.Context) {
+ var req StripePayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+ stripeAdaptor.RequestAmount(c, &req)
+}
+
+// RequestStripePay binds the JSON body and delegates to the Stripe
+// adaptor's checkout creation.
+func RequestStripePay(c *gin.Context) {
+ var req StripePayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+ stripeAdaptor.RequestPay(c, &req)
+}
+
+// StripeWebhook verifies the Stripe-Signature header against the configured
+// webhook secret and dispatches checkout session completed/expired events;
+// unknown event types are logged and acknowledged with 200.
+func StripeWebhook(c *gin.Context) {
+ payload, err := io.ReadAll(c.Request.Body)
+ if err != nil {
+ log.Printf("解析Stripe Webhook参数失败: %v\n", err)
+ c.AbortWithStatus(http.StatusServiceUnavailable)
+ return
+ }
+
+ signature := c.GetHeader("Stripe-Signature")
+ endpointSecret := setting.StripeWebhookSecret
+ // Tolerate API-version mismatch between the SDK and the dashboard config.
+ event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
+ IgnoreAPIVersionMismatch: true,
+ })
+
+ if err != nil {
+ log.Printf("Stripe Webhook验签失败: %v\n", err)
+ c.AbortWithStatus(http.StatusBadRequest)
+ return
+ }
+
+ switch event.Type {
+ case stripe.EventTypeCheckoutSessionCompleted:
+ sessionCompleted(event)
+ case stripe.EventTypeCheckoutSessionExpired:
+ sessionExpired(event)
+ default:
+ log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
+ }
+
+ c.Status(http.StatusOK)
+}
+
+// sessionCompleted handles checkout.session.completed: it verifies the
+// session reached "complete" and credits the matching order via
+// model.Recharge, passing the Stripe customer ID along.
+func sessionCompleted(event stripe.Event) {
+ customerId := event.GetObjectValue("customer")
+ referenceId := event.GetObjectValue("client_reference_id")
+ status := event.GetObjectValue("status")
+ if "complete" != status {
+ log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
+ return
+ }
+
+ err := model.Recharge(referenceId, customerId)
+ if err != nil {
+ log.Println(err.Error(), referenceId)
+ return
+ }
+
+ // amount_total is in the currency's smallest unit, hence the /100.
+ total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
+ currency := strings.ToUpper(event.GetObjectValue("currency"))
+ log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
+}
+
+// sessionExpired handles checkout.session.expired: it marks the matching
+// pending top-up order as expired.
+func sessionExpired(event stripe.Event) {
+ referenceId := event.GetObjectValue("client_reference_id")
+ status := event.GetObjectValue("status")
+ if "expired" != status {
+ log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
+ return
+ }
+
+ if len(referenceId) == 0 {
+ log.Println("未提供支付单号")
+ return
+ }
+
+ topUp := model.GetTopUpByTradeNo(referenceId)
+ if topUp == nil {
+ log.Println("充值订单不存在", referenceId)
+ return
+ }
+
+ // NOTE(review): a non-pending order is only logged, then still overwritten
+ // to expired below — this looks like a missing return; confirm intent.
+ if topUp.Status != common.TopUpStatusPending {
+ log.Println("充值订单状态错误", referenceId)
+ }
+
+ topUp.Status = common.TopUpStatusExpired
+ err := topUp.Update()
+ if err != nil {
+ log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
+ return
+ }
+
+ log.Println("充值订单已过期", referenceId)
+}
+
+// genStripeLink creates a Stripe Checkout session in payment mode — quantity
+// = top-up amount against the configured price ID — and returns its URL.
+func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
+ // Accept secret ("sk_") or restricted ("rk_") API keys only.
+ if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
+ return "", fmt.Errorf("无效的Stripe API密钥")
+ }
+
+ stripe.Key = setting.StripeApiSecret
+
+ params := &stripe.CheckoutSessionParams{
+ ClientReferenceID: stripe.String(referenceId),
+ // NOTE(review): the epay flow redirects to "/console/log" while this uses
+ // "/log" — confirm which path the frontend actually serves.
+ SuccessURL: stripe.String(setting.ServerAddress + "/log"),
+ CancelURL: stripe.String(setting.ServerAddress + "/topup"),
+ LineItems: []*stripe.CheckoutSessionLineItemParams{
+ {
+ Price: stripe.String(setting.StripePriceId),
+ Quantity: stripe.Int64(amount),
+ },
+ },
+ Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
+ }
+
+ // Attach an existing Stripe customer when known; otherwise let Checkout
+ // create one (pre-filling the email if available).
+ if "" == customerId {
+ if "" != email {
+ params.CustomerEmail = stripe.String(email)
+ }
+
+ params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
+ } else {
+ params.Customer = stripe.String(customerId)
+ }
+
+ result, err := session.New(params)
+ if err != nil {
+ return "", err
+ }
+
+ return result.URL, nil
+}
+
+// GetChargedAmount applies the user's top-up group ratio to the requested
+// count; a zero ratio is treated as 1 (no adjustment).
+func GetChargedAmount(count float64, user model.User) float64 {
+ topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
+ if topUpGroupRatio == 0 {
+ topUpGroupRatio = 1
+ }
+
+ return count * topUpGroupRatio
+}
+
+// getStripePayMoney converts a top-up amount (quota units unless currency
+// display is enabled) into money using the Stripe unit price and the
+// group's top-up ratio (a zero ratio is treated as 1).
+func getStripePayMoney(amount float64, group string) float64 {
+ if !common.DisplayInCurrencyEnabled {
+ amount = amount / common.QuotaPerUnit
+ }
+ // Using float64 for monetary calculations is acceptable here due to the small amounts involved
+ topupGroupRatio := common.GetTopupGroupRatio(group)
+ if topupGroupRatio == 0 {
+ topupGroupRatio = 1
+ }
+ payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
+ return payMoney
+}
+
+// getStripeMinTopup returns the minimum top-up in the same unit the client
+// submits (scaled to quota units when currency display is disabled).
+func getStripeMinTopup() int64 {
+ minTopup := setting.StripeMinTopUp
+ if !common.DisplayInCurrencyEnabled {
+ minTopup = minTopup * int(common.QuotaPerUnit)
+ }
+ return int64(minTopup)
+}
diff --git a/controller/twofa.go b/controller/twofa.go
new file mode 100644
index 00000000..1859a128
--- /dev/null
+++ b/controller/twofa.go
@@ -0,0 +1,553 @@
+package controller
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+// Setup2FARequest is the request body carrying a TOTP code for 2FA setup.
+type Setup2FARequest struct {
+ Code string `json:"code" binding:"required"`
+}
+
+// Verify2FARequest is the request body carrying a TOTP code or backup code.
+type Verify2FARequest struct {
+ Code string `json:"code" binding:"required"`
+}
+
+// Setup2FAResponse is returned by Setup2FA: the shared secret, the QR-code
+// payload for authenticator apps, and the freshly generated backup codes.
+type Setup2FAResponse struct {
+ Secret string `json:"secret"`
+ QRCodeData string `json:"qr_code_data"`
+ BackupCodes []string `json:"backup_codes"`
+}
+
+// Setup2FA initializes two-factor authentication for the current user: it
+// generates a TOTP secret, backup codes and a QR-code payload, and stores a
+// not-yet-enabled 2FA record. Enable2FA completes the flow.
+func Setup2FA(c *gin.Context) {
+ userId := c.GetInt("id")
+
+ // Refuse if 2FA is already enabled; the user must disable it first.
+ existing, err := model.GetTwoFAByUserId(userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if existing != nil && existing.IsEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户已启用2FA,请先禁用后重新设置",
+ })
+ return
+ }
+
+ // Remove any stale disabled record so setup starts from a clean slate.
+ if existing != nil && !existing.IsEnabled {
+ if err := existing.Delete(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ existing = nil // reset to nil; a new record will be created below
+ }
+
+ // Load the user; the username becomes the TOTP account name.
+ user, err := model.GetUserById(userId, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // Generate the TOTP secret.
+ key, err := common.GenerateTOTPSecret(user.Username)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "生成2FA密钥失败",
+ })
+ common.SysLog("生成TOTP密钥失败: " + err.Error())
+ return
+ }
+
+ // Generate one-time backup codes.
+ backupCodes, err := common.GenerateBackupCodes()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "生成备用码失败",
+ })
+ common.SysLog("生成备用码失败: " + err.Error())
+ return
+ }
+
+ // Build the QR-code payload for authenticator apps.
+ qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
+
+ // Create or update the 2FA record (not yet enabled).
+ twoFA := &model.TwoFA{
+ UserId: userId,
+ Secret: key.Secret(),
+ IsEnabled: false,
+ }
+
+ if existing != nil {
+ // NOTE(review): existing is always nil at this point (enabled records
+ // return above, disabled ones are deleted), so this update branch looks
+ // unreachable — confirm.
+ twoFA.Id = existing.Id
+ err = twoFA.Update()
+ } else {
+ // Create a new record.
+ err = twoFA.Create()
+ }
+
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // Persist the backup codes.
+ if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "保存备用码失败",
+ })
+ common.SysLog("保存备用码失败: " + err.Error())
+ return
+ }
+
+ // Audit log.
+ model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
+ "data": Setup2FAResponse{
+ Secret: key.Secret(),
+ QRCodeData: qrCodeData,
+ BackupCodes: backupCodes,
+ },
+ })
+}
+
+// Enable2FA completes 2FA setup: it validates the submitted TOTP code
+// against the pending (not yet enabled) secret and enables the record.
+func Enable2FA(c *gin.Context) {
+ var req Setup2FARequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+
+ userId := c.GetInt("id")
+
+ // Fetch the 2FA record created by Setup2FA.
+ twoFA, err := model.GetTwoFAByUserId(userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if twoFA == nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "请先完成2FA初始化设置",
+ })
+ return
+ }
+ if twoFA.IsEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "2FA已经启用",
+ })
+ return
+ }
+
+ // Validate the code format (numeric) before checking the TOTP value.
+ cleanCode, err := common.ValidateNumericCode(req.Code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "验证码或备用码错误,请重试",
+ })
+ return
+ }
+
+ // Enable 2FA.
+ if err := twoFA.Enable(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // Audit log.
+ model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "两步验证启用成功",
+ })
+}
+
+// Disable2FA turns off two-factor auth after verifying either a valid TOTP
+// code or an unused backup code.
+func Disable2FA(c *gin.Context) {
+ var req Verify2FARequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+
+ userId := c.GetInt("id")
+
+ // Fetch the 2FA record.
+ twoFA, err := model.GetTwoFAByUserId(userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if twoFA == nil || !twoFA.IsEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户未启用2FA",
+ })
+ return
+ }
+
+ // Accept either a TOTP code or a backup code.
+ cleanCode, err := common.ValidateNumericCode(req.Code)
+ isValidTOTP := false
+ isValidBackup := false
+
+ if err == nil {
+ // Try TOTP first (the format was valid).
+ isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
+ }
+
+ if !isValidTOTP {
+ // Fall back to treating the raw input as a backup code.
+ isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ }
+
+ if !isValidTOTP && !isValidBackup {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "验证码或备用码错误,请重试",
+ })
+ return
+ }
+
+ // Disable 2FA.
+ if err := model.DisableTwoFA(userId); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // Audit log.
+ model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "两步验证已禁用",
+ })
+}
+
+// Get2FAStatus reports whether 2FA is enabled/locked for the current user
+// and, when enabled, how many backup codes remain unused.
+func Get2FAStatus(c *gin.Context) {
+ userId := c.GetInt("id")
+
+ twoFA, err := model.GetTwoFAByUserId(userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ status := map[string]interface{}{
+ "enabled": false,
+ "locked": false,
+ }
+
+ if twoFA != nil {
+ status["enabled"] = twoFA.IsEnabled
+ status["locked"] = twoFA.IsLocked()
+ if twoFA.IsEnabled {
+ // Include the remaining unused backup-code count; a lookup failure is
+ // logged but does not fail the request.
+ backupCount, err := model.GetUnusedBackupCodeCount(userId)
+ if err != nil {
+ common.SysLog("获取备用码数量失败: " + err.Error())
+ } else {
+ status["backup_codes_remaining"] = backupCount
+ }
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": status,
+ })
+}
+
+// RegenerateBackupCodes replaces the user's backup codes with a fresh set
+// after verifying a current TOTP code (backup codes are not accepted here).
+func RegenerateBackupCodes(c *gin.Context) {
+ var req Verify2FARequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+
+ userId := c.GetInt("id")
+
+ // Fetch the 2FA record.
+ twoFA, err := model.GetTwoFAByUserId(userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if twoFA == nil || !twoFA.IsEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户未启用2FA",
+ })
+ return
+ }
+
+ // Validate the TOTP code format.
+ cleanCode, err := common.ValidateNumericCode(req.Code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ if !valid {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "验证码或备用码错误,请重试",
+ })
+ return
+ }
+
+ // Generate the new backup codes.
+ backupCodes, err := common.GenerateBackupCodes()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "生成备用码失败",
+ })
+ common.SysLog("生成备用码失败: " + err.Error())
+ return
+ }
+
+ // Persist the new backup codes (presumably replacing the old set —
+ // TODO confirm in model.CreateBackupCodes).
+ if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "保存备用码失败",
+ })
+ common.SysLog("保存备用码失败: " + err.Error())
+ return
+ }
+
+ // Audit log.
+ model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "备用码重新生成成功",
+ "data": map[string]interface{}{
+ "backup_codes": backupCodes,
+ },
+ })
+}
+
+// Verify2FALogin 登录时验证2FA
+func Verify2FALogin(c *gin.Context) {
+ var req Verify2FARequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+
+ // 从会话中获取pending用户信息
+ session := sessions.Default(c)
+ pendingUserId := session.Get("pending_user_id")
+ if pendingUserId == nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "会话已过期,请重新登录",
+ })
+ return
+ }
+ userId, ok := pendingUserId.(int)
+ if !ok {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "会话数据无效,请重新登录",
+ })
+ return
+ }
+ // 获取用户信息
+ user, err := model.GetUserById(userId, false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户不存在",
+ })
+ return
+ }
+
+ // 获取2FA记录
+ twoFA, err := model.GetTwoFAByUserId(user.Id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if twoFA == nil || !twoFA.IsEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户未启用2FA",
+ })
+ return
+ }
+
+ // 验证TOTP验证码或备用码
+ cleanCode, err := common.ValidateNumericCode(req.Code)
+ isValidTOTP := false
+ isValidBackup := false
+
+ if err == nil {
+ // 尝试验证TOTP
+ isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
+ }
+
+ if !isValidTOTP {
+ // 尝试验证备用码
+ isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ }
+
+ if !isValidTOTP && !isValidBackup {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "验证码或备用码错误,请重试",
+ })
+ return
+ }
+
+ // 2FA验证成功,清理pending会话信息并完成登录
+ session.Delete("pending_username")
+ session.Delete("pending_user_id")
+ session.Save()
+
+ setupLogin(user, c)
+}
+
+// Admin2FAStats 管理员获取2FA统计信息
+func Admin2FAStats(c *gin.Context) {
+ stats, err := model.GetTwoFAStats()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": stats,
+ })
+}
+
+// AdminDisable2FA 管理员强制禁用用户2FA
+func AdminDisable2FA(c *gin.Context) {
+ userIdStr := c.Param("id")
+ userId, err := strconv.Atoi(userIdStr)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户ID格式错误",
+ })
+ return
+ }
+
+ // 检查目标用户权限
+ targetUser, err := model.GetUserById(userId, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ myRole := c.GetInt("role")
+ if myRole <= targetUser.Role && myRole != common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权操作同级或更高级用户的2FA设置",
+ })
+ return
+ }
+
+ // 禁用2FA
+ if err := model.DisableTwoFA(userId); err != nil {
+ if errors.Is(err, model.ErrTwoFANotEnabled) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户未启用2FA",
+ })
+ return
+ }
+ common.ApiError(c, err)
+ return
+ }
+
+ // 记录操作日志
+ adminId := c.GetInt("id")
+ model.RecordLog(userId, model.LogTypeManage,
+ fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "用户2FA已被强制禁用",
+ })
+}
diff --git a/controller/uptime_kuma.go b/controller/uptime_kuma.go
new file mode 100644
index 00000000..05d6297e
--- /dev/null
+++ b/controller/uptime_kuma.go
@@ -0,0 +1,154 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "one-api/setting/console_setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "golang.org/x/sync/errgroup"
+)
+
+const (
+ requestTimeout = 30 * time.Second
+ httpTimeout = 10 * time.Second
+ uptimeKeySuffix = "_24"
+ apiStatusPath = "/api/status-page/"
+ apiHeartbeatPath = "/api/status-page/heartbeat/"
+)
+
+type Monitor struct {
+ Name string `json:"name"`
+ Uptime float64 `json:"uptime"`
+ Status int `json:"status"`
+ Group string `json:"group,omitempty"`
+}
+
+type UptimeGroupResult struct {
+ CategoryName string `json:"categoryName"`
+ Monitors []Monitor `json:"monitors"`
+}
+
+func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return err
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return errors.New("non-200 status")
+ }
+
+ return json.NewDecoder(resp.Body).Decode(dest)
+}
+
+func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
+ url, _ := groupConfig["url"].(string)
+ slug, _ := groupConfig["slug"].(string)
+ categoryName, _ := groupConfig["categoryName"].(string)
+
+ result := UptimeGroupResult{
+ CategoryName: categoryName,
+ Monitors: []Monitor{},
+ }
+
+ if url == "" || slug == "" {
+ return result
+ }
+
+ baseURL := strings.TrimSuffix(url, "/")
+
+ var statusData struct {
+ PublicGroupList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ MonitorList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ } `json:"monitorList"`
+ } `json:"publicGroupList"`
+ }
+
+ var heartbeatData struct {
+ HeartbeatList map[string][]struct {
+ Status int `json:"status"`
+ } `json:"heartbeatList"`
+ UptimeList map[string]float64 `json:"uptimeList"`
+ }
+
+ g, gCtx := errgroup.WithContext(ctx)
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
+ })
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
+ })
+
+ if g.Wait() != nil {
+ return result
+ }
+
+ for _, pg := range statusData.PublicGroupList {
+ if len(pg.MonitorList) == 0 {
+ continue
+ }
+
+ for _, m := range pg.MonitorList {
+ monitor := Monitor{
+ Name: m.Name,
+ Group: pg.Name,
+ }
+
+ monitorID := strconv.Itoa(m.ID)
+
+ if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
+ monitor.Uptime = uptime
+ }
+
+ if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
+ monitor.Status = heartbeats[0].Status
+ }
+
+ result.Monitors = append(result.Monitors, monitor)
+ }
+ }
+
+ return result
+}
+
+func GetUptimeKumaStatus(c *gin.Context) {
+ groups := console_setting.GetUptimeKumaGroups()
+ if len(groups) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
+ defer cancel()
+
+ client := &http.Client{Timeout: httpTimeout}
+ results := make([]UptimeGroupResult, len(groups))
+
+ g, gCtx := errgroup.WithContext(ctx)
+ for i, group := range groups {
+ i, group := i, group
+ g.Go(func() error {
+ results[i] = fetchGroupData(gCtx, client, group)
+ return nil
+ })
+ }
+
+ g.Wait()
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
+}
\ No newline at end of file
diff --git a/controller/usedata.go b/controller/usedata.go
index 270eadf3..4adee50f 100644
--- a/controller/usedata.go
+++ b/controller/usedata.go
@@ -1,10 +1,12 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
+ "one-api/common"
"one-api/model"
"strconv"
+
+ "github.com/gin-gonic/gin"
)
func GetAllQuotaDates(c *gin.Context) {
@@ -13,10 +15,7 @@ func GetAllQuotaDates(c *gin.Context) {
username := c.Query("username")
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -41,10 +40,7 @@ func GetUserQuotaDates(c *gin.Context) {
}
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/user.go b/controller/user.go
index fd53e743..c9795c0c 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -6,6 +6,8 @@ import (
"net/http"
"net/url"
"one-api/common"
+ "one-api/dto"
+ "one-api/logger"
"one-api/model"
"one-api/setting"
"strconv"
@@ -61,6 +63,32 @@ func Login(c *gin.Context) {
})
return
}
+
+ // 检查是否启用2FA
+ if model.IsTwoFAEnabled(user.Id) {
+ // 设置pending session,等待2FA验证
+ session := sessions.Default(c)
+ session.Set("pending_username", user.Username)
+ session.Set("pending_user_id", user.Id)
+ err := session.Save()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "无法保存会话信息,请重试",
+ "success": false,
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "message": "请输入两步验证码",
+ "success": true,
+ "data": map[string]interface{}{
+ "require_2fa": true,
+ },
+ })
+ return
+ }
+
setupLogin(&user, c)
}
@@ -165,7 +193,7 @@ func Register(c *gin.Context) {
"success": false,
"message": "数据库错误,请稍后重试",
})
- common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
+ common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return
}
if exist {
@@ -187,10 +215,7 @@ func Register(c *gin.Context) {
cleanUser.Email = user.Email
}
if err := cleanUser.Insert(inviterId); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -211,7 +236,7 @@ func Register(c *gin.Context) {
"success": false,
"message": "生成默认令牌失败",
})
- common.SysError("failed to generate token key: " + err.Error())
+ common.SysLog("failed to generate token key: " + err.Error())
return
}
// 生成默认令牌
@@ -226,6 +251,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true,
ModelLimitsEnabled: false,
}
+ if setting.DefaultUseAutoGroup {
+ token.Group = "auto"
+ }
if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -243,83 +271,45 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 1 {
- p = 1
- }
- if pageSize < 0 {
- pageSize = common.ItemsPerPage
- }
- users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
+ pageInfo := common.GetPageQuery(c)
+ users, total, err := model.GetAllUsers(pageInfo)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "items": users,
- "total": total,
- "page": p,
- "page_size": pageSize,
- },
- })
+
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(users)
+
+ common.ApiSuccess(c, pageInfo)
return
}
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
- p, _ := strconv.Atoi(c.Query("p"))
- pageSize, _ := strconv.Atoi(c.Query("page_size"))
- if p < 1 {
- p = 1
- }
- if pageSize < 0 {
- pageSize = common.ItemsPerPage
- }
- startIdx := (p - 1) * pageSize
- users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
+ pageInfo := common.GetPageQuery(c)
+ users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": gin.H{
- "items": users,
- "total": total,
- "page": p,
- "page_size": pageSize,
- },
- })
+
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(users)
+ common.ApiSuccess(c, pageInfo)
return
}
func GetUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user, err := model.GetUserById(id, false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
myRole := c.GetInt("role")
@@ -342,10 +332,7 @@ func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, true)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
// get rand int 28-32
@@ -356,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) {
"success": false,
"message": "生成失败",
})
- common.SysError("failed to generate key: " + err.Error())
+ common.SysLog("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
@@ -370,10 +357,7 @@ func GenerateAccessToken(c *gin.Context) {
}
if err := user.Update(false); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -393,18 +377,12 @@ func TransferAffQuota(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, true)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
tran := TransferAffQuotaRequest{}
if err := c.ShouldBindJSON(&tran); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
err = user.TransferAffQuotaToQuota(tran.Quota)
@@ -425,10 +403,7 @@ func GetAffCode(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, true)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
if user.AffCode == "" {
@@ -453,12 +428,12 @@ func GetSelf(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
+ // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
+ user.Remark = ""
+
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -474,16 +449,13 @@ func GetUserModels(c *gin.Context) {
}
user, err := model.GetUserCache(id)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
groups := setting.GetUserUsableGroups(user.Group)
var models []string
for group := range groups {
- for _, g := range model.GetGroupModels(group) {
+ for _, g := range model.GetGroupEnabledModels(group) {
if !common.StringsContains(models, g) {
models = append(models, g)
}
@@ -519,10 +491,7 @@ func UpdateUser(c *gin.Context) {
}
originUser, err := model.GetUserById(updatedUser.Id, false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
myRole := c.GetInt("role")
@@ -545,14 +514,11 @@ func UpdateUser(c *gin.Context) {
}
updatePassword := updatedUser.Password != ""
if err := updatedUser.Edit(updatePassword); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
if originUser.Quota != updatedUser.Quota {
- model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
+ model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
}
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -594,17 +560,11 @@ func UpdateSelf(c *gin.Context) {
}
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
if err := cleanUser.Update(updatePassword); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -635,18 +595,12 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
func DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
originUser, err := model.GetUserById(id, false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
myRole := c.GetInt("role")
@@ -681,10 +635,7 @@ func DeleteSelf(c *gin.Context) {
err := model.DeleteUserById(id)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -730,10 +681,7 @@ func CreateUser(c *gin.Context) {
DisplayName: user.DisplayName,
}
if err := cleanUser.Insert(0); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
@@ -843,10 +791,7 @@ func ManageUser(c *gin.Context) {
}
if err := user.Update(false); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
clearUser := model.User{
@@ -878,20 +823,14 @@ func EmailBind(c *gin.Context) {
}
err := user.FillUserById()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user.Email = email
// no need to check if this email already taken, because we have used verification code to check it
err = user.Update(false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -905,27 +844,67 @@ type topUpRequest struct {
Key string `json:"key"`
}
-var topUpLock = sync.Mutex{}
+var topUpLocks sync.Map
+var topUpCreateLock sync.Mutex
+
+type topUpTryLock struct {
+ ch chan struct{}
+}
+
+func newTopUpTryLock() *topUpTryLock {
+ return &topUpTryLock{ch: make(chan struct{}, 1)}
+}
+
+func (l *topUpTryLock) TryLock() bool {
+ select {
+ case l.ch <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (l *topUpTryLock) Unlock() {
+ select {
+ case <-l.ch:
+ default:
+ }
+}
+
+func getTopUpLock(userID int) *topUpTryLock {
+ if v, ok := topUpLocks.Load(userID); ok {
+ return v.(*topUpTryLock)
+ }
+ topUpCreateLock.Lock()
+ defer topUpCreateLock.Unlock()
+ if v, ok := topUpLocks.Load(userID); ok {
+ return v.(*topUpTryLock)
+ }
+ l := newTopUpTryLock()
+ topUpLocks.Store(userID, l)
+ return l
+}
func TopUp(c *gin.Context) {
- topUpLock.Lock()
- defer topUpLock.Unlock()
- req := topUpRequest{}
- err := c.ShouldBindJSON(&req)
- if err != nil {
+ id := c.GetInt("id")
+ lock := getTopUpLock(id)
+ if !lock.TryLock() {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": err.Error(),
+ "message": "充值处理中,请稍后重试",
})
return
}
- id := c.GetInt("id")
+ defer lock.Unlock()
+ req := topUpRequest{}
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
quota, err := model.Redeem(req.Key, id)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
@@ -933,7 +912,6 @@ func TopUp(c *gin.Context) {
"message": "",
"data": quota,
})
- return
}
type UpdateUserSettingRequest struct {
@@ -943,6 +921,7 @@ type UpdateUserSettingRequest struct {
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
+ RecordIpLog bool `json:"record_ip_log"`
}
func UpdateUserSetting(c *gin.Context) {
@@ -956,7 +935,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 验证预警类型
- if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
+ if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
@@ -974,7 +953,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是webhook类型,验证webhook地址
- if req.QuotaWarningType == constant.NotifyTypeWebhook {
+ if req.QuotaWarningType == dto.NotifyTypeWebhook {
if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -993,7 +972,7 @@ func UpdateUserSetting(c *gin.Context) {
}
// 如果是邮件类型,验证邮箱地址
- if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+ if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{
@@ -1007,31 +986,29 @@ func UpdateUserSetting(c *gin.Context) {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, true)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
// 构建设置
- settings := map[string]interface{}{
- constant.UserSettingNotifyType: req.QuotaWarningType,
- constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
- "accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
+ settings := dto.UserSetting{
+ NotifyType: req.QuotaWarningType,
+ QuotaWarningThreshold: req.QuotaWarningThreshold,
+ AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
+ RecordIpLog: req.RecordIpLog,
}
// 如果是webhook类型,添加webhook相关设置
- if req.QuotaWarningType == constant.NotifyTypeWebhook {
- settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
+ if req.QuotaWarningType == dto.NotifyTypeWebhook {
+ settings.WebhookUrl = req.WebhookUrl
if req.WebhookSecret != "" {
- settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
+ settings.WebhookSecret = req.WebhookSecret
}
}
// 如果提供了通知邮箱,添加到设置中
- if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
- settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
+ if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
+ settings.NotificationEmail = req.NotificationEmail
}
// 更新用户设置
diff --git a/controller/vendor_meta.go b/controller/vendor_meta.go
new file mode 100644
index 00000000..21d5a21d
--- /dev/null
+++ b/controller/vendor_meta.go
@@ -0,0 +1,124 @@
+package controller
+
+import (
+ "strconv"
+
+ "one-api/common"
+ "one-api/model"
+
+ "github.com/gin-gonic/gin"
+)
+
+// GetAllVendors 获取供应商列表(分页)
+func GetAllVendors(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+ vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ var total int64
+ model.DB.Model(&model.Vendor{}).Count(&total)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(vendors)
+ common.ApiSuccess(c, pageInfo)
+}
+
+// SearchVendors 搜索供应商
+func SearchVendors(c *gin.Context) {
+ keyword := c.Query("keyword")
+ pageInfo := common.GetPageQuery(c)
+ vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(vendors)
+ common.ApiSuccess(c, pageInfo)
+}
+
+// GetVendorMeta 根据 ID 获取供应商
+func GetVendorMeta(c *gin.Context) {
+ idStr := c.Param("id")
+ id, err := strconv.Atoi(idStr)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ v, err := model.GetVendorByID(id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, v)
+}
+
+// CreateVendorMeta 新建供应商
+func CreateVendorMeta(c *gin.Context) {
+ var v model.Vendor
+ if err := c.ShouldBindJSON(&v); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if v.Name == "" {
+ common.ApiErrorMsg(c, "供应商名称不能为空")
+ return
+ }
+ // 创建前先检查名称
+ if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
+ common.ApiError(c, err)
+ return
+ } else if dup {
+ common.ApiErrorMsg(c, "供应商名称已存在")
+ return
+ }
+
+ if err := v.Insert(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, &v)
+}
+
+// UpdateVendorMeta 更新供应商
+func UpdateVendorMeta(c *gin.Context) {
+ var v model.Vendor
+ if err := c.ShouldBindJSON(&v); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if v.Id == 0 {
+ common.ApiErrorMsg(c, "缺少供应商 ID")
+ return
+ }
+ // 名称冲突检查
+ if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
+ common.ApiError(c, err)
+ return
+ } else if dup {
+ common.ApiErrorMsg(c, "供应商名称已存在")
+ return
+ }
+
+ if err := v.Update(); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, &v)
+}
+
+// DeleteVendorMeta 删除供应商
+func DeleteVendorMeta(c *gin.Context) {
+ idStr := c.Param("id")
+ id, err := strconv.Atoi(idStr)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.ApiSuccess(c, nil)
+}
diff --git a/controller/wechat.go b/controller/wechat.go
index 9b5f2070..9a4bdfed 100644
--- a/controller/wechat.go
+++ b/controller/wechat.go
@@ -4,13 +4,14 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-contrib/sessions"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
)
type wechatLoginResponse struct {
@@ -150,19 +151,13 @@ func WeChatBind(c *gin.Context) {
}
err = user.FillUserById()
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
user.WeChatId = wechatId
err = user.Update(false)
if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
+ common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
diff --git a/docker-compose.yml b/docker-compose.yml
index 3d707ed0..d98fd706 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -16,7 +16,7 @@ services:
- REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
- # - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释
+ # - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
diff --git a/docs/api/user.md b/docs/api/user.md
deleted file mode 100644
index e69de29b..00000000
diff --git a/docs/api/web_api.md b/docs/api/web_api.md
new file mode 100644
index 00000000..e64fd359
--- /dev/null
+++ b/docs/api/web_api.md
@@ -0,0 +1,197 @@
+# New API – Web 界面后端接口文档
+
+> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
+>
+> 接口前缀统一为部署域名(例如 `https://your-domain.com`),以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
+>
+> 鉴权级别说明:
+> * **公开** – 不需要登录即可调用
+> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
+> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
+> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
+
+---
+
+## 1. 初始化 / 系统状态
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/setup | 公开 | 获取系统初始化状态 |
+| POST | /api/setup | 公开 | 完成首次安装向导 |
+| GET | /api/status | 公开 | 获取运行状态摘要 |
+| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
+| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
+
+## 2. 公共信息
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/models | 用户 | 获取前端可用模型列表 |
+| GET | /api/notice | 公开 | 获取公告栏内容 |
+| GET | /api/about | 公开 | 关于页面信息 |
+| GET | /api/home_page_content | 公开 | 首页自定义内容 |
+| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
+| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
+
+## 3. 邮件 / 身份验证
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
+| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
+| POST | /api/user/reset | 公开 | 提交重置密码请求 |
+
+## 4. OAuth / 第三方登录
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
+| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
+| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
+| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
+| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
+| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
+| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
+| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
+| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
+
+## 5. 用户模块
+### 5.1 账号注册/登录
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| POST | /api/user/register | 公开 | 注册新账号 |
+| POST | /api/user/login | 公开 | 用户登录 |
+| GET | /api/user/logout | 用户 | 退出登录 |
+| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
+| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
+
+### 5.2 用户自身操作 (需登录)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
+| GET | /api/user/self | 用户 | 获取个人资料 |
+| GET | /api/user/models | 用户 | 获取模型可见性 |
+| PUT | /api/user/self | 用户 | 修改个人资料 |
+| DELETE | /api/user/self | 用户 | 注销账号 |
+| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
+| GET | /api/user/aff | 用户 | 获取推广码信息 |
+| POST | /api/user/topup | 用户 | 余额直充 |
+| POST | /api/user/pay | 用户 | 提交支付订单 |
+| POST | /api/user/amount | 用户 | 余额支付 |
+| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
+| PUT | /api/user/setting | 用户 | 更新用户设置 |
+
+### 5.3 管理员用户管理
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/user/ | 管理员 | 获取全部用户列表 |
+| GET | /api/user/search | 管理员 | 搜索用户 |
+| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
+| POST | /api/user/ | 管理员 | 创建用户 |
+| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
+| PUT | /api/user/ | 管理员 | 更新用户 |
+| DELETE | /api/user/:id | 管理员 | 删除用户 |
+
+## 6. 站点选项 (Root)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/option/ | Root | 获取全局配置 |
+| PUT | /api/option/ | Root | 更新全局配置 |
+| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
+| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
+
+## 7. 模型倍率同步 (Root)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
+| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
+
+## 8. 渠道管理 (管理员)
+| 方法 | 路径 | 说明 |
+|------|------|------|
+| GET | /api/channel/ | 获取渠道列表 |
+| GET | /api/channel/search | 搜索渠道 |
+| GET | /api/channel/models | 查询渠道模型能力 |
+| GET | /api/channel/models_enabled | 查询启用模型能力 |
+| GET | /api/channel/:id | 获取单个渠道 |
+| GET | /api/channel/test | 批量测试渠道连通性 |
+| GET | /api/channel/test/:id | 单个渠道测试 |
+| GET | /api/channel/update_balance | 批量刷新余额 |
+| GET | /api/channel/update_balance/:id | 单个刷新余额 |
+| POST | /api/channel/ | 新增渠道 |
+| PUT | /api/channel/ | 更新渠道 |
+| DELETE | /api/channel/disabled | 删除已禁用渠道 |
+| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
+| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
+| PUT | /api/channel/tag | 编辑渠道标签 |
+| DELETE | /api/channel/:id | 删除渠道 |
+| POST | /api/channel/batch | 批量删除渠道 |
+| POST | /api/channel/fix | 修复渠道能力表 |
+| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
+| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
+| POST | /api/channel/batch/tag | 批量设置渠道标签 |
+| GET | /api/channel/tag/models | 根据标签获取模型 |
+| POST | /api/channel/copy/:id | 复制渠道 |
+
+## 9. Token 管理
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/token/ | 用户 | 获取全部 Token |
+| GET | /api/token/search | 用户 | 搜索 Token |
+| GET | /api/token/:id | 用户 | 获取单个 Token |
+| POST | /api/token/ | 用户 | 创建 Token |
+| PUT | /api/token/ | 用户 | 更新 Token |
+| DELETE | /api/token/:id | 用户 | 删除 Token |
+| POST | /api/token/batch | 用户 | 批量删除 Token |
+
+## 10. 兑换码管理 (管理员)
+| 方法 | 路径 | 说明 |
+|------|------|------|
+| GET | /api/redemption/ | 获取兑换码列表 |
+| GET | /api/redemption/search | 搜索兑换码 |
+| GET | /api/redemption/:id | 获取单个兑换码 |
+| POST | /api/redemption/ | 创建兑换码 |
+| PUT | /api/redemption/ | 更新兑换码 |
+| DELETE | /api/redemption/invalid | 删除无效兑换码 |
+| DELETE | /api/redemption/:id | 删除兑换码 |
+
+## 11. 日志
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/log/ | 管理员 | 获取全部日志 |
+| DELETE | /api/log/ | 管理员 | 删除历史日志 |
+| GET | /api/log/stat | 管理员 | 日志统计 |
+| GET | /api/log/self/stat | 用户 | 我的日志统计 |
+| GET | /api/log/search | 管理员 | 搜索全部日志 |
+| GET | /api/log/self | 用户 | 获取我的日志 |
+| GET | /api/log/self/search | 用户 | 搜索我的日志 |
+| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
+
+## 12. 数据统计
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
+| GET | /api/data/self | 用户 | 我的用量按日期统计 |
+
+## 13. 分组
+| GET | /api/group/ | 管理员 | 获取全部分组列表 |
+
+## 14. Midjourney 任务
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
+| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
+
+## 15. 任务中心
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/task/self | 用户 | 获取我的任务 |
+| GET | /api/task/ | 管理员 | 获取全部任务 |
+
+## 16. 账户计费面板 (Dashboard)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
+| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
+| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
+| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
+
+---
+
+> **更新日期**:2025.07.17
diff --git a/docs/images/aliyun.png b/docs/images/aliyun.png
new file mode 100644
index 00000000..6266bfbf
Binary files /dev/null and b/docs/images/aliyun.png differ
diff --git a/docs/images/cherry-studio.png b/docs/images/cherry-studio.png
new file mode 100644
index 00000000..a58a7713
Binary files /dev/null and b/docs/images/cherry-studio.png differ
diff --git a/docs/images/io-net.png b/docs/images/io-net.png
new file mode 100644
index 00000000..fb47534d
Binary files /dev/null and b/docs/images/io-net.png differ
diff --git a/docs/images/pku.png b/docs/images/pku.png
new file mode 100644
index 00000000..a058c3ce
Binary files /dev/null and b/docs/images/pku.png differ
diff --git a/docs/images/ucloud.png b/docs/images/ucloud.png
new file mode 100644
index 00000000..16cca764
Binary files /dev/null and b/docs/images/ucloud.png differ
diff --git a/dto/audio.go b/dto/audio.go
index c36b3da5..9d71f6f7 100644
--- a/dto/audio.go
+++ b/dto/audio.go
@@ -1,5 +1,11 @@
package dto
+import (
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
type AudioRequest struct {
Model string `json:"model"`
Input string `json:"input"`
@@ -8,6 +14,24 @@ type AudioRequest struct {
ResponseFormat string `json:"response_format,omitempty"`
}
+func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ meta := &types.TokenCountMeta{
+ CombineText: r.Input,
+ TokenType: types.TokenTypeTextNumber,
+ }
+ return meta
+}
+
+func (r *AudioRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
+func (r *AudioRequest) SetModelName(modelName string) {
+ if modelName != "" {
+ r.Model = modelName
+ }
+}
+
type AudioResponse struct {
Text string `json:"text"`
}
diff --git a/dto/channel_settings.go b/dto/channel_settings.go
new file mode 100644
index 00000000..2c58795c
--- /dev/null
+++ b/dto/channel_settings.go
@@ -0,0 +1,14 @@
+package dto
+
+type ChannelSettings struct {
+ ForceFormat bool `json:"force_format,omitempty"`
+ ThinkingToContent bool `json:"thinking_to_content,omitempty"`
+ Proxy string `json:"proxy"`
+ PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
+ SystemPrompt string `json:"system_prompt,omitempty"`
+ SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
+}
+
+type ChannelOtherSettings struct {
+ AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
+}
diff --git a/dto/claude.go b/dto/claude.go
index 36dfc02e..5c4396f2 100644
--- a/dto/claude.go
+++ b/dto/claude.go
@@ -1,6 +1,14 @@
package dto
-import "encoding/json"
+import (
+ "encoding/json"
+ "fmt"
+ "one-api/common"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
type ClaudeMetadata struct {
UserId string `json:"user_id"`
@@ -20,11 +28,11 @@ type ClaudeMediaMessage struct {
Delta string `json:"delta,omitempty"`
CacheControl json.RawMessage `json:"cache_control,omitempty"`
// tool_calls
- Id string `json:"id,omitempty"`
- Name string `json:"name,omitempty"`
- Input any `json:"input,omitempty"`
- Content json.RawMessage `json:"content,omitempty"`
- ToolUseId string `json:"tool_use_id,omitempty"`
+ Id string `json:"id,omitempty"`
+ Name string `json:"name,omitempty"`
+ Input any `json:"input,omitempty"`
+ Content any `json:"content,omitempty"`
+ ToolUseId string `json:"tool_use_id,omitempty"`
}
func (c *ClaudeMediaMessage) SetText(s string) {
@@ -39,34 +47,54 @@ func (c *ClaudeMediaMessage) GetText() string {
}
func (c *ClaudeMediaMessage) IsStringContent() bool {
- var content string
- return json.Unmarshal(c.Content, &content) == nil
+ if c.Content == nil {
+ return false
+ }
+ _, ok := c.Content.(string)
+ if ok {
+ return true
+ }
+ return false
}
func (c *ClaudeMediaMessage) GetStringContent() string {
- var content string
- if err := json.Unmarshal(c.Content, &content); err == nil {
- return content
+ if c.Content == nil {
+ return ""
}
+ switch c.Content.(type) {
+ case string:
+ return c.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range c.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
return ""
}
func (c *ClaudeMediaMessage) GetJsonRowString() string {
- jsonContent, _ := json.Marshal(c)
+ jsonContent, _ := common.Marshal(c)
return string(jsonContent)
}
func (c *ClaudeMediaMessage) SetContent(content any) {
- jsonContent, _ := json.Marshal(content)
- c.Content = jsonContent
+ c.Content = content
}
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
- var mediaContent []ClaudeMediaMessage
- if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
- return mediaContent
- }
- return make([]ClaudeMediaMessage, 0)
+ mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
+ return mediaContent
}
type ClaudeMessageSource struct {
@@ -82,14 +110,36 @@ type ClaudeMessage struct {
}
func (c *ClaudeMessage) IsStringContent() bool {
+ if c.Content == nil {
+ return false
+ }
_, ok := c.Content.(string)
return ok
}
func (c *ClaudeMessage) GetStringContent() string {
- if c.IsStringContent() {
- return c.Content.(string)
+ if c.Content == nil {
+ return ""
}
+ switch c.Content.(type) {
+ case string:
+ return c.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range c.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
return ""
}
@@ -98,15 +148,7 @@ func (c *ClaudeMessage) SetStringContent(content string) {
}
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
- // map content to []ClaudeMediaMessage
- // parse to json
- jsonContent, _ := json.Marshal(c.Content)
- var contentList []ClaudeMediaMessage
- err := json.Unmarshal(jsonContent, &contentList)
- if err != nil {
- return make([]ClaudeMediaMessage, 0), err
- }
- return contentList, nil
+ return common.Any2Type[[]ClaudeMediaMessage](c.Content)
}
type Tool struct {
@@ -121,6 +163,27 @@ type InputSchema struct {
Required any `json:"required,omitempty"`
}
+type ClaudeWebSearchTool struct {
+ Type string `json:"type"`
+ Name string `json:"name"`
+ MaxUses int `json:"max_uses,omitempty"`
+ UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
+}
+
+type ClaudeWebSearchUserLocation struct {
+ Type string `json:"type"`
+ Timezone string `json:"timezone,omitempty"`
+ Country string `json:"country,omitempty"`
+ Region string `json:"region,omitempty"`
+ City string `json:"city,omitempty"`
+}
+
+type ClaudeToolChoice struct {
+ Type string `json:"type"`
+ Name string `json:"name,omitempty"`
+ DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
+}
+
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
@@ -139,9 +202,210 @@ type ClaudeRequest struct {
Thinking *Thinking `json:"thinking,omitempty"`
}
+func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var tokenCountMeta = types.TokenCountMeta{
+ TokenType: types.TokenTypeTokenizer,
+ MaxTokens: int(c.MaxTokens),
+ }
+
+ var texts = make([]string, 0)
+ var fileMeta = make([]*types.FileMeta, 0)
+
+ // system
+ if c.System != nil {
+ if c.IsStringSystem() {
+ sys := c.GetStringSystem()
+ if sys != "" {
+ texts = append(texts, sys)
+ }
+ } else {
+ systemMedia := c.ParseSystem()
+ for _, media := range systemMedia {
+ switch media.Type {
+ case "text":
+ texts = append(texts, media.GetText())
+ case "image":
+ if media.Source != nil {
+ data := media.Source.Url
+ if data == "" {
+ data = common.Interface2String(media.Source.Data)
+ }
+ if data != "" {
+ fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // messages
+ for _, message := range c.Messages {
+ tokenCountMeta.MessagesCount++
+ texts = append(texts, message.Role)
+ if message.IsStringContent() {
+ content := message.GetStringContent()
+ if content != "" {
+ texts = append(texts, content)
+ }
+ continue
+ }
+
+ content, _ := message.ParseContent()
+ for _, media := range content {
+ switch media.Type {
+ case "text":
+ texts = append(texts, media.GetText())
+ case "image":
+ if media.Source != nil {
+ data := media.Source.Url
+ if data == "" {
+ data = common.Interface2String(media.Source.Data)
+ }
+ if data != "" {
+ fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
+ }
+ }
+ case "tool_use":
+ if media.Name != "" {
+ texts = append(texts, media.Name)
+ }
+ if media.Input != nil {
+ b, _ := common.Marshal(media.Input)
+ texts = append(texts, string(b))
+ }
+ case "tool_result":
+ if media.Content != nil {
+ b, _ := common.Marshal(media.Content)
+ texts = append(texts, string(b))
+ }
+ }
+ }
+ }
+
+ // tools
+ if c.Tools != nil {
+ tools := c.GetTools()
+ normalTools, webSearchTools := ProcessTools(tools)
+ if normalTools != nil {
+ for _, t := range normalTools {
+ tokenCountMeta.ToolsCount++
+ if t.Name != "" {
+ texts = append(texts, t.Name)
+ }
+ if t.Description != "" {
+ texts = append(texts, t.Description)
+ }
+ if t.InputSchema != nil {
+ b, _ := common.Marshal(t.InputSchema)
+ texts = append(texts, string(b))
+ }
+ }
+ }
+ if webSearchTools != nil {
+ for _, t := range webSearchTools {
+ tokenCountMeta.ToolsCount++
+ if t.Name != "" {
+ texts = append(texts, t.Name)
+ }
+ if t.UserLocation != nil {
+ b, _ := common.Marshal(t.UserLocation)
+ texts = append(texts, string(b))
+ }
+ }
+ }
+ }
+
+ tokenCountMeta.CombineText = strings.Join(texts, "\n")
+ tokenCountMeta.Files = fileMeta
+ return &tokenCountMeta
+}
+
+func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
+ return c.Stream
+}
+
+func (c *ClaudeRequest) SetModelName(modelName string) {
+ if modelName != "" {
+ c.Model = modelName
+ }
+}
+
+func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
+ for _, message := range c.Messages {
+ content, _ := message.ParseContent()
+ for _, mediaMessage := range content {
+ if mediaMessage.Id == toolCallId {
+ return mediaMessage.Name
+ }
+ }
+ }
+ return ""
+}
+
+// AddTool 添加工具到请求中
+func (c *ClaudeRequest) AddTool(tool any) {
+ if c.Tools == nil {
+ c.Tools = make([]any, 0)
+ }
+
+ switch tools := c.Tools.(type) {
+ case []any:
+ c.Tools = append(tools, tool)
+ default:
+ // 如果Tools不是[]any类型,重新初始化为[]any
+ c.Tools = []any{tool}
+ }
+}
+
+// GetTools 获取工具列表
+func (c *ClaudeRequest) GetTools() []any {
+ if c.Tools == nil {
+ return nil
+ }
+
+ switch tools := c.Tools.(type) {
+ case []any:
+ return tools
+ default:
+ return nil
+ }
+}
+
+// ProcessTools 处理工具列表,支持类型断言
+func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
+ var normalTools []*Tool
+ var webSearchTools []*ClaudeWebSearchTool
+
+ for _, tool := range tools {
+ switch t := tool.(type) {
+ case *Tool:
+ normalTools = append(normalTools, t)
+ case *ClaudeWebSearchTool:
+ webSearchTools = append(webSearchTools, t)
+ case Tool:
+ normalTools = append(normalTools, &t)
+ case ClaudeWebSearchTool:
+ webSearchTools = append(webSearchTools, &t)
+ default:
+ // 未知类型,跳过
+ continue
+ }
+ }
+
+ return normalTools, webSearchTools
+}
+
type Thinking struct {
Type string `json:"type"`
- BudgetTokens int `json:"budget_tokens"`
+ BudgetTokens *int `json:"budget_tokens,omitempty"`
+}
+
+func (c *Thinking) GetBudgetTokens() int {
+ if c.BudgetTokens == nil {
+ return 0
+ }
+ return *c.BudgetTokens
}
func (c *ClaudeRequest) IsStringSystem() bool {
@@ -161,24 +425,13 @@ func (c *ClaudeRequest) SetStringSystem(system string) {
}
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
- // map content to []ClaudeMediaMessage
- // parse to json
- jsonContent, _ := json.Marshal(c.System)
- var contentList []ClaudeMediaMessage
- if err := json.Unmarshal(jsonContent, &contentList); err == nil {
- return contentList
- }
- return make([]ClaudeMediaMessage, 0)
-}
-
-type ClaudeError struct {
- Type string `json:"type,omitempty"`
- Message string `json:"message,omitempty"`
+ mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
+ return mediaContent
}
type ClaudeErrorWithStatusCode struct {
- Error ClaudeError `json:"error"`
- StatusCode int `json:"status_code"`
+ Error types.ClaudeError `json:"error"`
+ StatusCode int `json:"status_code"`
LocalError bool
}
@@ -190,7 +443,7 @@ type ClaudeResponse struct {
Completion string `json:"completion,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Model string `json:"model,omitempty"`
- Error *ClaudeError `json:"error,omitempty"`
+ Error any `json:"error,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
@@ -211,9 +464,50 @@ func (c *ClaudeResponse) GetIndex() int {
return *c.Index
}
-type ClaudeUsage struct {
- InputTokens int `json:"input_tokens"`
- CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
- CacheReadInputTokens int `json:"cache_read_input_tokens"`
- OutputTokens int `json:"output_tokens"`
+// GetClaudeError 从动态错误类型中提取ClaudeError结构
+func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
+ if c.Error == nil {
+ return nil
+ }
+
+ switch err := c.Error.(type) {
+ case types.ClaudeError:
+ return &err
+ case *types.ClaudeError:
+ return err
+ case map[string]interface{}:
+ // 处理从JSON解析来的map结构
+ claudeErr := &types.ClaudeError{}
+ if errType, ok := err["type"].(string); ok {
+ claudeErr.Type = errType
+ }
+ if errMsg, ok := err["message"].(string); ok {
+ claudeErr.Message = errMsg
+ }
+ return claudeErr
+ case string:
+ // 处理简单字符串错误
+ return &types.ClaudeError{
+ Type: "error",
+ Message: err,
+ }
+ default:
+ // 未知类型,尝试转换为字符串
+ return &types.ClaudeError{
+ Type: "unknown_error",
+ Message: fmt.Sprintf("%v", err),
+ }
+ }
+}
+
+type ClaudeUsage struct {
+ InputTokens int `json:"input_tokens"`
+ CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens int `json:"cache_read_input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
+}
+
+type ClaudeServerToolUse struct {
+ WebSearchRequests int `json:"web_search_requests"`
}
diff --git a/dto/dalle.go b/dto/dalle.go
deleted file mode 100644
index a1309b6c..00000000
--- a/dto/dalle.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package dto
-
-import "encoding/json"
-
-type ImageRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt" binding:"required"`
- N int `json:"n,omitempty"`
- Size string `json:"size,omitempty"`
- Quality string `json:"quality,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
- Style string `json:"style,omitempty"`
- User string `json:"user,omitempty"`
- ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
- Background string `json:"background,omitempty"`
- Moderation string `json:"moderation,omitempty"`
- OutputFormat string `json:"output_format,omitempty"`
-}
-
-type ImageResponse struct {
- Data []ImageData `json:"data"`
- Created int64 `json:"created"`
-}
-type ImageData struct {
- Url string `json:"url"`
- B64Json string `json:"b64_json"`
- RevisedPrompt string `json:"revised_prompt"`
-}
diff --git a/dto/embedding.go b/dto/embedding.go
index 9d722292..b473b722 100644
--- a/dto/embedding.go
+++ b/dto/embedding.go
@@ -1,5 +1,12 @@
package dto
+import (
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
type EmbeddingOptions struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
@@ -24,9 +31,32 @@ type EmbeddingRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
-func (r EmbeddingRequest) ParseInput() []string {
+func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var texts = make([]string, 0)
+
+ inputs := r.ParseInput()
+ for _, input := range inputs {
+ texts = append(texts, input)
+ }
+
+ return &types.TokenCountMeta{
+ CombineText: strings.Join(texts, "\n"),
+ }
+}
+
+func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
+func (r *EmbeddingRequest) SetModelName(modelName string) {
+ if modelName != "" {
+ r.Model = modelName
+ }
+}
+
+func (r *EmbeddingRequest) ParseInput() []string {
if r.Input == nil {
- return nil
+ return make([]string, 0)
}
var input []string
switch r.Input.(type) {
diff --git a/dto/error.go b/dto/error.go
index b347f6a1..d7f6824d 100644
--- a/dto/error.go
+++ b/dto/error.go
@@ -1,5 +1,7 @@
package dto
+import "one-api/types"
+
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
}
type GeneralErrorResponse struct {
- Error OpenAIError `json:"error"`
- Message string `json:"message"`
- Msg string `json:"msg"`
- Err string `json:"err"`
- ErrorMsg string `json:"error_msg"`
+ Error types.OpenAIError `json:"error"`
+ Message string `json:"message"`
+ Msg string `json:"msg"`
+ Err string `json:"err"`
+ ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
diff --git a/dto/gemini.go b/dto/gemini.go
new file mode 100644
index 00000000..5df67ba0
--- /dev/null
+++ b/dto/gemini.go
@@ -0,0 +1,384 @@
+package dto
+
+import (
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+ "one-api/logger"
+ "one-api/types"
+ "strings"
+)
+
+type GeminiChatRequest struct {
+ Contents []GeminiChatContent `json:"contents"`
+ SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
+ GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
+ Tools json.RawMessage `json:"tools,omitempty"`
+ SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
+}
+
+func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var files []*types.FileMeta = make([]*types.FileMeta, 0)
+
+ var maxTokens int
+
+ if r.GenerationConfig.MaxOutputTokens > 0 {
+ maxTokens = int(r.GenerationConfig.MaxOutputTokens)
+ }
+
+ var inputTexts []string
+ for _, content := range r.Contents {
+ for _, part := range content.Parts {
+ if part.Text != "" {
+ inputTexts = append(inputTexts, part.Text)
+ }
+ if part.InlineData != nil && part.InlineData.Data != "" {
+ if strings.HasPrefix(part.InlineData.MimeType, "image/") {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeImage,
+ OriginData: part.InlineData.Data,
+ })
+ } else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeAudio,
+ OriginData: part.InlineData.Data,
+ })
+ } else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeVideo,
+ OriginData: part.InlineData.Data,
+ })
+ } else {
+ files = append(files, &types.FileMeta{
+ FileType: types.FileTypeFile,
+ OriginData: part.InlineData.Data,
+ })
+ }
+ }
+ }
+ }
+
+ inputText := strings.Join(inputTexts, "\n")
+ return &types.TokenCountMeta{
+ CombineText: inputText,
+ Files: files,
+ MaxTokens: maxTokens,
+ }
+}
+
+func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
+ if c.Query("alt") == "sse" {
+ return true
+ }
+ return false
+}
+
+func (r *GeminiChatRequest) SetModelName(modelName string) {
+ // GeminiChatRequest does not have a model field, so this method does nothing.
+}
+
+func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
+ var tools []GeminiChatTool
+	if strings.HasPrefix(string(r.Tools), "[") {
+		// is array
+ if err := common.Unmarshal(r.Tools, &tools); err != nil {
+ logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
+ return nil
+ }
+ } else if strings.HasPrefix(string(r.Tools), "{") {
+ // is object
+ singleTool := GeminiChatTool{}
+ if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
+ logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
+ return nil
+ }
+ tools = []GeminiChatTool{singleTool}
+ }
+ return tools
+}
+
+func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
+ if len(tools) == 0 {
+ r.Tools = json.RawMessage("[]")
+ return
+ }
+
+ // Marshal the tools to JSON
+ data, err := common.Marshal(tools)
+ if err != nil {
+ logger.LogError(nil, "error_marshalling_tools: "+err.Error())
+ return
+ }
+ r.Tools = data
+}
+
+type GeminiThinkingConfig struct {
+ IncludeThoughts bool `json:"includeThoughts,omitempty"`
+ ThinkingBudget *int `json:"thinkingBudget,omitempty"`
+}
+
+func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
+ c.ThinkingBudget = &budget
+}
+
+type GeminiInlineData struct {
+ MimeType string `json:"mimeType"`
+ Data string `json:"data"`
+}
+
+// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
+func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
+ type Alias GeminiInlineData // Use type alias to avoid recursion
+ var aux struct {
+ Alias
+ MimeTypeSnake string `json:"mime_type"`
+ }
+
+ if err := common.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ *g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
+
+ // Prioritize snake_case if present
+ if aux.MimeTypeSnake != "" {
+ g.MimeType = aux.MimeTypeSnake
+ } else if aux.MimeType != "" { // Fallback to camelCase from Alias
+ g.MimeType = aux.MimeType
+ }
+ // g.Data would be populated by aux.Alias.Data
+ return nil
+}
+
+type FunctionCall struct {
+ FunctionName string `json:"name"`
+ Arguments any `json:"args"`
+}
+
+type GeminiFunctionResponse struct {
+ Name string `json:"name"`
+ Response map[string]interface{} `json:"response"`
+}
+
+type GeminiPartExecutableCode struct {
+ Language string `json:"language,omitempty"`
+ Code string `json:"code,omitempty"`
+}
+
+type GeminiPartCodeExecutionResult struct {
+ Outcome string `json:"outcome,omitempty"`
+ Output string `json:"output,omitempty"`
+}
+
+type GeminiFileData struct {
+ MimeType string `json:"mimeType,omitempty"`
+ FileUri string `json:"fileUri,omitempty"`
+}
+
+type GeminiPart struct {
+ Text string `json:"text,omitempty"`
+ Thought bool `json:"thought,omitempty"`
+ InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+ FunctionCall *FunctionCall `json:"functionCall,omitempty"`
+ FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
+ FileData *GeminiFileData `json:"fileData,omitempty"`
+ ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
+ CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
+}
+
+// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
+func (p *GeminiPart) UnmarshalJSON(data []byte) error {
+ // Alias to avoid recursion during unmarshalling
+ type Alias GeminiPart
+ var aux struct {
+ Alias
+ InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
+ }
+
+ if err := common.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ // Assign fields from alias
+ *p = GeminiPart(aux.Alias)
+
+ // Prioritize snake_case for InlineData if present
+ if aux.InlineDataSnake != nil {
+ p.InlineData = aux.InlineDataSnake
+ } else if aux.InlineData != nil { // Fallback to camelCase from Alias
+ p.InlineData = aux.InlineData
+ }
+ // Other fields like Text, FunctionCall etc. are already populated via aux.Alias
+
+ return nil
+}
+
+type GeminiChatContent struct {
+ Role string `json:"role,omitempty"`
+ Parts []GeminiPart `json:"parts"`
+}
+
+type GeminiChatSafetySettings struct {
+ Category string `json:"category"`
+ Threshold string `json:"threshold"`
+}
+
+type GeminiChatTool struct {
+ GoogleSearch any `json:"googleSearch,omitempty"`
+ GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
+ CodeExecution any `json:"codeExecution,omitempty"`
+ FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type GeminiChatGenerationConfig struct {
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK float64 `json:"topK,omitempty"`
+ MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+ ResponseMimeType string `json:"responseMimeType,omitempty"`
+ ResponseSchema any `json:"responseSchema,omitempty"`
+ Seed int64 `json:"seed,omitempty"`
+ ResponseModalities []string `json:"responseModalities,omitempty"`
+ ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
+ SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
+}
+
+type GeminiChatCandidate struct {
+ Content GeminiChatContent `json:"content"`
+ FinishReason *string `json:"finishReason"`
+ Index int64 `json:"index"`
+ SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatSafetyRating struct {
+ Category string `json:"category"`
+ Probability string `json:"probability"`
+}
+
+type GeminiChatPromptFeedback struct {
+ SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatResponse struct {
+ Candidates []GeminiChatCandidate `json:"candidates"`
+ PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+ UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
+}
+
+type GeminiUsageMetadata struct {
+ PromptTokenCount int `json:"promptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+}
+
+type GeminiPromptTokensDetails struct {
+ Modality string `json:"modality"`
+ TokenCount int `json:"tokenCount"`
+}
+
+// Imagen related structs
+type GeminiImageRequest struct {
+ Instances []GeminiImageInstance `json:"instances"`
+ Parameters GeminiImageParameters `json:"parameters"`
+}
+
+type GeminiImageInstance struct {
+ Prompt string `json:"prompt"`
+}
+
+type GeminiImageParameters struct {
+ SampleCount int `json:"sampleCount,omitempty"`
+ AspectRatio string `json:"aspectRatio,omitempty"`
+ PersonGeneration string `json:"personGeneration,omitempty"`
+}
+
+type GeminiImageResponse struct {
+ Predictions []GeminiImagePrediction `json:"predictions"`
+}
+
+type GeminiImagePrediction struct {
+ MimeType string `json:"mimeType"`
+ BytesBase64Encoded string `json:"bytesBase64Encoded"`
+ RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
+ SafetyAttributes any `json:"safetyAttributes,omitempty"`
+}
+
+// Embedding related structs
+type GeminiEmbeddingRequest struct {
+ Model string `json:"model,omitempty"`
+ Content GeminiChatContent `json:"content"`
+ TaskType string `json:"taskType,omitempty"`
+ Title string `json:"title,omitempty"`
+ OutputDimensionality int `json:"outputDimensionality,omitempty"`
+}
+
+func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
+ // Gemini embedding requests are not streamed
+ return false
+}
+
+func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var inputTexts []string
+ for _, part := range r.Content.Parts {
+ if part.Text != "" {
+ inputTexts = append(inputTexts, part.Text)
+ }
+ }
+ inputText := strings.Join(inputTexts, "\n")
+ return &types.TokenCountMeta{
+ CombineText: inputText,
+ }
+}
+
+func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
+ if modelName != "" {
+ r.Model = modelName
+ }
+}
+
+type GeminiBatchEmbeddingRequest struct {
+ Requests []*GeminiEmbeddingRequest `json:"requests"`
+}
+
+func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
+ // Gemini batch embedding requests are not streamed
+ return false
+}
+
+func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var inputTexts []string
+ for _, request := range r.Requests {
+ meta := request.GetTokenCountMeta()
+ if meta != nil && meta.CombineText != "" {
+ inputTexts = append(inputTexts, meta.CombineText)
+ }
+ }
+ inputText := strings.Join(inputTexts, "\n")
+ return &types.TokenCountMeta{
+ CombineText: inputText,
+ }
+}
+
+func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
+ if modelName != "" {
+ for _, req := range r.Requests {
+ req.SetModelName(modelName)
+ }
+ }
+}
+
+type GeminiEmbeddingResponse struct {
+ Embedding ContentEmbedding `json:"embedding"`
+}
+
+type GeminiBatchEmbeddingResponse struct {
+ Embeddings []*ContentEmbedding `json:"embeddings"`
+}
+
+type ContentEmbedding struct {
+ Values []float64 `json:"values"`
+}
diff --git a/dto/midjourney.go b/dto/midjourney.go
index 40251ee9..6fbcb357 100644
--- a/dto/midjourney.go
+++ b/dto/midjourney.go
@@ -57,6 +57,8 @@ type MidjourneyDto struct {
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
+ VideoUrl string `json:"videoUrl"`
+ VideoUrls []ImgUrls `json:"videoUrls"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
Properties *Properties `json:"properties"`
}
+type ImgUrls struct {
+ Url string `json:"url"`
+}
+
type MidjourneyStatus struct {
Status int `json:"status"`
}
diff --git a/dto/openai_image.go b/dto/openai_image.go
new file mode 100644
index 00000000..c26c4200
--- /dev/null
+++ b/dto/openai_image.go
@@ -0,0 +1,80 @@
+package dto
+
+import (
+ "encoding/json"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type ImageRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt" binding:"required"`
+ N uint `json:"n,omitempty"`
+ Size string `json:"size,omitempty"`
+ Quality string `json:"quality,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ Style json.RawMessage `json:"style,omitempty"`
+ User json.RawMessage `json:"user,omitempty"`
+ ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
+ Background json.RawMessage `json:"background,omitempty"`
+ Moderation json.RawMessage `json:"moderation,omitempty"`
+ OutputFormat json.RawMessage `json:"output_format,omitempty"`
+ OutputCompression json.RawMessage `json:"output_compression,omitempty"`
+ PartialImages json.RawMessage `json:"partial_images,omitempty"`
+ // Stream bool `json:"stream,omitempty"`
+ Watermark *bool `json:"watermark,omitempty"`
+}
+
+func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var sizeRatio = 1.0
+ var qualityRatio = 1.0
+
+ if strings.HasPrefix(i.Model, "dall-e") {
+ // Size
+ if i.Size == "256x256" {
+ sizeRatio = 0.4
+ } else if i.Size == "512x512" {
+ sizeRatio = 0.45
+ } else if i.Size == "1024x1024" {
+ sizeRatio = 1
+ } else if i.Size == "1024x1792" || i.Size == "1792x1024" {
+ sizeRatio = 2
+ }
+
+ if i.Model == "dall-e-3" && i.Quality == "hd" {
+ qualityRatio = 2.0
+ if i.Size == "1024x1792" || i.Size == "1792x1024" {
+ qualityRatio = 1.5
+ }
+ }
+ }
+
+ // not support token count for dalle
+ return &types.TokenCountMeta{
+ CombineText: i.Prompt,
+ MaxTokens: 1584,
+ ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
+ }
+}
+
+func (i *ImageRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
+func (i *ImageRequest) SetModelName(modelName string) {
+ if modelName != "" {
+ i.Model = modelName
+ }
+}
+
+type ImageResponse struct {
+ Data []ImageData `json:"data"`
+ Created int64 `json:"created"`
+}
+type ImageData struct {
+ Url string `json:"url"`
+ B64Json string `json:"b64_json"`
+ RevisedPrompt string `json:"revised_prompt"`
+}
diff --git a/dto/openai_request.go b/dto/openai_request.go
index a7325fe8..02f969a7 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -2,70 +2,211 @@ package dto
import (
"encoding/json"
+ "fmt"
"one-api/common"
+ "one-api/types"
"strings"
+
+ "github.com/gin-gonic/gin"
)
type ResponseFormat struct {
- Type string `json:"type,omitempty"`
- JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
+ Type string `json:"type,omitempty"`
+ JsonSchema json.RawMessage `json:"json_schema,omitempty"`
}
type FormatJsonSchema struct {
- Description string `json:"description,omitempty"`
- Name string `json:"name"`
- Schema any `json:"schema,omitempty"`
- Strict any `json:"strict,omitempty"`
+ Description string `json:"description,omitempty"`
+ Name string `json:"name"`
+ Schema any `json:"schema,omitempty"`
+ Strict json.RawMessage `json:"strict,omitempty"`
}
type GeneralOpenAIRequest struct {
- Model string `json:"model,omitempty"`
- Messages []Message `json:"messages,omitempty"`
- Prompt any `json:"prompt,omitempty"`
- Prefix any `json:"prefix,omitempty"`
- Suffix any `json:"suffix,omitempty"`
- Stream bool `json:"stream,omitempty"`
- StreamOptions *StreamOptions `json:"stream_options,omitempty"`
- MaxTokens uint `json:"max_tokens,omitempty"`
- MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
- ReasoningEffort string `json:"reasoning_effort,omitempty"`
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- TopK int `json:"top_k,omitempty"`
- Stop any `json:"stop,omitempty"`
- N int `json:"n,omitempty"`
- Input any `json:"input,omitempty"`
- Instruction string `json:"instruction,omitempty"`
- Size string `json:"size,omitempty"`
- Functions any `json:"functions,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
- EncodingFormat any `json:"encoding_format,omitempty"`
- Seed float64 `json:"seed,omitempty"`
- ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
- Tools []ToolCallRequest `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- User string `json:"user,omitempty"`
- LogProbs bool `json:"logprobs,omitempty"`
- TopLogProbs int `json:"top_logprobs,omitempty"`
- Dimensions int `json:"dimensions,omitempty"`
- Modalities any `json:"modalities,omitempty"`
- Audio any `json:"audio,omitempty"`
- EnableThinking any `json:"enable_thinking,omitempty"` // ali
- ExtraBody any `json:"extra_body,omitempty"`
- WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
- // OpenRouter Params
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Prefix any `json:"prefix,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ StreamOptions *StreamOptions `json:"stream_options,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+ Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop any `json:"stop,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions json.RawMessage `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
+ Tools []ToolCallRequest `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
+ LogProbs bool `json:"logprobs,omitempty"`
+ TopLogProbs int `json:"top_logprobs,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
+ Modalities json.RawMessage `json:"modalities,omitempty"`
+ Audio json.RawMessage `json:"audio,omitempty"`
+ EnableThinking any `json:"enable_thinking,omitempty"` // ali
+ THINKING json.RawMessage `json:"thinking,omitempty"` // doubao,zhipu_v4
+ ExtraBody json.RawMessage `json:"extra_body,omitempty"`
+ SearchParameters any `json:"search_parameters,omitempty"` //xai
+ WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
+ // OpenRouter Params
+ Usage json.RawMessage `json:"usage,omitempty"`
Reasoning json.RawMessage `json:"reasoning,omitempty"`
+ // Ali Qwen Params
+ VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
+	// Catch-all for extra top-level parameters; e.g. ollama's "think" parameter is received here.
+ Extra map[string]json.RawMessage `json:"-"`
+}
+
+func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var tokenCountMeta types.TokenCountMeta
+ var texts = make([]string, 0)
+ var fileMeta = make([]*types.FileMeta, 0)
+
+ if r.Prompt != nil {
+ switch v := r.Prompt.(type) {
+ case string:
+ texts = append(texts, v)
+ case []any:
+ for _, item := range v {
+ if str, ok := item.(string); ok {
+ texts = append(texts, str)
+ }
+ }
+ default:
+ texts = append(texts, fmt.Sprintf("%v", r.Prompt))
+ }
+ }
+
+ if r.Input != nil {
+ inputs := r.ParseInput()
+ texts = append(texts, inputs...)
+ }
+
+ if r.MaxCompletionTokens > r.MaxTokens {
+ tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
+ } else {
+ tokenCountMeta.MaxTokens = int(r.MaxTokens)
+ }
+
+ for _, message := range r.Messages {
+ tokenCountMeta.MessagesCount++
+ texts = append(texts, message.Role)
+ if message.Content != nil {
+ if message.Name != nil {
+ tokenCountMeta.NameCount++
+ texts = append(texts, *message.Name)
+ }
+ arrayContent := message.ParseContent()
+ for _, m := range arrayContent {
+ if m.Type == ContentTypeImageURL {
+ imageUrl := m.GetImageMedia()
+ if imageUrl != nil {
+ if imageUrl.Url != "" {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeImage,
+ }
+ meta.OriginData = imageUrl.Url
+ meta.Detail = imageUrl.Detail
+ fileMeta = append(fileMeta, meta)
+ }
+ }
+ } else if m.Type == ContentTypeInputAudio {
+ inputAudio := m.GetInputAudio()
+ if inputAudio != nil {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeAudio,
+ }
+ meta.OriginData = inputAudio.Data
+ fileMeta = append(fileMeta, meta)
+ }
+ } else if m.Type == ContentTypeFile {
+ file := m.GetFile()
+ if file != nil {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeFile,
+ }
+ meta.OriginData = file.FileData
+ fileMeta = append(fileMeta, meta)
+ }
+ } else if m.Type == ContentTypeVideoUrl {
+ videoUrl := m.GetVideoUrl()
+ if videoUrl != nil && videoUrl.Url != "" {
+ meta := &types.FileMeta{
+ FileType: types.FileTypeVideo,
+ }
+ meta.OriginData = videoUrl.Url
+ fileMeta = append(fileMeta, meta)
+ }
+ } else {
+ texts = append(texts, m.Text)
+ }
+ }
+ }
+ }
+
+ if r.Tools != nil {
+ openaiTools := r.Tools
+ for _, tool := range openaiTools {
+ tokenCountMeta.ToolsCount++
+ texts = append(texts, tool.Function.Name)
+ if tool.Function.Description != "" {
+ texts = append(texts, tool.Function.Description)
+ }
+ if tool.Function.Parameters != nil {
+ texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
+ }
+ }
+ //toolTokens := CountTokenInput(countStr, request.Model)
+ //tkm += 8
+ //tkm += toolTokens
+ }
+ tokenCountMeta.CombineText = strings.Join(texts, "\n")
+ tokenCountMeta.Files = fileMeta
+ return &tokenCountMeta
+}
+
+func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool { // stream flag comes straight from the request body
+	return r.Stream
+}
+
+func (r *GeneralOpenAIRequest) SetModelName(modelName string) { // override the request model; empty input is ignored
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any)
- data, _ := common.EncodeJson(r)
- _ = common.DecodeJson(data, &result)
+ data, _ := common.Marshal(r)
+ _ = common.Unmarshal(data, &result)
return result
}
+func (r *GeneralOpenAIRequest) GetSystemRoleName() string { // role name to use for the system prompt with this model
+	if strings.HasPrefix(r.Model, "o") { // o-series models use "developer", except the two legacy ones below
+		if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
+			return "developer"
+		}
+	} else if strings.HasPrefix(r.Model, "gpt-5") { // gpt-5 also expects "developer"
+		return "developer"
+	}
+	return "system"
+}
+
type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`
@@ -83,8 +224,11 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
-func (r *GeneralOpenAIRequest) GetMaxTokens() int {
-	return int(r.MaxTokens)
+func (r *GeneralOpenAIRequest) GetMaxTokens() uint { // prefer max_completion_tokens when set, fall back to max_tokens
+	if r.MaxCompletionTokens != 0 {
+		return r.MaxCompletionTokens
+	}
+	return r.MaxTokens
 }
func (r *GeneralOpenAIRequest) ParseInput() []string {
@@ -107,16 +251,16 @@ func (r *GeneralOpenAIRequest) ParseInput() []string {
}
type Message struct {
- Role string `json:"role"`
- Content json.RawMessage `json:"content"`
- Name *string `json:"name,omitempty"`
- Prefix *bool `json:"prefix,omitempty"`
- ReasoningContent string `json:"reasoning_content,omitempty"`
- Reasoning string `json:"reasoning,omitempty"`
- ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
- ToolCallId string `json:"tool_call_id,omitempty"`
- parsedContent []MediaContent
- parsedStringContent *string
+ Role string `json:"role"`
+ Content any `json:"content"`
+ Name *string `json:"name,omitempty"`
+ Prefix *bool `json:"prefix,omitempty"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ Reasoning string `json:"reasoning,omitempty"`
+ ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
+ ToolCallId string `json:"tool_call_id,omitempty"`
+ parsedContent []MediaContent
+ //parsedStringContent *string
}
type MediaContent struct {
@@ -132,21 +276,65 @@ type MediaContent struct {
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
if m.ImageUrl != nil {
- return m.ImageUrl.(*MessageImageUrl)
+ if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
+ return m.ImageUrl.(*MessageImageUrl)
+ }
+ if itemMap, ok := m.ImageUrl.(map[string]any); ok {
+ out := &MessageImageUrl{
+ Url: common.Interface2String(itemMap["url"]),
+ Detail: common.Interface2String(itemMap["detail"]),
+ MimeType: common.Interface2String(itemMap["mime_type"]),
+ }
+ return out
+ }
}
return nil
}
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil {
- return m.InputAudio.(*MessageInputAudio)
+ if _, ok := m.InputAudio.(*MessageInputAudio); ok {
+ return m.InputAudio.(*MessageInputAudio)
+ }
+ if itemMap, ok := m.InputAudio.(map[string]any); ok {
+ out := &MessageInputAudio{
+ Data: common.Interface2String(itemMap["data"]),
+ Format: common.Interface2String(itemMap["format"]),
+ }
+ return out
+ }
}
return nil
}
func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil {
- return m.File.(*MessageFile)
+ if _, ok := m.File.(*MessageFile); ok {
+ return m.File.(*MessageFile)
+ }
+ if itemMap, ok := m.File.(map[string]any); ok {
+ out := &MessageFile{
+ FileName: common.Interface2String(itemMap["file_name"]),
+ FileData: common.Interface2String(itemMap["file_data"]),
+ FileId: common.Interface2String(itemMap["file_id"]),
+ }
+ return out
+ }
+ }
+ return nil
+}
+
+func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { // returns VideoUrl as a typed struct, converting from a JSON-decoded map if needed
+	if m.VideoUrl != nil {
+		if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
+			return m.VideoUrl.(*MessageVideoUrl)
+		}
+		if itemMap, ok := m.VideoUrl.(map[string]any); ok {
+			out := &MessageVideoUrl{
+				Url: common.Interface2String(itemMap["url"]),
+			}
+			return out
+		}
+	}
+	return nil
+}
@@ -182,6 +370,7 @@ const (
ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file"
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
+ //ContentTypeAudioUrl = "audio_url"
)
func (m *Message) GetPrefix() bool {
@@ -212,6 +401,186 @@ func (m *Message) SetToolCalls(toolCalls any) {
}
 func (m *Message) StringContent() string {
+	switch m.Content.(type) {
+	case string:
+		return m.Content.(string)
+	case []any: // concatenate the text of every "text" part; other part types are skipped
+		var contentStr string
+		for _, contentItem := range m.Content.([]any) {
+			contentMap, ok := contentItem.(map[string]any)
+			if !ok {
+				continue
+			}
+			if contentMap["type"] == ContentTypeText {
+				if subStr, ok := contentMap["text"].(string); ok {
+					contentStr += subStr
+				}
+			}
+		}
+		return contentStr
+	}
+
+	return "" // unsupported content shape
+}
+
+func (m *Message) SetNullContent() { // clear content and invalidate the parsed cache
+	m.Content = nil
+	m.parsedContent = nil
+}
+
+func (m *Message) SetStringContent(content string) { // set plain-text content; parsed cache is invalidated
+	m.Content = content
+	m.parsedContent = nil
+}
+
+func (m *Message) SetMediaContent(content []MediaContent) { // set structured content and prime the parsed cache
+	m.Content = content
+	m.parsedContent = content
+}
+
+func (m *Message) IsStringContent() bool { // reports whether Content holds a plain string
+	_, ok := m.Content.(string)
+	if ok {
+		return true
+	}
+	return false
+}
+
+func (m *Message) ParseContent() []MediaContent {
+ if m.Content == nil {
+ return nil
+ }
+ if len(m.parsedContent) > 0 {
+ return m.parsedContent
+ }
+
+ var contentList []MediaContent
+	// Try to parse the content as a plain string first.
+ content, ok := m.Content.(string)
+ if ok {
+ contentList = []MediaContent{{
+ Type: ContentTypeText,
+ Text: content,
+ }}
+ m.parsedContent = contentList
+ return contentList
+ }
+
+	// Otherwise try to parse it as an array of parts.
+ //var arrayContent []map[string]interface{}
+
+ arrayContent, ok := m.Content.([]any)
+ if !ok {
+ return contentList
+ }
+
+ for _, contentItemAny := range arrayContent {
+ mediaItem, ok := contentItemAny.(MediaContent)
+ if ok {
+ contentList = append(contentList, mediaItem)
+ continue
+ }
+
+ contentItem, ok := contentItemAny.(map[string]any)
+ if !ok {
+ continue
+ }
+ contentType, ok := contentItem["type"].(string)
+ if !ok {
+ continue
+ }
+
+ switch contentType {
+ case ContentTypeText:
+ if text, ok := contentItem["text"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeText,
+ Text: text,
+ })
+ }
+
+ case ContentTypeImageURL:
+ imageUrl := contentItem["image_url"]
+ temp := &MessageImageUrl{
+ Detail: "high",
+ }
+ switch v := imageUrl.(type) {
+ case string:
+ temp.Url = v
+ case map[string]interface{}:
+ url, ok1 := v["url"].(string)
+ detail, ok2 := v["detail"].(string)
+ if ok2 {
+ temp.Detail = detail
+ }
+ if ok1 {
+ temp.Url = url
+ }
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeImageURL,
+ ImageUrl: temp,
+ })
+
+ case ContentTypeInputAudio:
+ if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
+ data, ok1 := audioData["data"].(string)
+ format, ok2 := audioData["format"].(string)
+ if ok1 && ok2 {
+ temp := &MessageInputAudio{
+ Data: data,
+ Format: format,
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeInputAudio,
+ InputAudio: temp,
+ })
+ }
+ }
+ case ContentTypeFile:
+ if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
+ fileId, ok3 := fileData["file_id"].(string)
+ if ok3 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileId: fileId,
+ },
+ })
+ } else {
+ fileName, ok1 := fileData["filename"].(string)
+ fileDataStr, ok2 := fileData["file_data"].(string)
+ if ok1 && ok2 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileName: fileName,
+ FileData: fileDataStr,
+ },
+ })
+ }
+ }
+ }
+ case ContentTypeVideoUrl:
+ if videoUrl, ok := contentItem["video_url"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeVideoUrl,
+ VideoUrl: &MessageVideoUrl{
+ Url: videoUrl,
+ },
+ })
+ }
+ }
+ }
+
+ if len(contentList) > 0 {
+ m.parsedContent = contentList
+ }
+ return contentList
+}
+
+// old code
+/*func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
return *m.parsedStringContent
}
@@ -382,33 +751,106 @@ func (m *Message) ParseContent() []MediaContent {
m.parsedContent = contentList
}
return contentList
-}
+}*/
type WebSearchOptions struct {
SearchContextSize string `json:"search_context_size,omitempty"`
UserLocation json.RawMessage `json:"user_location,omitempty"`
}
+// https://platform.openai.com/docs/api-reference/responses/create
type OpenAIResponsesRequest struct {
- Model string `json:"model"`
- Input json.RawMessage `json:"input,omitempty"`
- Include json.RawMessage `json:"include,omitempty"`
- Instructions json.RawMessage `json:"instructions,omitempty"`
- MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
- Metadata json.RawMessage `json:"metadata,omitempty"`
- ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
- PreviousResponseID string `json:"previous_response_id,omitempty"`
- Reasoning *Reasoning `json:"reasoning,omitempty"`
- ServiceTier string `json:"service_tier,omitempty"`
- Store bool `json:"store,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- Text json.RawMessage `json:"text,omitempty"`
- ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
- Tools []ResponsesToolsCall `json:"tools,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- Truncation string `json:"truncation,omitempty"`
- User string `json:"user,omitempty"`
+ Model string `json:"model"`
+ Input any `json:"input,omitempty"`
+ Include json.RawMessage `json:"include,omitempty"`
+ Instructions json.RawMessage `json:"instructions,omitempty"`
+ MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
+ Metadata json.RawMessage `json:"metadata,omitempty"`
+ ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
+ PreviousResponseID string `json:"previous_response_id,omitempty"`
+ Reasoning *Reasoning `json:"reasoning,omitempty"`
+ ServiceTier string `json:"service_tier,omitempty"`
+ Store bool `json:"store,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ Text json.RawMessage `json:"text,omitempty"`
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+	Tools              []map[string]any `json:"tools,omitempty"` // only a few of these fields need handling; MCP adds too many uncertain ones, so a generic map is used
+ TopP float64 `json:"top_p,omitempty"`
+ Truncation string `json:"truncation,omitempty"`
+ User string `json:"user,omitempty"`
+ MaxToolCalls uint `json:"max_tool_calls,omitempty"`
+ Prompt json.RawMessage `json:"prompt,omitempty"`
+}
+
+func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var fileMeta = make([]*types.FileMeta, 0)
+ var texts = make([]string, 0)
+
+ if r.Input != nil {
+ inputs := r.ParseInput()
+ for _, input := range inputs {
+ if input.Type == "input_image" {
+ if input.ImageUrl != "" {
+ fileMeta = append(fileMeta, &types.FileMeta{
+ FileType: types.FileTypeImage,
+ OriginData: input.ImageUrl,
+ Detail: input.Detail,
+ })
+ }
+ } else if input.Type == "input_file" {
+ if input.FileUrl != "" {
+ fileMeta = append(fileMeta, &types.FileMeta{
+ FileType: types.FileTypeFile,
+ OriginData: input.FileUrl,
+ })
+ }
+ } else {
+ texts = append(texts, input.Text)
+ }
+ }
+ }
+
+ if len(r.Instructions) > 0 {
+ texts = append(texts, string(r.Instructions))
+ }
+
+ if len(r.Metadata) > 0 {
+ texts = append(texts, string(r.Metadata))
+ }
+
+ if len(r.Text) > 0 {
+ texts = append(texts, string(r.Text))
+ }
+
+ if len(r.ToolChoice) > 0 {
+ texts = append(texts, string(r.ToolChoice))
+ }
+
+ if len(r.Prompt) > 0 {
+ texts = append(texts, string(r.Prompt))
+ }
+
+ if len(r.Tools) > 0 {
+ toolStr, _ := common.Marshal(r.Tools)
+ texts = append(texts, string(toolStr))
+ }
+
+ return &types.TokenCountMeta{
+ CombineText: strings.Join(texts, "\n"),
+ Files: fileMeta,
+ MaxTokens: int(r.MaxOutputTokens),
+ }
+}
+
+func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool { // stream flag comes straight from the request body
+	return r.Stream
+}
+
+func (r *OpenAIResponsesRequest) SetModelName(modelName string) { // override the request model; empty input is ignored
+	if modelName != "" {
+		r.Model = modelName
+	}
+}
type Reasoning struct {
@@ -416,21 +858,80 @@ type Reasoning struct {
Summary string `json:"summary,omitempty"`
}
-type ResponsesToolsCall struct {
- Type string `json:"type"`
- // Web Search
- UserLocation json.RawMessage `json:"user_location,omitempty"`
- SearchContextSize string `json:"search_context_size,omitempty"`
- // File Search
- VectorStoreIds []string `json:"vector_store_ids,omitempty"`
- MaxNumResults uint `json:"max_num_results,omitempty"`
- Filters json.RawMessage `json:"filters,omitempty"`
- // Computer Use
- DisplayWidth uint `json:"display_width,omitempty"`
- DisplayHeight uint `json:"display_height,omitempty"`
- Environment string `json:"environment,omitempty"`
- // Function
- Name string `json:"name,omitempty"`
- Description string `json:"description,omitempty"`
- Parameters json.RawMessage `json:"parameters,omitempty"`
+type MediaInput struct { // normalized Responses-API input part (see ParseInput)
+	Type     string `json:"type"`
+	Text     string `json:"text,omitempty"`
+	FileUrl  string `json:"file_url,omitempty"`
+	ImageUrl string `json:"image_url,omitempty"`
+	Detail   string `json:"detail,omitempty"` // only meaningful for input_image
+}
+
+// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
+// Reference implementation mirrors Message.ParseContent:
+// - input can be a string, treated as an input_text item
+// - input can be an array of objects with a `type` field
+// supported types: input_text, input_image, input_file
+func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
+ if r.Input == nil {
+ return nil
+ }
+
+ var inputs []MediaInput
+
+ // Try string first
+ if str, ok := r.Input.(string); ok {
+ inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
+ return inputs
+ }
+
+ // Try array of parts
+ if array, ok := r.Input.([]any); ok {
+ for _, itemAny := range array {
+ // Already parsed MediaInput
+ if media, ok := itemAny.(MediaInput); ok {
+ inputs = append(inputs, media)
+ continue
+ }
+ // Generic map
+ item, ok := itemAny.(map[string]any)
+ if !ok {
+ continue
+ }
+ typeVal, ok := item["type"].(string)
+ if !ok {
+ continue
+ }
+ switch typeVal {
+ case "input_text":
+ text, _ := item["text"].(string)
+ inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
+ case "input_image":
+ // image_url may be string or object with url field
+ var imageUrl string
+ switch v := item["image_url"].(type) {
+ case string:
+ imageUrl = v
+ case map[string]any:
+ if url, ok := v["url"].(string); ok {
+ imageUrl = url
+ }
+ }
+ inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
+ case "input_file":
+ // file_url may be string or object with url field
+ var fileUrl string
+ switch v := item["file_url"].(type) {
+ case string:
+ fileUrl = v
+ case map[string]any:
+ if url, ok := v["url"].(string); ok {
+ fileUrl = url
+ }
+ }
+ inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
+ }
+ }
+ }
+
+ return inputs
}
diff --git a/dto/openai_response.go b/dto/openai_response.go
index 790d4df8..966748cb 100644
--- a/dto/openai_response.go
+++ b/dto/openai_response.go
@@ -1,10 +1,19 @@
package dto
-import "encoding/json"
+import (
+ "encoding/json"
+ "fmt"
+ "one-api/types"
+)
type SimpleResponse struct {
Usage `json:"usage"`
- Error *OpenAIError `json:"error"`
+ Error any `json:"error"`
+}
+
+// GetOpenAIError extracts an OpenAIError from the dynamically-typed error field.
+func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
+	return GetOpenAIError(s.Error)
+}
type TextResponse struct {
@@ -26,12 +35,17 @@ type OpenAITextResponse struct {
Id string `json:"id"`
Model string `json:"model"`
Object string `json:"object"`
- Created int64 `json:"created"`
+ Created any `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
- Error *OpenAIError `json:"error,omitempty"`
+ Error any `json:"error,omitempty"`
Usage `json:"usage"`
}
+// GetOpenAIError extracts an OpenAIError from the dynamically-typed error field.
+func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
+	return GetOpenAIError(o.Error)
+}
+
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
@@ -45,6 +59,19 @@ type OpenAIEmbeddingResponse struct {
Usage `json:"usage"`
}
+type FlexibleEmbeddingResponseItem struct {
+ Object string `json:"object"`
+ Index int `json:"index"`
+ Embedding any `json:"embedding"`
+}
+
+type FlexibleEmbeddingResponse struct {
+ Object string `json:"object"`
+ Data []FlexibleEmbeddingResponseItem `json:"data"`
+ Model string `json:"model"`
+ Usage `json:"usage"`
+}
+
type ChatCompletionsStreamResponseChoice struct {
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
Logprobs *any `json:"logprobs"`
@@ -83,7 +110,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
c.ReasoningContent = &s
- c.Reasoning = &s
+ //c.Reasoning = &s
}
type ToolCallResponse struct {
@@ -116,6 +143,13 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
+func (c *ChatCompletionsStreamResponse) IsFinished() bool { // true once the first choice carries a non-empty finish_reason
+	if len(c.Choices) == 0 {
+		return false
+	}
+	return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
+}
+
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
if len(c.Choices) == 0 {
return false
@@ -130,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
return nil
}
+func (c *ChatCompletionsStreamResponse) ClearToolCalls() { // strip id/type/name from delta tool calls, keeping streamed argument fragments
+	if !c.IsToolCall() {
+		return
+	}
+	for choiceIdx := range c.Choices {
+		for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
+		}
+	}
+}
+
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices)
@@ -178,6 +225,8 @@ type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
+ // OpenRouter Params
+ Cost any `json:"cost,omitempty"`
}
type InputTokenDetails struct {
@@ -195,28 +244,33 @@ type OutputTokenDetails struct {
}
type OpenAIResponsesResponse struct {
- ID string `json:"id"`
- Object string `json:"object"`
- CreatedAt int `json:"created_at"`
- Status string `json:"status"`
- Error *OpenAIError `json:"error,omitempty"`
- IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
- Instructions string `json:"instructions"`
- MaxOutputTokens int `json:"max_output_tokens"`
- Model string `json:"model"`
- Output []ResponsesOutput `json:"output"`
- ParallelToolCalls bool `json:"parallel_tool_calls"`
- PreviousResponseID string `json:"previous_response_id"`
- Reasoning *Reasoning `json:"reasoning"`
- Store bool `json:"store"`
- Temperature float64 `json:"temperature"`
- ToolChoice string `json:"tool_choice"`
- Tools []ResponsesToolsCall `json:"tools"`
- TopP float64 `json:"top_p"`
- Truncation string `json:"truncation"`
- Usage *Usage `json:"usage"`
- User json.RawMessage `json:"user"`
- Metadata json.RawMessage `json:"metadata"`
+ ID string `json:"id"`
+ Object string `json:"object"`
+ CreatedAt int `json:"created_at"`
+ Status string `json:"status"`
+ Error any `json:"error,omitempty"`
+ IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
+ Instructions string `json:"instructions"`
+ MaxOutputTokens int `json:"max_output_tokens"`
+ Model string `json:"model"`
+ Output []ResponsesOutput `json:"output"`
+ ParallelToolCalls bool `json:"parallel_tool_calls"`
+ PreviousResponseID string `json:"previous_response_id"`
+ Reasoning *Reasoning `json:"reasoning"`
+ Store bool `json:"store"`
+ Temperature float64 `json:"temperature"`
+ ToolChoice string `json:"tool_choice"`
+ Tools []map[string]any `json:"tools"`
+ TopP float64 `json:"top_p"`
+ Truncation string `json:"truncation"`
+ Usage *Usage `json:"usage"`
+ User json.RawMessage `json:"user"`
+ Metadata json.RawMessage `json:"metadata"`
+}
+
+// GetOpenAIError extracts an OpenAIError from the dynamically-typed error field.
+func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
+	return GetOpenAIError(o.Error)
+}
type IncompleteDetails struct {
@@ -258,3 +312,45 @@ type ResponsesStreamResponse struct {
Delta string `json:"delta,omitempty"`
Item *ResponsesOutput `json:"item,omitempty"`
}
+
+// GetOpenAIError extracts an OpenAIError from a dynamically-typed error field.
+func GetOpenAIError(errorField any) *types.OpenAIError {
+	if errorField == nil {
+		return nil
+	}
+
+	switch err := errorField.(type) {
+	case types.OpenAIError:
+		return &err
+	case *types.OpenAIError:
+		return err
+	case map[string]interface{}:
+		// Handle a map structure produced by JSON decoding.
+		openaiErr := &types.OpenAIError{}
+		if errType, ok := err["type"].(string); ok {
+			openaiErr.Type = errType
+		}
+		if errMsg, ok := err["message"].(string); ok {
+			openaiErr.Message = errMsg
+		}
+		if errParam, ok := err["param"].(string); ok {
+			openaiErr.Param = errParam
+		}
+		if errCode, ok := err["code"]; ok {
+			openaiErr.Code = errCode
+		}
+		return openaiErr
+	case string:
+		// Handle a plain string error.
+		return &types.OpenAIError{
+			Type:    "error",
+			Message: err,
+		}
+	default:
+		// Unknown type: fall back to its string representation.
+		return &types.OpenAIError{
+			Type:    "unknown_error",
+			Message: fmt.Sprintf("%v", err),
+		}
+	}
+}
diff --git a/dto/pricing.go b/dto/pricing.go
index ee77c098..bc024de3 100644
--- a/dto/pricing.go
+++ b/dto/pricing.go
@@ -1,26 +1,35 @@
package dto
-type OpenAIModelPermission struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int `json:"created"`
- AllowCreateEngine bool `json:"allow_create_engine"`
- AllowSampling bool `json:"allow_sampling"`
- AllowLogprobs bool `json:"allow_logprobs"`
- AllowSearchIndices bool `json:"allow_search_indices"`
- AllowView bool `json:"allow_view"`
- AllowFineTuning bool `json:"allow_fine_tuning"`
- Organization string `json:"organization"`
- Group *string `json:"group"`
- IsBlocking bool `json:"is_blocking"`
+import "one-api/constant"
+
+// Kept in place because it is awkward to move; originally intended to be split out on its own.
+type OpenAIModels struct {
+	Id                     string                  `json:"id"`
+	Object                 string                  `json:"object"`
+	Created                int                     `json:"created"`
+	OwnedBy                string                  `json:"owned_by"`
+	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
+}
-type OpenAIModels struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int `json:"created"`
- OwnedBy string `json:"owned_by"`
- Permission []OpenAIModelPermission `json:"permission"`
- Root string `json:"root"`
- Parent *string `json:"parent"`
+type AnthropicModel struct {
+ ID string `json:"id"`
+ CreatedAt string `json:"created_at"`
+ DisplayName string `json:"display_name"`
+ Type string `json:"type"`
+}
+
+type GeminiModel struct {
+ Name interface{} `json:"name"`
+ BaseModelId interface{} `json:"baseModelId"`
+ Version interface{} `json:"version"`
+ DisplayName interface{} `json:"displayName"`
+ Description interface{} `json:"description"`
+ InputTokenLimit interface{} `json:"inputTokenLimit"`
+ OutputTokenLimit interface{} `json:"outputTokenLimit"`
+ SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
+ Thinking interface{} `json:"thinking"`
+ Temperature interface{} `json:"temperature"`
+ MaxTemperature interface{} `json:"maxTemperature"`
+ TopP interface{} `json:"topP"`
+ TopK interface{} `json:"topK"`
}
diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go
new file mode 100644
index 00000000..6315f31a
--- /dev/null
+++ b/dto/ratio_sync.go
@@ -0,0 +1,38 @@
+package dto
+
+type UpstreamDTO struct {
+ ID int `json:"id,omitempty"`
+ Name string `json:"name" binding:"required"`
+ BaseURL string `json:"base_url" binding:"required"`
+ Endpoint string `json:"endpoint"`
+}
+
+type UpstreamRequest struct {
+ ChannelIDs []int64 `json:"channel_ids"`
+ Upstreams []UpstreamDTO `json:"upstreams"`
+ Timeout int `json:"timeout"`
+}
+
+// TestResult is the connectivity test result for one upstream.
+type TestResult struct {
+	Name   string `json:"name"`
+	Status string `json:"status"`
+	Error  string `json:"error,omitempty"`
+}
+
+// DifferenceItem is a single ratio difference entry.
+// Current is the local value and may be nil.
+// Upstreams maps each channel to its upstream value: a concrete value / "same" / nil.
+
+type DifferenceItem struct {
+	Current    interface{}            `json:"current"`
+	Upstreams  map[string]interface{} `json:"upstreams"`
+	Confidence map[string]bool        `json:"confidence"`
+}
+
+type SyncableChannel struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ BaseURL string `json:"base_url"`
+ Status int `json:"status"`
+}
\ No newline at end of file
diff --git a/dto/realtime.go b/dto/realtime.go
index 86ae352d..32a69056 100644
--- a/dto/realtime.go
+++ b/dto/realtime.go
@@ -1,5 +1,7 @@
package dto
+import "one-api/types"
+
const (
RealtimeEventTypeError = "error"
RealtimeEventTypeSessionUpdate = "session.update"
@@ -23,12 +25,12 @@ type RealtimeEvent struct {
EventId string `json:"event_id"`
Type string `json:"type"`
//PreviousItemId string `json:"previous_item_id"`
- Session *RealtimeSession `json:"session,omitempty"`
- Item *RealtimeItem `json:"item,omitempty"`
- Error *OpenAIError `json:"error,omitempty"`
- Response *RealtimeResponse `json:"response,omitempty"`
- Delta string `json:"delta,omitempty"`
- Audio string `json:"audio,omitempty"`
+ Session *RealtimeSession `json:"session,omitempty"`
+ Item *RealtimeItem `json:"item,omitempty"`
+ Error *types.OpenAIError `json:"error,omitempty"`
+ Response *RealtimeResponse `json:"response,omitempty"`
+ Delta string `json:"delta,omitempty"`
+ Audio string `json:"audio,omitempty"`
}
type RealtimeResponse struct {
diff --git a/dto/request_common.go b/dto/request_common.go
new file mode 100644
index 00000000..da3ac3c5
--- /dev/null
+++ b/dto/request_common.go
@@ -0,0 +1,25 @@
+package dto
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/types"
+)
+
+type Request interface {
+ GetTokenCountMeta() *types.TokenCountMeta
+ IsStream(c *gin.Context) bool
+ SetModelName(modelName string)
+}
+
+type BaseRequest struct {
+}
+
+func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ return &types.TokenCountMeta{
+ TokenType: types.TokenTypeTokenizer,
+ }
+}
+func (b *BaseRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+func (b *BaseRequest) SetModelName(modelName string) {}
diff --git a/dto/rerank.go b/dto/rerank.go
index 21f6437c..46f4bce6 100644
--- a/dto/rerank.go
+++ b/dto/rerank.go
@@ -1,15 +1,48 @@
package dto
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "one-api/types"
+ "strings"
+)
+
type RerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
- TopN int `json:"top_n"`
+ TopN int `json:"top_n,omitempty"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
+func (r *RerankRequest) IsStream(c *gin.Context) bool {
+ return false
+}
+
+func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
+ var texts = make([]string, 0)
+
+ for _, document := range r.Documents {
+ texts = append(texts, fmt.Sprintf("%v", document))
+ }
+
+ if r.Query != "" {
+ texts = append(texts, r.Query)
+ }
+
+ return &types.TokenCountMeta{
+ CombineText: strings.Join(texts, "\n"),
+ }
+}
+
+func (r *RerankRequest) SetModelName(modelName string) {
+ if modelName != "" {
+ r.Model = modelName
+ }
+}
+
func (r *RerankRequest) GetReturnDocuments() bool {
if r.ReturnDocuments == nil {
return false
diff --git a/dto/user_settings.go b/dto/user_settings.go
new file mode 100644
index 00000000..2e1a1541
--- /dev/null
+++ b/dto/user_settings.go
@@ -0,0 +1,16 @@
+package dto
+
+type UserSetting struct {
+ NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
+ QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
+ WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
+ WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
+ NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
+ AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
+ RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
+}
+
+var (
+ NotifyTypeEmail = "email" // Email 邮件
+ NotifyTypeWebhook = "webhook" // Webhook
+)
diff --git a/dto/video.go b/dto/video.go
new file mode 100644
index 00000000..5b48146a
--- /dev/null
+++ b/dto/video.go
@@ -0,0 +1,47 @@
+package dto
+
+type VideoRequest struct {
+ Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
+ Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
+ Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
+ Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
+ Width int `json:"width" example:"512"` // Video width
+ Height int `json:"height" example:"512"` // Video height
+ Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
+ Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
+ N int `json:"n,omitempty" example:"1"` // Number of videos to generate
+ ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
+ User string `json:"user,omitempty" example:"user-1234"` // User identifier
+ Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
+}
+
+// VideoResponse 视频生成提交任务后的响应
+type VideoResponse struct {
+ TaskId string `json:"task_id"`
+ Status string `json:"status"`
+}
+
+// VideoTaskResponse 查询视频生成任务状态的响应
+type VideoTaskResponse struct {
+ TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
+ Status string `json:"status" example:"succeeded"` // 任务状态
+ Url string `json:"url,omitempty"` // 视频资源URL(成功时)
+ Format string `json:"format,omitempty" example:"mp4"` // 视频格式
+ Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
+ Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
+}
+
+// VideoTaskMetadata 视频任务元数据
+type VideoTaskMetadata struct {
+ Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
+ Fps int `json:"fps" example:"30"` // 实际帧率
+ Width int `json:"width" example:"512"` // 实际宽度
+ Height int `json:"height" example:"512"` // 实际高度
+ Seed int `json:"seed" example:"20231234"` // 使用的随机种子
+}
+
+// VideoTaskError 视频任务错误信息
+type VideoTaskError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
diff --git a/go.mod b/go.mod
index ce768bf3..1a92947e 100644
--- a/go.mod
+++ b/go.mod
@@ -7,11 +7,11 @@ require (
github.com/Calcium-Ion/go-epay v0.0.4
github.com/andybalholm/brotli v1.1.1
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
- github.com/aws/aws-sdk-go-v2 v1.26.1
+ github.com/aws/aws-sdk-go-v2 v1.37.2
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
- github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
+ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0
+ github.com/aws/smithy-go v1.22.5
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
- github.com/bytedance/sonic v1.11.6
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -25,30 +25,41 @@ require (
github.com/gorilla/websocket v1.5.0
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
- github.com/pkoukk/tiktoken-go v0.1.7
+ github.com/pquerna/otp v1.5.0
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
+ github.com/stripe/stripe-go/v81 v81.4.0
+ github.com/thanhpk/randstr v1.0.6
+ github.com/tidwall/gjson v1.18.0
+ github.com/tidwall/sjson v1.2.5
+ github.com/tiktoken-go/tokenizer v0.6.2
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.35.0
+ golang.org/x/sync v0.11.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
)
require (
+ github.com/Masterminds/goutils v1.1.1 // indirect
+ github.com/Masterminds/semver/v3 v3.2.0 // indirect
+ github.com/Masterminds/sprig/v3 v3.2.3 // indirect
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
- github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
- github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
- github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
- github.com/aws/smithy-go v1.20.2 // indirect
+ github.com/antlabs/pcopy v0.1.5 // indirect
+ github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect
+ github.com/boombuler/barcode v1.1.0 // indirect
+ github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
- github.com/dlclark/regexp2 v1.11.0 // indirect
+ github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
@@ -62,6 +73,8 @@ require (
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
+ github.com/huandu/xstrings v1.3.3 // indirect
+ github.com/imdario/mergo v0.3.11 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.1 // indirect
@@ -72,11 +85,16 @@ require (
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/mitchellh/copystructure v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
+ github.com/mitchellh/reflectwalk v1.0.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/spf13/cast v1.3.1 // indirect
+ github.com/tidwall/match v1.1.1 // indirect
+ github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
@@ -84,7 +102,6 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
- golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
diff --git a/go.sum b/go.sum
index 2bd81fa3..7b8104b9 100644
--- a/go.sum
+++ b/go.sum
@@ -1,25 +1,36 @@
github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
+github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
+github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
+github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g=
+github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
+github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA=
+github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
-github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
-github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
-github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
-github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
+github.com/antlabs/pcopy v0.1.5 h1:5Fa1ExY9T6ar3ysAi4rzB5jiYg72Innm+/ESEIOSHvQ=
+github.com/antlabs/pcopy v0.1.5/go.mod h1:2FvdkPD3cFiM1CjGuXFCDQZqhKVcLI7IzeSJ2xUIOOI=
+github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo=
+github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
-github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
-github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
-github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
-github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0 h1:JzidOz4Hcn2RbP5fvIS1iAP+DcRv5VJtgixbEYDsI5g=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.33.0/go.mod h1:9A4/PJYlWjvjEzzoOLGQjkLt4bYK9fRWi7uz1GSsAcA=
+github.com/aws/smithy-go v1.22.5 h1:P9ATCXPMb2mPjYBgueqJNCA5S9UfktsW0tTxi+a7eqw=
+github.com/aws/smithy-go v1.22.5/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
+github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
+github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
+github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
@@ -38,8 +49,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
-github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
-github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
+github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
@@ -99,6 +110,7 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
+github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
@@ -109,6 +121,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
+github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
+github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA=
+github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -147,8 +163,12 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ=
+github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
+github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY=
+github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -167,10 +187,10 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
-github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
+github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
@@ -181,14 +201,19 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
+github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
+github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
+github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -197,6 +222,21 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
+github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
+github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
+github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
+github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
+github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
+github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
+github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
@@ -211,43 +251,67 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
+golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
@@ -262,6 +326,7 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/i18n/zh-cn.json b/i18n/zh-cn.json
new file mode 100644
index 00000000..dc7a1e4c
--- /dev/null
+++ b/i18n/zh-cn.json
@@ -0,0 +1,1054 @@
+{
+ "未登录或登录已过期,请重新登录": "未登录或登录已过期,请重新登录",
+ "登 录": "登 录",
+ "使用 微信 继续": "使用 微信 继续",
+ "使用 GitHub 继续": "使用 GitHub 继续",
+ "使用 LinuxDO 继续": "使用 LinuxDO 继续",
+ "使用 邮箱或用户名 登录": "使用 邮箱或用户名 登录",
+ "没有账户?": "没有账户?",
+ "用户名或邮箱": "用户名或邮箱",
+ "请输入您的用户名或邮箱地址": "请输入您的用户名或邮箱地址",
+ "请输入您的密码": "请输入您的密码",
+ "继续": "继续",
+ "忘记密码?": "忘记密码?",
+ "其他登录选项": "其他登录选项",
+ "微信扫码登录": "微信扫码登录",
+ "登录": "登录",
+ "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)",
+ "验证码": "验证码",
+ "处理中...": "处理中...",
+ "绑定成功!": "绑定成功!",
+ "登录成功!": "登录成功!",
+ "操作失败,重定向至登录界面中...": "操作失败,重定向至登录界面中...",
+ "出现错误,第 ${count} 次重试中...": "出现错误,第 ${count} 次重试中...",
+ "无效的重置链接,请重新发起密码重置请求": "无效的重置链接,请重新发起密码重置请求",
+ "密码已重置并已复制到剪贴板:": "密码已重置并已复制到剪贴板:",
+ "密码重置确认": "密码重置确认",
+ "等待获取邮箱信息...": "等待获取邮箱信息...",
+ "新密码": "新密码",
+ "密码已复制到剪贴板:": "密码已复制到剪贴板:",
+ "密码重置完成": "密码重置完成",
+ "确认重置密码": "确认重置密码",
+ "返回登录": "返回登录",
+ "请输入邮箱地址": "请输入邮箱地址",
+ "请稍后几秒重试,Turnstile 正在检查用户环境!": "请稍后几秒重试,Turnstile 正在检查用户环境!",
+ "重置邮件发送成功,请检查邮箱!": "重置邮件发送成功,请检查邮箱!",
+ "密码重置": "密码重置",
+ "请输入您的邮箱地址": "请输入您的邮箱地址",
+ "重试": "重试",
+ "想起来了?": "想起来了?",
+ "注 册": "注 册",
+ "使用 用户名 注册": "使用 用户名 注册",
+ "已有账户?": "已有账户?",
+ "用户名": "用户名",
+ "请输入用户名": "请输入用户名",
+ "输入密码,最短 8 位,最长 20 位": "输入密码,最短 8 位,最长 20 位",
+ "确认密码": "确认密码",
+ "输入邮箱地址": "输入邮箱地址",
+ "获取验证码": "获取验证码",
+ "输入验证码": "输入验证码",
+ "或": "或",
+ "其他注册选项": "其他注册选项",
+ "加载中...": "加载中...",
+ "复制代码": "复制代码",
+ "代码已复制到剪贴板": "代码已复制到剪贴板",
+ "复制失败,请手动复制": "复制失败,请手动复制",
+ "显示更多": "显示更多",
+ "关于我们": "关于我们",
+ "关于项目": "关于项目",
+ "联系我们": "联系我们",
+ "功能特性": "功能特性",
+ "快速开始": "快速开始",
+ "安装指南": "安装指南",
+ "API 文档": "API 文档",
+ "基于New API的项目": "基于New API的项目",
+ "版权所有": "版权所有",
+ "设计与开发由": "设计与开发由",
+ "首页": "首页",
+ "控制台": "控制台",
+ "文档": "文档",
+ "关于": "关于",
+ "注销成功!": "注销成功!",
+ "个人设置": "个人设置",
+ "令牌管理": "令牌管理",
+ "退出": "退出",
+ "关闭侧边栏": "关闭侧边栏",
+ "打开侧边栏": "打开侧边栏",
+ "关闭菜单": "关闭菜单",
+ "打开菜单": "打开菜单",
+ "演示站点": "演示站点",
+ "自用模式": "自用模式",
+ "系统公告": "系统公告",
+ "切换主题": "切换主题",
+ "切换语言": "切换语言",
+ "暂无公告": "暂无公告",
+ "暂无系统公告": "暂无系统公告",
+ "今日关闭": "今日关闭",
+ "关闭公告": "关闭公告",
+ "数据看板": "数据看板",
+ "绘图日志": "绘图日志",
+ "任务日志": "任务日志",
+ "渠道": "渠道",
+ "兑换码": "兑换码",
+ "用户管理": "用户管理",
+ "操练场": "操练场",
+ "聊天": "聊天",
+ "管理员": "管理员",
+ "个人中心": "个人中心",
+ "展开侧边栏": "展开侧边栏",
+ "AI 对话": "AI 对话",
+ "选择模型开始对话": "选择模型开始对话",
+ "显示调试": "显示调试",
+ "请输入您的问题...": "请输入您的问题...",
+ "已复制到剪贴板": "已复制到剪贴板",
+ "复制失败": "复制失败",
+ "正在构造请求体预览...": "正在构造请求体预览...",
+ "暂无请求数据": "暂无请求数据",
+ "暂无响应数据": "暂无响应数据",
+ "内容较大,已启用性能优化模式": "内容较大,已启用性能优化模式",
+ "内容较大,部分功能可能受限": "内容较大,部分功能可能受限",
+ "已复制": "已复制",
+ "正在处理大内容...": "正在处理大内容...",
+ "显示完整内容": "显示完整内容",
+ "收起": "收起",
+ "配置已导出到下载文件夹": "配置已导出到下载文件夹",
+ "导出配置失败: ": "导出配置失败: ",
+ "确认导入配置": "确认导入配置",
+ "导入的配置将覆盖当前设置,是否继续?": "导入的配置将覆盖当前设置,是否继续?",
+ "取消": "取消",
+ "配置导入成功": "配置导入成功",
+ "导入配置失败: ": "导入配置失败: ",
+ "重置配置": "重置配置",
+ "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?": "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?",
+ "重置选项": "重置选项",
+ "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。": "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。",
+ "同时重置消息": "同时重置消息",
+ "仅重置配置": "仅重置配置",
+ "配置和消息已全部重置": "配置和消息已全部重置",
+ "配置已重置,对话消息已保留": "配置已重置,对话消息已保留",
+ "已有保存的配置": "已有保存的配置",
+ "暂无保存的配置": "暂无保存的配置",
+ "导出配置": "导出配置",
+ "导入配置": "导入配置",
+ "导出": "导出",
+ "导入": "导入",
+ "调试信息": "调试信息",
+ "预览请求体": "预览请求体",
+ "实际请求体": "实际请求体",
+ "预览更新": "预览更新",
+ "最后请求": "最后请求",
+ "操作暂时被禁用": "操作暂时被禁用",
+ "复制": "复制",
+ "编辑": "编辑",
+ "切换为System角色": "切换为System角色",
+ "切换为Assistant角色": "切换为Assistant角色",
+ "删除": "删除",
+ "请求发生错误": "请求发生错误",
+ "系统消息": "系统消息",
+ "请输入消息内容...": "请输入消息内容...",
+ "保存": "保存",
+ "模型配置": "模型配置",
+ "分组": "分组",
+ "请选择分组": "请选择分组",
+ "请选择模型": "请选择模型",
+ "思考中...": "思考中...",
+ "思考过程": "思考过程",
+ "选择同步渠道": "选择同步渠道",
+ "搜索渠道名称或地址": "搜索渠道名称或地址",
+ "暂无渠道": "暂无渠道",
+ "暂无选择": "暂无选择",
+ "无搜索结果": "无搜索结果",
+ "公告已更新": "公告已更新",
+ "公告更新失败": "公告更新失败",
+ "系统名称已更新": "系统名称已更新",
+ "系统名称更新失败": "系统名称更新失败",
+ "系统信息": "系统信息",
+ "当前版本": "当前版本",
+ "检查更新": "检查更新",
+ "启动时间": "启动时间",
+ "通用设置": "通用设置",
+ "设置公告": "设置公告",
+ "个性化设置": "个性化设置",
+ "系统名称": "系统名称",
+ "在此输入系统名称": "在此输入系统名称",
+ "设置系统名称": "设置系统名称",
+ "Logo 图片地址": "Logo 图片地址",
+ "在此输入 Logo 图片地址": "在此输入 Logo 图片地址",
+ "首页内容": "首页内容",
+ "设置首页内容": "设置首页内容",
+ "设置关于": "设置关于",
+ "页脚": "页脚",
+ "设置页脚": "设置页脚",
+ "详情": "详情",
+ "刷新失败": "刷新失败",
+ "令牌已重置并已复制到剪贴板": "令牌已重置并已复制到剪贴板",
+ "加载模型列表失败": "加载模型列表失败",
+ "系统令牌已复制到剪切板": "系统令牌已复制到剪切板",
+ "请输入你的账户名以确认删除!": "请输入你的账户名以确认删除!",
+ "账户已删除!": "账户已删除!",
+ "微信账户绑定成功!": "微信账户绑定成功!",
+ "请输入原密码!": "请输入原密码!",
+ "请输入新密码!": "请输入新密码!",
+ "新密码需要和原密码不一致!": "新密码需要和原密码不一致!",
+ "两次输入的密码不一致!": "两次输入的密码不一致!",
+ "密码修改成功!": "密码修改成功!",
+ "验证码发送成功,请检查邮箱!": "验证码发送成功,请检查邮箱!",
+ "请输入邮箱验证码!": "请输入邮箱验证码!",
+ "邮箱账户绑定成功!": "邮箱账户绑定成功!",
+ "无法复制到剪贴板,请手动复制": "无法复制到剪贴板,请手动复制",
+ "设置保存成功": "设置保存成功",
+ "设置保存失败": "设置保存失败",
+ "超级管理员": "超级管理员",
+ "普通用户": "普通用户",
+ "当前余额": "当前余额",
+ "历史消耗": "历史消耗",
+ "请求次数": "请求次数",
+ "默认": "默认",
+ "可用模型": "可用模型",
+ "模型列表": "模型列表",
+ "点击模型名称可复制": "点击模型名称可复制",
+ "没有可用模型": "没有可用模型",
+ "该分类下没有可用模型": "该分类下没有可用模型",
+ "更多": "更多",
+ "个模型": "个模型",
+ "账户绑定": "账户绑定",
+ "未绑定": "未绑定",
+ "修改绑定": "修改绑定",
+ "微信": "微信",
+ "已绑定": "已绑定",
+ "未启用": "未启用",
+ "绑定": "绑定",
+ "安全设置": "安全设置",
+ "系统访问令牌": "系统访问令牌",
+ "用于API调用的身份验证令牌,请妥善保管": "用于API调用的身份验证令牌,请妥善保管",
+ "生成令牌": "生成令牌",
+ "密码管理": "密码管理",
+ "定期更改密码可以提高账户安全性": "定期更改密码可以提高账户安全性",
+ "修改密码": "修改密码",
+ "此操作不可逆,所有数据将被永久删除": "此操作不可逆,所有数据将被永久删除",
+ "删除账户": "删除账户",
+ "其他设置": "其他设置",
+ "通知设置": "通知设置",
+ "邮件通知": "邮件通知",
+ "通过邮件接收通知": "通过邮件接收通知",
+ "Webhook通知": "Webhook通知",
+ "通过HTTP请求接收通知": "通过HTTP请求接收通知",
+ "请输入Webhook地址,例如: https://example.com/webhook": "请输入Webhook地址,例如: https://example.com/webhook",
+ "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求",
+ "接口凭证(可选)": "接口凭证(可选)",
+ "请输入密钥": "请输入密钥",
+ "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性",
+ "通知邮箱": "通知邮箱",
+ "留空则使用账号绑定的邮箱": "留空则使用账号绑定的邮箱",
+ "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱",
+ "额度预警阈值": "额度预警阈值",
+ "请输入预警额度": "请输入预警额度",
+ "当剩余额度低于此数值时,系统将通过选择的方式发送通知": "当剩余额度低于此数值时,系统将通过选择的方式发送通知",
+ "接受未设置价格模型": "接受未设置价格模型",
+ "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用": "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用",
+ "IP记录": "IP记录",
+ "记录请求与错误日志 IP": "记录请求与错误日志 IP",
+ "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址": "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址",
+ "绑定邮箱地址": "绑定邮箱地址",
+ "重新发送": "重新发送",
+ "绑定微信账户": "绑定微信账户",
+ "删除账户确认": "删除账户确认",
+ "您正在删除自己的帐户,将清空所有数据且不可恢复": "您正在删除自己的帐户,将清空所有数据且不可恢复",
+ "请输入您的用户名以确认删除": "请输入您的用户名以确认删除",
+ "输入你的账户名{{username}}以确认删除": "输入你的账户名{{username}}以确认删除",
+ "原密码": "原密码",
+ "请输入原密码": "请输入原密码",
+ "请输入新密码": "请输入新密码",
+ "确认新密码": "确认新密码",
+ "请再次输入新密码": "请再次输入新密码",
+ "模型倍率设置": "模型倍率设置",
+ "可视化倍率设置": "可视化倍率设置",
+ "未设置倍率模型": "未设置倍率模型",
+ "上游倍率同步": "上游倍率同步",
+ "未知类型": "未知类型",
+ "标签聚合": "标签聚合",
+ "已启用": "已启用",
+ "自动禁用": "自动禁用",
+ "未知状态": "未知状态",
+ "未测试": "未测试",
+ "名称": "名称",
+ "类型": "类型",
+ "状态": "状态",
+ ",时间:": ",时间:",
+ "响应时间": "响应时间",
+ "已用/剩余": "已用/剩余",
+ "剩余额度$": "剩余额度$",
+ ",点击更新": ",点击更新",
+ "已用额度": "已用额度",
+ "修改子渠道优先级": "修改子渠道优先级",
+ "确定要修改所有子渠道优先级为 ": "确定要修改所有子渠道优先级为 ",
+ "权重": "权重",
+ "修改子渠道权重": "修改子渠道权重",
+ "确定要修改所有子渠道权重为 ": "确定要修改所有子渠道权重为 ",
+ "确定是否要删除此渠道?": "确定是否要删除此渠道?",
+ "此修改将不可逆": "此修改将不可逆",
+ "确定是否要复制此渠道?": "确定是否要复制此渠道?",
+ "复制渠道的所有信息": "复制渠道的所有信息",
+ "测试单个渠道操作项目组": "测试单个渠道操作项目组",
+ "禁用": "禁用",
+ "启用": "启用",
+ "启用全部": "启用全部",
+ "禁用全部": "禁用全部",
+ "重置": "重置",
+ "全选": "全选",
+ "_复制": "_复制",
+ "渠道未找到,请刷新页面后重试。": "渠道未找到,请刷新页面后重试。",
+ "渠道复制成功": "渠道复制成功",
+ "渠道复制失败: ": "渠道复制失败: ",
+ "操作成功完成!": "操作成功完成!",
+ "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。",
+ "已停止测试": "已停止测试",
+ "全部": "全部",
+ "请先选择要设置标签的渠道!": "请先选择要设置标签的渠道!",
+ "标签不能为空!": "标签不能为空!",
+ "已为 ${count} 个渠道设置标签!": "已为 ${count} 个渠道设置标签!",
+ "已成功开始测试所有已启用通道,请刷新页面查看结果。": "已成功开始测试所有已启用通道,请刷新页面查看结果。",
+ "已删除所有禁用渠道,共计 ${data} 个": "已删除所有禁用渠道,共计 ${data} 个",
+ "已更新完毕所有已启用通道余额!": "已更新完毕所有已启用通道余额!",
+ "通道 ${name} 余额更新成功!": "通道 ${name} 余额更新成功!",
+ "已删除 ${data} 个通道!": "已删除 ${data} 个通道!",
+ "已修复 ${data} 个通道!": "已修复 ${data} 个通道!",
+ "确定是否要删除所选通道?": "确定是否要删除所选通道?",
+ "删除所选通道": "删除所选通道",
+ "批量设置标签": "批量设置标签",
+ "确定要测试所有通道吗?": "确定要测试所有通道吗?",
+ "测试所有通道": "测试所有通道",
+ "确定要更新所有已启用通道余额吗?": "确定要更新所有已启用通道余额吗?",
+ "更新所有已启用通道余额": "更新所有已启用通道余额",
+ "确定是否要删除禁用通道?": "确定是否要删除禁用通道?",
+ "删除禁用通道": "删除禁用通道",
+ "确定是否要修复数据库一致性?": "确定是否要修复数据库一致性?",
+ "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用",
+ "批量操作": "批量操作",
+ "使用ID排序": "使用ID排序",
+ "开启批量操作": "开启批量操作",
+ "标签聚合模式": "标签聚合模式",
+ "刷新": "刷新",
+ "列设置": "列设置",
+ "搜索渠道的 ID,名称,密钥和API地址 ...": "搜索渠道的 ID,名称,密钥和API地址 ...",
+ "模型关键字": "模型关键字",
+ "选择分组": "选择分组",
+ "查询": "查询",
+ "第 {{start}} - {{end}} 条,共 {{total}} 条": "第 {{start}} - {{end}} 条,共 {{total}} 条",
+ "搜索无结果": "搜索无结果",
+ "请输入要设置的标签名称": "请输入要设置的标签名称",
+ "请输入标签名称": "请输入标签名称",
+ "已选择 ${count} 个渠道": "已选择 ${count} 个渠道",
+ "共": "共",
+ "停止测试": "停止测试",
+ "测试中...": "测试中...",
+ "批量测试${count}个模型": "批量测试${count}个模型",
+ "搜索模型...": "搜索模型...",
+ "模型名称": "模型名称",
+ "测试中": "测试中",
+ "未开始": "未开始",
+ "失败": "失败",
+ "请求时长: ${time}s": "请求时长: ${time}s",
+ "充值": "充值",
+ "消费": "消费",
+ "系统": "系统",
+ "错误": "错误",
+ "流": "流",
+ "非流": "非流",
+ "请求并计费模型": "请求并计费模型",
+ "实际模型": "实际模型",
+ "用户": "用户",
+ "用时/首字": "用时/首字",
+ "提示": "提示",
+ "花费": "花费",
+ "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录": "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录",
+ "确定": "确定",
+ "用户信息": "用户信息",
+ "渠道信息": "渠道信息",
+ "语音输入": "语音输入",
+ "文字输入": "文字输入",
+ "文字输出": "文字输出",
+ "缓存创建 Tokens": "缓存创建 Tokens",
+ "日志详情": "日志详情",
+ "消耗额度": "消耗额度",
+ "开始时间": "开始时间",
+ "结束时间": "结束时间",
+ "用户名称": "用户名称",
+ "日志类型": "日志类型",
+ "绘图": "绘图",
+ "放大": "放大",
+ "变换": "变换",
+ "强变换": "强变换",
+ "平移": "平移",
+ "图生文": "图生文",
+ "图混合": "图混合",
+ "重绘": "重绘",
+ "局部重绘-提交": "局部重绘-提交",
+ "自定义变焦-提交": "自定义变焦-提交",
+ "窗口处理": "窗口处理",
+ "未知": "未知",
+ "已提交": "已提交",
+ "等待中": "等待中",
+ "重复提交": "重复提交",
+ "成功": "成功",
+ "未启动": "未启动",
+ "执行中": "执行中",
+ "窗口等待": "窗口等待",
+ "秒": "秒",
+ "提交时间": "提交时间",
+ "花费时间": "花费时间",
+ "任务ID": "任务ID",
+ "提交结果": "提交结果",
+ "任务状态": "任务状态",
+ "结果图片": "结果图片",
+ "查看图片": "查看图片",
+ "无": "无",
+ "失败原因": "失败原因",
+ "已复制:": "已复制:",
+ "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。": "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。",
+ "Midjourney 任务记录": "Midjourney 任务记录",
+ "任务 ID": "任务 ID",
+ "按次计费": "按次计费",
+ "按量计费": "按量计费",
+ "您的分组可以使用该模型": "您的分组可以使用该模型",
+ "可用性": "可用性",
+ "计费类型": "计费类型",
+ "当前查看的分组为:{{group}},倍率为:{{ratio}}": "当前查看的分组为:{{group}},倍率为:{{ratio}}",
+ "倍率": "倍率",
+ "倍率是为了方便换算不同价格的模型": "倍率是为了方便换算不同价格的模型",
+ "模型倍率": "模型倍率",
+ "补全倍率": "补全倍率",
+ "分组倍率": "分组倍率",
+ "模型价格": "模型价格",
+ "补全": "补全",
+ "模糊搜索模型名称": "模糊搜索模型名称",
+ "复制选中模型": "复制选中模型",
+ "模型定价": "模型定价",
+ "当前分组": "当前分组",
+ "未登录,使用默认分组倍率": "未登录,使用默认分组倍率",
+ "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)": "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)",
+ "已过期": "已过期",
+ "未使用": "未使用",
+ "已禁用": "已禁用",
+ "创建时间": "创建时间",
+ "过期时间": "过期时间",
+ "永不过期": "永不过期",
+ "确定是否要删除此兑换码?": "确定是否要删除此兑换码?",
+ "查看": "查看",
+ "已复制到剪贴板!": "已复制到剪贴板!",
+ "兑换码可以批量生成和分发,适合用于推广活动或批量充值。": "兑换码可以批量生成和分发,适合用于推广活动或批量充值。",
+ "添加兑换码": "添加兑换码",
+ "请至少选择一个兑换码!": "请至少选择一个兑换码!",
+ "复制所选兑换码到剪贴板": "复制所选兑换码到剪贴板",
+ "确定清除所有失效兑换码?": "确定清除所有失效兑换码?",
+ "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。",
+ "已删除 {{count}} 条失效兑换码": "已删除 {{count}} 条失效兑换码",
+ "关键字(id或者名称)": "关键字(id或者名称)",
+ "生成音乐": "生成音乐",
+ "生成歌词": "生成歌词",
+ "生成视频": "生成视频",
+ "排队中": "排队中",
+ "正在提交": "正在提交",
+ "平台": "平台",
+ "点击预览视频": "点击预览视频",
+ "任务记录": "任务记录",
+ "渠道 ID": "渠道 ID",
+ "已启用:限制模型": "已启用:限制模型",
+ "已耗尽": "已耗尽",
+ "剩余额度": "剩余额度",
+ "聊天链接配置错误,请联系管理员": "聊天链接配置错误,请联系管理员",
+ "令牌详情": "令牌详情",
+ "确定是否要删除此令牌?": "确定是否要删除此令牌?",
+ "项目操作按钮组": "项目操作按钮组",
+ "请联系管理员配置聊天链接": "请联系管理员配置聊天链接",
+ "令牌用于API访问认证,可以设置额度限制和模型权限。": "令牌用于API访问认证,可以设置额度限制和模型权限。",
+ "添加令牌": "添加令牌",
+ "请至少选择一个令牌!": "请至少选择一个令牌!",
+ "复制所选令牌到剪贴板": "复制所选令牌到剪贴板",
+ "搜索关键字": "搜索关键字",
+ "未知身份": "未知身份",
+ "已封禁": "已封禁",
+ "统计信息": "统计信息",
+ "剩余": "剩余",
+ "调用": "调用",
+ "邀请信息": "邀请信息",
+ "收益": "收益",
+ "无邀请人": "无邀请人",
+ "已注销": "已注销",
+ "确定要提升此用户吗?": "确定要提升此用户吗?",
+ "此操作将提升用户的权限级别": "此操作将提升用户的权限级别",
+ "确定要降级此用户吗?": "确定要降级此用户吗?",
+ "此操作将降低用户的权限级别": "此操作将降低用户的权限级别",
+ "确定是否要注销此用户?": "确定是否要注销此用户?",
+ "相当于删除用户,此修改将不可逆": "相当于删除用户,此修改将不可逆",
+ "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。": "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。",
+ "添加用户": "添加用户",
+ "支持搜索用户的 ID、用户名、显示名称和邮箱地址": "支持搜索用户的 ID、用户名、显示名称和邮箱地址",
+ "全部模型": "全部模型",
+ "智谱": "智谱",
+ "通义千问": "通义千问",
+ "文心一言": "文心一言",
+ "腾讯混元": "腾讯混元",
+ "360智脑": "360智脑",
+ "豆包": "豆包",
+ "用户分组": "用户分组",
+ "专属倍率": "专属倍率",
+ "输入价格:${{price}} / 1M tokens{{audioPrice}}": "输入价格:${{price}} / 1M tokens{{audioPrice}}",
+ "Web搜索价格:${{price}} / 1K 次": "Web搜索价格:${{price}} / 1K 次",
+ "文件搜索价格:${{price}} / 1K 次": "文件搜索价格:${{price}} / 1K 次",
+ "仅供参考,以实际扣费为准": "仅供参考,以实际扣费为准",
+ "价格:${{price}} * {{ratioType}}:{{ratio}}": "价格:${{price}} * {{ratioType}}:{{ratio}}",
+ "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}": "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}",
+ "提示价格:${{price}} / 1M tokens": "提示价格:${{price}} / 1M tokens",
+ "模型价格 ${{price}},{{ratioType}} {{ratio}}": "模型价格 ${{price}},{{ratioType}} {{ratio}}",
+ "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}": "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}",
+ "不是合法的 JSON 字符串": "不是合法的 JSON 字符串",
+ "请求发生错误: ": "请求发生错误: ",
+ "解析响应数据时发生错误": "解析响应数据时发生错误",
+ "连接已断开": "连接已断开",
+ "建立连接时发生错误": "建立连接时发生错误",
+ "加载模型失败": "加载模型失败",
+ "加载分组失败": "加载分组失败",
+ "消息已复制到剪贴板": "消息已复制到剪贴板",
+ "确认删除": "确认删除",
+ "确定要删除这条消息吗?": "确定要删除这条消息吗?",
+ "已删除消息及其回复": "已删除消息及其回复",
+ "消息已删除": "消息已删除",
+ "消息已编辑": "消息已编辑",
+ "检测到该消息后有AI回复,是否删除后续回复并重新生成?": "检测到该消息后有AI回复,是否删除后续回复并重新生成?",
+ "重新生成": "重新生成",
+ "消息已更新": "消息已更新",
+ "加载关于内容失败...": "加载关于内容失败...",
+ "可在设置页面设置关于内容,支持 HTML & Markdown": "可在设置页面设置关于内容,支持 HTML & Markdown",
+ "New API项目仓库地址:": "New API项目仓库地址:",
+ "| 基于": "| 基于",
+ "本项目根据": "本项目根据",
+ "MIT许可证": "MIT许可证",
+ "授权,需在遵守": "授权,需在遵守",
+ "Apache-2.0协议": "Apache-2.0协议",
+ "管理员暂时未设置任何关于内容": "管理员暂时未设置任何关于内容",
+ "仅支持 OpenAI 接口格式": "仅支持 OpenAI 接口格式",
+ "请填写密钥": "请填写密钥",
+ "获取模型列表成功": "获取模型列表成功",
+ "获取模型列表失败": "获取模型列表失败",
+ "请填写渠道名称和渠道密钥!": "请填写渠道名称和渠道密钥!",
+ "请至少选择一个模型!": "请至少选择一个模型!",
+ "模型映射必须是合法的 JSON 格式!": "模型映射必须是合法的 JSON 格式!",
+ "提交失败,请勿重复提交!": "提交失败,请勿重复提交!",
+ "渠道创建成功!": "渠道创建成功!",
+ "已新增 {{count}} 个模型:{{list}}": "已新增 {{count}} 个模型:{{list}}",
+ "未发现新增模型": "未发现新增模型",
+ "新建": "新建",
+ "更新渠道信息": "更新渠道信息",
+ "创建新的渠道": "创建新的渠道",
+ "基本信息": "基本信息",
+ "渠道的基本配置信息": "渠道的基本配置信息",
+ "请选择渠道类型": "请选择渠道类型",
+ "请为渠道命名": "请为渠道命名",
+ "请输入密钥,一行一个": "请输入密钥,一行一个",
+ "批量创建": "批量创建",
+ "API 配置": "API 配置",
+ "API 地址和相关配置": "API 地址和相关配置",
+ "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"",
+ "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com",
+ "请输入默认 API 版本,例如:2025-04-01-preview": "请输入默认 API 版本,例如:2025-04-01-preview",
+ "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。": "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。",
+ "完整的 Base URL,支持变量{model}": "完整的 Base URL,支持变量{model}",
+ "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions": "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions",
+ "Dify渠道只适配chatflow和agent,并且agent不支持图片!": "Dify渠道只适配chatflow和agent,并且agent不支持图片!",
+ "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/": "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/",
+ "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写",
+ "私有部署地址": "私有部署地址",
+ "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi": "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi",
+ "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用": "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用",
+ "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com": "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com",
+ "模型选择和映射设置": "模型选择和映射设置",
+ "模型": "模型",
+ "请选择该渠道所支持的模型": "请选择该渠道所支持的模型",
+ "填入相关模型": "填入相关模型",
+ "填入所有模型": "填入所有模型",
+ "获取模型列表": "获取模型列表",
+ "清除所有模型": "清除所有模型",
+ "输入自定义模型名称": "输入自定义模型名称",
+ "模型重定向": "模型重定向",
+ "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:",
+ "填入模板": "填入模板",
+ "默认测试模型": "默认测试模型",
+ "不填则为模型列表第一个": "不填则为模型列表第一个",
+ "渠道的高级配置选项": "渠道的高级配置选项",
+ "请选择可以使用该渠道的分组": "请选择可以使用该渠道的分组",
+ "请在系统设置页面编辑分组倍率以添加新的分组:": "请在系统设置页面编辑分组倍率以添加新的分组:",
+ "部署地区": "部署地区",
+ "知识库 ID": "知识库 ID",
+ "渠道标签": "渠道标签",
+ "渠道优先级": "渠道优先级",
+ "渠道权重": "渠道权重",
+ "渠道额外设置": "渠道额外设置",
+ "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:",
+ "强制格式化": "强制格式化",
+ "强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)": "强制将响应格式化为 OpenAI 标准格式(只适用于OpenAI渠道类型)",
+ "思考内容转换": "思考内容转换",
+ "将 reasoning_content 转换为 标签拼接到内容中": "将 reasoning_content 转换为 标签拼接到内容中",
+ "透传请求体": "透传请求体",
+ "启用请求体透传功能": "启用请求体透传功能",
+ "代理地址": "代理地址",
+ "例如: socks5://user:pass@host:port": "例如: socks5://user:pass@host:port",
+ "用于配置网络代理": "用于配置网络代理",
+ "用于配置网络代理,支持 socks5 协议": "用于配置网络代理,支持 socks5 协议",
+ "系统提示词": "系统提示词",
+ "输入系统提示词,用户的系统提示词将优先于此设置": "输入系统提示词,用户的系统提示词将优先于此设置",
+ "用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置": "用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置",
+ "参数覆盖": "参数覆盖",
+ "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:",
+ "请输入组织org-xxx": "请输入组织org-xxx",
+ "组织,可选,不填则为默认组织": "组织,可选,不填则为默认组织",
+ "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道": "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道",
+ "状态码复写(仅影响本地判断,不修改返回到上游的状态码)": "状态码复写(仅影响本地判断,不修改返回到上游的状态码)",
+ "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:",
+ "编辑标签": "编辑标签",
+ "标签信息": "标签信息",
+ "标签的基本配置": "标签的基本配置",
+ "所有编辑均为覆盖操作,留空则不更改": "所有编辑均为覆盖操作,留空则不更改",
+ "标签名称": "标签名称",
+ "请输入新标签,留空则解散标签": "请输入新标签,留空则解散标签",
+ "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。": "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。",
+ "请选择该渠道所支持的模型,留空则不更改": "请选择该渠道所支持的模型,留空则不更改",
+ "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改",
+ "清空重定向": "清空重定向",
+ "分组设置": "分组设置",
+ "用户分组配置": "用户分组配置",
+ "请选择可以使用该渠道的分组,留空则不更改": "请选择可以使用该渠道的分组,留空则不更改",
+ "正在跳转...": "正在跳转...",
+ "小时": "小时",
+ "周": "周",
+ "模型调用次数占比": "模型调用次数占比",
+ "模型消耗分布": "模型消耗分布",
+ "总计": "总计",
+ "早上好": "早上好",
+ "中午好": "中午好",
+ "下午好": "下午好",
+ "账户数据": "账户数据",
+ "使用统计": "使用统计",
+ "统计次数": "统计次数",
+ "资源消耗": "资源消耗",
+ "统计额度": "统计额度",
+ "性能指标": "性能指标",
+ "平均RPM": "平均RPM",
+ "复制成功": "复制成功",
+ "进行中": "进行中",
+ "异常": "异常",
+ "正常": "正常",
+ "可用率": "可用率",
+ "有异常": "有异常",
+ "高延迟": "高延迟",
+ "维护中": "维护中",
+ "暂无监控数据": "暂无监控数据",
+ "搜索条件": "搜索条件",
+ "时间粒度": "时间粒度",
+ "模型数据分析": "模型数据分析",
+ "消耗分布": "消耗分布",
+ "调用次数分布": "调用次数分布",
+ "API信息": "API信息",
+ "暂无API信息": "暂无API信息",
+ "请联系管理员在系统设置中配置API信息": "请联系管理员在系统设置中配置API信息",
+ "显示最新20条": "显示最新20条",
+ "请联系管理员在系统设置中配置公告信息": "请联系管理员在系统设置中配置公告信息",
+ "暂无常见问答": "暂无常见问答",
+ "请联系管理员在系统设置中配置常见问答": "请联系管理员在系统设置中配置常见问答",
+ "服务可用性": "服务可用性",
+ "请联系管理员在系统设置中配置Uptime": "请联系管理员在系统设置中配置Uptime",
+ "加载首页内容失败...": "加载首页内容失败...",
+ "统一的大模型接口网关": "统一的大模型接口网关",
+ "更好的价格,更好的稳定性,无需订阅": "更好的价格,更好的稳定性,无需订阅",
+ "开始使用": "开始使用",
+ "支持众多的大模型供应商": "支持众多的大模型供应商",
+ "页面未找到,请检查您的浏览器地址是否正确": "页面未找到,请检查您的浏览器地址是否正确",
+ "登录过期,请重新登录!": "登录过期,请重新登录!",
+ "兑换码更新成功!": "兑换码更新成功!",
+ "兑换码创建成功!": "兑换码创建成功!",
+ "兑换码创建成功": "兑换码创建成功",
+ "兑换码创建成功,是否下载兑换码?": "兑换码创建成功,是否下载兑换码?",
+ "兑换码将以文本文件的形式下载,文件名为兑换码的名称。": "兑换码将以文本文件的形式下载,文件名为兑换码的名称。",
+ "更新兑换码信息": "更新兑换码信息",
+ "创建新的兑换码": "创建新的兑换码",
+ "设置兑换码的基本信息": "设置兑换码的基本信息",
+ "请输入名称": "请输入名称",
+ "选择过期时间(可选,留空为永久)": "选择过期时间(可选,留空为永久)",
+ "额度设置": "额度设置",
+ "设置兑换码的额度和数量": "设置兑换码的额度和数量",
+ "请输入额度": "请输入额度",
+ "生成数量": "生成数量",
+ "请输入生成数量": "请输入生成数量",
+ "你似乎并没有修改什么": "你似乎并没有修改什么",
+ "部分保存失败,请重试": "部分保存失败,请重试",
+ "保存成功": "保存成功",
+ "保存失败,请重试": "保存失败,请重试",
+ "请检查输入": "请检查输入",
+ "聊天配置": "聊天配置",
+ "为一个 JSON 文本": "为一个 JSON 文本",
+ "保存聊天设置": "保存聊天设置",
+ "设置已保存": "设置已保存",
+ "API地址": "API地址",
+ "说明": "说明",
+ "颜色": "颜色",
+ "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)": "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)",
+ "批量删除": "批量删除",
+ "保存设置": "保存设置",
+ "添加API": "添加API",
+ "请输入API地址": "请输入API地址",
+ "如:香港线路": "如:香港线路",
+ "请输入线路描述": "请输入线路描述",
+ "如:大带宽批量分析图片推荐": "如:大带宽批量分析图片推荐",
+ "请输入说明": "请输入说明",
+ "标识颜色": "标识颜色",
+ "确定要删除此API信息吗?": "确定要删除此API信息吗?",
+ "警告": "警告",
+ "发布时间": "发布时间",
+ "操作": "操作",
+ "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)": "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)",
+ "添加公告": "添加公告",
+ "编辑公告": "编辑公告",
+ "公告内容": "公告内容",
+ "请输入公告内容": "请输入公告内容",
+ "请选择发布日期": "请选择发布日期",
+ "公告类型": "公告类型",
+ "说明信息": "说明信息",
+ "可选,公告的补充说明": "可选,公告的补充说明",
+ "确定要删除此公告吗?": "确定要删除此公告吗?",
+ "数据看板设置": "数据看板设置",
+ "启用数据看板(实验性)": "启用数据看板(实验性)",
+ "数据看板更新间隔": "数据看板更新间隔",
+ "设置过短会影响数据库性能": "设置过短会影响数据库性能",
+ "数据看板默认时间粒度": "数据看板默认时间粒度",
+ "仅修改展示粒度,统计精确到小时": "仅修改展示粒度,统计精确到小时",
+ "保存数据看板设置": "保存数据看板设置",
+ "问题标题": "问题标题",
+ "回答内容": "回答内容",
+ "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)": "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)",
+ "添加问答": "添加问答",
+ "编辑问答": "编辑问答",
+ "请输入问题标题": "请输入问题标题",
+ "请输入回答内容": "请输入回答内容",
+ "确定要删除此问答吗?": "确定要删除此问答吗?",
+ "分类名称": "分类名称",
+ "Uptime Kuma地址": "Uptime Kuma地址",
+ "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)": "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)",
+ "编辑分类": "编辑分类",
+ "添加分类": "添加分类",
+ "请输入分类名称,如:OpenAI、Claude等": "请输入分类名称,如:OpenAI、Claude等",
+ "请输入分类名称": "请输入分类名称",
+ "请输入Uptime Kuma服务地址,如:https://status.example.com": "请输入Uptime Kuma服务地址,如:https://status.example.com",
+ "请输入Uptime Kuma地址": "请输入Uptime Kuma地址",
+ "请输入状态页面的Slug,如:my-status": "请输入状态页面的Slug,如:my-status",
+ "请输入状态页面Slug": "请输入状态页面Slug",
+ "确定要删除此分类吗?": "确定要删除此分类吗?",
+ "绘图设置": "绘图设置",
+ "启用绘图功能": "启用绘图功能",
+ "允许回调(会泄露服务器 IP 地址)": "允许回调(会泄露服务器 IP 地址)",
+ "允许 AccountFilter 参数": "允许 AccountFilter 参数",
+ "开启之后会清除用户提示词中的": "开启之后会清除用户提示词中的",
+ "以及": "以及",
+ "检测必须等待绘图成功才能进行放大等操作": "检测必须等待绘图成功才能进行放大等操作",
+ "保存绘图设置": "保存绘图设置",
+ "Claude设置": "Claude设置",
+ "Claude请求头覆盖": "Claude请求头覆盖",
+ "为一个 JSON 文本,例如:": "为一个 JSON 文本,例如:",
+ "缺省 MaxTokens": "缺省 MaxTokens",
+ "启用Claude思考适配(-thinking后缀)": "启用Claude思考适配(-thinking后缀)",
+ "思考适配 BudgetTokens 百分比": "思考适配 BudgetTokens 百分比",
+ "0.1-1之间的小数": "0.1-1之间的小数",
+ "Gemini设置": "Gemini设置",
+ "Gemini安全设置": "Gemini安全设置",
+ "default为默认设置,可单独设置每个模型的版本": "default为默认设置,可单独设置每个模型的版本",
+ "例如:": "例如:",
+ "Gemini思考适配设置": "Gemini思考适配设置",
+ "启用Gemini思考后缀适配": "启用Gemini思考后缀适配",
+ "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀",
+ "0.002-1之间的小数": "0.002-1之间的小数",
+ "全局设置": "全局设置",
+ "启用请求透传": "启用请求透传",
+ "连接保活设置": "连接保活设置",
+ "启用Ping间隔": "启用Ping间隔",
+ "Ping间隔(秒)": "Ping间隔(秒)",
+ "新用户初始额度": "新用户初始额度",
+ "请求预扣费额度": "请求预扣费额度",
+ "请求结束后多退少补": "请求结束后多退少补",
+ "邀请新用户奖励额度": "邀请新用户奖励额度",
+ "新用户使用邀请码奖励额度": "新用户使用邀请码奖励额度",
+ "例如:1000": "例如:1000",
+ "保存额度设置": "保存额度设置",
+ "例如发卡网站的购买链接": "例如发卡网站的购买链接",
+ "文档地址": "文档地址",
+ "单位美元额度": "单位美元额度",
+ "一单位货币能兑换的额度": "一单位货币能兑换的额度",
+ "失败重试次数": "失败重试次数",
+ "以货币形式显示额度": "以货币形式显示额度",
+ "额度查询接口返回令牌额度而非用户额度": "额度查询接口返回令牌额度而非用户额度",
+ "默认折叠侧边栏": "默认折叠侧边栏",
+ "开启后不限制:必须设置模型倍率": "开启后不限制:必须设置模型倍率",
+ "保存通用设置": "保存通用设置",
+ "请选择日志记录时间": "请选择日志记录时间",
+ "条日志已清理!": "条日志已清理!",
+ "日志清理失败:": "日志清理失败:",
+ "启用额度消费日志记录": "启用额度消费日志记录",
+ "日志记录时间": "日志记录时间",
+ "清除历史日志": "清除历史日志",
+ "保存日志设置": "保存日志设置",
+ "监控设置": "监控设置",
+ "测试所有渠道的最长响应时间": "测试所有渠道的最长响应时间",
+ "额度提醒阈值": "额度提醒阈值",
+ "低于此额度时将发送邮件提醒用户": "低于此额度时将发送邮件提醒用户",
+ "失败时自动禁用通道": "失败时自动禁用通道",
+ "成功时自动启用通道": "成功时自动启用通道",
+ "自动禁用关键词": "自动禁用关键词",
+ "一行一个,不区分大小写": "一行一个,不区分大小写",
+ "屏蔽词过滤设置": "屏蔽词过滤设置",
+ "启用屏蔽词过滤功能": "启用屏蔽词过滤功能",
+ "启用 Prompt 检查": "启用 Prompt 检查",
+ "一行一个屏蔽词,不需要符号分割": "一行一个屏蔽词,不需要符号分割",
+ "保存屏蔽词过滤设置": "保存屏蔽词过滤设置",
+ "更新成功": "更新成功",
+ "更新失败": "更新失败",
+ "服务器地址": "服务器地址",
+ "更新服务器地址": "更新服务器地址",
+ "请先填写服务器地址": "请先填写服务器地址",
+ "充值分组倍率不是合法的 JSON 字符串": "充值分组倍率不是合法的 JSON 字符串",
+ "充值方式设置不是合法的 JSON 字符串": "充值方式设置不是合法的 JSON 字符串",
+ "支付设置": "支付设置",
+ "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)": "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)",
+ "例如:https://yourdomain.com": "例如:https://yourdomain.com",
+ "易支付商户ID": "易支付商户ID",
+ "易支付商户密钥": "易支付商户密钥",
+ "敏感信息不会发送到前端显示": "敏感信息不会发送到前端显示",
+ "回调地址": "回调地址",
+ "充值价格(x元/美金)": "充值价格(x元/美金)",
+ "例如:7,就是7元/美金": "例如:7,就是7元/美金",
+ "最低充值美元数量": "最低充值美元数量",
+ "例如:2,就是最低充值2$": "例如:2,就是最低充值2$",
+ "为一个 JSON 文本,键为组名称,值为倍率": "为一个 JSON 文本,键为组名称,值为倍率",
+ "充值方式设置": "充值方式设置",
+ "更新支付设置": "更新支付设置",
+ "模型请求速率限制": "模型请求速率限制",
+ "启用用户模型请求速率限制(可能会影响高并发性能)": "启用用户模型请求速率限制(可能会影响高并发性能)",
+ "分钟": "分钟",
+ "频率限制的周期(分钟)": "频率限制的周期(分钟)",
+ "用户每周期最多请求次数": "用户每周期最多请求次数",
+ "包括失败请求的次数,0代表不限制": "包括失败请求的次数,0代表不限制",
+ "用户每周期最多请求完成次数": "用户每周期最多请求完成次数",
+ "只包括请求成功的次数": "只包括请求成功的次数",
+ "分组速率限制": "分组速率限制",
+ "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}",
+ "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。": "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。",
+ "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。": "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。",
+ "分组速率配置优先级高于全局速率限制。": "分组速率配置优先级高于全局速率限制。",
+ "限制周期统一使用上方配置的“限制周期”值。": "限制周期统一使用上方配置的“限制周期”值。",
+ "保存模型速率限制": "保存模型速率限制",
+ "保存失败": "保存失败",
+ "为一个 JSON 文本,键为分组名称,值为倍率": "为一个 JSON 文本,键为分组名称,值为倍率",
+ "用户可选分组": "用户可选分组",
+ "为一个 JSON 文本,键为分组名称,值为分组描述": "为一个 JSON 文本,键为分组名称,值为分组描述",
+ "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择",
+ "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]": "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]",
+ "模型固定价格": "模型固定价格",
+ "一次调用消耗多少刀,优先级大于模型倍率": "一次调用消耗多少刀,优先级大于模型倍率",
+ "为一个 JSON 文本,键为模型名称,值为倍率": "为一个 JSON 文本,键为模型名称,值为倍率",
+ "模型补全倍率(仅对自定义模型有效)": "模型补全倍率(仅对自定义模型有效)",
+ "仅对自定义模型有效": "仅对自定义模型有效",
+ "保存模型倍率设置": "保存模型倍率设置",
+ "确定重置模型倍率吗?": "确定重置模型倍率吗?",
+ "重置模型倍率": "重置模型倍率",
+ "获取启用模型失败:": "获取启用模型失败:",
+ "获取启用模型失败": "获取启用模型失败",
+ "JSON解析错误:": "JSON解析错误:",
+ "保存失败:": "保存失败:",
+ "输入模型倍率": "输入模型倍率",
+ "输入补全倍率": "输入补全倍率",
+ "请输入数字": "请输入数字",
+ "模型名称已存在": "模型名称已存在",
+ "请先选择需要批量设置的模型": "请先选择需要批量设置的模型",
+ "请输入模型倍率和补全倍率": "请输入模型倍率和补全倍率",
+ "请输入有效的数字": "请输入有效的数字",
+ "请输入填充值": "请输入填充值",
+ "批量设置成功": "批量设置成功",
+ "已为 {{count}} 个模型设置{{type}}": "已为 {{count}} 个模型设置{{type}}",
+ "模型倍率和补全倍率": "模型倍率和补全倍率",
+ "添加模型": "添加模型",
+ "批量设置": "批量设置",
+ "应用更改": "应用更改",
+ "搜索模型名称": "搜索模型名称",
+ "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除": "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除",
+ "定价模式": "定价模式",
+ "固定价格": "固定价格",
+ "固定价格(每次)": "固定价格(每次)",
+ "输入每次价格": "输入每次价格",
+ "输入补全价格": "输入补全价格",
+ "批量设置模型参数": "批量设置模型参数",
+ "设置类型": "设置类型",
+ "模型倍率和补全倍率同时设置": "模型倍率和补全倍率同时设置",
+ "模型倍率值": "模型倍率值",
+ "请输入模型倍率": "请输入模型倍率",
+ "补全倍率值": "补全倍率值",
+ "请输入补全倍率": "请输入补全倍率",
+ "请输入数值": "请输入数值",
+ "将为选中的 ": "将为选中的 ",
+ " 个模型设置相同的值": " 个模型设置相同的值",
+ "当前设置类型: ": "当前设置类型: ",
+ "默认补全倍率": "默认补全倍率",
+ "添加成功": "添加成功",
+ "价格设置方式": "价格设置方式",
+ "按倍率设置": "按倍率设置",
+ "按价格设置": "按价格设置",
+ "输入价格": "输入价格",
+ "输出价格": "输出价格",
+ "获取渠道失败:": "获取渠道失败:",
+ "请至少选择一个渠道": "请至少选择一个渠道",
+ "后端请求失败": "后端请求失败",
+ "部分渠道测试失败:": "部分渠道测试失败:",
+ "未找到差异化倍率,无需同步": "未找到差异化倍率,无需同步",
+ "请求后端接口失败:": "请求后端接口失败:",
+ "同步成功": "同步成功",
+ "部分保存失败": "部分保存失败",
+ "未找到匹配的模型": "未找到匹配的模型",
+ "暂无差异化倍率显示": "暂无差异化倍率显示",
+ "请先选择同步渠道": "请先选择同步渠道",
+ "倍率类型": "倍率类型",
+ "缓存倍率": "缓存倍率",
+ "当前值": "当前值",
+ "未设置": "未设置",
+ "与本地相同": "与本地相同",
+ "运营设置": "运营设置",
+ "聊天设置": "聊天设置",
+ "速率限制设置": "速率限制设置",
+ "模型相关设置": "模型相关设置",
+ "系统设置": "系统设置",
+ "仪表盘设置": "仪表盘设置",
+ "获取初始化状态失败": "获取初始化状态失败",
+ "表单引用错误,请刷新页面重试": "表单引用错误,请刷新页面重试",
+ "请输入管理员用户名": "请输入管理员用户名",
+ "密码长度至少为8个字符": "密码长度至少为8个字符",
+ "两次输入的密码不一致": "两次输入的密码不一致",
+ "系统初始化成功,正在跳转...": "系统初始化成功,正在跳转...",
+ "初始化失败,请重试": "初始化失败,请重试",
+ "系统初始化失败,请重试": "系统初始化失败,请重试",
+ "系统初始化": "系统初始化",
+ "欢迎使用,请完成以下设置以开始使用系统": "欢迎使用,请完成以下设置以开始使用系统",
+ "数据库信息": "数据库信息",
+ "管理员账号": "管理员账号",
+ "设置系统管理员的登录信息": "设置系统管理员的登录信息",
+ "管理员账号已经初始化过,请继续设置其他参数": "管理员账号已经初始化过,请继续设置其他参数",
+ "密码": "密码",
+ "请输入管理员密码": "请输入管理员密码",
+ "请确认管理员密码": "请确认管理员密码",
+ "选择适合您使用场景的模式": "选择适合您使用场景的模式",
+ "对外运营模式": "对外运营模式",
+ "适用于为多个用户提供服务的场景": "适用于为多个用户提供服务的场景",
+ "默认模式": "默认模式",
+ "适用于个人使用的场景,不需要设置模型价格": "适用于个人使用的场景,不需要设置模型价格",
+ "无需计费": "无需计费",
+ "演示站点模式": "演示站点模式",
+ "适用于展示系统功能的场景,提供基础功能演示": "适用于展示系统功能的场景,提供基础功能演示",
+ "初始化系统": "初始化系统",
+ "使用模式说明": "使用模式说明",
+ "我已了解": "我已了解",
+ "默认模式,适用于为多个用户提供服务的场景。": "默认模式,适用于为多个用户提供服务的场景。",
+ "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。": "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。",
+ "多用户支持": "多用户支持",
+ "适用于个人使用的场景。": "适用于个人使用的场景。",
+ "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。": "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。",
+ "个人使用": "个人使用",
+ "适用于展示系统功能的场景。": "适用于展示系统功能的场景。",
+ "提供基础功能演示,方便用户了解系统特性。": "提供基础功能演示,方便用户了解系统特性。",
+ "体验试用": "体验试用",
+ "自动选择": "自动选择",
+ "过期时间格式错误!": "过期时间格式错误!",
+ "令牌更新成功!": "令牌更新成功!",
+ "令牌创建成功,请在列表页面点击复制获取令牌!": "令牌创建成功,请在列表页面点击复制获取令牌!",
+ "更新令牌信息": "更新令牌信息",
+ "创建新的令牌": "创建新的令牌",
+ "设置令牌的基本信息": "设置令牌的基本信息",
+ "请选择过期时间": "请选择过期时间",
+ "一天": "一天",
+ "一个月": "一个月",
+ "设置令牌可用额度和数量": "设置令牌可用额度和数量",
+ "新建数量": "新建数量",
+ "请选择或输入创建令牌的数量": "请选择或输入创建令牌的数量",
+ "20个": "20个",
+ "100个": "100个",
+ "取消无限额度": "取消无限额度",
+ "设为无限额度": "设为无限额度",
+ "设置令牌的访问限制": "设置令牌的访问限制",
+ "IP白名单": "IP白名单",
+ "允许的IP,一行一个,不填写则不限制": "允许的IP,一行一个,不填写则不限制",
+ "请勿过度信任此功能,IP可能被伪造": "请勿过度信任此功能,IP可能被伪造",
+ "勾选启用模型限制后可选择": "勾选启用模型限制后可选择",
+ "非必要,不建议启用模型限制": "非必要,不建议启用模型限制",
+ "分组信息": "分组信息",
+ "设置令牌的分组": "设置令牌的分组",
+ "令牌分组,默认为用户的分组": "令牌分组,默认为用户的分组",
+ "管理员未设置用户可选分组": "管理员未设置用户可选分组",
+ "请输入兑换码!": "请输入兑换码!",
+ "兑换成功!": "兑换成功!",
+ "成功兑换额度:": "成功兑换额度:",
+ "请求失败": "请求失败",
+ "超级管理员未设置充值链接!": "超级管理员未设置充值链接!",
+ "管理员未开启在线充值!": "管理员未开启在线充值!",
+ "充值数量不能小于": "充值数量不能小于",
+ "支付请求失败": "支付请求失败",
+ "划转金额最低为": "划转金额最低为",
+ "邀请链接已复制到剪切板": "邀请链接已复制到剪切板",
+ "支付方式配置错误, 请联系管理员": "支付方式配置错误, 请联系管理员",
+ "划转邀请额度": "划转邀请额度",
+ "可用邀请额度": "可用邀请额度",
+ "划转额度": "划转额度",
+ "充值确认": "充值确认",
+ "充值数量": "充值数量",
+ "实付金额": "实付金额",
+ "支付方式": "支付方式",
+ "在线充值": "在线充值",
+ "快速方便的充值方式": "快速方便的充值方式",
+ "选择充值额度": "选择充值额度",
+ "实付": "实付",
+ "或输入自定义金额": "或输入自定义金额",
+ "充值数量,最低 ": "充值数量,最低 ",
+ "选择支付方式": "选择支付方式",
+ "处理中": "处理中",
+ "兑换码充值": "兑换码充值",
+ "使用兑换码快速充值": "使用兑换码快速充值",
+ "请输入兑换码": "请输入兑换码",
+ "兑换中...": "兑换中...",
+ "兑换": "兑换",
+ "邀请奖励": "邀请奖励",
+ "邀请好友获得额外奖励": "邀请好友获得额外奖励",
+ "待使用收益": "待使用收益",
+ "总收益": "总收益",
+ "邀请人数": "邀请人数",
+ "邀请链接": "邀请链接",
+ "邀请好友注册,好友充值后您可获得相应奖励": "邀请好友注册,好友充值后您可获得相应奖励",
+ "通过划转功能将奖励额度转入到您的账户余额中": "通过划转功能将奖励额度转入到您的账户余额中",
+ "邀请的好友越多,获得的奖励越多": "邀请的好友越多,获得的奖励越多",
+ "用户名和密码不能为空!": "用户名和密码不能为空!",
+ "用户账户创建成功!": "用户账户创建成功!",
+ "提交": "提交",
+ "创建新用户账户": "创建新用户账户",
+ "请输入显示名称": "请输入显示名称",
+ "请输入密码": "请输入密码",
+ "请输入备注(仅管理员可见)": "请输入备注(仅管理员可见)",
+ "编辑用户": "编辑用户",
+ "用户的基本账户信息": "用户的基本账户信息",
+ "请输入新的用户名": "请输入新的用户名",
+ "请输入新的密码,最短 8 位": "请输入新的密码,最短 8 位",
+ "显示名称": "显示名称",
+ "请输入新的显示名称": "请输入新的显示名称",
+ "权限设置": "权限设置",
+ "用户分组和额度管理": "用户分组和额度管理",
+ "请输入新的剩余额度": "请输入新的剩余额度",
+ "添加额度": "添加额度",
+ "第三方账户绑定状态(只读)": "第三方账户绑定状态(只读)",
+ "已绑定的 GitHub 账户": "已绑定的 GitHub 账户",
+ "已绑定的 OIDC 账户": "已绑定的 OIDC 账户",
+ "已绑定的微信账户": "已绑定的微信账户",
+ "已绑定的邮箱账户": "已绑定的邮箱账户",
+ "已绑定的 Telegram 账户": "已绑定的 Telegram 账户",
+ "新额度": "新额度",
+ "需要添加的额度(支持负数)": "需要添加的额度(支持负数)"
+}
\ No newline at end of file
diff --git a/common/logger.go b/logger/logger.go
similarity index 70%
rename from common/logger.go
rename to logger/logger.go
index 86d15fa4..d59e51cb 100644
--- a/common/logger.go
+++ b/logger/logger.go
@@ -1,23 +1,26 @@
-package common
+package logger
import (
"context"
"encoding/json"
"fmt"
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/gin-gonic/gin"
"io"
"log"
+ "one-api/common"
"os"
"path/filepath"
"sync"
"time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-gonic/gin"
)
const (
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
+ loggerDebug = "DEBUG"
)
const maxLogCount = 1000000
@@ -27,7 +30,10 @@ var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
- if *LogDir != "" {
+ defer func() {
+ setupLogWorking = false
+ }()
+ if *common.LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
@@ -35,9 +41,8 @@ func SetupLogger() {
}
defer func() {
setupLogLock.Unlock()
- setupLogWorking = false
}()
- logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
+ logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
@@ -47,16 +52,6 @@ func SetupLogger() {
}
}
-func SysLog(s string) {
- t := time.Now()
- _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
-}
-
-func SysError(s string) {
- t := time.Now()
- _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
-}
-
func LogInfo(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
@@ -69,12 +64,21 @@ func LogError(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
+func LogDebug(ctx context.Context, msg string) {
+ if common.DebugEnabled {
+ logHelper(ctx, loggerDebug, msg)
+ }
+}
+
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
- id := ctx.Value(RequestIdKey)
+ id := ctx.Value(common.RequestIdKey)
+ if id == nil {
+ id = "SYSTEM"
+ }
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
@@ -87,23 +91,17 @@ func logHelper(ctx context.Context, level string, msg string) {
}
}
-func FatalLog(v ...any) {
- t := time.Now()
- _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
- os.Exit(1)
-}
-
func LogQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
+ if common.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/common.QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
}
func FormatQuota(quota int) string {
- if DisplayInCurrencyEnabled {
- return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
+ if common.DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f", float64(quota)/common.QuotaPerUnit)
} else {
return fmt.Sprintf("%d", quota)
}
diff --git a/main.go b/main.go
index c286650f..2dfddacc 100644
--- a/main.go
+++ b/main.go
@@ -8,11 +8,12 @@ import (
"one-api/common"
"one-api/constant"
"one-api/controller"
+ "one-api/logger"
"one-api/middleware"
"one-api/model"
"one-api/router"
"one-api/service"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
"os"
"strconv"
@@ -32,14 +33,13 @@ var buildFS embed.FS
var indexPage []byte
func main() {
- err := godotenv.Load(".env")
+
+ err := InitResources()
if err != nil {
- common.SysLog("Support for .env file is disabled: " + err.Error())
+ common.FatalLog("failed to initialize resources: " + err.Error())
+ return
}
- common.LoadEnv()
-
- common.SetupLogger()
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
@@ -47,19 +47,7 @@ func main() {
if common.DebugEnabled {
common.SysLog("running in debug mode")
}
- // Initialize SQL Database
- err = model.InitDB()
- if err != nil {
- common.FatalLog("failed to initialize database: " + err.Error())
- }
- model.CheckSetup()
-
- // Initialize SQL Database
- err = model.InitLogDB()
- if err != nil {
- common.FatalLog("failed to initialize database: " + err.Error())
- }
defer func() {
err := model.CloseDB()
if err != nil {
@@ -67,48 +55,35 @@ func main() {
}
}()
- // Initialize Redis
- err = common.InitRedisClient()
- if err != nil {
- common.FatalLog("failed to initialize Redis: " + err.Error())
- }
-
- // Initialize model settings
- operation_setting.InitRatioSettings()
- // Initialize constants
- constant.InitEnv()
- // Initialize options
- model.InitOptionMap()
-
- service.InitTokenEncoders()
-
if common.RedisEnabled {
// for compatibility with old versions
common.MemoryCacheEnabled = true
}
if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled")
- common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
+ common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
// Add panic recovery and retry for InitChannelCache
func() {
defer func() {
if r := recover(); r != nil {
- common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
+ common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once
- _, fixErr := model.FixAbility()
+ _, _, fixErr := model.FixAbility()
if fixErr != nil {
- common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+ common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
}
}
}()
model.InitChannelCache()
}()
- go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
+ // 热更新配置
+ go model.SyncOptions(common.SyncFrequency)
+
// 数据看板
go model.UpdateQuotaData()
@@ -151,7 +126,7 @@ func main() {
// Initialize HTTP server
server := gin.New()
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
- common.SysError(fmt.Sprintf("panic detected: %v", err))
+ common.SysLog(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
@@ -184,3 +159,53 @@ func main() {
common.FatalLog("failed to start HTTP server: " + err.Error())
}
}
+
+func InitResources() error {
+	// InitResources loads the .env file and environment variables, then sets up
+	// logging, ratio settings, HTTP client, token encoders, databases and Redis.
+ err := godotenv.Load(".env")
+ if err != nil {
+ common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
+ common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
+ }
+
+ // 加载环境变量
+ common.InitEnv()
+
+ logger.SetupLogger()
+
+ // Initialize model settings
+ ratio_setting.InitRatioSettings()
+
+ service.InitHttpClient()
+
+ service.InitTokenEncoders()
+
+ // Initialize SQL Database
+ err = model.InitDB()
+ if err != nil {
+ common.FatalLog("failed to initialize database: " + err.Error())
+ return err
+ }
+
+ model.CheckSetup()
+
+	// Initialize options; must be called after model.InitDB()
+ model.InitOptionMap()
+
+ // 初始化模型
+ model.GetPricing()
+
+	// Initialize log database
+ err = model.InitLogDB()
+ if err != nil {
+ return err
+ }
+
+ // Initialize Redis
+ err = common.InitRedisClient()
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/makefile b/makefile
index 5042723c..cbc4ea6a 100644
--- a/makefile
+++ b/makefile
@@ -7,7 +7,7 @@ all: build-frontend start-backend
build-frontend:
@echo "Building frontend..."
- @cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) npm run build
+ @cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
start-backend:
@echo "Starting backend dev server..."
diff --git a/middleware/auth.go b/middleware/auth.go
index ce86bb36..25caf50d 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -1,9 +1,13 @@
package middleware
import (
+ "fmt"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/model"
+ "one-api/setting"
+ "one-api/setting/ratio_setting"
"strconv"
"strings"
@@ -121,7 +125,20 @@ func authHelper(c *gin.Context, minRole int) {
c.Set("role", role)
c.Set("id", id)
c.Set("group", session.Get("group"))
+ c.Set("user_group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
+
+ //userCache, err := model.GetUserCache(id.(int))
+ //if err != nil {
+ // c.JSON(http.StatusOK, gin.H{
+ // "success": false,
+ // "message": err.Error(),
+ // })
+ // c.Abort()
+ // return
+ //}
+ //userCache.WriteContext(c)
+
c.Next()
}
@@ -177,18 +194,24 @@ func TokenAuth() func(c *gin.Context) {
}
// 检查path包含/v1/messages
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
- // 从x-api-key中获取key
- key := c.Request.Header.Get("x-api-key")
- if key != "" {
- c.Request.Header.Set("Authorization", "Bearer "+key)
+ anthropicKey := c.Request.Header.Get("x-api-key")
+ if anthropicKey != "" {
+ c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
}
}
// gemini api 从query中获取key
- if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+ if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models") ||
+ strings.HasPrefix(c.Request.URL.Path, "/v1beta/openai/models") ||
+ strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
skKey := c.Query("key")
if skKey != "" {
c.Request.Header.Set("Authorization", "Bearer "+skKey)
}
+ // 从x-goog-api-key header中获取key
+ xGoogKey := c.Request.Header.Get("x-goog-api-key")
+ if xGoogKey != "" {
+ c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
+ }
}
key := c.Request.Header.Get("Authorization")
parts := make([]string, 0)
@@ -215,6 +238,16 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
+
+ allowIpsMap := token.GetIpLimitsMap()
+ if len(allowIpsMap) != 0 {
+ clientIp := c.ClientIP()
+ if _, ok := allowIpsMap[clientIp]; !ok {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
+ return
+ }
+ }
+
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
@@ -228,30 +261,59 @@ func TokenAuth() func(c *gin.Context) {
userCache.WriteContext(c)
- c.Set("id", token.UserId)
- c.Set("token_id", token.Id)
- c.Set("token_key", token.Key)
- c.Set("token_name", token.Name)
- c.Set("token_unlimited_quota", token.UnlimitedQuota)
- if !token.UnlimitedQuota {
- c.Set("token_quota", token.RemainQuota)
- }
- if token.ModelLimitsEnabled {
- c.Set("token_model_limit_enabled", true)
- c.Set("token_model_limit", token.GetModelLimitsMap())
- } else {
- c.Set("token_model_limit_enabled", false)
- }
- c.Set("allow_ips", token.GetIpLimitsMap())
- c.Set("token_group", token.Group)
- if len(parts) > 1 {
- if model.IsAdmin(token.UserId) {
- c.Set("specific_channel_id", parts[1])
- } else {
- abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
+ userGroup := userCache.Group
+ tokenGroup := token.Group
+ if tokenGroup != "" {
+ // check common.UserUsableGroups[userGroup]
+ if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
+ // check group in common.GroupRatio
+ if !ratio_setting.ContainsGroupRatio(tokenGroup) {
+ if tokenGroup != "auto" {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+ return
+ }
+ }
+ userGroup = tokenGroup
+ }
+ common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
+
+ err = SetupContextForToken(c, token, parts...)
+ if err != nil {
+ return
}
c.Next()
}
}
+
+func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
+ if token == nil {
+ return fmt.Errorf("token is nil")
+ }
+ c.Set("id", token.UserId)
+ c.Set("token_id", token.Id)
+ c.Set("token_key", token.Key)
+ c.Set("token_name", token.Name)
+ c.Set("token_unlimited_quota", token.UnlimitedQuota)
+ if !token.UnlimitedQuota {
+ c.Set("token_quota", token.RemainQuota)
+ }
+ if token.ModelLimitsEnabled {
+ c.Set("token_model_limit_enabled", true)
+ c.Set("token_model_limit", token.GetModelLimitsMap())
+ } else {
+ c.Set("token_model_limit_enabled", false)
+ }
+ c.Set("token_group", token.Group)
+ if len(parts) > 1 {
+ if model.IsAdmin(token.UserId) {
+ c.Set("specific_channel_id", parts[1])
+ } else {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
+ return fmt.Errorf("普通用户不支持指定渠道")
+ }
+ }
+ return nil
+}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 1bfe1821..28b66a3a 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -11,6 +11,8 @@ import (
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
+ "one-api/setting/ratio_setting"
+ "one-api/types"
"strconv"
"strings"
"time"
@@ -20,41 +22,18 @@ import (
type ModelRequest struct {
Model string `json:"model"`
+ Group string `json:"group,omitempty"`
}
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
- allowIpsMap := c.GetStringMap("allow_ips")
- if len(allowIpsMap) != 0 {
- clientIp := c.ClientIP()
- if _, ok := allowIpsMap[clientIp]; !ok {
- abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
- return
- }
- }
var channel *model.Channel
- channelId, ok := c.Get("specific_channel_id")
+ channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
modelRequest, shouldSelectChannel, err := getModelRequest(c)
if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
- userGroup := c.GetString(constant.ContextKeyUserGroup)
- tokenGroup := c.GetString("token_group")
- if tokenGroup != "" {
- // check common.UserUsableGroups[userGroup]
- if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
- return
- }
- // check group in common.GroupRatio
- if !setting.ContainsGroupRatio(tokenGroup) {
- abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
- return
- }
- userGroup = tokenGroup
- }
- c.Set("group", userGroup)
if ok {
id, err := strconv.Atoi(channelId.(string))
if err != nil {
@@ -73,47 +52,71 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
// check token model mapping
- modelLimitEnable := c.GetBool("token_model_limit_enabled")
+ modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable {
- s, ok := c.Get("token_model_limit")
- var tokenModelLimit map[string]bool
- if ok {
- tokenModelLimit = s.(map[string]bool)
- } else {
- tokenModelLimit = map[string]bool{}
- }
- if tokenModelLimit != nil {
- if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
- abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
- return
- }
- } else {
+ s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
+ if !ok {
// token model limit is empty, all models are not allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
+ var tokenModelLimit map[string]bool
+ tokenModelLimit, ok = s.(map[string]bool)
+ if !ok {
+ tokenModelLimit = map[string]bool{}
+ }
+ matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model) // match gpts & thinking-*
+ if _, ok := tokenModelLimit[matchName]; !ok {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
+ return
+ }
}
if shouldSelectChannel {
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
- if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
- // 如果错误,但是渠道不为空,说明是数据库一致性问题
- if channel != nil {
- common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
- message = "数据库一致性已被破坏,请联系管理员"
+ if modelRequest.Model == "" {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
+ return
+ }
+ var selectGroup string
+ userGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
+ // check path is /pg/chat/completions
+ if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
+ playgroundRequest := &dto.PlayGroundRequest{}
+ err = common.UnmarshalBodyReusable(c, playgroundRequest)
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
+ return
}
- // 如果错误,而且渠道为空,说明是没有可用渠道
- abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
+ if playgroundRequest.Group != "" {
+ if !setting.GroupInUserUsableGroups(playgroundRequest.Group) && playgroundRequest.Group != userGroup {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "无权访问该分组")
+ return
+ }
+ userGroup = playgroundRequest.Group
+ }
+ }
+ channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
+ if err != nil {
+ showGroup := userGroup
+ if userGroup == "auto" {
+ showGroup = fmt.Sprintf("auto(%s)", selectGroup)
+ }
+ message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor): %s", showGroup, modelRequest.Model, err.Error())
+ // 如果错误,但是渠道不为空,说明是数据库一致性问题
+ //if channel != nil {
+ // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ // message = "数据库一致性已被破坏,请联系管理员"
+ //}
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
return
}
if channel == nil {
- abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道(distributor)", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
return
}
}
}
- c.Set(constant.ContextKeyRequestStartTime, time.Now())
+ common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}
@@ -162,7 +165,19 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode)
- } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+ } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ relayMode := relayconstant.RelayModeUnknown
+ if c.Request.Method == http.MethodPost {
+ relayMode = relayconstant.RelayModeVideoSubmit
+ } else if c.Request.Method == http.MethodGet {
+ relayMode = relayconstant.RelayModeVideoFetchByID
+ shouldSelectChannel = false
+ }
+ if _, ok := c.Get("relay_mode"); !ok {
+ c.Set("relay_mode", relayMode)
+ }
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
@@ -210,47 +225,73 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
}
c.Set("relay_mode", relayMode)
}
+ if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
+ // playground chat completions
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ if err != nil {
+ return nil, false, errors.New("无效的请求, " + err.Error())
+ }
+ common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
+ }
return &modelRequest, shouldSelectChannel, nil
}
-func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
c.Set("original_model", modelName) // for retry
if channel == nil {
- return
+ return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
- c.Set("channel_id", channel.Id)
- c.Set("channel_name", channel.Name)
- c.Set("channel_type", channel.Type)
- c.Set("channel_create_time", channel.CreatedTime)
- c.Set("channel_setting", channel.GetSetting())
- c.Set("param_override", channel.GetParamOverride())
- if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
- c.Set("channel_organization", *channel.OpenAIOrganization)
+ common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
+ common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
+ common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
+ common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
+ common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
+ common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
+ common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
+ if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
+ common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
}
- c.Set("auto_ban", channel.GetAutoBan())
- c.Set("model_mapping", channel.GetModelMapping())
- c.Set("status_code_mapping", channel.GetStatusCodeMapping())
- c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- c.Set("base_url", channel.GetBaseURL())
+ common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
+ common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
+ common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
+
+ key, index, newAPIError := channel.GetNextEnabledKey()
+ if newAPIError != nil {
+ return newAPIError
+ }
+ if channel.ChannelInfo.IsMultiKey {
+ common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
+ common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
+ } else {
+ // 必须设置为 false,否则在重试到单个 key 的时候会导致日志显示错误
+ common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
+ }
+ // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
+ common.SetContextKey(c, constant.ContextKeyChannelKey, key)
+ common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
+
+ common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false)
+
// TODO: api_version统一
switch channel.Type {
- case common.ChannelTypeAzure:
+ case constant.ChannelTypeAzure:
c.Set("api_version", channel.Other)
- case common.ChannelTypeVertexAi:
+ case constant.ChannelTypeVertexAi:
c.Set("region", channel.Other)
- case common.ChannelTypeXunfei:
+ case constant.ChannelTypeXunfei:
c.Set("api_version", channel.Other)
- case common.ChannelTypeGemini:
+ case constant.ChannelTypeGemini:
c.Set("api_version", channel.Other)
- case common.ChannelTypeAli:
+ case constant.ChannelTypeAli:
c.Set("plugin", channel.Other)
- case common.ChannelCloudflare:
+ case constant.ChannelCloudflare:
c.Set("api_version", channel.Other)
- case common.ChannelTypeMokaAI:
+ case constant.ChannelTypeMokaAI:
c.Set("api_version", channel.Other)
- case common.ChannelTypeCoze:
+ case constant.ChannelTypeCoze:
c.Set("bot_id", channel.Other)
}
+ return nil
}
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
diff --git a/middleware/email-verification-rate-limit.go b/middleware/email-verification-rate-limit.go
new file mode 100644
index 00000000..a7d828d9
--- /dev/null
+++ b/middleware/email-verification-rate-limit.go
@@ -0,0 +1,80 @@
+package middleware
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ EmailVerificationRateLimitMark = "EV"
+ EmailVerificationMaxRequests = 2 // 30秒内最多2次
+ EmailVerificationDuration = 30 // 30秒时间窗口
+)
+
+func redisEmailVerificationRateLimiter(c *gin.Context) {
+ ctx := context.Background()
+ rdb := common.RDB
+ key := "emailVerification:" + EmailVerificationRateLimitMark + ":" + c.ClientIP()
+
+ count, err := rdb.Incr(ctx, key).Result()
+ if err != nil {
+ // fallback
+ memoryEmailVerificationRateLimiter(c)
+ return
+ }
+
+ // 第一次设置键时设置过期时间
+ if count == 1 {
+ _ = rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second).Err()
+ }
+
+ // 检查是否超出限制
+ if count <= int64(EmailVerificationMaxRequests) {
+ c.Next()
+ return
+ }
+
+ // 获取剩余等待时间
+ ttl, err := rdb.TTL(ctx, key).Result()
+ waitSeconds := int64(EmailVerificationDuration)
+ if err == nil && ttl > 0 {
+ waitSeconds = int64(ttl.Seconds())
+ }
+
+ c.JSON(http.StatusTooManyRequests, gin.H{
+ "success": false,
+ "message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", waitSeconds),
+ })
+ c.Abort()
+}
+
+func memoryEmailVerificationRateLimiter(c *gin.Context) {
+ key := EmailVerificationRateLimitMark + ":" + c.ClientIP()
+
+ if !inMemoryRateLimiter.Request(key, EmailVerificationMaxRequests, EmailVerificationDuration) {
+ c.JSON(http.StatusTooManyRequests, gin.H{
+ "success": false,
+ "message": "发送过于频繁,请稍后再试",
+ })
+ c.Abort()
+ return
+ }
+
+ c.Next()
+}
+
+func EmailVerificationRateLimit() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if common.RedisEnabled {
+ redisEmailVerificationRateLimiter(c)
+ } else {
+ inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+ memoryEmailVerificationRateLimiter(c)
+ }
+ }
+}
diff --git a/middleware/jimeng_adapter.go b/middleware/jimeng_adapter.go
new file mode 100644
index 00000000..ce5e1467
--- /dev/null
+++ b/middleware/jimeng_adapter.go
@@ -0,0 +1,66 @@
+package middleware
+
+import (
+ "bytes"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ relayconstant "one-api/relay/constant"
+)
+
+func JimengRequestConvert() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ action := c.Query("Action")
+ if action == "" {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "Action query parameter is required")
+ return
+ }
+
+ // Handle Jimeng official API request
+ var originalReq map[string]interface{}
+ if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request body")
+ return
+ }
+ model, _ := originalReq["req_key"].(string)
+ prompt, _ := originalReq["prompt"].(string)
+
+ unifiedReq := map[string]interface{}{
+ "model": model,
+ "prompt": prompt,
+ "metadata": originalReq,
+ }
+
+ jsonData, err := json.Marshal(unifiedReq)
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusInternalServerError, "Failed to marshal request body")
+ return
+ }
+
+ // Update request body
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
+ c.Set(common.KeyRequestBody, jsonData)
+
+ if image, ok := originalReq["image"]; !ok || image == "" {
+ c.Set("action", constant.TaskActionTextGenerate)
+ }
+
+ c.Request.URL.Path = "/v1/video/generations"
+
+ if action == "CVSync2AsyncGetResult" {
+ taskId, ok := originalReq["task_id"].(string)
+ if !ok || taskId == "" {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "task_id is required for CVSync2AsyncGetResult")
+ return
+ }
+ c.Request.URL.Path = "/v1/video/generations/" + taskId
+ c.Request.Method = http.MethodGet
+ c.Set("task_id", taskId)
+ c.Set("relay_mode", relayconstant.RelayModeVideoFetchByID)
+ }
+ c.Next()
+ }
+}
diff --git a/middleware/kling_adapter.go b/middleware/kling_adapter.go
new file mode 100644
index 00000000..20973c9f
--- /dev/null
+++ b/middleware/kling_adapter.go
@@ -0,0 +1,51 @@
+package middleware
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "one-api/common"
+ "one-api/constant"
+
+ "github.com/gin-gonic/gin"
+)
+
+func KlingRequestConvert() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ var originalReq map[string]interface{}
+ if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
+ c.Next()
+ return
+ }
+
+ // Support both model_name and model fields
+ model, _ := originalReq["model_name"].(string)
+ if model == "" {
+ model, _ = originalReq["model"].(string)
+ }
+ prompt, _ := originalReq["prompt"].(string)
+
+ unifiedReq := map[string]interface{}{
+ "model": model,
+ "prompt": prompt,
+ "metadata": originalReq,
+ }
+
+ jsonData, err := json.Marshal(unifiedReq)
+ if err != nil {
+ c.Next()
+ return
+ }
+
+ // Rewrite request body and path
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
+ c.Request.URL.Path = "/v1/video/generations"
+ if image, ok := originalReq["image"]; !ok || image == "" {
+ c.Set("action", constant.TaskActionTextGenerate)
+ }
+
+ // We have to reset the request body for the next handlers
+ c.Set(common.KeyRequestBody, jsonData)
+ c.Next()
+ }
+}
diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
index 34caa59b..14d9a737 100644
--- a/middleware/model-rate-limit.go
+++ b/middleware/model-rate-limit.go
@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 获取分组
- group := c.GetString("token_group")
+ group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if group == "" {
- group = c.GetString(constant.ContextKeyUserGroup)
+ group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
}
//获取分组的限流配置
diff --git a/middleware/recover.go b/middleware/recover.go
index 51fc7190..d78c8137 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
- common.SysError(fmt.Sprintf("panic detected: %v", err))
- common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
+ common.SysLog(fmt.Sprintf("panic detected: %v", err))
+ common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
diff --git a/middleware/stats.go b/middleware/stats.go
new file mode 100644
index 00000000..1c97983f
--- /dev/null
+++ b/middleware/stats.go
@@ -0,0 +1,41 @@
+package middleware
+
+import (
+ "sync/atomic"
+
+ "github.com/gin-gonic/gin"
+)
+
+// HTTPStats 存储HTTP统计信息
+type HTTPStats struct {
+ activeConnections int64
+}
+
+var globalStats = &HTTPStats{}
+
+// StatsMiddleware 统计中间件
+func StatsMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ // 增加活跃连接数
+ atomic.AddInt64(&globalStats.activeConnections, 1)
+
+ // 确保在请求结束时减少连接数
+ defer func() {
+ atomic.AddInt64(&globalStats.activeConnections, -1)
+ }()
+
+ c.Next()
+ }
+}
+
+// StatsInfo 统计信息结构
+type StatsInfo struct {
+ ActiveConnections int64 `json:"active_connections"`
+}
+
+// GetStats 获取统计信息
+func GetStats() StatsInfo {
+ return StatsInfo{
+ ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
+ }
+}
\ No newline at end of file
diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go
index 26688810..106a7278 100644
--- a/middleware/turnstile-check.go
+++ b/middleware/turnstile-check.go
@@ -37,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc {
"remoteip": {c.ClientIP()},
})
if err != nil {
- common.SysError(err.Error())
+ common.SysLog(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -49,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil {
- common.SysError(err.Error())
+ common.SysLog(err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/middleware/utils.go b/middleware/utils.go
index 082f5657..77d1eb80 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -4,18 +4,24 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
+ "one-api/logger"
)
-func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
+func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
+ codeStr := ""
+ if len(code) > 0 {
+ codeStr = code[0]
+ }
userId := c.GetInt("id")
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "new_api_error",
+ "code": codeStr,
},
})
c.Abort()
- common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
+ logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
}
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
@@ -25,5 +31,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri
"code": code,
})
c.Abort()
- common.LogError(c.Request.Context(), description)
+ logger.LogError(c.Request.Context(), description)
}
diff --git a/model/ability.go b/model/ability.go
index 38b0bd73..123fc7be 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -5,9 +5,11 @@ import (
"fmt"
"one-api/common"
"strings"
+ "sync"
"github.com/samber/lo"
"gorm.io/gorm"
+ "gorm.io/gorm/clause"
)
type Ability struct {
@@ -20,10 +22,25 @@ type Ability struct {
Tag *string `json:"tag" gorm:"index"`
}
-func GetGroupModels(group string) []string {
+type AbilityWithChannel struct {
+ Ability
+ ChannelType int `json:"channel_type"`
+}
+
+func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
+ var abilities []AbilityWithChannel
+ err := DB.Table("abilities").
+ Select("abilities.*, channels.type as channel_type").
+ Joins("left join channels on abilities.channel_id = channels.id").
+ Where("abilities.enabled = ?", true).
+ Scan(&abilities).Error
+ return abilities, err
+}
+
+func GetGroupEnabledModels(group string) []string {
var models []string
// Find distinct models
- DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
+ DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
return models
}
@@ -41,16 +58,12 @@ func GetAllEnableAbilities() []Ability {
}
func getPriority(group string, model string, retry int) (int, error) {
- trueVal := "1"
- if common.UsingPostgreSQL {
- trueVal = "true"
- }
var priorities []int
err := DB.Model(&Ability{}).
Select("DISTINCT(priority)").
- Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
- Order("priority DESC"). // 按优先级降序排序
+ Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
+ Order("priority DESC"). // 按优先级降序排序
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
if err != nil {
@@ -74,30 +87,29 @@ func getPriority(group string, model string, retry int) (int, error) {
return priorityToUse, nil
}
-func getChannelQuery(group string, model string, retry int) *gorm.DB {
- trueVal := "1"
- if common.UsingPostgreSQL {
- trueVal = "true"
- }
- maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
- channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
+ maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
+ channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
if retry != 0 {
priority, err := getPriority(group, model, retry)
if err != nil {
- common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
+ return nil, err
} else {
- channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
+ channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
}
}
- return channelQuery
+ return channelQuery, nil
}
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
var abilities []Ability
var err error = nil
- channelQuery := getChannelQuery(group, model, retry)
+ channelQuery, err := getChannelQuery(group, model, retry)
+ if err != nil {
+ return nil, err
+ }
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
@@ -124,18 +136,24 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
}
}
} else {
- return nil, errors.New("channel not found")
+ return nil, nil
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
}
-func (channel *Channel) AddAbilities() error {
+func (channel *Channel) AddAbilities(tx *gorm.DB) error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -151,8 +169,13 @@ func (channel *Channel) AddAbilities() error {
if len(abilities) == 0 {
return nil
}
+	// Use the provided transaction when given; otherwise fall back to the global DB handle.
+ useDB := DB
+ if tx != nil {
+ useDB = tx
+ }
for _, chunk := range lo.Chunk(abilities, 50) {
- err := DB.Create(&chunk).Error
+ err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil {
return err
}
@@ -194,9 +217,15 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
// Then add new abilities
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
ability := Ability{
Group: group,
Model: model,
@@ -212,7 +241,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
- err = tx.Create(&chunk).Error
+ err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
@@ -252,74 +281,60 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
}
-func FixAbility() (int, error) {
- var channelIds []int
- count := 0
- // Find all channel ids from channel table
- err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
- return 0, err
- }
+var fixLock = sync.Mutex{}
- // Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
- if len(channelIds) > 0 {
- // Process deletion in chunks to avoid "too many placeholders" error
- for _, chunk := range lo.Chunk(channelIds, 100) {
- err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
- return 0, err
- }
+func FixAbility() (int, int, error) {
+ lock := fixLock.TryLock()
+ if !lock {
+ return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
+ }
+ defer fixLock.Unlock()
+
+ // truncate abilities table
+ if common.UsingSQLite {
+ err := DB.Exec("DELETE FROM abilities").Error
+ if err != nil {
+ common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+ return 0, 0, err
}
} else {
- // If no channels exist, delete all abilities
- err = DB.Delete(&Ability{}).Error
+ err := DB.Exec("TRUNCATE TABLE abilities").Error
if err != nil {
- common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
- return 0, err
+ common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
+ return 0, 0, err
}
- common.SysLog("Delete all abilities successfully")
- return 0, nil
}
-
- common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
- count += len(channelIds)
-
- // Use channelIds to find channel not in abilities table
- var abilityChannelIds []int
- err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
+ var channels []*Channel
+ // Find all channels
+ err := DB.Model(&Channel{}).Find(&channels).Error
if err != nil {
- common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
- return count, err
+ return 0, 0, err
}
-
- var channels []Channel
- if len(abilityChannelIds) == 0 {
- err = DB.Find(&channels).Error
- } else {
- // Process query in chunks to avoid "too many placeholders" error
- err = nil
- for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
- var channelsChunk []Channel
- err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
- return count, err
- }
- channels = append(channels, channelsChunk...)
- }
+ if len(channels) == 0 {
+ return 0, 0, nil
}
-
- for _, channel := range channels {
- err := channel.UpdateAbilities(nil)
+ successCount := 0
+ failCount := 0
+ for _, chunk := range lo.Chunk(channels, 50) {
+ ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
+ // Delete all abilities of this channel
+ err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil {
- common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
- } else {
- common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
- count++
+ common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+ failCount += len(chunk)
+ continue
+ }
+ // Then add new abilities
+ for _, channel := range chunk {
+ err = channel.AddAbilities(nil)
+ if err != nil {
+ common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
+ failCount++
+ } else {
+ successCount++
+ }
}
}
InitChannelCache()
- return count, nil
+ return successCount, failCount, nil
}
diff --git a/model/cache.go b/model/cache.go
deleted file mode 100644
index e2f83e22..00000000
--- a/model/cache.go
+++ /dev/null
@@ -1,166 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- "math/rand"
- "one-api/common"
- "sort"
- "strings"
- "sync"
- "time"
-)
-
-var group2model2channels map[string]map[string][]*Channel
-var channelsIDM map[int]*Channel
-var channelSyncLock sync.RWMutex
-
-func InitChannelCache() {
- if !common.MemoryCacheEnabled {
- return
- }
- newChannelId2channel := make(map[int]*Channel)
- var channels []*Channel
- DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
- for _, channel := range channels {
- newChannelId2channel[channel.Id] = channel
- }
- var abilities []*Ability
- DB.Find(&abilities)
- groups := make(map[string]bool)
- for _, ability := range abilities {
- groups[ability.Group] = true
- }
- newGroup2model2channels := make(map[string]map[string][]*Channel)
- newChannelsIDM := make(map[int]*Channel)
- for group := range groups {
- newGroup2model2channels[group] = make(map[string][]*Channel)
- }
- for _, channel := range channels {
- newChannelsIDM[channel.Id] = channel
- groups := strings.Split(channel.Group, ",")
- for _, group := range groups {
- models := strings.Split(channel.Models, ",")
- for _, model := range models {
- if _, ok := newGroup2model2channels[group][model]; !ok {
- newGroup2model2channels[group][model] = make([]*Channel, 0)
- }
- newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
- }
- }
- }
-
- // sort by priority
- for group, model2channels := range newGroup2model2channels {
- for model, channels := range model2channels {
- sort.Slice(channels, func(i, j int) bool {
- return channels[i].GetPriority() > channels[j].GetPriority()
- })
- newGroup2model2channels[group][model] = channels
- }
- }
-
- channelSyncLock.Lock()
- group2model2channels = newGroup2model2channels
- channelsIDM = newChannelsIDM
- channelSyncLock.Unlock()
- common.SysLog("channels synced from database")
-}
-
-func SyncChannelCache(frequency int) {
- for {
- time.Sleep(time.Duration(frequency) * time.Second)
- common.SysLog("syncing channels from database")
- InitChannelCache()
- }
-}
-
-func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
- if strings.HasPrefix(model, "gpt-4-gizmo") {
- model = "gpt-4-gizmo-*"
- }
- if strings.HasPrefix(model, "gpt-4o-gizmo") {
- model = "gpt-4o-gizmo-*"
- }
-
- // if memory cache is disabled, get channel directly from database
- if !common.MemoryCacheEnabled {
- return GetRandomSatisfiedChannel(group, model, retry)
- }
-
- channelSyncLock.RLock()
- channels := group2model2channels[group][model]
- channelSyncLock.RUnlock()
-
- if len(channels) == 0 {
- return nil, errors.New("channel not found")
- }
-
- uniquePriorities := make(map[int]bool)
- for _, channel := range channels {
- uniquePriorities[int(channel.GetPriority())] = true
- }
- var sortedUniquePriorities []int
- for priority := range uniquePriorities {
- sortedUniquePriorities = append(sortedUniquePriorities, priority)
- }
- sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
-
- if retry >= len(uniquePriorities) {
- retry = len(uniquePriorities) - 1
- }
- targetPriority := int64(sortedUniquePriorities[retry])
-
- // get the priority for the given retry number
- var targetChannels []*Channel
- for _, channel := range channels {
- if channel.GetPriority() == targetPriority {
- targetChannels = append(targetChannels, channel)
- }
- }
-
- // 平滑系数
- smoothingFactor := 10
- // Calculate the total weight of all channels up to endIdx
- totalWeight := 0
- for _, channel := range targetChannels {
- totalWeight += channel.GetWeight() + smoothingFactor
- }
- // Generate a random value in the range [0, totalWeight)
- randomWeight := rand.Intn(totalWeight)
-
- // Find a channel based on its weight
- for _, channel := range targetChannels {
- randomWeight -= channel.GetWeight() + smoothingFactor
- if randomWeight < 0 {
- return channel, nil
- }
- }
- // return null if no channel is not found
- return nil, errors.New("channel not found")
-}
-
-func CacheGetChannel(id int) (*Channel, error) {
- if !common.MemoryCacheEnabled {
- return GetChannelById(id, true)
- }
- channelSyncLock.RLock()
- defer channelSyncLock.RUnlock()
-
- c, ok := channelsIDM[id]
- if !ok {
- return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
- }
- return c, nil
-}
-
-func CacheUpdateChannelStatus(id int, status int) {
- if !common.MemoryCacheEnabled {
- return
- }
- channelSyncLock.Lock()
- defer channelSyncLock.Unlock()
- if channel, ok := channelsIDM[id]; ok {
- channel.Status = status
- }
-}
diff --git a/model/channel.go b/model/channel.go
index ed7a0a7e..a9a23481 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -1,11 +1,19 @@
package model
import (
+ "database/sql/driver"
"encoding/json"
+ "errors"
+ "fmt"
+ "math/rand"
"one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/types"
"strings"
"sync"
+ "github.com/samber/lo"
"gorm.io/gorm"
)
@@ -34,9 +42,148 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
+ OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
Tag *string `json:"tag" gorm:"index"`
- Setting *string `json:"setting" gorm:"type:text"`
+ Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"`
+	// Added after v0.8.5.
+ ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
+
+ // cache info
+ Keys []string `json:"-" gorm:"-"`
+}
+
+type ChannelInfo struct {
+ IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
+ MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
+ MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
+ MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason
+ MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time
+ MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
+ MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+}
+
+// Value implements driver.Valuer interface
+func (c ChannelInfo) Value() (driver.Value, error) {
+ return common.Marshal(&c)
+}
+
+// Scan implements sql.Scanner interface
+func (c *ChannelInfo) Scan(value interface{}) error {
+ bytesValue, _ := value.([]byte)
+ return common.Unmarshal(bytesValue, c)
+}
+
+func (channel *Channel) GetKeys() []string {
+ if channel.Key == "" {
+ return []string{}
+ }
+ if len(channel.Keys) > 0 {
+ return channel.Keys
+ }
+ trimmed := strings.TrimSpace(channel.Key)
+ // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
+ if strings.HasPrefix(trimmed, "[") {
+ var arr []json.RawMessage
+ if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
+ res := make([]string, len(arr))
+ for i, v := range arr {
+ res[i] = string(v)
+ }
+ return res
+ }
+ }
+ // Otherwise, fall back to splitting by newline
+ keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
+ return keys
+}
+
+func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
+ // If not in multi-key mode, return the original key string directly.
+ if !channel.ChannelInfo.IsMultiKey {
+ return channel.Key, 0, nil
+ }
+
+ // Obtain all keys (split by \n)
+ keys := channel.GetKeys()
+ if len(keys) == 0 {
+		// No keys available; return an error so the caller can disable the channel.
+ return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
+ }
+
+ statusList := channel.ChannelInfo.MultiKeyStatusList
+ // helper to get key status, default to enabled when missing
+ getStatus := func(idx int) int {
+ if statusList == nil {
+ return common.ChannelStatusEnabled
+ }
+ if status, ok := statusList[idx]; ok {
+ return status
+ }
+ return common.ChannelStatusEnabled
+ }
+
+ // Collect indexes of enabled keys
+ enabledIdx := make([]int, 0, len(keys))
+ for i := range keys {
+ if getStatus(i) == common.ChannelStatusEnabled {
+ enabledIdx = append(enabledIdx, i)
+ }
+ }
+ // If no specific status list or none enabled, fall back to first key
+ if len(enabledIdx) == 0 {
+ return keys[0], 0, nil
+ }
+
+ switch channel.ChannelInfo.MultiKeyMode {
+ case constant.MultiKeyModeRandom:
+ // Randomly pick one enabled key
+ selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
+ return keys[selectedIdx], selectedIdx, nil
+ case constant.MultiKeyModePolling:
+ // Use channel-specific lock to ensure thread-safe polling
+ lock := GetChannelPollingLock(channel.Id)
+ lock.Lock()
+ defer lock.Unlock()
+
+ channelInfo, err := CacheGetChannelInfo(channel.Id)
+ if err != nil {
+ return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
+ }
+ //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
+ defer func() {
+ if common.DebugEnabled {
+ println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
+ }
+ if !common.MemoryCacheEnabled {
+ _ = channel.SaveChannelInfo()
+ } else {
+ // CacheUpdateChannel(channel)
+ }
+ }()
+ // Start from the saved polling index and look for the next enabled key
+ start := channelInfo.MultiKeyPollingIndex
+ if start < 0 || start >= len(keys) {
+ start = 0
+ }
+ for i := 0; i < len(keys); i++ {
+ idx := (start + i) % len(keys)
+ if getStatus(idx) == common.ChannelStatusEnabled {
+ // update polling index for next call (point to the next position)
+ channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
+ return keys[idx], idx, nil
+ }
+ }
+ // Fallback – should not happen, but return first enabled key
+ return keys[enabledIdx[0]], enabledIdx[0], nil
+ default:
+ // Unknown mode, default to first enabled key (or original key string)
+ return keys[enabledIdx[0]], enabledIdx[0], nil
+ }
+}
+
+func (channel *Channel) SaveChannelInfo() error {
+ return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
}
func (channel *Channel) GetModels() []string {
@@ -60,9 +207,9 @@ func (channel *Channel) GetGroups() []string {
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
- err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
+ err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil {
- common.SysError("failed to unmarshal other info: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
}
}
return otherInfo
@@ -71,7 +218,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
otherInfoBytes, err := json.Marshal(otherInfo)
if err != nil {
- common.SysError("failed to marshal other info: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
return
}
channel.OtherInfo = string(otherInfoBytes)
@@ -145,7 +292,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
}
// 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+ baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
@@ -153,15 +300,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
- groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -174,49 +321,71 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
}
func GetChannelById(id int, selectAll bool) (*Channel, error) {
- channel := Channel{Id: id}
+ channel := &Channel{Id: id}
var err error = nil
if selectAll {
- err = DB.First(&channel, "id = ?", id).Error
+ err = DB.First(channel, "id = ?", id).Error
} else {
- err = DB.Omit("key").First(&channel, "id = ?", id).Error
+ err = DB.Omit("key").First(channel, "id = ?", id).Error
}
- return &channel, err
+ if err != nil {
+ return nil, err
+ }
+ if channel == nil {
+ return nil, errors.New("channel not found")
+ }
+ return channel, nil
}
func BatchInsertChannels(channels []Channel) error {
- var err error
- err = DB.Create(&channels).Error
- if err != nil {
- return err
+ if len(channels) == 0 {
+ return nil
}
- for _, channel_ := range channels {
- err = channel_.AddAbilities()
- if err != nil {
+ tx := DB.Begin()
+ if tx.Error != nil {
+ return tx.Error
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ tx.Rollback()
+ }
+ }()
+
+ for _, chunk := range lo.Chunk(channels, 50) {
+ if err := tx.Create(&chunk).Error; err != nil {
+ tx.Rollback()
return err
}
+ for _, channel_ := range chunk {
+ if err := channel_.AddAbilities(tx); err != nil {
+ tx.Rollback()
+ return err
+ }
+ }
}
- return nil
+ return tx.Commit().Error
}
func BatchDeleteChannels(ids []int) error {
- //使用事务 删除channel表和channel_ability表
+ if len(ids) == 0 {
+ return nil
+ }
+ // 使用事务 分批删除channel表和abilities表
tx := DB.Begin()
- err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
- if err != nil {
- // 回滚事务
- tx.Rollback()
- return err
+ if tx.Error != nil {
+ return tx.Error
}
- err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
- if err != nil {
- // 回滚事务
- tx.Rollback()
- return err
+ for _, chunk := range lo.Chunk(ids, 200) {
+ if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil {
+ tx.Rollback()
+ return err
+ }
+ if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil {
+ tx.Rollback()
+ return err
+ }
}
- // 提交事务
- tx.Commit()
- return err
+ return tx.Commit().Error
}
func (channel *Channel) GetPriority() int64 {
@@ -237,7 +406,11 @@ func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
}
- return *channel.BaseURL
+ url := *channel.BaseURL
+ if url == "" {
+ url = constant.ChannelBaseURLs[channel.Type]
+ }
+ return url
}
func (channel *Channel) GetModelMapping() string {
@@ -260,11 +433,49 @@ func (channel *Channel) Insert() error {
if err != nil {
return err
}
- err = channel.AddAbilities()
+ err = channel.AddAbilities(nil)
return err
}
func (channel *Channel) Update() error {
+ // If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
+ if channel.ChannelInfo.IsMultiKey {
+ var keyStr string
+ if channel.Key != "" {
+ keyStr = channel.Key
+ } else {
+ // If key is not provided, read the existing key from the database
+ if existing, err := GetChannelById(channel.Id, true); err == nil {
+ keyStr = existing.Key
+ }
+ }
+ // Parse the key list (supports newline separation or JSON array)
+ keys := []string{}
+ if keyStr != "" {
+ trimmed := strings.TrimSpace(keyStr)
+ if strings.HasPrefix(trimmed, "[") {
+ var arr []json.RawMessage
+ if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
+ keys = make([]string, len(arr))
+ for i, v := range arr {
+ keys[i] = string(v)
+ }
+ }
+ }
+ if len(keys) == 0 { // fallback to newline split
+ keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
+ }
+ }
+ channel.ChannelInfo.MultiKeySize = len(keys)
+ // Clean up status data that exceeds the new key count to prevent index out of range
+ if channel.ChannelInfo.MultiKeyStatusList != nil {
+ for idx := range channel.ChannelInfo.MultiKeyStatusList {
+ if idx >= channel.ChannelInfo.MultiKeySize {
+ delete(channel.ChannelInfo.MultiKeyStatusList, idx)
+ }
+ }
+ }
+ }
var err error
err = DB.Model(channel).Updates(channel).Error
if err != nil {
@@ -281,7 +492,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
ResponseTime: int(responseTime),
}).Error
if err != nil {
- common.SysError("failed to update response time: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err))
}
}
@@ -291,7 +502,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
Balance: balance,
}).Error
if err != nil {
- common.SysError("failed to update balance: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err))
}
}
@@ -307,51 +518,135 @@ func (channel *Channel) Delete() error {
var channelStatusLock sync.Mutex
-func UpdateChannelStatusById(id int, status int, reason string) bool {
+// channelPollingLocks maps each channel ID to a mutex that serializes key polling.
+var channelPollingLocks sync.Map
+
+// GetChannelPollingLock returns or creates a mutex for the given channel ID
+func GetChannelPollingLock(channelId int) *sync.Mutex {
+ if lock, exists := channelPollingLocks.Load(channelId); exists {
+ return lock.(*sync.Mutex)
+ }
+ // Create new lock for this channel
+ newLock := &sync.Mutex{}
+ actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
+ return actual.(*sync.Mutex)
+}
+
+// CleanupChannelPollingLocks removes locks for channels that no longer exist.
+// Calling it is optional; run it periodically to prevent unbounded lock growth.
+func CleanupChannelPollingLocks() {
+ var activeChannelIds []int
+ DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
+
+ activeChannelSet := make(map[int]bool)
+ for _, id := range activeChannelIds {
+ activeChannelSet[id] = true
+ }
+
+ channelPollingLocks.Range(func(key, value interface{}) bool {
+ channelId := key.(int)
+ if !activeChannelSet[channelId] {
+ channelPollingLocks.Delete(channelId)
+ }
+ return true
+ })
+}
+
+func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
+ keys := channel.GetKeys()
+ if len(keys) == 0 {
+ channel.Status = status
+ } else {
+ var keyIndex int
+ for i, key := range keys {
+ if key == usingKey {
+ keyIndex = i
+ break
+ }
+ }
+ if channel.ChannelInfo.MultiKeyStatusList == nil {
+ channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
+ }
+ if status == common.ChannelStatusEnabled {
+ delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
+ } else {
+ channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
+ if channel.ChannelInfo.MultiKeyDisabledReason == nil {
+ channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
+ }
+ if channel.ChannelInfo.MultiKeyDisabledTime == nil {
+ channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
+ }
+ channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
+ channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
+ }
+ if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
+ channel.Status = common.ChannelStatusAutoDisabled
+ info := channel.GetOtherInfo()
+ info["status_reason"] = "All keys are disabled"
+ info["status_time"] = common.GetTimestamp()
+ channel.SetOtherInfo(info)
+ }
+ }
+}
+
+func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
if common.MemoryCacheEnabled {
channelStatusLock.Lock()
defer channelStatusLock.Unlock()
- channelCache, _ := CacheGetChannel(id)
- // 如果缓存渠道存在,且状态已是目标状态,直接返回
- if channelCache != nil && channelCache.Status == status {
+ channelCache, _ := CacheGetChannel(channelId)
+ if channelCache == nil {
return false
}
- // 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
- if channelCache == nil && status != common.ChannelStatusEnabled {
- return false
+ if channelCache.ChannelInfo.IsMultiKey {
+ // 如果是多Key模式,更新缓存中的状态
+ handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
+ //CacheUpdateChannel(channelCache)
+ //return true
+ } else {
+ // 如果缓存渠道存在,且状态已是目标状态,直接返回
+ if channelCache.Status == status {
+ return false
+ }
+ CacheUpdateChannelStatus(channelId, status)
}
- CacheUpdateChannelStatus(id, status)
}
- err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
+
+ shouldUpdateAbilities := false
+ defer func() {
+ if shouldUpdateAbilities {
+ err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
+ if err != nil {
+ common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err))
+ }
+ }
+ }()
+ channel, err := GetChannelById(channelId, true)
if err != nil {
- common.SysError("failed to update ability status: " + err.Error())
return false
- }
- channel, err := GetChannelById(id, true)
- if err != nil {
- // find channel by id error, directly update status
- result := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status)
- if result.Error != nil {
- common.SysError("failed to update channel status: " + result.Error.Error())
- return false
- }
- if result.RowsAffected == 0 {
- return false
- }
} else {
if channel.Status == status {
return false
}
- // find channel by id success, update status and other info
- info := channel.GetOtherInfo()
- info["status_reason"] = reason
- info["status_time"] = common.GetTimestamp()
- channel.SetOtherInfo(info)
- channel.Status = status
+
+ if channel.ChannelInfo.IsMultiKey {
+ beforeStatus := channel.Status
+ handlerMultiKeyUpdate(channel, usingKey, status, reason)
+ if beforeStatus != channel.Status {
+ shouldUpdateAbilities = true
+ }
+ } else {
+ info := channel.GetOtherInfo()
+ info["status_reason"] = reason
+ info["status_time"] = common.GetTimestamp()
+ channel.SetOtherInfo(info)
+ channel.Status = status
+ shouldUpdateAbilities = true
+ }
err = channel.Save()
if err != nil {
- common.SysError("failed to update channel status: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
return false
}
}
@@ -413,7 +708,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
for _, channel := range channels {
err = channel.UpdateAbilities(nil)
if err != nil {
- common.SysError("failed to update abilities: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err))
}
}
}
@@ -437,7 +732,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
- common.SysError("failed to update channel used quota: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err))
}
}
@@ -478,7 +773,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
}
// 构造基础查询
- baseQuery := DB.Model(&Channel{}).Omit(keyCol)
+ baseQuery := DB.Model(&Channel{}).Omit("key")
// 构造WHERE子句
var whereClause string
@@ -486,15 +781,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
if group != "" && group != "null" {
var groupCondition string
if common.UsingMySQL {
- groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
} else {
// sqlite, PostgreSQL
- groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
}
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
} else {
- whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
}
@@ -514,32 +809,67 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
return tags, nil
}
-func (channel *Channel) GetSetting() map[string]interface{} {
- setting := make(map[string]interface{})
+func (channel *Channel) ValidateSettings() error {
+ channelParams := &dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
- err := json.Unmarshal([]byte(*channel.Setting), &setting)
+ err := common.Unmarshal([]byte(*channel.Setting), channelParams)
if err != nil {
- common.SysError("failed to unmarshal setting: " + err.Error())
+ return err
+ }
+ }
+ return nil
+}
+
+func (channel *Channel) GetSetting() dto.ChannelSettings {
+ setting := dto.ChannelSettings{}
+ if channel.Setting != nil && *channel.Setting != "" {
+ err := common.Unmarshal([]byte(*channel.Setting), &setting)
+ if err != nil {
+ common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
+ channel.Setting = nil // 清空设置以避免后续错误
+ _ = channel.Save() // 保存修改
}
}
return setting
}
-func (channel *Channel) SetSetting(setting map[string]interface{}) {
- settingBytes, err := json.Marshal(setting)
+func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
+ settingBytes, err := common.Marshal(setting)
if err != nil {
- common.SysError("failed to marshal setting: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
return
}
channel.Setting = common.GetPointer[string](string(settingBytes))
}
+func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
+ setting := dto.ChannelOtherSettings{}
+ if channel.OtherSettings != "" {
+ err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
+ if err != nil {
+ common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
+ channel.OtherSettings = "{}" // 清空设置以避免后续错误
+ _ = channel.Save() // 保存修改
+ }
+ }
+ return setting
+}
+
+func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
+ settingBytes, err := common.Marshal(setting)
+ if err != nil {
+ common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
+ return
+ }
+ channel.OtherSettings = string(settingBytes)
+}
+
func (channel *Channel) GetParamOverride() map[string]interface{} {
paramOverride := make(map[string]interface{})
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
- err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
+ err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
if err != nil {
- common.SysError("failed to unmarshal param override: " + err.Error())
+ common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err))
}
}
return paramOverride
@@ -583,3 +913,53 @@ func BatchSetChannelTag(ids []int, tag *string) error {
// 提交事务
return tx.Commit().Error
}
+
+// CountAllChannels returns the total number of channels in the database.
+func CountAllChannels() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Count(&total).Error
+ return total, err
+}
+
+// CountAllTags returns the number of distinct non-empty tags across all channels.
+func CountAllTags() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
+ return total, err
+}
+
+// GetChannelsByType returns a page of channels of the specified type, ordered by priority (or by id when idSort is true).
+func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
+ var channels []*Channel
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+ err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+ return channels, err
+}
+
+// CountChannelsByType returns the number of channels of the given type.
+func CountChannelsByType(channelType int) (int64, error) {
+ var count int64
+ err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
+ return count, err
+}
+
+// CountChannelsGroupByType returns a map from channel type to channel count.
+func CountChannelsGroupByType() (map[int64]int64, error) {
+ type result struct {
+ Type int64 `gorm:"column:type"`
+ Count int64 `gorm:"column:count"`
+ }
+ var results []result
+ err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
+ if err != nil {
+ return nil, err
+ }
+ counts := make(map[int64]int64)
+ for _, r := range results {
+ counts[r.Type] = r.Count
+ }
+ return counts, nil
+}
diff --git a/model/channel_cache.go b/model/channel_cache.go
new file mode 100644
index 00000000..86866e40
--- /dev/null
+++ b/model/channel_cache.go
@@ -0,0 +1,284 @@
+package model
+
+import (
+ "errors"
+ "fmt"
+ "math/rand"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/setting"
+ "one-api/setting/ratio_setting"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+var group2model2channels map[string]map[string][]int // enabled channel
+var channelsIDM map[int]*Channel // all channels include disabled
+var channelSyncLock sync.RWMutex
+
+func InitChannelCache() {
+ if !common.MemoryCacheEnabled {
+ return
+ }
+ newChannelId2channel := make(map[int]*Channel)
+ var channels []*Channel
+ DB.Find(&channels)
+ for _, channel := range channels {
+ newChannelId2channel[channel.Id] = channel
+ }
+ var abilities []*Ability
+ DB.Find(&abilities)
+ groups := make(map[string]bool)
+ for _, ability := range abilities {
+ groups[ability.Group] = true
+ }
+ newGroup2model2channels := make(map[string]map[string][]int)
+ for group := range groups {
+ newGroup2model2channels[group] = make(map[string][]int)
+ }
+ for _, channel := range channels {
+ if channel.Status != common.ChannelStatusEnabled {
+ continue // skip disabled channels
+ }
+ groups := strings.Split(channel.Group, ",")
+ for _, group := range groups {
+ models := strings.Split(channel.Models, ",")
+ for _, model := range models {
+ if _, ok := newGroup2model2channels[group][model]; !ok {
+ newGroup2model2channels[group][model] = make([]int, 0)
+ }
+ newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
+ }
+ }
+ }
+
+ // sort by priority
+ for group, model2channels := range newGroup2model2channels {
+ for model, channels := range model2channels {
+ sort.Slice(channels, func(i, j int) bool {
+ return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
+ })
+ newGroup2model2channels[group][model] = channels
+ }
+ }
+
+ channelSyncLock.Lock()
+ group2model2channels = newGroup2model2channels
+ //channelsIDM = newChannelId2channel
+ for i, channel := range newChannelId2channel {
+ if channel.ChannelInfo.IsMultiKey {
+ channel.Keys = channel.GetKeys()
+ if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
+ if oldChannel, ok := channelsIDM[i]; ok {
+ // 存在旧的渠道,如果是多key且轮询,保留轮询索引信息
+ if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
+ channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
+ }
+ }
+ }
+ }
+ }
+ channelsIDM = newChannelId2channel
+ channelSyncLock.Unlock()
+ common.SysLog("channels synced from database")
+}
+
+func SyncChannelCache(frequency int) {
+ for {
+ time.Sleep(time.Duration(frequency) * time.Second)
+ common.SysLog("syncing channels from database")
+ InitChannelCache()
+ }
+}
+
+func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
+ var channel *Channel
+ var err error
+ selectGroup := group
+ if group == "auto" {
+ if len(setting.AutoGroups) == 0 {
+ return nil, selectGroup, errors.New("auto groups is not enabled")
+ }
+ for _, autoGroup := range setting.AutoGroups {
+ if common.DebugEnabled {
+ println("autoGroup:", autoGroup)
+ }
+ channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
+ if channel == nil {
+ continue
+ } else {
+ c.Set("auto_group", autoGroup)
+ selectGroup = autoGroup
+ if common.DebugEnabled {
+ println("selectGroup:", selectGroup)
+ }
+ break
+ }
+ }
+ } else {
+ channel, err = getRandomSatisfiedChannel(group, model, retry)
+ if err != nil {
+ return nil, group, err
+ }
+ }
+ return channel, selectGroup, nil
+}
+
+func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+ // if memory cache is disabled, get channel directly from database
+ if !common.MemoryCacheEnabled {
+ return GetRandomSatisfiedChannel(group, model, retry)
+ }
+
+ channelSyncLock.RLock()
+ defer channelSyncLock.RUnlock()
+
+ // First, try to find channels with the exact model name.
+ channels := group2model2channels[group][model]
+
+ // If no channels found, try to find channels with the normalized model name.
+ if len(channels) == 0 {
+ normalizedModel := ratio_setting.FormatMatchingModelName(model)
+ channels = group2model2channels[group][normalizedModel]
+ }
+
+ if len(channels) == 0 {
+ return nil, nil
+ }
+
+ if len(channels) == 1 {
+ if channel, ok := channelsIDM[channels[0]]; ok {
+ return channel, nil
+ }
+ return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
+ }
+
+ uniquePriorities := make(map[int]bool)
+ for _, channelId := range channels {
+ if channel, ok := channelsIDM[channelId]; ok {
+ uniquePriorities[int(channel.GetPriority())] = true
+ } else {
+ return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
+ }
+ }
+ var sortedUniquePriorities []int
+ for priority := range uniquePriorities {
+ sortedUniquePriorities = append(sortedUniquePriorities, priority)
+ }
+ sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
+
+ if retry >= len(uniquePriorities) {
+ retry = len(uniquePriorities) - 1
+ }
+ targetPriority := int64(sortedUniquePriorities[retry])
+
+ // get the priority for the given retry number
+ var targetChannels []*Channel
+ for _, channelId := range channels {
+ if channel, ok := channelsIDM[channelId]; ok {
+ if channel.GetPriority() == targetPriority {
+ targetChannels = append(targetChannels, channel)
+ }
+ } else {
+ return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
+ }
+ }
+
+ // 平滑系数
+ smoothingFactor := 10
+ // Calculate the total weight of all candidate channels at the target priority
+ totalWeight := 0
+ for _, channel := range targetChannels {
+ totalWeight += channel.GetWeight() + smoothingFactor
+ }
+ // Generate a random value in the range [0, totalWeight)
+ randomWeight := rand.Intn(totalWeight)
+
+ // Find a channel based on its weight
+ for _, channel := range targetChannels {
+ randomWeight -= channel.GetWeight() + smoothingFactor
+ if randomWeight < 0 {
+ return channel, nil
+ }
+ }
+ // return an error if no channel could be selected
+ return nil, errors.New("channel not found")
+}
+
+func CacheGetChannel(id int) (*Channel, error) {
+ if !common.MemoryCacheEnabled {
+ return GetChannelById(id, true)
+ }
+ channelSyncLock.RLock()
+ defer channelSyncLock.RUnlock()
+
+ c, ok := channelsIDM[id]
+ if !ok {
+ return nil, fmt.Errorf("渠道# %d,已不存在", id)
+ }
+ return c, nil
+}
+
+func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
+ if !common.MemoryCacheEnabled {
+ channel, err := GetChannelById(id, true)
+ if err != nil {
+ return nil, err
+ }
+ return &channel.ChannelInfo, nil
+ }
+ channelSyncLock.RLock()
+ defer channelSyncLock.RUnlock()
+
+ c, ok := channelsIDM[id]
+ if !ok {
+ return nil, fmt.Errorf("渠道# %d,已不存在", id)
+ }
+ return &c.ChannelInfo, nil
+}
+
+func CacheUpdateChannelStatus(id int, status int) {
+ if !common.MemoryCacheEnabled {
+ return
+ }
+ channelSyncLock.Lock()
+ defer channelSyncLock.Unlock()
+ if channel, ok := channelsIDM[id]; ok {
+ channel.Status = status
+ }
+ if status != common.ChannelStatusEnabled {
+ // delete the channel from group2model2channels
+ for group, model2channels := range group2model2channels {
+ for model, channels := range model2channels {
+ for i, channelId := range channels {
+ if channelId == id {
+ // remove the channel from the slice
+ group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
+ break
+ }
+ }
+ }
+ }
+ }
+}
+
+func CacheUpdateChannel(channel *Channel) {
+ if !common.MemoryCacheEnabled {
+ return
+ }
+ channelSyncLock.Lock()
+ defer channelSyncLock.Unlock()
+ if channel == nil {
+ return
+ }
+
+ println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
+
+ println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
+ channelsIDM[channel.Id] = channel
+ println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
+}
diff --git a/model/log.go b/model/log.go
index 0a891fcd..979cbe7b 100644
--- a/model/log.go
+++ b/model/log.go
@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"one-api/common"
+ "one-api/logger"
+ "one-api/types"
"os"
"strings"
"time"
@@ -27,11 +29,12 @@ type Log struct {
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
UseTime int `json:"use_time" gorm:"default:0"`
- IsStream bool `json:"is_stream" gorm:"default:false"`
+ IsStream bool `json:"is_stream"`
ChannelId int `json:"channel" gorm:"index"`
ChannelName string `json:"channel_name" gorm:"->"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
+ Ip string `json:"ip" gorm:"index;default:''"`
Other string `json:"other"`
}
@@ -48,7 +51,7 @@ func formatUserLogs(logs []*Log) {
for i := range logs {
logs[i].ChannelName = ""
var otherMap map[string]interface{}
- otherMap = common.StrToMap(logs[i].Other)
+ otherMap, _ = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")
@@ -61,7 +64,7 @@ func formatUserLogs(logs []*Log) {
func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
- if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
+ if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
@@ -86,15 +89,22 @@ func RecordLog(userId int, logType int, content string) {
}
err := LOG_DB.Create(log).Error
if err != nil {
- common.SysError("failed to record log: " + err.Error())
+ common.SysLog("failed to record log: " + err.Error())
}
}
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
- common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
+ logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
otherStr := common.MapToJsonStr(other)
+ // 判断是否需要记录 IP
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if settingMap.RecordIpLog {
+ needRecordIp = true
+ }
+ }
log := &Log{
UserId: userId,
Username: username,
@@ -111,48 +121,80 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
- Other: otherStr,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
- common.LogError(c, "failed to record log: "+err.Error())
+ logger.LogError(c, "failed to record log: "+err.Error())
}
}
-func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
- modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
- isStream bool, group string, other map[string]interface{}) {
- common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+type RecordConsumeLogParams struct {
+ ChannelId int `json:"channel_id"`
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ ModelName string `json:"model_name"`
+ TokenName string `json:"token_name"`
+ Quota int `json:"quota"`
+ Content string `json:"content"`
+ TokenId int `json:"token_id"`
+ UseTimeSeconds int `json:"use_time_seconds"`
+ IsStream bool `json:"is_stream"`
+ Group string `json:"group"`
+ Other map[string]interface{} `json:"other"`
+}
+
+func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
if !common.LogConsumeEnabled {
return
}
+ logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
username := c.GetString("username")
- otherStr := common.MapToJsonStr(other)
+ otherStr := common.MapToJsonStr(params.Other)
+ // 判断是否需要记录 IP
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if settingMap.RecordIpLog {
+ needRecordIp = true
+ }
+ }
log := &Log{
UserId: userId,
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume,
- Content: content,
- PromptTokens: promptTokens,
- CompletionTokens: completionTokens,
- TokenName: tokenName,
- ModelName: modelName,
- Quota: quota,
- ChannelId: channelId,
- TokenId: tokenId,
- UseTime: useTimeSeconds,
- IsStream: isStream,
- Group: group,
- Other: otherStr,
+ Content: params.Content,
+ PromptTokens: params.PromptTokens,
+ CompletionTokens: params.CompletionTokens,
+ TokenName: params.TokenName,
+ ModelName: params.ModelName,
+ Quota: params.Quota,
+ ChannelId: params.ChannelId,
+ TokenId: params.TokenId,
+ UseTime: params.UseTimeSeconds,
+ IsStream: params.IsStream,
+ Group: params.Group,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
- common.LogError(c, "failed to record log: "+err.Error())
+ logger.LogError(c, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
gopool.Go(func() {
- LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
+ LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
})
}
}
@@ -184,7 +226,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = tx.Where("logs.channel_id = ?", channel)
}
if group != "" {
- tx = tx.Where("logs."+groupCol+" = ?", group)
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -195,21 +237,22 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return nil, 0, err
}
- channelIds := make([]int, 0)
- channelMap := make(map[int]string)
+ channelIds := types.NewSet[int]()
for _, log := range logs {
if log.ChannelId != 0 {
- channelIds = append(channelIds, log.ChannelId)
+ channelIds.Add(log.ChannelId)
}
}
- if len(channelIds) > 0 {
+
+ if channelIds.Len() > 0 {
var channels []struct {
Id int `gorm:"column:id"`
Name string `gorm:"column:name"`
}
- if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
+ if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
return logs, total, err
}
+ channelMap := make(map[int]string, len(channels))
for _, channel := range channels {
channelMap[channel.Id] = channel.Name
}
@@ -242,7 +285,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = tx.Where("logs.created_at <= ?", endTimestamp)
}
if group != "" {
- tx = tx.Where("logs."+groupCol+" = ?", group)
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
@@ -303,8 +346,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
if group != "" {
- tx = tx.Where(groupCol+" = ?", group)
- rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
+ tx = tx.Where(logGroupCol+" = ?", group)
+ rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
}
tx = tx.Where("type = ?", LogTypeConsume)
diff --git a/model/main.go b/model/main.go
index 61d6bb10..dbf27152 100644
--- a/model/main.go
+++ b/model/main.go
@@ -1,6 +1,7 @@
package model
import (
+ "fmt"
"log"
"one-api/common"
"one-api/constant"
@@ -15,24 +16,70 @@ import (
"gorm.io/gorm"
)
-var groupCol string
-var keyCol string
+var commonGroupCol string
+var commonKeyCol string
+var commonTrueVal string
+var commonFalseVal string
+
+var logKeyCol string
+var logGroupCol string
func initCol() {
+ // init common column names
if common.UsingPostgreSQL {
- groupCol = `"group"`
- keyCol = `"key"`
-
+ commonGroupCol = `"group"`
+ commonKeyCol = `"key"`
+ commonTrueVal = "true"
+ commonFalseVal = "false"
} else {
- groupCol = "`group`"
- keyCol = "`key`"
+ commonGroupCol = "`group`"
+ commonKeyCol = "`key`"
+ commonTrueVal = "1"
+ commonFalseVal = "0"
}
+ if os.Getenv("LOG_SQL_DSN") != "" {
+ switch common.LogSqlType {
+ case common.DatabaseTypePostgreSQL:
+ logGroupCol = `"group"`
+ logKeyCol = `"key"`
+ default:
+ logGroupCol = commonGroupCol
+ logKeyCol = commonKeyCol
+ }
+ } else {
+ // LOG_SQL_DSN 为空时,日志数据库与主数据库相同
+ if common.UsingPostgreSQL {
+ logGroupCol = `"group"`
+ logKeyCol = `"key"`
+ } else {
+ logGroupCol = commonGroupCol
+ logKeyCol = commonKeyCol
+ }
+ }
+ // log sql type and database type
+ //common.SysLog("Using Log SQL Type: " + common.LogSqlType)
}
var DB *gorm.DB
var LOG_DB *gorm.DB
+// dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors
+func dropIndexIfExists(tableName string, indexName string) {
+ if !common.UsingMySQL {
+ return
+ }
+ var count int64
+ // Check index existence via information_schema
+ err := DB.Raw(
+ "SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
+ tableName, indexName,
+ ).Scan(&count).Error
+ if err == nil && count > 0 {
+ _ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
+ }
+}
+
func createRootAccountIfNeed() error {
var user User
//if user.Status != common.UserStatusEnabled {
@@ -83,7 +130,7 @@ func CheckSetup() {
}
}
-func chooseDB(envName string) (*gorm.DB, error) {
+func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
defer func() {
initCol()
}()
@@ -92,7 +139,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
- common.UsingPostgreSQL = true
+ if !isLog {
+ common.UsingPostgreSQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypePostgreSQL
+ }
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
@@ -102,7 +153,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
}
if strings.HasPrefix(dsn, "local") {
common.SysLog("SQL_DSN not set, using SQLite as database")
- common.UsingSQLite = true
+ if !isLog {
+ common.UsingSQLite = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeSQLite
+ }
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -117,7 +172,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
dsn += "?parseTime=true"
}
}
- common.UsingMySQL = true
+ if !isLog {
+ common.UsingMySQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeMySQL
+ }
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -131,12 +190,18 @@ func chooseDB(envName string) (*gorm.DB, error) {
}
func InitDB() (err error) {
- db, err := chooseDB("SQL_DSN")
+ db, err := chooseDB("SQL_DSN", false)
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
DB = db
+ // MySQL charset/collation startup check: ensure Chinese-capable charset
+ if common.UsingMySQL {
+ if err := checkMySQLChineseSupport(DB); err != nil {
+ panic(err)
+ }
+ }
sqlDB, err := DB.DB()
if err != nil {
return err
@@ -149,7 +214,7 @@ func InitDB() (err error) {
return nil
}
if common.UsingMySQL {
- _, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
+ //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
}
common.SysLog("database migration started")
err = migrateDB()
@@ -165,12 +230,18 @@ func InitLogDB() (err error) {
LOG_DB = DB
return
}
- db, err := chooseDB("LOG_SQL_DSN")
+ db, err := chooseDB("LOG_SQL_DSN", true)
if err == nil {
if common.DebugEnabled {
db = db.Debug()
}
LOG_DB = db
+ // If log DB is MySQL, also ensure Chinese-capable charset
+ if common.LogSqlType == common.DatabaseTypeMySQL {
+ if err := checkMySQLChineseSupport(LOG_DB); err != nil {
+ panic(err)
+ }
+ }
sqlDB, err := LOG_DB.DB()
if err != nil {
return err
@@ -182,12 +253,6 @@ func InitLogDB() (err error) {
if !common.IsMasterNode {
return nil
}
- //if common.UsingMySQL {
- // _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
- // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
- // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
- // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
- //}
common.SysLog("database migration started")
err = migrateLOGDB()
return err
@@ -198,54 +263,99 @@ func InitLogDB() (err error) {
}
func migrateDB() error {
- err := DB.AutoMigrate(&Channel{})
+ // 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
+ // 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引 (model_name, deleted_at) 冲突
+ dropIndexIfExists("models", "uk_model_name") // 新版复合索引名称(若已存在)
+ dropIndexIfExists("models", "model_name") // 旧版列级唯一索引名称
+
+ dropIndexIfExists("vendors", "uk_vendor_name") // 新版复合索引名称(若已存在)
+ dropIndexIfExists("vendors", "name") // 旧版列级唯一索引名称
+ //if !common.UsingPostgreSQL {
+ // return migrateDBFast()
+ //}
+ err := DB.AutoMigrate(
+ &Channel{},
+ &Token{},
+ &User{},
+ &Option{},
+ &Redemption{},
+ &Ability{},
+ &Log{},
+ &Midjourney{},
+ &TopUp{},
+ &QuotaData{},
+ &Task{},
+ &Model{},
+ &Vendor{},
+ &PrefillGroup{},
+ &Setup{},
+ &TwoFA{},
+ &TwoFABackupCode{},
+ )
if err != nil {
return err
}
- err = DB.AutoMigrate(&Token{})
- if err != nil {
- return err
+ return nil
+}
+
+func migrateDBFast() error {
+ // 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
+ // 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引冲突
+ dropIndexIfExists("models", "uk_model_name")
+ dropIndexIfExists("models", "model_name")
+
+ dropIndexIfExists("vendors", "uk_vendor_name")
+ dropIndexIfExists("vendors", "name")
+
+ var wg sync.WaitGroup
+
+ migrations := []struct {
+ model interface{}
+ name string
+ }{
+ {&Channel{}, "Channel"},
+ {&Token{}, "Token"},
+ {&User{}, "User"},
+ {&Option{}, "Option"},
+ {&Redemption{}, "Redemption"},
+ {&Ability{}, "Ability"},
+ {&Log{}, "Log"},
+ {&Midjourney{}, "Midjourney"},
+ {&TopUp{}, "TopUp"},
+ {&QuotaData{}, "QuotaData"},
+ {&Task{}, "Task"},
+ {&Model{}, "Model"},
+ {&Vendor{}, "Vendor"},
+ {&PrefillGroup{}, "PrefillGroup"},
+ {&Setup{}, "Setup"},
+ {&TwoFA{}, "TwoFA"},
+ {&TwoFABackupCode{}, "TwoFABackupCode"},
}
- err = DB.AutoMigrate(&User{})
- if err != nil {
- return err
+ // 动态计算migration数量,确保errChan缓冲区足够大
+ errChan := make(chan error, len(migrations))
+
+ for _, m := range migrations {
+ wg.Add(1)
+ go func(model interface{}, name string) {
+ defer wg.Done()
+ if err := DB.AutoMigrate(model); err != nil {
+ errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
+ }
+ }(m.model, m.name)
}
- err = DB.AutoMigrate(&Option{})
- if err != nil {
- return err
+
+ // Wait for all migrations to complete
+ wg.Wait()
+ close(errChan)
+
+ // Check for any errors
+ for err := range errChan {
+ if err != nil {
+ return err
+ }
}
- err = DB.AutoMigrate(&Redemption{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Ability{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Log{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Midjourney{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&TopUp{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&QuotaData{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Task{})
- if err != nil {
- return err
- }
- err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
- //err = createRootAccountIfNeed()
- return err
+ return nil
}
func migrateLOGDB() error {
@@ -275,6 +385,98 @@ func CloseDB() error {
return closeDB(DB)
}
+// checkMySQLChineseSupport ensures the MySQL connection's current schema
+// default charset/collation and every base table's collation can store
+// Chinese characters. Accepted charsets: utf8mb4, utf8, gbk, big5, gb18030.
+// Returns a bilingual (zh/en) error describing the offending schema or tables.
+func checkMySQLChineseSupport(db *gorm.DB) error {
+	// Only two things are inspected: the current schema's defaults and each
+	// table's collation (which implies its charset).
+
+	// Read the current schema's default charset/collation.
+	var schemaCharset, schemaCollation string
+	err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation)
+	if err != nil {
+		return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err)
+	}
+
+	// Charsets able to store Chinese text, mapped to their collation prefix.
+	allowedCharsets := map[string]string{
+		"utf8mb4": "utf8mb4_",
+		"utf8":    "utf8_",
+		"gbk":     "gbk_",
+		"big5":    "big5_",
+		"gb18030": "gb18030_",
+	}
+	// collationCapable reports whether a collation name alone implies a
+	// Chinese-capable charset. Shared by the schema fallback and the table
+	// loop below (previously duplicated in both places).
+	collationCapable := func(cl string) bool {
+		lower := strings.ToLower(cl)
+		for _, prefix := range allowedCharsets {
+			if strings.HasPrefix(lower, prefix) {
+				return true
+			}
+		}
+		return false
+	}
+	isChineseCapable := func(cs, cl string) bool {
+		if prefix, ok := allowedCharsets[strings.ToLower(cs)]; ok {
+			// Empty collation inherits the charset default, which is fine.
+			if cl == "" {
+				return true
+			}
+			return strings.HasPrefix(strings.ToLower(cl), prefix)
+		}
+		// Charset unknown: fall back to judging by the collation prefix.
+		return collationCapable(cl)
+	}
+
+	// 1) The schema defaults must support Chinese.
+	if !isChineseCapable(schemaCharset, schemaCollation) {
+		return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030",
+			schemaCharset, schemaCollation, schemaCharset, schemaCollation)
+	}
+
+	// 2) Every physical table's collation (implying its charset) must support Chinese.
+	type tableInfo struct {
+		Name      string
+		Collation *string
+	}
+	var tables []tableInfo
+	if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil {
+		return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err)
+	}
+
+	var badTables []string
+	for _, t := range tables {
+		// NULL/empty collation means the table inherits the schema default,
+		// which was already validated above, so it passes.
+		if t.Collation == nil || *t.Collation == "" {
+			continue
+		}
+		if !collationCapable(*t.Collation) {
+			badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, *t.Collation))
+		}
+	}
+
+	if len(badTables) > 0 {
+		// Cap the number of examples so the log stays readable.
+		maxShow := 20
+		shown := badTables
+		if len(shown) > maxShow {
+			shown = shown[:maxShow]
+		}
+		return fmt.Errorf(
+			"存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v",
+			maxShow, shown, maxShow, shown,
+		)
+	}
+	return nil
+}
+
var (
lastPingTime time.Time
pingMutex sync.Mutex
diff --git a/model/midjourney.go b/model/midjourney.go
index 5f85abfd..c6ef5de5 100644
--- a/model/midjourney.go
+++ b/model/midjourney.go
@@ -14,6 +14,8 @@ type Midjourney struct {
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
ImageUrl string `json:"image_url"`
+ VideoUrl string `json:"video_url"`
+ VideoUrls string `json:"video_urls"`
Status string `json:"status" gorm:"type:varchar(20);index"`
Progress string `json:"progress" gorm:"type:varchar(30);index"`
FailReason string `json:"fail_reason"`
@@ -166,3 +168,40 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
Where("id in (?)", taskIDs).
Updates(params).Error
}
+
+// CountAllTasks returns total midjourney tasks for admin query
+func CountAllTasks(queryParams TaskQueryParams) int64 {
+	db := DB.Model(&Midjourney{})
+	if id := queryParams.ChannelID; id != "" {
+		db = db.Where("channel_id = ?", id)
+	}
+	if mjID := queryParams.MjID; mjID != "" {
+		db = db.Where("mj_id = ?", mjID)
+	}
+	if start := queryParams.StartTimestamp; start != "" {
+		db = db.Where("submit_time >= ?", start)
+	}
+	if end := queryParams.EndTimestamp; end != "" {
+		db = db.Where("submit_time <= ?", end)
+	}
+	var count int64
+	// Best-effort: a count error simply yields 0.
+	_ = db.Count(&count).Error
+	return count
+}
+
+// CountAllUserTask returns total midjourney tasks for user
+func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
+	db := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
+	if mjID := queryParams.MjID; mjID != "" {
+		db = db.Where("mj_id = ?", mjID)
+	}
+	if start := queryParams.StartTimestamp; start != "" {
+		db = db.Where("submit_time >= ?", start)
+	}
+	if end := queryParams.EndTimestamp; end != "" {
+		db = db.Where("submit_time <= ?", end)
+	}
+	var count int64
+	// Best-effort: a count error simply yields 0.
+	_ = db.Count(&count).Error
+	return count
+}
diff --git a/model/missing_models.go b/model/missing_models.go
new file mode 100644
index 00000000..18191ba6
--- /dev/null
+++ b/model/missing_models.go
@@ -0,0 +1,30 @@
+package model
+
+// GetMissingModels returns model names that are referenced in the system
+func GetMissingModels() ([]string, error) {
+	// 1. Collect every enabled model name (already de-duplicated).
+	enabled := GetEnabledModels()
+	if len(enabled) == 0 {
+		return []string{}, nil
+	}
+
+	// 2. Look up which of them already have a metadata row.
+	var known []string
+	err := DB.Model(&Model{}).Where("model_name IN ?", enabled).Pluck("model_name", &known).Error
+	if err != nil {
+		return nil, err
+	}
+
+	knownSet := make(map[string]struct{}, len(known))
+	for _, name := range known {
+		knownSet[name] = struct{}{}
+	}
+
+	// 3. Whatever is enabled but has no metadata row is missing.
+	var missing []string
+	for _, name := range enabled {
+		if _, seen := knownSet[name]; !seen {
+			missing = append(missing, name)
+		}
+	}
+	return missing, nil
+}
diff --git a/model/model_extra.go b/model/model_extra.go
new file mode 100644
index 00000000..71fd84e7
--- /dev/null
+++ b/model/model_extra.go
@@ -0,0 +1,31 @@
+package model
+
+// GetModelEnableGroups returns the groups in which modelName is enabled,
+// read from the pricing cache (GetPricing is called first, so the data is at
+// most one minute stale). A copy of the cached slice is returned so callers
+// cannot mutate the cache shared across goroutines.
+func GetModelEnableGroups(modelName string) []string {
+	// Make sure the pricing cache is reasonably fresh.
+	GetPricing()
+
+	if modelName == "" {
+		return make([]string, 0)
+	}
+
+	modelEnableGroupsLock.RLock()
+	defer modelEnableGroupsLock.RUnlock()
+	groups, ok := modelEnableGroups[modelName]
+	if !ok {
+		return make([]string, 0)
+	}
+	// Defensive copy: the cached slice is shared with other readers.
+	return append(make([]string, 0, len(groups)), groups...)
+}
+
+// GetModelQuotaTypes returns the set of billing (quota) types for the given
+// model, read from the pricing cache (refreshed via GetPricing).
+func GetModelQuotaTypes(modelName string) []int {
+	GetPricing()
+
+	// modelQuotaTypeMap shares modelEnableGroupsLock with modelEnableGroups;
+	// both maps are rebuilt together in updatePricing.
+	modelEnableGroupsLock.RLock()
+	quota, ok := modelQuotaTypeMap[modelName]
+	modelEnableGroupsLock.RUnlock()
+	if !ok {
+		return []int{}
+	}
+	return []int{quota}
+}
diff --git a/model/model_meta.go b/model/model_meta.go
new file mode 100644
index 00000000..b7602b0e
--- /dev/null
+++ b/model/model_meta.go
@@ -0,0 +1,146 @@
+package model
+
+import (
+ "one-api/common"
+ "strconv"
+
+ "gorm.io/gorm"
+)
+
+// Name-matching rules for associating a metadata row with runtime model names
+// (used by updatePricing when mapping abilities to metadata).
+const (
+	NameRuleExact    = iota // ModelName must equal the runtime name
+	NameRulePrefix          // ModelName is a prefix of the runtime name
+	NameRuleContains        // ModelName appears anywhere in the runtime name
+	NameRuleSuffix          // ModelName is a suffix of the runtime name
+)
+
+// BoundChannel is a lightweight view of a channel bound to a model.
+type BoundChannel struct {
+	Name string `json:"name"`
+	Type int    `json:"type"`
+}
+
+// Model is the metadata row for a model: description, vendor, custom
+// endpoints and enable status. Fields tagged gorm:"-" are runtime-only and
+// filled in by API handlers rather than persisted.
+type Model struct {
+	Id          int    `json:"id"`
+	ModelName   string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
+	Description string `json:"description,omitempty" gorm:"type:text"`
+	Icon        string `json:"icon,omitempty" gorm:"type:varchar(128)"`
+	Tags        string `json:"tags,omitempty" gorm:"type:varchar(255)"`
+	VendorID    int    `json:"vendor_id,omitempty" gorm:"index"`
+	Endpoints   string `json:"endpoints,omitempty" gorm:"type:text"` // JSON object: endpoint name -> path string or {path, method}
+	Status      int    `json:"status" gorm:"default:1"`              // 1 = enabled; any other value hides the model from pricing
+	CreatedTime int64  `json:"created_time" gorm:"bigint"`
+	UpdatedTime int64  `json:"updated_time" gorm:"bigint"`
+	// DeletedAt is part of uk_model_name, so uniqueness is scoped per
+	// deletion state (a soft-deleted name can be created again).
+	DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
+
+	BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` // runtime-only
+	EnableGroups  []string       `json:"enable_groups,omitempty" gorm:"-"`  // runtime-only
+	QuotaTypes    []int          `json:"quota_types,omitempty" gorm:"-"`    // runtime-only
+	NameRule      int            `json:"name_rule" gorm:"default:0"`        // persisted; one of the NameRule* constants
+
+	MatchedModels []string `json:"matched_models,omitempty" gorm:"-"` // runtime-only
+	MatchedCount  int      `json:"matched_count,omitempty" gorm:"-"`  // runtime-only
+}
+
+// Insert creates the model row, stamping both created and updated times.
+func (mi *Model) Insert() error {
+	now := common.GetTimestamp()
+	mi.CreatedTime = now
+	mi.UpdatedTime = now
+	return DB.Create(mi).Error
+}
+
+// IsModelNameDuplicated reports whether another model row (a different id)
+// already uses the given name. An empty name is never considered a duplicate.
+func IsModelNameDuplicated(id int, name string) (bool, error) {
+	if name == "" {
+		return false, nil
+	}
+	var count int64
+	if err := DB.Model(&Model{}).
+		Where("model_name = ? AND id <> ?", name, id).
+		Count(&count).Error; err != nil {
+		return false, err
+	}
+	return count > 0, nil
+}
+
+// Update persists mi by primary key. Select("*") forces GORM to write every
+// column (including zero values) while Omit("created_time") keeps the
+// creation timestamp intact; AllowGlobalUpdate stays off so a zero Id cannot
+// turn this into a table-wide update.
+func (mi *Model) Update() error {
+	mi.UpdatedTime = common.GetTimestamp()
+	return DB.Session(&gorm.Session{AllowGlobalUpdate: false, FullSaveAssociations: false}).
+		Model(&Model{}).
+		Where("id = ?", mi.Id).
+		Omit("created_time").
+		Select("*").
+		Updates(mi).Error
+}
+
+// Delete soft-deletes the model (gorm sets DeletedAt instead of removing the row).
+func (mi *Model) Delete() error {
+	return DB.Delete(mi).Error
+}
+
+// GetVendorModelCounts returns, for every vendor id, how many model rows
+// reference it.
+func GetVendorModelCounts() (map[int64]int64, error) {
+	type vendorStat struct {
+		VendorID int64
+		Count    int64
+	}
+	var rows []vendorStat
+	err := DB.Model(&Model{}).
+		Select("vendor_id as vendor_id, count(*) as count").
+		Group("vendor_id").
+		Scan(&rows).Error
+	if err != nil {
+		return nil, err
+	}
+	counts := make(map[int64]int64, len(rows))
+	for _, row := range rows {
+		counts[row.VendorID] = row.Count
+	}
+	return counts, nil
+}
+
+// GetAllModels returns one page of model rows, newest id first.
+func GetAllModels(offset int, limit int) ([]*Model, error) {
+	var models []*Model
+	err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error
+	return models, err
+}
+
+// GetBoundChannelsByModelsMap maps each model name to the distinct channels
+// that currently have an enabled ability for it. Models with no enabled
+// channel are simply absent from the result.
+func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
+	result := make(map[string][]BoundChannel)
+	if len(modelNames) == 0 {
+		return result, nil
+	}
+	// One row per (model, channel) pair; Distinct collapses duplicates that
+	// arise from multiple abilities on the same channel.
+	type row struct {
+		Model string
+		Name  string
+		Type  int
+	}
+	var rows []row
+	err := DB.Table("channels").
+		Select("abilities.model as model, channels.name as name, channels.type as type").
+		Joins("JOIN abilities ON abilities.channel_id = channels.id").
+		Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
+		Distinct().
+		Scan(&rows).Error
+	if err != nil {
+		return nil, err
+	}
+	for _, r := range rows {
+		result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type})
+	}
+	return result, nil
+}
+
+// SearchModels returns a page of models filtered by a free-text keyword
+// (matched against name/description/tags) and/or a vendor (numeric id, or a
+// vendor-name fragment), plus the total match count for pagination.
+// NOTE(review): LIKE wildcards (% and _) in keyword/vendor are passed through
+// unescaped, so user input can act as a wildcard pattern — confirm intended.
+func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
+	var models []*Model
+	db := DB.Model(&Model{})
+	if keyword != "" {
+		like := "%" + keyword + "%"
+		db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
+	}
+	if vendor != "" {
+		// Numeric input filters by vendor id; anything else joins vendors and
+		// fuzzy-matches the vendor name.
+		if vid, err := strconv.Atoi(vendor); err == nil {
+			db = db.Where("models.vendor_id = ?", vid)
+		} else {
+			db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
+		}
+	}
+	var total int64
+	if err := db.Count(&total).Error; err != nil {
+		return nil, 0, err
+	}
+	if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
+		return nil, 0, err
+	}
+	return models, total, nil
+}
diff --git a/model/option.go b/model/option.go
index d892b120..2121710c 100644
--- a/model/option.go
+++ b/model/option.go
@@ -5,6 +5,7 @@ import (
"one-api/setting"
"one-api/setting/config"
"one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
"strconv"
"strings"
"time"
@@ -73,9 +74,18 @@ func InitOptionMap() {
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
+ common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
+ common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
+ common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
+ common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
+ common.OptionMap["StripePriceId"] = setting.StripePriceId
+ common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
+ common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
+ common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
+ common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -94,12 +104,13 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
- common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
- common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
- common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
- common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
+ common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
+ common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
+ common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
+ common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
+ common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
- common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
+ common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
//common.OptionMap["ChatLink"] = common.ChatLink
//common.OptionMap["ChatLink2"] = common.ChatLink2
@@ -122,6 +133,7 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
+ common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -138,7 +150,7 @@ func loadOptionsFromDatabase() {
for _, option := range options {
err := updateOptionMap(option.Key, option.Value)
if err != nil {
- common.SysError("failed to update option map: " + err.Error())
+ common.SysLog("failed to update option map: " + err.Error())
}
}
}
@@ -191,7 +203,7 @@ func updateOptionMap(key string, value string) (err error) {
common.ImageDownloadPermission = intValue
}
}
- if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
+ if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
boolValue := value == "true"
switch key {
case "PasswordRegisterEnabled":
@@ -260,6 +272,10 @@ func updateOptionMap(key string, value string) (err error) {
common.SMTPSSLEnabled = boolValue
case "WorkerAllowHttpImageRequestEnabled":
setting.WorkerAllowHttpImageRequestEnabled = boolValue
+ case "DefaultUseAutoGroup":
+ setting.DefaultUseAutoGroup = boolValue
+ case "ExposeRatioEnabled":
+ ratio_setting.SetExposeRatioEnabled(boolValue)
}
}
switch key {
@@ -286,6 +302,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.PayAddress = value
case "Chats":
err = setting.UpdateChatsByJsonString(value)
+ case "AutoGroups":
+ err = setting.UpdateAutoGroupsByJsonString(value)
case "CustomCallbackAddress":
setting.CustomCallbackAddress = value
case "EpayId":
@@ -294,8 +312,20 @@ func updateOptionMap(key string, value string) (err error) {
setting.EpayKey = value
case "Price":
setting.Price, _ = strconv.ParseFloat(value, 64)
+ case "USDExchangeRate":
+ setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
setting.MinTopUp, _ = strconv.Atoi(value)
+ case "StripeApiSecret":
+ setting.StripeApiSecret = value
+ case "StripeWebhookSecret":
+ setting.StripeWebhookSecret = value
+ case "StripePriceId":
+ setting.StripePriceId = value
+ case "StripeUnitPrice":
+ setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
+ case "StripeMinTopUp":
+ setting.StripeMinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
@@ -306,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
common.LinuxDOClientId = value
case "LinuxDOClientSecret":
common.LinuxDOClientSecret = value
+ case "LinuxDOMinimumTrustLevel":
+ common.LinuxDOMinimumTrustLevel, _ = strconv.Atoi(value)
case "Footer":
common.Footer = value
case "SystemName":
@@ -351,17 +383,19 @@ func updateOptionMap(key string, value string) (err error) {
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
- err = operation_setting.UpdateModelRatioByJSONString(value)
+ err = ratio_setting.UpdateModelRatioByJSONString(value)
case "GroupRatio":
- err = setting.UpdateGroupRatioByJSONString(value)
+ err = ratio_setting.UpdateGroupRatioByJSONString(value)
+ case "GroupGroupRatio":
+ err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
case "UserUsableGroups":
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
- err = operation_setting.UpdateCompletionRatioByJSONString(value)
+ err = ratio_setting.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
- err = operation_setting.UpdateModelPriceByJSONString(value)
+ err = ratio_setting.UpdateModelPriceByJSONString(value)
case "CacheRatio":
- err = operation_setting.UpdateCacheRatioByJSONString(value)
+ err = ratio_setting.UpdateCacheRatioByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
//case "ChatLink":
@@ -378,6 +412,8 @@ func updateOptionMap(key string, value string) (err error) {
operation_setting.AutomaticDisableKeywordsFromString(value)
case "StreamCacheQueueLength":
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
+ case "PayMethods":
+ err = setting.UpdatePayMethodsByJsonString(value)
}
return err
}
diff --git a/model/prefill_group.go b/model/prefill_group.go
new file mode 100644
index 00000000..a21b76fe
--- /dev/null
+++ b/model/prefill_group.go
@@ -0,0 +1,126 @@
+package model
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "one-api/common"
+
+ "gorm.io/gorm"
+)
+
+// PrefillGroup 用于存储可复用的“组”信息,例如模型组、标签组、端点组等。
+// Name 字段保持唯一,用于在前端下拉框中展示。
+// Type 字段用于区分组的类别,可选值如:model、tag、endpoint。
+// Items 字段使用 JSON 数组保存对应类型的字符串集合,示例:
+// ["gpt-4o", "gpt-3.5-turbo"]
+// 设计遵循 3NF,避免冗余,提供灵活扩展能力。
+
+// JSONValue is a json.RawMessage-style value that can additionally Scan both
+// the []byte and string forms different database drivers return.
+type JSONValue json.RawMessage
+
+// Value implements driver.Valuer for database writes; nil stores SQL NULL.
+func (j JSONValue) Value() (driver.Value, error) {
+	if j == nil {
+		return nil, nil
+	}
+	return []byte(j), nil
+}
+
+// Scan implements sql.Scanner, accepting the value types different drivers
+// return for JSON columns.
+func (j *JSONValue) Scan(value interface{}) error {
+	switch v := value.(type) {
+	case nil:
+		*j = nil
+		return nil
+	case []byte:
+		// Copy the bytes so we do not retain the driver's underlying buffer.
+		b := make([]byte, len(v))
+		copy(b, v)
+		*j = JSONValue(b)
+		return nil
+	case string:
+		*j = JSONValue([]byte(v))
+		return nil
+	default:
+		// Any other type: best-effort serialize it to JSON.
+		b, err := json.Marshal(v)
+		if err != nil {
+			return err
+		}
+		*j = JSONValue(b)
+		return nil
+	}
+}
+
+// MarshalJSON mirrors json.RawMessage: the stored bytes are emitted verbatim,
+// and a nil value encodes as JSON null.
+func (j JSONValue) MarshalJSON() ([]byte, error) {
+	if j != nil {
+		return j, nil
+	}
+	return []byte("null"), nil
+}
+
+// UnmarshalJSON mirrors json.RawMessage: the incoming bytes are copied and
+// stored verbatim (nil input resets the value to nil).
+func (j *JSONValue) UnmarshalJSON(data []byte) error {
+	if data == nil {
+		*j = nil
+		return nil
+	}
+	buf := append(make([]byte, 0, len(data)), data...)
+	*j = JSONValue(buf)
+	return nil
+}
+
+// PrefillGroup stores a reusable named "group" of strings (see the file-top
+// comment): Type distinguishes the category (model / tag / endpoint), and
+// Items holds a JSON array of strings for that category.
+type PrefillGroup struct {
+	Id int `json:"id"`
+	// Name is unique among non-deleted rows (partial unique index).
+	Name        string         `json:"name" gorm:"size:64;not null;uniqueIndex:uk_prefill_name,where:deleted_at IS NULL"`
+	Type        string         `json:"type" gorm:"size:32;index;not null"`
+	Items       JSONValue      `json:"items" gorm:"type:json"` // JSON array, e.g. ["gpt-4o", "gpt-3.5-turbo"]
+	Description string         `json:"description,omitempty" gorm:"type:varchar(255)"`
+	CreatedTime int64          `json:"created_time" gorm:"bigint"`
+	UpdatedTime int64          `json:"updated_time" gorm:"bigint"`
+	DeletedAt   gorm.DeletedAt `json:"-" gorm:"index"`
+}
+
+// Insert creates a new group, stamping both created and updated times.
+func (g *PrefillGroup) Insert() error {
+	now := common.GetTimestamp()
+	g.CreatedTime = now
+	g.UpdatedTime = now
+	return DB.Create(g).Error
+}
+
+// IsPrefillGroupNameDuplicated reports whether another group (a different id)
+// already uses the given name. An empty name is never considered a duplicate.
+func IsPrefillGroupNameDuplicated(id int, name string) (bool, error) {
+	if name == "" {
+		return false, nil
+	}
+	var count int64
+	if err := DB.Model(&PrefillGroup{}).
+		Where("name = ? AND id <> ?", name, id).
+		Count(&count).Error; err != nil {
+		return false, err
+	}
+	return count > 0, nil
+}
+
+// Update saves the group. gorm Save writes every field of g, so callers must
+// pass a fully-populated struct.
+func (g *PrefillGroup) Update() error {
+	g.UpdatedTime = common.GetTimestamp()
+	return DB.Save(g).Error
+}
+
+// DeletePrefillGroupByID soft-deletes the group with the given ID
+// (the struct carries gorm.DeletedAt, so the row is marked, not removed).
+func DeletePrefillGroupByID(id int) error {
+	return DB.Delete(&PrefillGroup{}, id).Error
+}
+
+// GetAllPrefillGroups returns all groups ordered newest-updated first,
+// optionally filtered by type (an empty groupType returns every group).
+func GetAllPrefillGroups(groupType string) ([]*PrefillGroup, error) {
+	db := DB.Model(&PrefillGroup{})
+	if groupType != "" {
+		db = db.Where("type = ?", groupType)
+	}
+	var groups []*PrefillGroup
+	err := db.Order("updated_time DESC").Find(&groups).Error
+	if err != nil {
+		return nil, err
+	}
+	return groups, nil
+}
diff --git a/model/pricing.go b/model/pricing.go
index ba1815e2..3c9349de 100644
--- a/model/pricing.go
+++ b/model/pricing.go
@@ -1,81 +1,309 @@
package model
import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
"one-api/common"
- "one-api/setting/operation_setting"
+ "one-api/constant"
+ "one-api/setting/ratio_setting"
+ "one-api/types"
"sync"
"time"
)
type Pricing struct {
- ModelName string `json:"model_name"`
- QuotaType int `json:"quota_type"`
- ModelRatio float64 `json:"model_ratio"`
- ModelPrice float64 `json:"model_price"`
- OwnerBy string `json:"owner_by"`
- CompletionRatio float64 `json:"completion_ratio"`
- EnableGroup []string `json:"enable_groups,omitempty"`
+ ModelName string `json:"model_name"`
+ Description string `json:"description,omitempty"`
+ Icon string `json:"icon,omitempty"`
+ Tags string `json:"tags,omitempty"`
+ VendorID int `json:"vendor_id,omitempty"`
+ QuotaType int `json:"quota_type"`
+ ModelRatio float64 `json:"model_ratio"`
+ ModelPrice float64 `json:"model_price"`
+ OwnerBy string `json:"owner_by"`
+ CompletionRatio float64 `json:"completion_ratio"`
+ EnableGroup []string `json:"enable_groups"`
+ SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
+}
+
+type PricingVendor struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ Icon string `json:"icon,omitempty"`
}
var (
- pricingMap []Pricing
- lastGetPricingTime time.Time
- updatePricingLock sync.Mutex
+ pricingMap []Pricing
+ vendorsList []PricingVendor
+ supportedEndpointMap map[string]common.EndpointInfo
+ lastGetPricingTime time.Time
+ updatePricingLock sync.Mutex
+
+ // 缓存映射:模型名 -> 启用分组 / 计费类型
+ modelEnableGroups = make(map[string][]string)
+ modelQuotaTypeMap = make(map[string]int)
+ modelEnableGroupsLock = sync.RWMutex{}
+)
+
+var (
+ modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
+ modelSupportEndpointsLock = sync.RWMutex{}
)
func GetPricing() []Pricing {
-	updatePricingLock.Lock()
-	defer updatePricingLock.Unlock()
-
+	// NOTE(review): lastGetPricingTime and pricingMap are read here before
+	// the lock is taken; a concurrent updatePricing can race with this check —
+	// confirm this unlocked fast path is acceptable for the cache.
	if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
-		updatePricing()
+		updatePricingLock.Lock()
+		defer updatePricingLock.Unlock()
+		// Double check after acquiring the lock
+		if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
+			// Also hold the endpoint-cache lock: updatePricing rebuilds
+			// modelSupportEndpointTypes as well.
+			modelSupportEndpointsLock.Lock()
+			defer modelSupportEndpointsLock.Unlock()
+			updatePricing()
+		}
	}
-	//if group != "" {
-	//	userPricingMap := make([]Pricing, 0)
-	//	models := GetGroupModels(group)
-	//	for _, pricing := range pricingMap {
-	//		if !common.StringsContains(models, pricing.ModelName) {
-	//			pricing.Available = false
-	//		}
-	//		userPricingMap = append(userPricingMap, pricing)
-	//	}
-	//	return userPricingMap
-	//}
	return pricingMap
}
+// GetVendors returns the vendor info used by the pricing API.
+// NOTE(review): vendorsList is read without holding updatePricingLock, so a
+// concurrent refresh may race with this read — confirm acceptability.
+func GetVendors() []PricingVendor {
+	if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
+		// Refresh the pricing caches (which also rebuild vendorsList) first.
+		GetPricing()
+	}
+	return vendorsList
+}
+
+// GetModelSupportEndpointTypes returns the endpoint types the given model
+// supports, in priority order (the first entry is the preferred endpoint,
+// per the ordering built in updatePricing). The returned slice is the
+// internal cache entry — callers must not mutate it.
+func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
+	if model == "" {
+		return make([]constant.EndpointType, 0)
+	}
+	modelSupportEndpointsLock.RLock()
+	defer modelSupportEndpointsLock.RUnlock()
+	if endpoints, ok := modelSupportEndpointTypes[model]; ok {
+		return endpoints
+	}
+	return make([]constant.EndpointType, 0)
+}
+
func updatePricing() {
//modelRatios := common.GetModelRatios()
- enableAbilities := GetAllEnableAbilities()
- modelGroupsMap := make(map[string][]string)
+ enableAbilities, err := GetAllEnableAbilityWithChannels()
+ if err != nil {
+ common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
+ return
+ }
+ // 预加载模型元数据与供应商一次,避免循环查询
+ var allMeta []Model
+ _ = DB.Find(&allMeta).Error
+ metaMap := make(map[string]*Model)
+ prefixList := make([]*Model, 0)
+ suffixList := make([]*Model, 0)
+ containsList := make([]*Model, 0)
+ for i := range allMeta {
+ m := &allMeta[i]
+ if m.NameRule == NameRuleExact {
+ metaMap[m.ModelName] = m
+ } else {
+ switch m.NameRule {
+ case NameRulePrefix:
+ prefixList = append(prefixList, m)
+ case NameRuleSuffix:
+ suffixList = append(suffixList, m)
+ case NameRuleContains:
+ containsList = append(containsList, m)
+ }
+ }
+ }
+
+ // 将非精确规则模型匹配到 metaMap
+ for _, m := range prefixList {
+ for _, pricingModel := range enableAbilities {
+ if strings.HasPrefix(pricingModel.Model, m.ModelName) {
+ if _, exists := metaMap[pricingModel.Model]; !exists {
+ metaMap[pricingModel.Model] = m
+ }
+ }
+ }
+ }
+ for _, m := range suffixList {
+ for _, pricingModel := range enableAbilities {
+ if strings.HasSuffix(pricingModel.Model, m.ModelName) {
+ if _, exists := metaMap[pricingModel.Model]; !exists {
+ metaMap[pricingModel.Model] = m
+ }
+ }
+ }
+ }
+ for _, m := range containsList {
+ for _, pricingModel := range enableAbilities {
+ if strings.Contains(pricingModel.Model, m.ModelName) {
+ if _, exists := metaMap[pricingModel.Model]; !exists {
+ metaMap[pricingModel.Model] = m
+ }
+ }
+ }
+ }
+
+ // 预加载供应商
+ var vendors []Vendor
+ _ = DB.Find(&vendors).Error
+ vendorMap := make(map[int]*Vendor)
+ for i := range vendors {
+ vendorMap[vendors[i].Id] = &vendors[i]
+ }
+
+ // 构建对前端友好的供应商列表
+ vendorsList = make([]PricingVendor, 0, len(vendors))
+ for _, v := range vendors {
+ vendorsList = append(vendorsList, PricingVendor{
+ ID: v.Id,
+ Name: v.Name,
+ Description: v.Description,
+ Icon: v.Icon,
+ })
+ }
+
+ modelGroupsMap := make(map[string]*types.Set[string])
+
for _, ability := range enableAbilities {
- groups := modelGroupsMap[ability.Model]
- if groups == nil {
- groups = make([]string, 0)
+ groups, ok := modelGroupsMap[ability.Model]
+ if !ok {
+ groups = types.NewSet[string]()
+ modelGroupsMap[ability.Model] = groups
}
- if !common.StringsContains(groups, ability.Group) {
- groups = append(groups, ability.Group)
+ groups.Add(ability.Group)
+ }
+
+ //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
+ modelSupportEndpointsStr := make(map[string][]string)
+
+ // 先根据已有能力填充原生端点
+ for _, ability := range enableAbilities {
+ endpoints := modelSupportEndpointsStr[ability.Model]
+ channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
+ for _, channelType := range channelTypes {
+ if !common.StringsContains(endpoints, string(channelType)) {
+ endpoints = append(endpoints, string(channelType))
+ }
+ }
+ modelSupportEndpointsStr[ability.Model] = endpoints
+ }
+
+ // 再补充模型自定义端点
+ for modelName, meta := range metaMap {
+ if strings.TrimSpace(meta.Endpoints) == "" {
+ continue
+ }
+ var raw map[string]interface{}
+ if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
+ endpoints := modelSupportEndpointsStr[modelName]
+ for k := range raw {
+ if !common.StringsContains(endpoints, k) {
+ endpoints = append(endpoints, k)
+ }
+ }
+ modelSupportEndpointsStr[modelName] = endpoints
+ }
+ }
+
+ modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
+ for model, endpoints := range modelSupportEndpointsStr {
+ supportedEndpoints := make([]constant.EndpointType, 0)
+ for _, endpointStr := range endpoints {
+ endpointType := constant.EndpointType(endpointStr)
+ supportedEndpoints = append(supportedEndpoints, endpointType)
+ }
+ modelSupportEndpointTypes[model] = supportedEndpoints
+ }
+
+ // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
+ supportedEndpointMap = make(map[string]common.EndpointInfo)
+ // 1. 默认端点
+ for _, endpoints := range modelSupportEndpointTypes {
+ for _, et := range endpoints {
+ if info, ok := common.GetDefaultEndpointInfo(et); ok {
+ if _, exists := supportedEndpointMap[string(et)]; !exists {
+ supportedEndpointMap[string(et)] = info
+ }
+ }
+ }
+ }
+ // 2. 自定义端点(models 表)覆盖默认
+ for _, meta := range metaMap {
+ if strings.TrimSpace(meta.Endpoints) == "" {
+ continue
+ }
+ var raw map[string]interface{}
+ if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
+ for k, v := range raw {
+ switch val := v.(type) {
+ case string:
+ supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
+ case map[string]interface{}:
+ ep := common.EndpointInfo{Method: "POST"}
+ if p, ok := val["path"].(string); ok {
+ ep.Path = p
+ }
+ if m, ok := val["method"].(string); ok {
+ ep.Method = strings.ToUpper(m)
+ }
+ supportedEndpointMap[k] = ep
+ default:
+ // ignore unsupported types
+ }
+ }
}
- modelGroupsMap[ability.Model] = groups
}
pricingMap = make([]Pricing, 0)
for model, groups := range modelGroupsMap {
pricing := Pricing{
- ModelName: model,
- EnableGroup: groups,
+ ModelName: model,
+ EnableGroup: groups.Items(),
+ SupportedEndpointTypes: modelSupportEndpointTypes[model],
}
- modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
+
+ // 补充模型元数据(描述、标签、供应商、状态)
+ if meta, ok := metaMap[model]; ok {
+ // 若模型被禁用(status!=1),则直接跳过,不返回给前端
+ if meta.Status != 1 {
+ continue
+ }
+ pricing.Description = meta.Description
+ pricing.Icon = meta.Icon
+ pricing.Tags = meta.Tags
+ pricing.VendorID = meta.VendorID
+ }
+ modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
if findPrice {
pricing.ModelPrice = modelPrice
pricing.QuotaType = 1
} else {
- modelRatio, _ := operation_setting.GetModelRatio(model)
+ modelRatio, _, _ := ratio_setting.GetModelRatio(model)
pricing.ModelRatio = modelRatio
- pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
+ pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
pricing.QuotaType = 0
}
pricingMap = append(pricingMap, pricing)
}
+
+ // 刷新缓存映射,供高并发快速查询
+ modelEnableGroupsLock.Lock()
+ modelEnableGroups = make(map[string][]string)
+ modelQuotaTypeMap = make(map[string]int)
+ for _, p := range pricingMap {
+ modelEnableGroups[p.ModelName] = p.EnableGroup
+ modelQuotaTypeMap[p.ModelName] = p.QuotaType
+ }
+ modelEnableGroupsLock.Unlock()
+
lastGetPricingTime = time.Now()
}
+
+// GetSupportedEndpointMap returns the global endpoint-name -> EndpointInfo
+// (path/method) map built by updatePricing (defaults plus per-model custom
+// endpoints from the models table).
+// NOTE(review): the live map is returned and is rebuilt by updatePricing
+// without synchronization around this accessor — confirm read-only use.
+func GetSupportedEndpointMap() map[string]common.EndpointInfo {
+	return supportedEndpointMap
+}
diff --git a/model/pricing_refresh.go b/model/pricing_refresh.go
new file mode 100644
index 00000000..cd0d7559
--- /dev/null
+++ b/model/pricing_refresh.go
@@ -0,0 +1,14 @@
+package model
+
+// RefreshPricing forces an immediate rebuild of all pricing-related caches.
+// It is used by internal admin APIs that need fresh data, so it bypasses the
+// default one-minute lazy refresh performed by GetPricing.
+func RefreshPricing() {
+	updatePricingLock.Lock()
+	defer updatePricingLock.Unlock()
+
+	// Same lock order as GetPricing: pricing lock first, then endpoint lock.
+	modelSupportEndpointsLock.Lock()
+	defer modelSupportEndpointsLock.Unlock()
+
+	updatePricing()
+}
diff --git a/model/redemption.go b/model/redemption.go
index 89c4ac8c..1ab84f45 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/logger"
"strconv"
"gorm.io/gorm"
@@ -21,6 +22,7 @@ type Redemption struct {
Count int `json:"count" gorm:"-:all"` // only for api request
UsedUserId int `json:"used_user_id"`
DeletedAt gorm.DeletedAt `gorm:"index"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // 过期时间,0 表示不过期
}
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
@@ -131,6 +133,9 @@ func Redeem(key string, userId int) (quota int, err error) {
if redemption.Status != common.RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
+ if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
+ return errors.New("该兑换码已过期")
+ }
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
if err != nil {
return err
@@ -144,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil {
return 0, errors.New("兑换失败," + err.Error())
}
- RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
+ RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
return redemption.Quota, nil
}
@@ -162,7 +167,7 @@ func (redemption *Redemption) SelectUpdate() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (redemption *Redemption) Update() error {
var err error
- err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
+ err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
return err
}
@@ -183,3 +188,9 @@ func DeleteRedemptionById(id int) (err error) {
}
return redemption.Delete()
}
+
+func DeleteInvalidRedemptions() (int64, error) {
+ now := common.GetTimestamp()
+ result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
+ return result.RowsAffected, result.Error
+}
diff --git a/model/task.go b/model/task.go
index df221edf..9e4177ba 100644
--- a/model/task.go
+++ b/model/task.go
@@ -302,3 +302,64 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
return stat, err
}
+
+// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
+func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{})
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.UserID != "" {
+ query = query.Where("user_id = ?", queryParams.UserID)
+ }
+ if len(queryParams.UserIDs) != 0 {
+ query = query.Where("user_id in (?)", queryParams.UserIDs)
+ }
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
+
+// TaskCountAllUserTask returns total tasks for given user
+func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{}).Where("user_id = ?", userId)
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
diff --git a/model/token.go b/model/token.go
index 8587ea62..320b5cf0 100644
--- a/model/token.go
+++ b/model/token.go
@@ -20,8 +20,8 @@ type Token struct {
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
- UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
- ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
+ UnlimitedQuota bool `json:"unlimited_quota"`
+ ModelLimitsEnabled bool `json:"model_limits_enabled"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
if token != "" {
token = strings.Trim(token, "sk-")
}
- err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
+ err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
}
@@ -91,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ common.SysLog("failed to update token status" + err.Error())
}
}
return token, errors.New("该令牌已过期")
@@ -102,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
- common.SysError("failed to update token status" + err.Error())
+ common.SysLog("failed to update token status" + err.Error())
}
}
keyPrefix := key[:3]
@@ -134,7 +134,7 @@ func GetTokenById(id int) (*Token, error) {
if shouldUpdateRedis(true, err) {
gopool.Go(func() {
if err := cacheSetToken(token); err != nil {
- common.SysError("failed to update user status cache: " + err.Error())
+ common.SysLog("failed to update user status cache: " + err.Error())
}
})
}
@@ -147,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
if shouldUpdateRedis(fromDB, err) && token != nil {
gopool.Go(func() {
if err := cacheSetToken(*token); err != nil {
- common.SysError("failed to update user status cache: " + err.Error())
+ common.SysLog("failed to update user status cache: " + err.Error())
}
})
}
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
// Don't return error - fall through to DB
}
fromDB = true
- err = DB.Where(keyCol+" = ?", key).First(&token).Error
+ err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
return token, err
}
@@ -178,7 +178,7 @@ func (token *Token) Update() (err error) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
- common.SysError("failed to update token cache: " + err.Error())
+ common.SysLog("failed to update token cache: " + err.Error())
}
})
}
@@ -194,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
- common.SysError("failed to update token cache: " + err.Error())
+ common.SysLog("failed to update token cache: " + err.Error())
}
})
}
@@ -209,7 +209,7 @@ func (token *Token) Delete() (err error) {
gopool.Go(func() {
err := cacheDeleteToken(token.Key)
if err != nil {
- common.SysError("failed to delete token cache: " + err.Error())
+ common.SysLog("failed to delete token cache: " + err.Error())
}
})
}
@@ -269,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() {
err := cacheIncrTokenQuota(key, int64(quota))
if err != nil {
- common.SysError("failed to increase token quota: " + err.Error())
+ common.SysLog("failed to increase token quota: " + err.Error())
}
})
}
@@ -299,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() {
err := cacheDecrTokenQuota(key, int64(quota))
if err != nil {
- common.SysError("failed to decrease token quota: " + err.Error())
+ common.SysLog("failed to decrease token quota: " + err.Error())
}
})
}
@@ -320,3 +320,44 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
+
+// CountUserTokens returns total number of tokens for the given user, used for pagination
+func CountUserTokens(userId int) (int64, error) {
+ var total int64
+ err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
+ return total, err
+}
+
+// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
+func BatchDeleteTokens(ids []int, userId int) (int, error) {
+ if len(ids) == 0 {
+ return 0, errors.New("ids 不能为空!")
+ }
+
+ tx := DB.Begin()
+
+ var tokens []Token
+ if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
+ tx.Rollback()
+ return 0, err
+ }
+
+ if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
+ tx.Rollback()
+ return 0, err
+ }
+
+ if err := tx.Commit().Error; err != nil {
+ return 0, err
+ }
+
+ if common.RedisEnabled {
+ gopool.Go(func() {
+ for _, t := range tokens {
+ _ = cacheDeleteToken(t.Key)
+ }
+ })
+ }
+
+ return len(tokens), nil
+}
diff --git a/model/token_cache.go b/model/token_cache.go
index 0fe02fea..5399dbc8 100644
--- a/model/token_cache.go
+++ b/model/token_cache.go
@@ -10,7 +10,7 @@ import (
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
- err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
+ err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
if err != nil {
return err
}
@@ -19,7 +19,7 @@ func cacheSetToken(token Token) error {
func cacheDeleteToken(key string) error {
key = common.GenerateHMAC(key)
- err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
+ err := common.RedisDelKey(fmt.Sprintf("token:%s", key))
if err != nil {
return err
}
diff --git a/model/topup.go b/model/topup.go
index 507b8518..802c866f 100644
--- a/model/topup.go
+++ b/model/topup.go
@@ -1,13 +1,23 @@
package model
import (
	"errors"
	"fmt"
	"one-api/common"
	"one-api/logger"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)
+
type TopUp struct {
- Id int `json:"id"`
- UserId int `json:"user_id" gorm:"index"`
- Amount int64 `json:"amount"`
- Money float64 `json:"money"`
- TradeNo string `json:"trade_no"`
- CreateTime int64 `json:"create_time"`
- Status string `json:"status"`
+ Id int `json:"id"`
+ UserId int `json:"user_id" gorm:"index"`
+ Amount int64 `json:"amount"`
+ Money float64 `json:"money"`
+ TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
+ CreateTime int64 `json:"create_time"`
+ CompleteTime int64 `json:"complete_time"`
+ Status string `json:"status"`
}
func (topUp *TopUp) Insert() error {
@@ -41,3 +51,51 @@ func GetTopUpByTradeNo(tradeNo string) *TopUp {
}
return topUp
}
+
+func Recharge(referenceId string, customerId string) (err error) {
+ if referenceId == "" {
+ return errors.New("未提供支付单号")
+ }
+
+ var quota float64
+ topUp := &TopUp{}
+
+ refCol := "`trade_no`"
+ if common.UsingPostgreSQL {
+ refCol = `"trade_no"`
+ }
+
+ err = DB.Transaction(func(tx *gorm.DB) error {
+ err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
+ if err != nil {
+ return errors.New("充值订单不存在")
+ }
+
+ if topUp.Status != common.TopUpStatusPending {
+ return errors.New("充值订单状态错误")
+ }
+
+ topUp.CompleteTime = common.GetTimestamp()
+ topUp.Status = common.TopUpStatusSuccess
+ err = tx.Save(topUp).Error
+ if err != nil {
+ return err
+ }
+
+ quota = topUp.Money * common.QuotaPerUnit
+ err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+
+ if err != nil {
+ return errors.New("充值失败," + err.Error())
+ }
+
+ RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", logger.FormatQuota(int(quota)), topUp.Amount))
+
+ return nil
+}
diff --git a/model/twofa.go b/model/twofa.go
new file mode 100644
index 00000000..8e97289f
--- /dev/null
+++ b/model/twofa.go
@@ -0,0 +1,322 @@
+package model
+
+import (
+ "errors"
+ "fmt"
+ "one-api/common"
+ "time"
+
+ "gorm.io/gorm"
+)
+
+var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
+
+// TwoFA 用户2FA设置表
+type TwoFA struct {
+ Id int `json:"id" gorm:"primaryKey"`
+ UserId int `json:"user_id" gorm:"unique;not null;index"`
+ Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
+ IsEnabled bool `json:"is_enabled" gorm:"default:false"`
+ FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
+ LockedUntil *time.Time `json:"locked_until,omitempty"`
+ LastUsedAt *time.Time `json:"last_used_at,omitempty"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+ DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
+}
+
+// TwoFABackupCode 备用码使用记录表
+type TwoFABackupCode struct {
+ Id int `json:"id" gorm:"primaryKey"`
+ UserId int `json:"user_id" gorm:"not null;index"`
+ CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
+ IsUsed bool `json:"is_used" gorm:"default:false"`
+ UsedAt *time.Time `json:"used_at,omitempty"`
+ CreatedAt time.Time `json:"created_at"`
+ DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
+}
+
+// GetTwoFAByUserId 根据用户ID获取2FA设置
+func GetTwoFAByUserId(userId int) (*TwoFA, error) {
+ if userId == 0 {
+ return nil, errors.New("用户ID不能为空")
+ }
+
+ var twoFA TwoFA
+ err := DB.Where("user_id = ?", userId).First(&twoFA).Error
+ if err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, nil // 返回nil表示未设置2FA
+ }
+ return nil, err
+ }
+
+ return &twoFA, nil
+}
+
+// IsTwoFAEnabled 检查用户是否启用了2FA
+func IsTwoFAEnabled(userId int) bool {
+ twoFA, err := GetTwoFAByUserId(userId)
+ if err != nil || twoFA == nil {
+ return false
+ }
+ return twoFA.IsEnabled
+}
+
+// CreateTwoFA 创建2FA设置
+func (t *TwoFA) Create() error {
+ // 检查用户是否已存在2FA设置
+ existing, err := GetTwoFAByUserId(t.UserId)
+ if err != nil {
+ return err
+ }
+ if existing != nil {
+ return errors.New("用户已存在2FA设置")
+ }
+
+ // 验证用户存在
+ var user User
+ if err := DB.First(&user, t.UserId).Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return errors.New("用户不存在")
+ }
+ return err
+ }
+
+ return DB.Create(t).Error
+}
+
+// Update 更新2FA设置
+func (t *TwoFA) Update() error {
+ if t.Id == 0 {
+ return errors.New("2FA记录ID不能为空")
+ }
+ return DB.Save(t).Error
+}
+
+// Delete 删除2FA设置
+func (t *TwoFA) Delete() error {
+ if t.Id == 0 {
+ return errors.New("2FA记录ID不能为空")
+ }
+
+ // 使用事务确保原子性
+ return DB.Transaction(func(tx *gorm.DB) error {
+ // 同时删除相关的备用码记录(硬删除)
+ if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
+ return err
+ }
+
+ // 硬删除2FA记录
+ return tx.Unscoped().Delete(t).Error
+ })
+}
+
+// ResetFailedAttempts 重置失败尝试次数
+func (t *TwoFA) ResetFailedAttempts() error {
+ t.FailedAttempts = 0
+ t.LockedUntil = nil
+ return t.Update()
+}
+
+// IncrementFailedAttempts 增加失败尝试次数
+func (t *TwoFA) IncrementFailedAttempts() error {
+ t.FailedAttempts++
+
+ // 检查是否需要锁定
+ if t.FailedAttempts >= common.MaxFailAttempts {
+ lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
+ t.LockedUntil = &lockUntil
+ }
+
+ return t.Update()
+}
+
+// IsLocked 检查账户是否被锁定
+func (t *TwoFA) IsLocked() bool {
+ if t.LockedUntil == nil {
+ return false
+ }
+ return time.Now().Before(*t.LockedUntil)
+}
+
+// CreateBackupCodes 创建备用码
+func CreateBackupCodes(userId int, codes []string) error {
+ return DB.Transaction(func(tx *gorm.DB) error {
+ // 先删除现有的备用码
+ if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
+ return err
+ }
+
+ // 创建新的备用码记录
+ for _, code := range codes {
+ hashedCode, err := common.HashBackupCode(code)
+ if err != nil {
+ return err
+ }
+
+ backupCode := TwoFABackupCode{
+ UserId: userId,
+ CodeHash: hashedCode,
+ IsUsed: false,
+ }
+
+ if err := tx.Create(&backupCode).Error; err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+// ValidateBackupCode 验证并使用备用码
+func ValidateBackupCode(userId int, code string) (bool, error) {
+ if !common.ValidateBackupCode(code) {
+ return false, errors.New("验证码或备用码不正确")
+ }
+
+ normalizedCode := common.NormalizeBackupCode(code)
+
+ // 查找未使用的备用码
+ var backupCodes []TwoFABackupCode
+ if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
+ return false, err
+ }
+
+ // 验证备用码
+ for _, bc := range backupCodes {
+ if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
+ // 标记为已使用
+ now := time.Now()
+ bc.IsUsed = true
+ bc.UsedAt = &now
+
+ if err := DB.Save(&bc).Error; err != nil {
+ return false, err
+ }
+
+ return true, nil
+ }
+ }
+
+ return false, nil
+}
+
+// GetUnusedBackupCodeCount 获取未使用的备用码数量
+func GetUnusedBackupCodeCount(userId int) (int, error) {
+ var count int64
+ err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
+ return int(count), err
+}
+
+// DisableTwoFA 禁用用户的2FA
+func DisableTwoFA(userId int) error {
+ twoFA, err := GetTwoFAByUserId(userId)
+ if err != nil {
+ return err
+ }
+ if twoFA == nil {
+ return ErrTwoFANotEnabled
+ }
+
+ // 删除2FA设置和备用码
+ return twoFA.Delete()
+}
+
+// EnableTwoFA 启用2FA
+func (t *TwoFA) Enable() error {
+ t.IsEnabled = true
+ t.FailedAttempts = 0
+ t.LockedUntil = nil
+ return t.Update()
+}
+
+// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
+func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
+ // 检查是否被锁定
+ if t.IsLocked() {
+ return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
+ }
+
+ // 验证TOTP码
+ if !common.ValidateTOTPCode(t.Secret, code) {
+ // 增加失败次数
+ if err := t.IncrementFailedAttempts(); err != nil {
+ common.SysLog("更新2FA失败次数失败: " + err.Error())
+ }
+ return false, nil
+ }
+
+ // 验证成功,重置失败次数并更新最后使用时间
+ now := time.Now()
+ t.FailedAttempts = 0
+ t.LockedUntil = nil
+ t.LastUsedAt = &now
+
+ if err := t.Update(); err != nil {
+ common.SysLog("更新2FA使用记录失败: " + err.Error())
+ }
+
+ return true, nil
+}
+
+// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
+func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
+ // 检查是否被锁定
+ if t.IsLocked() {
+ return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
+ }
+
+ // 验证备用码
+ valid, err := ValidateBackupCode(t.UserId, code)
+ if err != nil {
+ return false, err
+ }
+
+ if !valid {
+ // 增加失败次数
+ if err := t.IncrementFailedAttempts(); err != nil {
+ common.SysLog("更新2FA失败次数失败: " + err.Error())
+ }
+ return false, nil
+ }
+
+ // 验证成功,重置失败次数并更新最后使用时间
+ now := time.Now()
+ t.FailedAttempts = 0
+ t.LockedUntil = nil
+ t.LastUsedAt = &now
+
+ if err := t.Update(); err != nil {
+ common.SysLog("更新2FA使用记录失败: " + err.Error())
+ }
+
+ return true, nil
+}
+
+// GetTwoFAStats 获取2FA统计信息(管理员使用)
+func GetTwoFAStats() (map[string]interface{}, error) {
+ var totalUsers, enabledUsers int64
+
+ // 总用户数
+ if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
+ return nil, err
+ }
+
+ // 启用2FA的用户数
+ if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
+ return nil, err
+ }
+
+ enabledRate := float64(0)
+ if totalUsers > 0 {
+ enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
+ }
+
+ return map[string]interface{}{
+ "total_users": totalUsers,
+ "enabled_users": enabledUsers,
+ "enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
+ }, nil
+}
diff --git a/model/usedata.go b/model/usedata.go
index 1255b0be..7e525d2e 100644
--- a/model/usedata.go
+++ b/model/usedata.go
@@ -21,12 +21,6 @@ type QuotaData struct {
}
func UpdateQuotaData() {
- // recover
- defer func() {
- if r := recover(); r != nil {
- common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
- }
- }()
for {
if common.DataExportEnabled {
common.SysLog("正在更新数据看板数据...")
diff --git a/model/user.go b/model/user.go
index 1a3372aa..29d7a446 100644
--- a/model/user.go
+++ b/model/user.go
@@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"one-api/common"
+ "one-api/dto"
+ "one-api/logger"
"strconv"
"strings"
@@ -41,6 +43,8 @@ type User struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
Setting string `json:"setting" gorm:"type:text;column:setting"`
+ Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
+ StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
}
func (user *User) ToBaseUser() *UserBase {
@@ -67,17 +71,21 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
-func (user *User) GetSetting() map[string]interface{} {
- if user.Setting == "" {
- return nil
+func (user *User) GetSetting() dto.UserSetting {
+ setting := dto.UserSetting{}
+ if user.Setting != "" {
+ err := json.Unmarshal([]byte(user.Setting), &setting)
+ if err != nil {
+ common.SysLog("failed to unmarshal setting: " + err.Error())
+ }
}
- return common.StrToMap(user.Setting)
+ return setting
}
-func (user *User) SetSetting(setting map[string]interface{}) {
+func (user *User) SetSetting(setting dto.UserSetting) {
settingBytes, err := json.Marshal(setting)
if err != nil {
- common.SysError("failed to marshal setting: " + err.Error())
+ common.SysLog("failed to marshal setting: " + err.Error())
return
}
user.Setting = string(settingBytes)
@@ -113,7 +121,7 @@ func GetMaxUserId() int {
return user.Id
}
-func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
+func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
// Start transaction
tx := DB.Begin()
if tx.Error != nil {
@@ -133,7 +141,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error)
}
// Get paginated users within same transaction
- err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
+ err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err
@@ -175,7 +183,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
// 如果是数字,同时搜索ID和其他字段
likeCondition = "id = ? OR " + likeCondition
if group != "" {
- query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -184,7 +192,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
} else {
// 非数字关键字,只搜索字符串字段
if group != "" {
- query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
@@ -267,7 +275,7 @@ func inviteUser(inviterId int) (err error) {
func (user *User) TransferAffQuotaToQuota(quota int) error {
// 检查quota是否小于最小额度
if float64(quota) < common.QuotaPerUnit {
- return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit)))
+ return fmt.Errorf("转移额度最小为%s!", logger.LogQuota(int(common.QuotaPerUnit)))
}
// 开始数据库事务
@@ -317,16 +325,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error
}
if common.QuotaForNewUser > 0 {
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
}
if inviterId != 0 {
if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
- RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
}
if common.QuotaForInviter > 0 {
//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
- RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
_ = inviteUser(inviterId)
}
}
@@ -366,6 +374,7 @@ func (user *User) Edit(updatePassword bool) error {
"display_name": newUser.DisplayName,
"group": newUser.Group,
"quota": newUser.Quota,
+ "remark": newUser.Remark,
}
if updatePassword {
updates["password"] = newUser.Password
@@ -509,7 +518,7 @@ func IsAdmin(userId int) bool {
var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil {
- common.SysError("no such user " + err.Error())
+ common.SysLog("no such user " + err.Error())
return false
}
return user.Role >= common.RoleAdminUser
@@ -564,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserQuotaCache(id, quota); err != nil {
- common.SysError("failed to update user quota cache: " + err.Error())
+ common.SysLog("failed to update user quota cache: " + err.Error())
}
})
}
@@ -602,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserGroupCache(id, group); err != nil {
- common.SysError("failed to update user group cache: " + err.Error())
+ common.SysLog("failed to update user group cache: " + err.Error())
}
})
}
@@ -615,7 +624,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
// Don't return error - fall through to DB
}
fromDB = true
- err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
+ err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
if err != nil {
return "", err
}
@@ -624,14 +633,14 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
}
// GetUserSetting gets setting from Redis first, falls back to DB if needed
-func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
+func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
var setting string
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil {
- common.SysError("failed to update user setting cache: " + err.Error())
+ common.SysLog("failed to update user setting cache: " + err.Error())
}
})
}
@@ -646,10 +655,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
if err != nil {
- return map[string]interface{}{}, err
+ return settingMap, err
}
-
- return common.StrToMap(setting), nil
+ userBase := &UserBase{
+ Setting: setting,
+ }
+ return userBase.GetSetting(), nil
}
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
@@ -659,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) {
gopool.Go(func() {
err := cacheIncrUserQuota(id, int64(quota))
if err != nil {
- common.SysError("failed to increase user quota: " + err.Error())
+ common.SysLog("failed to increase user quota: " + err.Error())
}
})
if !db && common.BatchUpdateEnabled {
@@ -684,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
gopool.Go(func() {
err := cacheDecrUserQuota(id, int64(quota))
if err != nil {
- common.SysError("failed to decrease user quota: " + err.Error())
+ common.SysLog("failed to decrease user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
@@ -740,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota and request count: " + err.Error())
+ common.SysLog("failed to update user used quota and request count: " + err.Error())
return
}
@@ -757,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) {
},
).Error
if err != nil {
- common.SysError("failed to update user used quota: " + err.Error())
+ common.SysLog("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
- common.SysError("failed to update user request count: " + err.Error())
+ common.SysLog("failed to update user request count: " + err.Error())
}
}
@@ -775,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserNameCache(id, username); err != nil {
- common.SysError("failed to update user name cache: " + err.Error())
+ common.SysLog("failed to update user name cache: " + err.Error())
}
})
}
diff --git a/model/user_cache.go b/model/user_cache.go
index bc412e77..936e1a43 100644
--- a/model/user_cache.go
+++ b/model/user_cache.go
@@ -1,13 +1,14 @@
package model
import (
- "encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
+ "one-api/dto"
"time"
+ "github.com/gin-gonic/gin"
+
"github.com/bytedance/gopkg/util/gopool"
)
@@ -23,28 +24,23 @@ type UserBase struct {
}
func (user *UserBase) WriteContext(c *gin.Context) {
- c.Set(constant.ContextKeyUserGroup, user.Group)
- c.Set(constant.ContextKeyUserQuota, user.Quota)
- c.Set(constant.ContextKeyUserStatus, user.Status)
- c.Set(constant.ContextKeyUserEmail, user.Email)
- c.Set("username", user.Username)
- c.Set(constant.ContextKeyUserSetting, user.GetSetting())
+ common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
+ common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
+ common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
+ common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
+ common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
+ common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
}
-func (user *UserBase) GetSetting() map[string]interface{} {
- if user.Setting == "" {
- return nil
+func (user *UserBase) GetSetting() dto.UserSetting {
+ setting := dto.UserSetting{}
+ if user.Setting != "" {
+ err := common.Unmarshal([]byte(user.Setting), &setting)
+ if err != nil {
+ common.SysLog("failed to unmarshal setting: " + err.Error())
+ }
}
- return common.StrToMap(user.Setting)
-}
-
-func (user *UserBase) SetSetting(setting map[string]interface{}) {
- settingBytes, err := json.Marshal(setting)
- if err != nil {
- common.SysError("failed to marshal setting: " + err.Error())
- return
- }
- user.Setting = string(settingBytes)
+ return setting
}
// getUserCacheKey returns the key for user cache
@@ -57,7 +53,7 @@ func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
- return common.RedisHDelObj(getUserCacheKey(userId))
+ return common.RedisDelKey(getUserCacheKey(userId))
}
// updateUserCache updates all user cache fields using hash
@@ -69,7 +65,7 @@ func updateUserCache(user User) error {
return common.RedisHSetObj(
getUserCacheKey(user.Id),
user.ToBaseUser(),
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
+ time.Duration(common.RedisKeyCacheSeconds())*time.Second,
)
}
@@ -82,7 +78,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) {
if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() {
if err := updateUserCache(*user); err != nil {
- common.SysError("failed to update user status cache: " + err.Error())
+ common.SysLog("failed to update user status cache: " + err.Error())
}
})
}
@@ -173,11 +169,10 @@ func getUserNameCache(userId int) (string, error) {
return cache.Username, nil
}
-func getUserSettingCache(userId int) (map[string]interface{}, error) {
- setting := make(map[string]interface{})
+func getUserSettingCache(userId int) (dto.UserSetting, error) {
cache, err := GetUserCache(userId)
if err != nil {
- return setting, err
+ return dto.UserSetting{}, err
}
return cache.GetSetting(), nil
}
diff --git a/model/utils.go b/model/utils.go
index e6b09aa5..dced2bc6 100644
--- a/model/utils.go
+++ b/model/utils.go
@@ -2,11 +2,12 @@ package model
import (
"errors"
- "github.com/bytedance/gopkg/util/gopool"
- "gorm.io/gorm"
"one-api/common"
"sync"
"time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "gorm.io/gorm"
)
const (
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
}
func batchUpdate() {
+ // check if there's any data to update
+ hasData := false
+ for i := 0; i < BatchUpdateTypeCount; i++ {
+ batchUpdateLocks[i].Lock()
+ if len(batchUpdateStores[i]) > 0 {
+ hasData = true
+ batchUpdateLocks[i].Unlock()
+ break
+ }
+ batchUpdateLocks[i].Unlock()
+ }
+
+ if !hasData {
+ return
+ }
+
common.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
@@ -60,12 +77,12 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
- common.SysError("failed to batch update user quota: " + err.Error())
+ common.SysLog("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
- common.SysError("failed to batch update token quota: " + err.Error())
+ common.SysLog("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
diff --git a/model/vendor_meta.go b/model/vendor_meta.go
new file mode 100644
index 00000000..88439f24
--- /dev/null
+++ b/model/vendor_meta.go
@@ -0,0 +1,88 @@
+package model
+
+import (
+ "one-api/common"
+
+ "gorm.io/gorm"
+)
+
+// Vendor 用于存储供应商信息,供模型引用
+// Name 唯一,用于在模型中关联
+// Icon 采用 @lobehub/icons 的图标名,前端可直接渲染
+// Status 预留字段,1 表示启用
+// 本表同样遵循 3NF 设计范式
+
+type Vendor struct {
+ Id int `json:"id"`
+ Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
+ Description string `json:"description,omitempty" gorm:"type:text"`
+ Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
+ Status int `json:"status" gorm:"default:1"`
+ CreatedTime int64 `json:"created_time" gorm:"bigint"`
+ UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
+ DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
+}
+
+// Insert 创建新的供应商记录
+func (v *Vendor) Insert() error {
+ now := common.GetTimestamp()
+ v.CreatedTime = now
+ v.UpdatedTime = now
+ return DB.Create(v).Error
+}
+
+// IsVendorNameDuplicated 检查供应商名称是否重复(排除自身 ID)
+func IsVendorNameDuplicated(id int, name string) (bool, error) {
+ if name == "" {
+ return false, nil
+ }
+ var cnt int64
+ err := DB.Model(&Vendor{}).Where("name = ? AND id <> ?", name, id).Count(&cnt).Error
+ return cnt > 0, err
+}
+
+// Update 更新供应商记录
+func (v *Vendor) Update() error {
+ v.UpdatedTime = common.GetTimestamp()
+ return DB.Save(v).Error
+}
+
+// Delete 软删除供应商
+func (v *Vendor) Delete() error {
+ return DB.Delete(v).Error
+}
+
+// GetVendorByID 根据 ID 获取供应商
+func GetVendorByID(id int) (*Vendor, error) {
+ var v Vendor
+ err := DB.First(&v, id).Error
+ if err != nil {
+ return nil, err
+ }
+ return &v, nil
+}
+
+// GetAllVendors 获取全部供应商(分页)
+func GetAllVendors(offset int, limit int) ([]*Vendor, error) {
+ var vendors []*Vendor
+ err := DB.Offset(offset).Limit(limit).Find(&vendors).Error
+ return vendors, err
+}
+
+// SearchVendors 按关键字搜索供应商
+func SearchVendors(keyword string, offset int, limit int) ([]*Vendor, int64, error) {
+ db := DB.Model(&Vendor{})
+ if keyword != "" {
+ like := "%" + keyword + "%"
+ db = db.Where("name LIKE ? OR description LIKE ?", like, like)
+ }
+ var total int64
+ if err := db.Count(&total).Error; err != nil {
+ return nil, 0, err
+ }
+ var vendors []*Vendor
+ if err := db.Offset(offset).Limit(limit).Order("id DESC").Find(&vendors).Error; err != nil {
+ return nil, 0, err
+ }
+ return vendors, total, nil
+}
diff --git a/relay/audio_handler.go b/relay/audio_handler.go
new file mode 100644
index 00000000..711cc7a9
--- /dev/null
+++ b/relay/audio_handler.go
@@ -0,0 +1,73 @@
+package relay
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
+
+ audioReq, ok := info.Request.(*dto.AudioRequest)
+ if !ok {
+ return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ request, err := common.DeepCopy(audioReq)
+ if err != nil {
+ return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ err = helper.ModelMappedHelper(c, info, request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+ }
+
+ adaptor := GetAdaptor(info.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ }
+ adaptor.Init(info)
+
+ ioReader, err := adaptor.ConvertAudioRequest(c, info, *request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ resp, err := adaptor.DoRequest(c, info, ioReader)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeDoRequestFailed)
+ }
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
+
+ return nil
+}
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
index 50255d0a..ec749133 100644
--- a/relay/channel/adapter.go
+++ b/relay/channel/adapter.go
@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -21,10 +22,11 @@ type Adaptor interface {
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
- DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
+ DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
GetModelList() []string
GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
+ ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
}
type TaskAdaptor interface {
@@ -44,4 +46,6 @@ type TaskAdaptor interface {
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
+
+ ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
index 31e926d6..5e31c753 100644
--- a/relay/channel/ali/adaptor.go
+++ b/relay/channel/ali/adaptor.go
@@ -3,24 +3,29 @@ package ali
import (
"errors"
"fmt"
+ "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
+ "one-api/relay/channel/claude"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
-
- "github.com/gin-gonic/gin"
+ "one-api/types"
+ "strings"
)
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
- panic("implement me")
- return nil, nil
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ return req, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -28,16 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
- switch info.RelayMode {
- case constant.RelayModeEmbeddings:
- fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
- case constant.RelayModeImagesGenerations:
- fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
- case constant.RelayModeCompletions:
- fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
default:
- fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
+ switch info.RelayMode {
+ case constant.RelayModeEmbeddings:
+ fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl)
+ case constant.RelayModeRerank:
+ fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
+ case constant.RelayModeImagesGenerations:
+ fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
+ case constant.RelayModeCompletions:
+ fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
+ default:
+ fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
+ }
}
+
return fullRequestURL, nil
}
@@ -57,7 +70,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
-
+ // docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
+ // fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
+ if strings.Contains(request.Model, "thinking") {
+ request.EnableThinking = true
+ request.Stream = true
+ info.IsStream = true
+ }
// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
if !info.IsStream {
request.EnableThinking = false
@@ -76,11 +95,11 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
- return nil, errors.New("not implemented")
+ return ConvertRerankRequest(request), nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
- return embeddingRequestOpenAI2Ali(request), nil
+ return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -97,20 +116,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
- switch info.RelayMode {
- case constant.RelayModeImagesGenerations:
- err, usage = aliImageHandler(c, resp, info)
- case constant.RelayModeEmbeddings:
- err, usage = aliEmbeddingHandler(c, resp)
- default:
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
+ return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info)
+ return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
+ default:
+ adaptor := openai.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
}
- return
}
func (a *Adaptor) GetModelList() []string {
diff --git a/relay/channel/ali/constants.go b/relay/channel/ali/constants.go
index 46de5e40..df64439b 100644
--- a/relay/channel/ali/constants.go
+++ b/relay/channel/ali/constants.go
@@ -8,6 +8,7 @@ var ModelList = []string{
"qwq-32b",
"qwen3-235b-a22b",
"text-embedding-v1",
+ "gte-rerank-v2",
}
var ChannelName = "ali"
diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go
index f51286ad..dbd18968 100644
--- a/relay/channel/ali/dto.go
+++ b/relay/channel/ali/dto.go
@@ -1,5 +1,7 @@
package ali
+import "one-api/dto"
+
type AliMessage struct {
Content string `json:"content"`
Role string `json:"role"`
@@ -97,3 +99,28 @@ type AliImageRequest struct {
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
+
+type AliRerankParameters struct {
+ TopN *int `json:"top_n,omitempty"`
+ ReturnDocuments *bool `json:"return_documents,omitempty"`
+}
+
+type AliRerankInput struct {
+ Query string `json:"query"`
+ Documents []any `json:"documents"`
+}
+
+type AliRerankRequest struct {
+ Model string `json:"model"`
+ Input AliRerankInput `json:"input"`
+ Parameters AliRerankParameters `json:"parameters,omitempty"`
+}
+
+type AliRerankResponse struct {
+ Output struct {
+ Results []dto.RerankResponseResult `json:"results"`
+ } `json:"output"`
+ Usage AliUsage `json:"usage"`
+ RequestId string `json:"request_id"`
+ AliError
+}
diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go
index 44203583..645882bc 100644
--- a/relay/channel/ali/image.go
+++ b/relay/channel/ali/image.go
@@ -4,15 +4,18 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/service"
+ "one-api/types"
"strings"
"time"
+
+ "github.com/gin-gonic/gin"
)
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
@@ -20,14 +23,14 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
- imageRequest.Parameters.N = request.N
+ imageRequest.Parameters.N = int(request.N)
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
- url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
+ url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
var aliResponse AliResponse
@@ -41,7 +44,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
- common.SysError("updateTask client.Do err: " + err.Error())
+ common.SysLog("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
@@ -51,7 +54,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
- common.SysError("updateTask NewDecoder err: " + err.Error())
+ common.SysLog("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
@@ -107,7 +110,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
- common.LogError(c, "get_image_data_failed: "+err.Error())
+ logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
@@ -124,52 +127,46 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
return &imageResponse
}
-func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
- common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
- return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
+ logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
+ return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
if err != nil {
- return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponse), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: aliResponse.Output.Message,
- Type: "ali_error",
- Param: "",
- Code: aliResponse.Output.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return types.WithOpenAIError(types.OpenAIError{
+ Message: aliResponse.Output.Message,
+ Type: "ali_error",
+ Param: "",
+ Code: aliResponse.Output.Code,
+ }, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, nil
+ c.Writer.Write(jsonResponse)
+ return nil, &dto.Usage{}
}
diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go
new file mode 100644
index 00000000..e7d6b514
--- /dev/null
+++ b/relay/channel/ali/rerank.go
@@ -0,0 +1,74 @@
+package ali
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
+ returnDocuments := request.ReturnDocuments
+ if returnDocuments == nil {
+ t := true
+ returnDocuments = &t
+ }
+ return &AliRerankRequest{
+ Model: request.Model,
+ Input: AliRerankInput{
+ Query: request.Query,
+ Documents: request.Documents,
+ },
+ Parameters: AliRerankParameters{
+ TopN: &request.TopN,
+ ReturnDocuments: returnDocuments,
+ },
+ }
+}
+
+func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
+ }
+ service.CloseResponseBodyGracefully(resp)
+
+ var aliResponse AliRerankResponse
+ err = json.Unmarshal(responseBody, &aliResponse)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
+ }
+
+ if aliResponse.Code != "" {
+ return types.WithOpenAIError(types.OpenAIError{
+ Message: aliResponse.Message,
+ Type: aliResponse.Code,
+ Param: aliResponse.RequestId,
+ Code: aliResponse.Code,
+ }, resp.StatusCode), nil
+ }
+
+ usage := dto.Usage{
+ PromptTokens: aliResponse.Usage.TotalTokens,
+ CompletionTokens: 0,
+ TotalTokens: aliResponse.Usage.TotalTokens,
+ }
+ rerankResponse := dto.RerankResponse{
+ Results: aliResponse.Output.Results,
+ Usage: usage,
+ }
+
+ jsonResponse, err := json.Marshal(rerankResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ c.Writer.Write(jsonResponse)
+ return nil, &usage
+}
diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go
index 3fe893b3..67b63286 100644
--- a/relay/channel/ali/text.go
+++ b/relay/channel/ali/text.go
@@ -3,7 +3,6 @@ package ali
import (
"bufio"
"encoding/json"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -11,6 +10,10 @@ import (
"one-api/relay/helper"
"one-api/service"
"strings"
+
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
@@ -27,9 +30,6 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
}
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
- if request.Model == "" {
- request.Model = "text-embedding-v1"
- }
return &AliEmbeddingRequest{
Model: request.Model,
Input: struct {
@@ -40,46 +40,34 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
}
}
-func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- var aliResponse AliEmbeddingResponse
- err := json.NewDecoder(resp.Body).Decode(&aliResponse)
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ var fullTextResponse dto.FlexibleEmbeddingResponse
+ err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
+ service.CloseResponseBodyGracefully(resp)
- if aliResponse.Code != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: aliResponse.Message,
- Type: aliResponse.Code,
- Param: aliResponse.RequestId,
- Code: aliResponse.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ model := c.GetString("model")
+ if model == "" {
+ model = "text-embedding-v4"
}
-
- fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
+ c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
-func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
- Model: "text-embedding-v1",
+ Model: model,
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
}
@@ -94,12 +82,11 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe
}
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
- content, _ := json.Marshal(response.Output.Text)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
@@ -134,7 +121,7 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
return &response
}
-func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -162,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
@@ -175,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -185,42 +172,33 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
return false
}
})
- err := resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
+ service.CloseResponseBodyGracefully(resp)
return nil, &usage
}
-func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: aliResponse.Message,
- Type: aliResponse.Code,
- Param: aliResponse.RequestId,
- Code: aliResponse.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return types.WithOpenAIError(types.OpenAIError{
+ Message: aliResponse.Message,
+ Type: "ali_error",
+ Param: aliResponse.RequestId,
+ Code: aliResponse.Code,
+ }, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
- jsonResponse, err := json.Marshal(fullTextResponse)
+ jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
index 1d733bd4..fd745cf7 100644
--- a/relay/channel/api_request.go
+++ b/relay/channel/api_request.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
common2 "one-api/common"
+ "one-api/logger"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
@@ -109,6 +110,12 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
gopool.Go(func() {
defer func() {
+ // 增加panic恢复处理
+ if r := recover(); r != nil {
+ if common2.DebugEnabled {
+ println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
+ }
+ }
if common2.DebugEnabled {
println("SSE ping goroutine stopped.")
}
@@ -119,19 +126,32 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
}
ticker := time.NewTicker(pingInterval)
- // 退出时清理 ticker
- defer ticker.Stop()
+ // 确保在任何情况下都清理ticker
+ defer func() {
+ ticker.Stop()
+ if common2.DebugEnabled {
+ println("SSE ping ticker stopped")
+ }
+ }()
var pingMutex sync.Mutex
if common2.DebugEnabled {
println("SSE ping goroutine started")
}
+ // 增加超时控制,防止goroutine长时间运行
+ maxPingDuration := 120 * time.Minute // 最大ping持续时间
+ pingTimeout := time.NewTimer(maxPingDuration)
+ defer pingTimeout.Stop()
+
for {
select {
// 发送 ping 数据
case <-ticker.C:
if err := sendPingData(c, &pingMutex); err != nil {
+ if common2.DebugEnabled {
+ println("SSE ping error, stopping goroutine:", err.Error())
+ }
return
}
// 收到退出信号
@@ -140,6 +160,12 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
// request 结束
case <-c.Request.Context().Done():
return
+ // 超时保护,防止goroutine无限运行
+ case <-pingTimeout.C:
+ if common2.DebugEnabled {
+ println("SSE ping goroutine timeout, stopping")
+ }
+ return
}
}
})
@@ -148,26 +174,44 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
}
func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
- mutex.Lock()
- defer mutex.Unlock()
+ // 增加超时控制,防止锁死等待
+ done := make(chan error, 1)
+ go func() {
+ mutex.Lock()
+ defer mutex.Unlock()
- err := helper.PingData(c)
- if err != nil {
- common2.LogError(c, "SSE ping error: "+err.Error())
+ err := helper.PingData(c)
+ if err != nil {
+ logger.LogError(c, "SSE ping error: "+err.Error())
+ done <- err
+ return
+ }
+
+ if common2.DebugEnabled {
+ println("SSE ping data sent.")
+ }
+ done <- nil
+ }()
+
+ // 设置发送ping数据的超时时间
+ select {
+ case err := <-done:
return err
+ case <-time.After(10 * time.Second):
+ return errors.New("SSE ping data send timeout")
+ case <-c.Request.Context().Done():
+ return errors.New("request context cancelled during ping")
}
-
- if common2.DebugEnabled {
- println("SSE ping data sent.")
- }
- return nil
}
+func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+ return doRequest(c, req, info)
+}
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
- if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
- client, err = service.NewProxyHttpClient(proxyURL.(string))
+ if info.ChannelSetting.Proxy != "" {
+ client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
@@ -175,15 +219,23 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
client = service.GetHttpClient()
}
+ var stopPinger context.CancelFunc
if info.IsStream {
helper.SetEventStreamHeaders(c)
-
// 处理流式请求的 ping 保活
generalSettings := operation_setting.GetGeneralSetting()
- if generalSettings.PingIntervalEnabled {
+ if generalSettings.PingIntervalEnabled && !info.DisablePing {
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
- stopPinger := startPingKeepAlive(c, pingInterval)
- defer stopPinger()
+ stopPinger = startPingKeepAlive(c, pingInterval)
+ // 使用defer确保在任何情况下都能停止ping goroutine
+ defer func() {
+ if stopPinger != nil {
+ stopPinger()
+ if common2.DebugEnabled {
+ println("SSE ping goroutine stopped by defer")
+ }
+ }
+ }()
}
}
diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go
index 9c879399..1526a7f7 100644
--- a/relay/channel/aws/adaptor.go
+++ b/relay/channel/aws/adaptor.go
@@ -8,6 +8,7 @@ import (
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -21,6 +22,11 @@ type Adaptor struct {
RequestMode int
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model)
c.Set("converted_request", request)
@@ -57,7 +63,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
var claudeReq *dto.ClaudeRequest
var err error
- claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
+ claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
if err != nil {
return nil, err
}
@@ -84,7 +90,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return nil, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {
diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go
index 64c7b747..3f8800b1 100644
--- a/relay/channel/aws/constants.go
+++ b/relay/channel/aws/constants.go
@@ -13,6 +13,7 @@ var awsModelIDMap = map[string]string{
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
+ "claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
@@ -54,6 +55,9 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-opus-4-20250514-v1:0": {
"us": true,
},
+ "anthropic.claude-opus-4-1-20250805-v1:0": {
+ "us": true,
+ },
}
var awsRegionCrossModelPrefixMap = map[string]string{
diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go
index 3c9542c6..5822e363 100644
--- a/relay/channel/aws/relay-aws.go
+++ b/relay/channel/aws/relay-aws.go
@@ -1,35 +1,48 @@
package aws
import (
- "encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
- "github.com/pkg/errors"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/types"
"strings"
+ "github.com/gin-gonic/gin"
+ "github.com/pkg/errors"
+
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
- "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
+ bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
+ "github.com/aws/smithy-go/auth/bearer"
)
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
awsSecret := strings.Split(info.ApiKey, "|")
- if len(awsSecret) != 3 {
+ var client *bedrockruntime.Client
+ switch len(awsSecret) {
+ case 2:
+ apiKey := awsSecret[0]
+ region := awsSecret[1]
+ client = bedrockruntime.New(bedrockruntime.Options{
+ Region: region,
+ BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
+ })
+ case 3:
+ ak := awsSecret[0]
+ sk := awsSecret[1]
+ region := awsSecret[2]
+ client = bedrockruntime.New(bedrockruntime.Options{
+ Region: region,
+ Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
+ })
+ default:
return nil, errors.New("invalid aws secret key")
}
- ak := awsSecret[0]
- sk := awsSecret[1]
- region := awsSecret[2]
- client := bedrockruntime.New(bedrockruntime.Options{
- Region: region,
- Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
- })
return client, nil
}
@@ -65,24 +78,21 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
return modelPrefix + "." + awsModelId
}
-func awsModelID(requestModel string) (string, error) {
+func awsModelID(requestModel string) string {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
- return awsModelID, nil
+ return awsModelID
}
- return requestModel, nil
+ return requestModel
}
-func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
- return wrapErr(errors.Wrap(err, "newAwsClient")), nil
+ return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
- awsModelId, err := awsModelID(c.GetString("request_model"))
- if err != nil {
- return wrapErr(errors.Wrap(err, "awsModelID")), nil
- }
+ awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
@@ -98,42 +108,42 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
claudeReq_, ok := c.Get("converted_request")
if !ok {
- return wrapErr(errors.New("request not found")), nil
+ return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
- awsReq.Body, err = json.Marshal(awsClaudeReq)
+ awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
- return wrapErr(errors.Wrap(err, "marshal request")), nil
+ return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
- return wrapErr(errors.Wrap(err, "InvokeModel")), nil
+ return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
}
claudeInfo := &claude.ClaudeResponseInfo{
- ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
- claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
+ handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
+ if handlerErr != nil {
+ return handlerErr, nil
+ }
return nil, claudeInfo.Usage
}
-func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
- return wrapErr(errors.Wrap(err, "newAwsClient")), nil
+ return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
- awsModelId, err := awsModelID(c.GetString("request_model"))
- if err != nil {
- return wrapErr(errors.Wrap(err, "awsModelID")), nil
- }
+ awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
@@ -149,25 +159,25 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
claudeReq_, ok := c.Get("converted_request")
if !ok {
- return wrapErr(errors.New("request not found")), nil
+ return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
- awsReq.Body, err = json.Marshal(awsClaudeReq)
+ awsReq.Body, err = common.Marshal(awsClaudeReq)
if err != nil {
- return wrapErr(errors.Wrap(err, "marshal request")), nil
+ return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
- return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
+ return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, http.StatusInternalServerError), nil
}
stream := awsResp.GetStream()
defer stream.Close()
claudeInfo := &claude.ClaudeResponseInfo{
- ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
@@ -176,18 +186,18 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
for event := range stream.Events() {
switch v := event.(type) {
- case *types.ResponseStreamMemberChunk:
+ case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
if respErr != nil {
return respErr, nil
}
- case *types.UnknownUnionMember:
+ case *bedrockruntimeTypes.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
- return wrapErr(errors.New("unknown response type")), nil
+ return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
default:
fmt.Println("union is nil or unknown type")
- return wrapErr(errors.New("nil or unknown response type")), nil
+ return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
}
}
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
index 396c31ab..32e301ee 100644
--- a/relay/channel/baidu/adaptor.go
+++ b/relay/channel/baidu/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -17,6 +18,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -95,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
default:
suffix += strings.ToLower(info.UpstreamModelName)
}
- fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
+ fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix)
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
@@ -140,15 +146,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = baiduStreamHandler(c, resp)
+ err, usage = baiduStreamHandler(c, info, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
- err, usage = baiduEmbeddingHandler(c, resp)
+ err, usage = baiduEmbeddingHandler(c, info, resp)
default:
- err, usage = baiduHandler(c, resp)
+ err, usage = baiduHandler(c, info, resp)
}
}
return
diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go
index 62b06413..31e8319e 100644
--- a/relay/channel/baidu/relay-baidu.go
+++ b/relay/channel/baidu/relay-baidu.go
@@ -1,21 +1,23 @@
package baidu
import (
- "bufio"
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
@@ -32,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
EnableCitation: false,
UserId: request.User,
}
- if request.MaxTokens != 0 {
- maxTokens := int(request.MaxTokens)
- if request.MaxTokens == 1 {
+ if request.GetMaxTokens() != 0 {
+ maxTokens := int(request.GetMaxTokens())
+ if request.GetMaxTokens() == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
@@ -53,12 +55,11 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
- content, _ := json.Marshal(response.Result)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Result,
},
FinishReason: "stop",
}
@@ -111,98 +112,49 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
return &openAIEmbeddingResponse
}
-func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- var usage dto.Usage
- scanner := bufio.NewScanner(resp.Body)
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- if atEOF && len(data) == 0 {
- return 0, nil, nil
- }
- if i := strings.Index(string(data), "\n"); i >= 0 {
- return i + 1, data[0:i], nil
- }
- if atEOF {
- return len(data), data, nil
- }
- return 0, nil, nil
- })
- dataChan := make(chan string)
- stopChan := make(chan bool)
- go func() {
- for scanner.Scan() {
- data := scanner.Text()
- if len(data) < 6 { // ignore blank line or wrong format
- continue
- }
- data = data[6:]
- dataChan <- data
- }
- stopChan <- true
- }()
- helper.SetEventStreamHeaders(c)
- c.Stream(func(w io.Writer) bool {
- select {
- case data := <-dataChan:
- var baiduResponse BaiduChatStreamResponse
- err := json.Unmarshal([]byte(data), &baiduResponse)
- if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
- return true
- }
- if baiduResponse.Usage.TotalTokens != 0 {
- usage.TotalTokens = baiduResponse.Usage.TotalTokens
- usage.PromptTokens = baiduResponse.Usage.PromptTokens
- usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
- }
- response := streamResponseBaidu2OpenAI(&baiduResponse)
- jsonResponse, err := json.Marshal(response)
- if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
- return true
- }
- c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ usage := &dto.Usage{}
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ var baiduResponse BaiduChatStreamResponse
+ err := common.Unmarshal([]byte(data), &baiduResponse)
+ if err != nil {
+ common.SysLog("error unmarshalling stream response: " + err.Error())
return true
- case <-stopChan:
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- return false
}
+ if baiduResponse.Usage.TotalTokens != 0 {
+ usage.TotalTokens = baiduResponse.Usage.TotalTokens
+ usage.PromptTokens = baiduResponse.Usage.PromptTokens
+ usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
+ }
+ response := streamResponseBaidu2OpenAI(&baiduResponse)
+ err = helper.ObjectData(c, response)
+ if err != nil {
+ common.SysLog("error sending stream response: " + err.Error())
+ }
+ return true
})
- err := resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- return nil, &usage
+ service.CloseResponseBodyGracefully(resp)
+ return nil, usage
}
-func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var baiduResponse BaiduChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: baiduResponse.ErrorMsg,
- Type: "baidu_error",
- Param: "",
- Code: baiduResponse.ErrorCode,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -210,35 +162,24 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
return nil, &fullTextResponse.Usage
}
-func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: baiduResponse.ErrorMsg,
- Type: "baidu_error",
- Param: "",
- Code: baiduResponse.ErrorCode,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -281,7 +222,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
- res, err := service.GetImpatientHttpClient().Do(req)
+ res, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go
index 2b8a52a2..6744f8ba 100644
--- a/relay/channel/baidu_v2/adaptor.go
+++ b/relay/channel/baidu_v2/adaptor.go
@@ -9,6 +9,8 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -17,10 +19,14 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
- panic("implement me")
- return nil, nil
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ adaptor := openai.Adaptor{}
+ return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -37,12 +43,34 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
+ switch info.RelayMode {
+ case constant.RelayModeChatCompletions:
+ return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil
+ case constant.RelayModeEmbeddings:
+ return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil
+ case constant.RelayModeImagesGenerations:
+ return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil
+ case constant.RelayModeImagesEdits:
+ return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil
+ case constant.RelayModeRerank:
+ return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil
+ default:
+ }
+ return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
+ keyParts := strings.Split(info.ApiKey, "|")
+ if len(keyParts) == 0 || keyParts[0] == "" {
+ return errors.New("invalid API key: authorization token is required")
+ }
+ if len(keyParts) > 1 {
+ if keyParts[1] != "" {
+ req.Set("appid", keyParts[1])
+ }
+ }
+ req.Set("Authorization", "Bearer "+keyParts[0])
return nil
}
@@ -83,12 +111,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
- if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
- } else {
- err, usage = openai.OpenaiHandler(c, resp, info)
- }
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ adaptor := openai.Adaptor{}
+ usage, err = adaptor.DoResponse(c, resp, info)
return
}
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
index 8389b9f1..959327e1 100644
--- a/relay/channel/claude/adaptor.go
+++ b/relay/channel/claude/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -23,6 +24,11 @@ type Adaptor struct {
RequestMode int
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return request, nil
}
@@ -47,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if a.RequestMode == RequestModeMessage {
- return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil
} else {
- return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil
}
}
@@ -72,7 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if a.RequestMode == RequestModeCompletion {
return RequestOpenAI2ClaudeComplete(*request), nil
} else {
- return RequestOpenAI2ClaudeMessage(*request)
+ return RequestOpenAI2ClaudeMessage(c, *request)
}
}
@@ -94,11 +100,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
+ return ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {
- err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
+ return ClaudeHandler(c, resp, info, a.RequestMode)
}
return
}
diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go
index e0e3c421..a23543d2 100644
--- a/relay/channel/claude/constants.go
+++ b/relay/channel/claude/constants.go
@@ -17,6 +17,8 @@ var ModelList = []string{
"claude-sonnet-4-20250514-thinking",
"claude-opus-4-20250514",
"claude-opus-4-20250514-thinking",
+ "claude-opus-4-1-20250805",
+ "claude-opus-4-1-20250805-thinking",
}
var ChannelName = "claude"
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
index 95e7c4be..0c445bb9 100644
--- a/relay/channel/claude/relay-claude.go
+++ b/relay/channel/claude/relay-claude.go
@@ -7,15 +7,24 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
+ "one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
+const (
+ WebSearchMaxUsesLow = 1
+ WebSearchMaxUsesMedium = 5
+ WebSearchMaxUsesHigh = 10
+)
+
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
@@ -48,9 +57,9 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
- prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
+ prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
} else if message.Role == "assistant" {
- prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
+ prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
@@ -62,8 +71,8 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
return &claudeRequest
}
-func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
- claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
+func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
+ claudeTools := make([]any, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
@@ -83,13 +92,65 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
}
claudeTool.InputSchema[s] = a
}
- claudeTools = append(claudeTools, claudeTool)
+ claudeTools = append(claudeTools, &claudeTool)
}
}
+ // Web search tool
+ // https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
+ if textRequest.WebSearchOptions != nil {
+ webSearchTool := dto.ClaudeWebSearchTool{
+ Type: "web_search_20250305",
+ Name: "web_search",
+ }
+
+ // 处理 user_location
+ if textRequest.WebSearchOptions.UserLocation != nil {
+ anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
+ Type: "approximate", // 固定为 "approximate"
+ }
+
+ // 解析 UserLocation JSON
+ var userLocationMap map[string]interface{}
+ if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
+ // 检查是否有 approximate 字段
+ if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
+ if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
+ anthropicUserLocation.Timezone = timezone
+ }
+ if country, ok := approximateData["country"].(string); ok && country != "" {
+ anthropicUserLocation.Country = country
+ }
+ if region, ok := approximateData["region"].(string); ok && region != "" {
+ anthropicUserLocation.Region = region
+ }
+ if city, ok := approximateData["city"].(string); ok && city != "" {
+ anthropicUserLocation.City = city
+ }
+ }
+ }
+
+ webSearchTool.UserLocation = anthropicUserLocation
+ }
+
+ // 处理 search_context_size 转换为 max_uses
+ if textRequest.WebSearchOptions.SearchContextSize != "" {
+ switch textRequest.WebSearchOptions.SearchContextSize {
+ case "low":
+ webSearchTool.MaxUses = WebSearchMaxUsesLow
+ case "medium":
+ webSearchTool.MaxUses = WebSearchMaxUsesMedium
+ case "high":
+ webSearchTool.MaxUses = WebSearchMaxUsesHigh
+ }
+ }
+
+ claudeTools = append(claudeTools, &webSearchTool)
+ }
+
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
- MaxTokens: textRequest.MaxTokens,
+ MaxTokens: textRequest.GetMaxTokens(),
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
@@ -98,6 +159,14 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
Tools: claudeTools,
}
+ // 处理 tool_choice 和 parallel_tool_calls
+ if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
+ claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
+ if claudeToolChoice != nil {
+ claudeRequest.ToolChoice = claudeToolChoice
+ }
+ }
+
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
@@ -113,7 +182,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
- BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
+ BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
@@ -122,6 +191,42 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
+ if textRequest.ReasoningEffort != "" {
+ switch textRequest.ReasoningEffort {
+ case "low":
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](1280),
+ }
+ case "medium":
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](2048),
+ }
+ case "high":
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](4096),
+ }
+ }
+ }
+
+ // 指定了 reasoning 参数,覆盖 budgetTokens
+ if textRequest.Reasoning != nil {
+ var reasoning openrouter.RequestReasoning
+ if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
+ return nil, err
+ }
+
+ budgetTokens := reasoning.MaxTokens
+ if budgetTokens > 0 {
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: &budgetTokens,
+ }
+ }
+ }
+
if textRequest.Stop != nil {
// stop maybe string/array string, convert to array string
switch textRequest.Stop.(type) {
@@ -155,15 +260,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
}
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
if lastMessage.IsStringContent() && message.IsStringContent() {
- content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
- fmtMessage.Content = content
+ fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
// delete last message
formatMessages = formatMessages[:len(formatMessages)-1]
}
}
if fmtMessage.Content == nil {
- content, _ := json.Marshal("...")
- fmtMessage.Content = content
+ fmtMessage.SetStringContent("...")
}
formatMessages = append(formatMessages, fmtMessage)
lastMessage = fmtMessage
@@ -252,7 +355,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url,获取图片的类型和base64编码的数据
- fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+ fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
@@ -273,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
- common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+ common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
@@ -397,12 +500,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
thinkingContent := ""
if reqMode == RequestModeCompletion {
- content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
@@ -457,6 +559,7 @@ type ClaudeResponseInfo struct {
Model string
ResponseText strings.Builder
Usage *dto.Usage
+ Done bool
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
@@ -464,20 +567,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
- // message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
+
+ // message_start, 获取usage
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
+ if claudeResponse.Delta.Thinking != "" {
+ claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
+ }
} else if claudeResponse.Type == "message_delta" {
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ // 最终的usage获取
if claudeResponse.Usage.InputTokens > 0 {
+ // 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
- claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
+
+ // 判断是否完整
+ claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
@@ -491,47 +606,30 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
return true
}
-func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
+func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
- err := common.DecodeJsonStr(data, &claudeResponse)
+ err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
- return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
+ common.SysLog("error unmarshalling stream response: " + err.Error())
+ return types.NewError(err, types.ErrorCodeBadResponseBody)
}
- if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Code: "stream_response_error",
- Type: claudeResponse.Error.Type,
- Message: claudeResponse.Error.Message,
- },
- StatusCode: http.StatusInternalServerError,
- }
+ if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
+ return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
- if info.RelayFormat == relaycommon.RelayFormatClaude {
+ if info.RelayFormat == types.RelayFormatClaude {
+ FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
+
if requestMode == RequestModeCompletion {
- claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
- claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
- claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
- claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
- claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
- if claudeResponse.Usage.InputTokens > 0 {
- // 不叠加,只取最新的
- claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
- }
- claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
- claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
- } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
+ } else if info.RelayFormat == types.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
@@ -540,56 +638,51 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
err = helper.ObjectData(c, response)
if err != nil {
- common.LogError(c, "send_stream_response_failed: "+err.Error())
+ logger.LogError(c, "send_stream_response_failed: "+err.Error())
}
}
return nil
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
- if info.RelayFormat == relaycommon.RelayFormatClaude {
- if requestMode == RequestModeCompletion {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- // 说明流模式建立失败,可能为官方出错
- if claudeInfo.Usage.PromptTokens == 0 {
- //usage.PromptTokens = info.PromptTokens
- }
- if claudeInfo.Usage.CompletionTokens == 0 {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
- }
+
+ if requestMode == RequestModeCompletion {
+ claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ if claudeInfo.Usage.PromptTokens == 0 {
+ //上游出错
}
- } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
- if requestMode == RequestModeCompletion {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
- } else {
- if claudeInfo.Usage.PromptTokens == 0 {
- //上游出错
- }
- if claudeInfo.Usage.CompletionTokens == 0 {
- claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+ if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
+ if common.DebugEnabled {
+ common.SysLog("claude response usage is not complete, maybe upstream error")
}
+ claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
+ }
+
+ if info.RelayFormat == types.RelayFormatClaude {
+		// Claude 格式:数据已在流式处理中原样透传,这里无需再发送额外的结束响应
+ } else if info.RelayFormat == types.RelayFormatOpenAI {
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
- common.SysError("send final response failed: " + err.Error())
+ common.SysLog("send final response failed: " + err.Error())
}
}
helper.Done(c)
}
}
-func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
claudeInfo := &ClaudeResponseInfo{
- ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
- var err *dto.OpenAIErrorWithStatusCode
+ var err *types.NewAPIError
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
if err != nil {
@@ -598,34 +691,24 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
return true
})
if err != nil {
- return err, nil
+ return nil, err
}
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
- return nil, claudeInfo.Usage
+ return claudeInfo.Usage, nil
}
-func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
+func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
- err := common.DecodeJson(data, &claudeResponse)
+ err := common.Unmarshal(data, &claudeResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
+ return types.NewError(err, types.ErrorCodeBadResponseBody)
}
- if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: claudeResponse.Error.Message,
- Type: claudeResponse.Error.Type,
- Code: claudeResponse.Error.Type,
- },
- StatusCode: http.StatusInternalServerError,
- }
+ if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
+ return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
if requestMode == RequestModeCompletion {
- completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
- }
+ completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
@@ -638,25 +721,30 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
var responseData []byte
switch info.RelayFormat {
- case relaycommon.RelayFormatOpenAI:
+ case types.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+ return types.NewError(err, types.ErrorCodeBadResponseBody)
}
- case relaycommon.RelayFormatClaude:
+ case types.RelayFormatClaude:
responseData = data
}
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(http.StatusOK)
- _, err = c.Writer.Write(responseData)
+
+ if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 {
+ c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
+ }
+
+ service.IOCopyBytesGracefully(c, nil, responseData)
return nil
}
-func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
claudeInfo := &ClaudeResponseInfo{
- ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
@@ -664,15 +752,62 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- resp.Body.Close()
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
if handleErr != nil {
- return handleErr, nil
+ return nil, handleErr
}
- return nil, claudeInfo.Usage
+ return claudeInfo.Usage, nil
+}
+
+func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
+ var claudeToolChoice *dto.ClaudeToolChoice
+
+ // 处理 tool_choice 字符串值
+ if toolChoiceStr, ok := toolChoice.(string); ok {
+ switch toolChoiceStr {
+ case "auto":
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "auto",
+ }
+ case "required":
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "any",
+ }
+ case "none":
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "none",
+ }
+ }
+ } else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
+ // 处理 tool_choice 对象值
+ if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
+ if toolName, ok := function["name"].(string); ok {
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "tool",
+ Name: toolName,
+ }
+ }
+ }
+ }
+
+ // 处理 parallel_tool_calls
+ if parallelToolCalls != nil {
+ if claudeToolChoice == nil {
+ // 如果没有 tool_choice,但有 parallel_tool_calls,创建默认的 auto 类型
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "auto",
+ }
+ }
+
+ // 设置 disable_parallel_tool_use
+ // 如果 parallel_tool_calls 为 true,则 disable_parallel_tool_use 为 false
+ claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
+ }
+
+ return claudeToolChoice
}
diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go
index 06f4ca34..bdea72f0 100644
--- a/relay/channel/cloudflare/adaptor.go
+++ b/relay/channel/cloudflare/adaptor.go
@@ -8,8 +8,10 @@ import (
"net/http"
"one-api/dto"
"one-api/relay/channel"
+ "one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -17,6 +19,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -29,11 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
- return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil
+ case constant.RelayModeResponses:
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil
default:
- return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil
}
}
@@ -56,8 +65,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // TODO implement me
- return nil, errors.New("not implemented")
+ return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
@@ -94,20 +102,26 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not implemented")
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fallthrough
case constant.RelayModeChatCompletions:
if info.IsStream {
- err, usage = cfStreamHandler(c, resp, info)
+ err, usage = cfStreamHandler(c, info, resp)
} else {
- err, usage = cfHandler(c, resp, info)
+ err, usage = cfHandler(c, info, resp)
+ }
+ case constant.RelayModeResponses:
+ if info.IsStream {
+ usage, err = openai.OaiResponsesStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OaiResponsesHandler(c, info, resp)
}
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
- err, usage = cfSTTHandler(c, resp, info)
+ err, usage = cfSTTHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/cloudflare/dto.go b/relay/channel/cloudflare/dto.go
index 62a45c40..72b40615 100644
--- a/relay/channel/cloudflare/dto.go
+++ b/relay/channel/cloudflare/dto.go
@@ -5,7 +5,7 @@ import "one-api/dto"
type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
index a487429c..00f6b6c5 100644
--- a/relay/channel/cloudflare/relay_cloudflare.go
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -3,16 +3,18 @@ package cloudflare
import (
"bufio"
"encoding/json"
- "github.com/gin-gonic/gin"
"io"
"net/http"
- "one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
"time"
+
+ "github.com/gin-gonic/gin"
)
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
@@ -25,7 +27,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
}
}
-func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -49,7 +51,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
- common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+ logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
@@ -64,56 +66,50 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
info.FirstResponseTime = time.Now()
}
if err != nil {
- common.LogError(c, "error_rendering_stream_response: "+err.Error())
+ logger.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
- common.LogError(c, "error_scanning_stream_response: "+err.Error())
+ logger.LogError(c, "error_scanning_stream_response: "+err.Error())
}
- usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
- common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+ logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
helper.Done(c)
- err := resp.Body.Close()
- if err != nil {
- common.LogError(c, "close_response_body_failed: "+err.Error())
- }
+ service.CloseResponseBodyGracefully(resp)
return nil, usage
}
-func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
+ service.CloseResponseBodyGracefully(resp)
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
response.Model = info.UpstreamModelName
var responseText string
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
- usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -121,19 +117,16 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
return nil, usage
}
-func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var cfResp CfAudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
audioResp := &dto.AudioResponse{
@@ -142,7 +135,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
jsonResponse, err := json.Marshal(audioResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
@@ -150,7 +143,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+ usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go
index a93b10f6..c8a38d46 100644
--- a/relay/channel/cohere/adaptor.go
+++ b/relay/channel/cohere/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -16,6 +17,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -37,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
- return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else {
- return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil
}
}
@@ -71,14 +77,14 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank {
- err, usage = cohereRerankHandler(c, resp, info)
+ usage, err = cohereRerankHandler(c, resp, info)
} else {
if info.IsStream {
- err, usage = cohereStreamHandler(c, resp, info)
+			usage, err = cohereStreamHandler(c, info, resp) // TODO(review): 与其他 adaptor 的 (c, resp, info) 参数顺序不一致,待统一后清理
} else {
- err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+ usage, err = cohereHandler(c, info, resp)
}
}
return
diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go
index 410540c0..d5127963 100644
--- a/relay/channel/cohere/dto.go
+++ b/relay/channel/cohere/dto.go
@@ -7,7 +7,7 @@ type CohereRequest struct {
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
- MaxTokens int `json:"max_tokens"`
+ MaxTokens uint `json:"max_tokens"`
SafetyMode string `json:"safety_mode,omitempty"`
}
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
index 17b58dbc..af357348 100644
--- a/relay/channel/cohere/relay-cohere.go
+++ b/relay/channel/cohere/relay-cohere.go
@@ -3,8 +3,6 @@ package cohere
import (
"bufio"
"encoding/json"
- "fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -12,8 +10,11 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
"time"
+
+ "github.com/gin-gonic/gin"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@@ -77,8 +78,8 @@ func stopReasonCohere2OpenAI(reason string) string {
}
}
-func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
@@ -117,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResp dto.ChatCompletionsStreamResponse
@@ -152,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
jsonStr, err := json.Marshal(openaiResp)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -163,25 +164,22 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
})
if usage.PromptTokens == 0 {
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
- return nil, usage
+ return usage, nil
}
-func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
createdTime := common.GetTimestamp()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
+ service.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
@@ -192,41 +190,37 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
openaiResp.Id = cohereResp.ResponseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion"
- openaiResp.Model = modelName
+ openaiResp.Model = info.UpstreamModelName
openaiResp.Usage = usage
- content, _ := json.Marshal(cohereResp.Text)
openaiResp.Choices = []dto.OpenAITextResponseChoice{
{
Index: 0,
- Message: dto.Message{Content: content, Role: "assistant"},
+ Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
},
}
jsonResponse, err := json.Marshal(openaiResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &usage
+ _, _ = c.Writer.Write(jsonResponse)
+ return &usage, nil
}
-func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
+ service.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
@@ -245,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
- return nil, &usage
+ return &usage, nil
}
diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go
index 80441a51..0f2a6fd3 100644
--- a/relay/channel/coze/adaptor.go
+++ b/relay/channel/coze/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/common"
+ "one-api/types"
"time"
"github.com/gin-gonic/gin"
@@ -17,6 +18,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
// ConvertAudioRequest implements channel.Adaptor.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
@@ -95,11 +101,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
}
// DoResponse implements channel.Adaptor.
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = cozeChatStreamHandler(c, resp, info)
+ usage, err = cozeChatStreamHandler(c, info, resp)
} else {
- err, usage = cozeChatHandler(c, resp, info)
+ usage, err = cozeChatHandler(c, info, resp)
}
return
}
@@ -116,7 +122,7 @@ func (a *Adaptor) GetModelList() []string {
// GetRequestURL implements channel.Adaptor.
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil
}
// Init implements channel.Adaptor.
diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go
index 4e9afa23..d5dc9a81 100644
--- a/relay/channel/coze/dto.go
+++ b/relay/channel/coze/dto.go
@@ -10,7 +10,7 @@ type CozeError struct {
type CozeEnterMessage struct {
Role string `json:"role"`
Type string `json:"type,omitempty"`
- Content json.RawMessage `json:"content,omitempty"`
+ Content any `json:"content,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ContentType string `json:"content_type,omitempty"`
}
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
index 6db40213..c480045f 100644
--- a/relay/channel/coze/relay-coze.go
+++ b/relay/channel/coze/relay-coze.go
@@ -12,6 +12,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -43,25 +44,22 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
return cozeRequest
}
-func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
+ service.CloseResponseBodyGracefully(resp)
// convert coze response to openai response
var response dto.TextResponse
var cozeResponse CozeChatDetailResponse
response.Model = info.UpstreamModelName
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if cozeResponse.Code != 0 {
- return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
+ return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
}
// 从上下文获取 usage
var usage dto.Usage
@@ -88,16 +86,16 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
}
jsonResponse, err := json.Marshal(response)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
- return nil, &usage
+ return &usage, nil
}
-func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
@@ -106,7 +104,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
var currentEvent string
var currentData string
- var usage dto.Usage
+ var usage = &dto.Usage{}
for scanner.Scan() {
line := scanner.Text()
@@ -114,7 +112,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
if line == "" {
if currentEvent != "" && currentData != "" {
// handle last event
- handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+ handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
currentEvent = ""
currentData = ""
}
@@ -134,21 +132,19 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
// Last event
if currentEvent != "" && currentData != "" {
- handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+ handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
}
if err := scanner.Err(); err != nil {
- return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
helper.Done(c)
if usage.TotalTokens == 0 {
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
- return nil, &usage
+ return usage, nil
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
@@ -158,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var chatData CozeChatResponseData
err := json.Unmarshal([]byte(data), &chatData)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -175,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var messageData CozeChatV3MessageDetail
err := json.Unmarshal([]byte(data), &messageData)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
var content string
err = json.Unmarshal(messageData.Content, &content)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -207,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var errorData CozeError
err := json.Unmarshal([]byte(data), &errorData)
if err != nil {
- common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
- common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
+ common.SysLog(fmt.Sprintf("stream event error: %d %s", errorData.Code, errorData.Message))
}
}
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
- requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
+ requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
// 将 conversationId和chatId作为参数发送get请求
@@ -262,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo
}
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
- requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
+ requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
req, err := http.NewRequest("GET", requestURL, nil)
@@ -283,8 +279,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error // 声明 err 变量
- if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
- client, err = service.NewProxyHttpClient(proxyURL.(string))
+ if info.ChannelSetting.Proxy != "" {
+ client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go
index 76e7fa8d..17d732ab 100644
--- a/relay/channel/deepseek/adaptor.go
+++ b/relay/channel/deepseek/adaptor.go
@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -18,10 +19,14 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
- panic("implement me")
- return nil, nil
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ adaptor := openai.Adaptor{}
+ return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -38,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- fimBaseUrl := info.BaseUrl
- if !strings.HasSuffix(info.BaseUrl, "/beta") {
+ fimBaseUrl := info.ChannelBaseUrl
+ if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
fimBaseUrl += "/beta"
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default:
- return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
}
@@ -81,11 +86,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
+ usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info)
+ usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go
index 51dbee71..0a08d035 100644
--- a/relay/channel/dify/adaptor.go
+++ b/relay/channel/dify/adaptor.go
@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -23,6 +24,11 @@ type Adaptor struct {
BotType int
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -55,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch a.BotType {
case BotTypeWorkFlow:
- return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil
case BotTypeCompletion:
- return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil
case BotTypeAgent:
fallthrough
default:
- return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil
}
}
@@ -96,11 +102,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = difyStreamHandler(c, resp, info)
+ return difyStreamHandler(c, info, resp)
} else {
- err, usage = difyHandler(c, resp, info)
+ return difyHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go
index b58fbe53..2336fd4c 100644
--- a/relay/channel/dify/relay-dify.go
+++ b/relay/channel/dify/relay-dify.go
@@ -14,6 +14,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"os"
"strings"
@@ -21,7 +22,7 @@ import (
)
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
- uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
+ uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl)
switch media.Type {
case dto.ContentTypeImageURL:
// Decode base64 data
@@ -35,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
- common.SysError("failed to decode base64: " + err.Error())
+ common.SysLog("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
- common.SysError("failed to create temp file: " + err.Error())
+ common.SysLog("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
@@ -50,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
- common.SysError("failed to write to temp file: " + err.Error())
+ common.SysLog("failed to write to temp file: " + err.Error())
return nil
}
@@ -60,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Add user field
if err := writer.WriteField("user", user); err != nil {
- common.SysError("failed to add user field: " + err.Error())
+ common.SysLog("failed to add user field: " + err.Error())
return nil
}
@@ -73,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
- common.SysError("failed to create form file: " + err.Error())
+ common.SysLog("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
- common.SysError("failed to copy file content: " + err.Error())
+ common.SysLog("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
@@ -87,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
- common.SysError("failed to create request: " + err.Error())
+ common.SysLog("failed to create request: " + err.Error())
return nil
}
@@ -95,10 +96,10 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
// Send request
- client := service.GetImpatientHttpClient()
+ client := service.GetHttpClient()
resp, err := client.Do(req)
if err != nil {
- common.SysError("failed to send request: " + err.Error())
+ common.SysLog("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
@@ -108,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- common.SysError("failed to decode response: " + err.Error())
+ common.SysLog("failed to decode response: " + err.Error())
return nil
}
@@ -209,7 +210,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
return &response
}
-func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
usage := &dto.Usage{}
var nodeToken int
@@ -218,7 +219,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResponse dto.ChatCompletionsStreamResponse
@@ -238,39 +239,29 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
}
err = helper.ObjectData(c, openaiResponse)
if err != nil {
- common.SysError(err.Error())
+ common.SysLog(err.Error())
}
return true
})
helper.Done(c)
- err := resp.Body.Close()
- if err != nil {
- // return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- common.SysError("close_response_body_failed: " + err.Error())
- }
if usage.TotalTokens == 0 {
- usage.PromptTokens = info.PromptTokens
- usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
usage.CompletionTokens += nodeToken
- return nil, usage
+ return usage, nil
}
-func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
fullTextResponse := dto.OpenAITextResponse{
Id: difyResponse.ConversationId,
@@ -278,22 +269,21 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
- content, _ := json.Marshal(difyResponse.Answer)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: difyResponse.Answer,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &difyResponse.MetaData.Usage
+ _, _ = c.Writer.Write(jsonResponse)
+ return &difyResponse.MetaData.Usage, nil
}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index e6f66d5f..4968f78f 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -1,18 +1,17 @@
package gemini
import (
- "encoding/json"
"errors"
"fmt"
"io"
"net/http"
- "one-api/common"
"one-api/dto"
"one-api/relay/channel"
+ "one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
- "one-api/service"
"one-api/setting/model_setting"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -21,10 +20,33 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- //TODO implement me
- panic("implement me")
- return nil, nil
+func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
+ if len(request.Contents) > 0 {
+ for i, content := range request.Contents {
+ if i == 0 {
+ if request.Contents[0].Role == "" {
+ request.Contents[0].Role = "user"
+ }
+ }
+ for _, part := range content.Parts {
+ if part.FileData != nil {
+ if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
+ part.FileData.MimeType = "video/webm"
+ }
+ }
+ }
+ }
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ adaptor := openai.Adaptor{}
+ oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
+ if err != nil {
+ return nil, err
+ }
+ return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest))
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -37,26 +59,33 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
return nil, errors.New("not supported model for image generation")
}
- // convert size to aspect ratio
+ // convert size to aspect ratio but allow user to specify aspect ratio
aspectRatio := "1:1" // default aspect ratio
- switch request.Size {
- case "1024x1024":
- aspectRatio = "1:1"
- case "1024x1792":
- aspectRatio = "9:16"
- case "1792x1024":
- aspectRatio = "16:9"
+ size := strings.TrimSpace(request.Size)
+ if size != "" {
+ if strings.Contains(size, ":") {
+ aspectRatio = size
+ } else {
+ switch size {
+ case "1024x1024":
+ aspectRatio = "1:1"
+ case "1024x1792":
+ aspectRatio = "9:16"
+ case "1792x1024":
+ aspectRatio = "16:9"
+ }
+ }
}
// build gemini imagen request
- geminiRequest := GeminiImageRequest{
- Instances: []GeminiImageInstance{
+ geminiRequest := dto.GeminiImageRequest{
+ Instances: []dto.GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
- Parameters: GeminiImageParameters{
- SampleCount: request.N,
+ Parameters: dto.GeminiImageParameters{
+ SampleCount: int(request.N),
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
@@ -72,10 +101,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- // suffix -thinking and -nothinking
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
+ // 新增逻辑:处理 -thinking- 格式
+ if strings.Contains(info.UpstreamModelName, "-thinking-") {
+ parts := strings.Split(info.UpstreamModelName, "-thinking-")
+ info.UpstreamModelName = parts[0]
+ } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- } else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+ } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
@@ -83,20 +115,27 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
- return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
+ return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil
}
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
- return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
+ action := "embedContent"
+ if info.IsGeminiBatchEmbedding {
+ action = "batchEmbedContents"
+ }
+ return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
+ if info.RelayMode == constant.RelayModeGemini {
+ info.DisablePing = true
+ }
}
- return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
+ return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -110,7 +149,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
return nil, errors.New("request is nil")
}
- geminiRequest, err := CovertGemini2OpenAI(*request, info)
+ geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
if err != nil {
return nil, err
}
@@ -131,29 +170,38 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
if len(inputs) == 0 {
return nil, errors.New("input is empty")
}
-
- // only process the first input
- geminiRequest := GeminiEmbeddingRequest{
- Content: GeminiChatContent{
- Parts: []GeminiPart{
- {
- Text: inputs[0],
+ // We always build a batch-style payload with `requests`, so ensure we call the
+ // batch endpoint upstream to avoid payload/endpoint mismatches.
+ info.IsGeminiBatchEmbedding = true
+ // process all inputs
+ geminiRequests := make([]map[string]interface{}, 0, len(inputs))
+ for _, input := range inputs {
+ geminiRequest := map[string]interface{}{
+ "model": fmt.Sprintf("models/%s", info.UpstreamModelName),
+ "content": dto.GeminiChatContent{
+ Parts: []dto.GeminiPart{
+ {
+ Text: input,
+ },
},
},
- },
- }
-
- // set specific parameters for different models
- // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
- switch info.UpstreamModelName {
- case "text-embedding-004":
- // except embedding-001 supports setting `OutputDimensionality`
- if request.Dimensions > 0 {
- geminiRequest.OutputDimensionality = request.Dimensions
}
+
+ // set specific parameters for different models
+ // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
+ switch info.UpstreamModelName {
+ case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
+ // Only newer models introduced after 2024 support OutputDimensionality
+ if request.Dimensions > 0 {
+ geminiRequest["outputDimensionality"] = request.Dimensions
+ }
+ }
+ geminiRequests = append(geminiRequests, geminiRequest)
}
- return geminiRequest, nil
+ return map[string]interface{}{
+ "requests": geminiRequests,
+ }, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
@@ -165,98 +213,36 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
+ if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
+ strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
+ return NativeGeminiEmbeddingHandler(c, resp, info)
+ }
if info.IsStream {
- return GeminiTextGenerationStreamHandler(c, resp, info)
+ return GeminiTextGenerationStreamHandler(c, info, resp)
} else {
- return GeminiTextGenerationHandler(c, resp, info)
+ return GeminiTextGenerationHandler(c, info, resp)
}
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
- return GeminiImageHandler(c, resp, info)
+ return GeminiImageHandler(c, info, resp)
}
// check if the model is an embedding model
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
- return GeminiEmbeddingHandler(c, resp, info)
+ return GeminiEmbeddingHandler(c, info, resp)
}
if info.IsStream {
- err, usage = GeminiChatStreamHandler(c, resp, info)
+ return GeminiChatStreamHandler(c, info, resp)
} else {
- err, usage = GeminiChatHandler(c, resp, info)
+ return GeminiChatHandler(c, info, resp)
}
- //if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
- // // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费
- // if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
- // !strings.HasSuffix(info.OriginModelName, "-nothinking") {
- // thinkingModelName := info.OriginModelName + "-thinking"
- // if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
- // info.OriginModelName = thinkingModelName
- // }
- // }
- //}
-
- return
-}
-
-func GeminiImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
- responseBody, readErr := io.ReadAll(resp.Body)
- if readErr != nil {
- return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
- }
- _ = resp.Body.Close()
-
- var geminiResponse GeminiImageResponse
- if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
- return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
-
- if len(geminiResponse.Predictions) == 0 {
- return nil, service.OpenAIErrorWrapper(errors.New("no images generated"), "no_images", http.StatusBadRequest)
- }
-
- // convert to openai format response
- openAIResponse := dto.ImageResponse{
- Created: common.GetTimestamp(),
- Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
- }
-
- for _, prediction := range geminiResponse.Predictions {
- if prediction.RaiFilteredReason != "" {
- continue // skip filtered image
- }
- openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
- B64Json: prediction.BytesBase64Encoded,
- })
- }
-
- jsonResponse, jsonErr := json.Marshal(openAIResponse)
- if jsonErr != nil {
- return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
- }
-
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, _ = c.Writer.Write(jsonResponse)
-
- // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
- // each image has fixed 258 tokens
- const imageTokens = 258
- generatedImages := len(openAIResponse.Data)
-
- usage = &dto.Usage{
- PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
- CompletionTokens: 0, // image generation does not calculate completion tokens
- TotalTokens: imageTokens * generatedImages,
- }
-
- return usage, nil
}
func (a *Adaptor) GetModelList() []string {
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
deleted file mode 100644
index a0e38cb4..00000000
--- a/relay/channel/gemini/dto.go
+++ /dev/null
@@ -1,167 +0,0 @@
-package gemini
-
-type GeminiChatRequest struct {
- Contents []GeminiChatContent `json:"contents"`
- SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
- GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
- Tools []GeminiChatTool `json:"tools,omitempty"`
- SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
-}
-
-type GeminiThinkingConfig struct {
- IncludeThoughts bool `json:"includeThoughts,omitempty"`
- ThinkingBudget *int `json:"thinkingBudget,omitempty"`
-}
-
-func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
- c.ThinkingBudget = &budget
-}
-
-type GeminiInlineData struct {
- MimeType string `json:"mimeType"`
- Data string `json:"data"`
-}
-
-type FunctionCall struct {
- FunctionName string `json:"name"`
- Arguments any `json:"args"`
-}
-
-type GeminiFunctionResponseContent struct {
- Name string `json:"name"`
- Content any `json:"content"`
-}
-
-type FunctionResponse struct {
- Name string `json:"name"`
- Response GeminiFunctionResponseContent `json:"response"`
-}
-
-type GeminiPartExecutableCode struct {
- Language string `json:"language,omitempty"`
- Code string `json:"code,omitempty"`
-}
-
-type GeminiPartCodeExecutionResult struct {
- Outcome string `json:"outcome,omitempty"`
- Output string `json:"output,omitempty"`
-}
-
-type GeminiFileData struct {
- MimeType string `json:"mimeType,omitempty"`
- FileUri string `json:"fileUri,omitempty"`
-}
-
-type GeminiPart struct {
- Text string `json:"text,omitempty"`
- Thought bool `json:"thought,omitempty"`
- InlineData *GeminiInlineData `json:"inlineData,omitempty"`
- FunctionCall *FunctionCall `json:"functionCall,omitempty"`
- FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
- FileData *GeminiFileData `json:"fileData,omitempty"`
- ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
- CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
-}
-
-type GeminiChatContent struct {
- Role string `json:"role,omitempty"`
- Parts []GeminiPart `json:"parts"`
-}
-
-type GeminiChatSafetySettings struct {
- Category string `json:"category"`
- Threshold string `json:"threshold"`
-}
-
-type GeminiChatTool struct {
- GoogleSearch any `json:"googleSearch,omitempty"`
- GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
- CodeExecution any `json:"codeExecution,omitempty"`
- FunctionDeclarations any `json:"functionDeclarations,omitempty"`
-}
-
-type GeminiChatGenerationConfig struct {
- Temperature *float64 `json:"temperature,omitempty"`
- TopP float64 `json:"topP,omitempty"`
- TopK float64 `json:"topK,omitempty"`
- MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
- CandidateCount int `json:"candidateCount,omitempty"`
- StopSequences []string `json:"stopSequences,omitempty"`
- ResponseMimeType string `json:"responseMimeType,omitempty"`
- ResponseSchema any `json:"responseSchema,omitempty"`
- Seed int64 `json:"seed,omitempty"`
- ResponseModalities []string `json:"responseModalities,omitempty"`
- ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
-}
-
-type GeminiChatCandidate struct {
- Content GeminiChatContent `json:"content"`
- FinishReason *string `json:"finishReason"`
- Index int64 `json:"index"`
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-type GeminiChatSafetyRating struct {
- Category string `json:"category"`
- Probability string `json:"probability"`
-}
-
-type GeminiChatPromptFeedback struct {
- SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
-}
-
-type GeminiChatResponse struct {
- Candidates []GeminiChatCandidate `json:"candidates"`
- PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
- UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
-}
-
-type GeminiUsageMetadata struct {
- PromptTokenCount int `json:"promptTokenCount"`
- CandidatesTokenCount int `json:"candidatesTokenCount"`
- TotalTokenCount int `json:"totalTokenCount"`
- ThoughtsTokenCount int `json:"thoughtsTokenCount"`
-}
-
-// Imagen related structs
-type GeminiImageRequest struct {
- Instances []GeminiImageInstance `json:"instances"`
- Parameters GeminiImageParameters `json:"parameters"`
-}
-
-type GeminiImageInstance struct {
- Prompt string `json:"prompt"`
-}
-
-type GeminiImageParameters struct {
- SampleCount int `json:"sampleCount,omitempty"`
- AspectRatio string `json:"aspectRatio,omitempty"`
- PersonGeneration string `json:"personGeneration,omitempty"`
-}
-
-type GeminiImageResponse struct {
- Predictions []GeminiImagePrediction `json:"predictions"`
-}
-
-type GeminiImagePrediction struct {
- MimeType string `json:"mimeType"`
- BytesBase64Encoded string `json:"bytesBase64Encoded"`
- RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
- SafetyAttributes any `json:"safetyAttributes,omitempty"`
-}
-
-// Embedding related structs
-type GeminiEmbeddingRequest struct {
- Content GeminiChatContent `json:"content"`
- TaskType string `json:"taskType,omitempty"`
- Title string `json:"title,omitempty"`
- OutputDimensionality int `json:"outputDimensionality,omitempty"`
-}
-
-type GeminiEmbeddingResponse struct {
- Embedding ContentEmbedding `json:"embedding"`
-}
-
-type ContentEmbedding struct {
- Values []float64 `json:"values"`
-}
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
index c055e299..974a22f5 100644
--- a/relay/channel/gemini/relay-gemini-native.go
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -1,27 +1,29 @@
package gemini
import (
- "encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
+ "strings"
+
+ "github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
-func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- }
- err = resp.Body.Close()
- if err != nil {
- return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if common.DebugEnabled {
@@ -29,60 +31,83 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
}
// 解析为 Gemini 原生响应格式
- var geminiResponse GeminiChatResponse
- err = common.DecodeJson(responseBody, &geminiResponse)
+ var geminiResponse dto.GeminiChatResponse
+ err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
- return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
- }
-
- // 检查是否有候选响应
- if len(geminiResponse.Candidates) == 0 {
- return nil, &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: "No candidates returned",
- Type: "server_error",
- Param: "",
- Code: 500,
- },
- StatusCode: resp.StatusCode,
- }
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// 计算使用量(基于 UsageMetadata)
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
- CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+ CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
- // 直接返回 Gemini 原生格式的 JSON 响应
- jsonResponse, err := json.Marshal(geminiResponse)
- if err != nil {
- return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
}
- // 设置响应头并写入响应
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- if err != nil {
- return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
- }
+ service.IOCopyBytesGracefully(c, resp, responseBody)
return &usage, nil
}
-func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+
+ if common.DebugEnabled {
+ println(string(responseBody))
+ }
+
+ usage := &dto.Usage{
+ PromptTokens: info.PromptTokens,
+ TotalTokens: info.PromptTokens,
+ }
+
+ if info.IsGeminiBatchEmbedding {
+ var geminiResponse dto.GeminiBatchEmbeddingResponse
+ err = common.Unmarshal(responseBody, &geminiResponse)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+ } else {
+ var geminiResponse dto.GeminiEmbeddingResponse
+ err = common.Unmarshal(responseBody, &geminiResponse)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+ }
+
+ service.IOCopyBytesGracefully(c, resp, responseBody)
+
+ return usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage = &dto.Usage{}
var imageCount int
helper.SetEventStreamHeaders(c)
+ responseText := strings.Builder{}
+
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var geminiResponse GeminiChatResponse
- err := common.DecodeJsonStr(data, &geminiResponse)
+ var geminiResponse dto.GeminiChatResponse
+ err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
- common.LogError(c, "error unmarshalling stream response: "+err.Error())
+ logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
@@ -92,37 +117,59 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++
}
+ if part.Text != "" {
+ responseText.WriteString(part.Text)
+ }
}
}
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
- usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+ usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
}
// 直接发送 GeminiChatResponse 响应
- err = helper.ObjectData(c, geminiResponse)
+ err = helper.StringData(c, data)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
}
-
+ info.SendResponseCount++
return true
})
+ if info.SendResponseCount == 0 {
+ return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
+ }
+
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
- // 计算最终使用量
- usage.PromptTokensDetails.TextTokens = usage.PromptTokens
- usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+ // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
+ if usage.CompletionTokens == 0 {
+ str := responseText.String()
+ if len(str) > 0 {
+ usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ // 空补全,不需要使用量
+ usage = &dto.Usage{}
+ }
+ }
- // 结束流式响应
- helper.Done(c)
+ // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
+ //helper.Done(c)
return usage, nil
}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index bf1ece57..c54eb5b6 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -2,16 +2,21 @@ package gemini
import (
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
+ "one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
+ "one-api/types"
+ "strconv"
"strings"
"unicode/utf8"
@@ -36,15 +41,151 @@ var geminiSupportedMimeTypes = map[string]bool{
"video/flv": true,
}
-// Setting safety to the lowest possible values since Gemini is already powerless enough
-func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
+// Gemini 允许的思考预算范围
+const (
+ pro25MinBudget = 128
+ pro25MaxBudget = 32768
+ flash25MaxBudget = 24576
+ flash25LiteMinBudget = 512
+ flash25LiteMaxBudget = 24576
+)
- geminiRequest := GeminiChatRequest{
- Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
- GenerationConfig: GeminiChatGenerationConfig{
+func isNew25ProModel(modelName string) bool {
+ return strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+}
+
+func is25FlashLiteModel(modelName string) bool {
+ return strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+}
+
+// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
+func clampThinkingBudget(modelName string, budget int) int {
+ isNew25Pro := isNew25ProModel(modelName)
+ is25FlashLite := is25FlashLiteModel(modelName)
+
+ if is25FlashLite {
+ if budget < flash25LiteMinBudget {
+ return flash25LiteMinBudget
+ }
+ if budget > flash25LiteMaxBudget {
+ return flash25LiteMaxBudget
+ }
+ } else if isNew25Pro {
+ if budget < pro25MinBudget {
+ return pro25MinBudget
+ }
+ if budget > pro25MaxBudget {
+ return pro25MaxBudget
+ }
+ } else { // 其他模型
+ if budget < 0 {
+ return 0
+ }
+ if budget > flash25MaxBudget {
+ return flash25MaxBudget
+ }
+ }
+ return budget
+}
+
+// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
+// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
+// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
+func clampThinkingBudgetByEffort(modelName string, effort string) int {
+	isNew25Pro := isNew25ProModel(modelName)
+	is25FlashLite := is25FlashLiteModel(modelName)
+
+	// Pick the model family's maximum budget. The branches are mutually
+	// exclusive, so chain them: previously the flash-lite assignment was
+	// always clobbered by the trailing standalone if/else.
+	maxBudget := 0
+	if is25FlashLite {
+		maxBudget = flash25LiteMaxBudget
+	} else if isNew25Pro {
+		maxBudget = pro25MaxBudget
+	} else {
+		maxBudget = flash25MaxBudget
+	}
+	switch effort {
+	case "high":
+		maxBudget = maxBudget * 80 / 100
+	case "medium":
+		maxBudget = maxBudget * 50 / 100
+	case "low":
+		maxBudget = maxBudget * 20 / 100
+	}
+	return clampThinkingBudget(modelName, maxBudget)
+}
+
+func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
+	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
+		modelName := info.UpstreamModelName
+		// Reuse the shared helper so the "new 2.5 pro" detection stays
+		// consistent with clampThinkingBudget instead of duplicating it.
+		isNew25Pro := isNew25ProModel(modelName)
+
+		if strings.Contains(modelName, "-thinking-") {
+			parts := strings.SplitN(modelName, "-thinking-", 2)
+			if len(parts) == 2 && parts[1] != "" {
+				if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
+					clampedBudget := clampThinkingBudget(modelName, budgetTokens)
+					geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+						ThinkingBudget:  common.GetPointer(clampedBudget),
+						IncludeThoughts: true,
+					}
+				}
+			}
+		} else if strings.HasSuffix(modelName, "-thinking") {
+			unsupportedModels := []string{
+				"gemini-2.5-pro-preview-05-06",
+				"gemini-2.5-pro-preview-03-25",
+			}
+			isUnsupported := false
+			for _, unsupportedModel := range unsupportedModels {
+				if strings.HasPrefix(modelName, unsupportedModel) {
+					isUnsupported = true
+					break
+				}
+			}
+
+			if isUnsupported {
+				geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+					IncludeThoughts: true,
+				}
+			} else {
+				geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+					IncludeThoughts: true,
+				}
+				if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
+					budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
+					clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
+					geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
+				} else {
+					if len(oaiRequest) > 0 {
+						// If a reasoning_effort parameter is present, size the thinking budget from it.
+						geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampThinkingBudgetByEffort(modelName, oaiRequest[0].ReasoningEffort))
+					}
+				}
+			}
+		} else if strings.HasSuffix(modelName, "-nothinking") {
+			if !isNew25Pro {
+				geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+					ThinkingBudget: common.GetPointer(0),
+				}
+			}
+		}
+	}
+}
+
+// Setting safety to the lowest possible values since Gemini is already powerless enough
+func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
+
+ geminiRequest := dto.GeminiChatRequest{
+ Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
+ GenerationConfig: dto.GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
- MaxOutputTokens: textRequest.MaxTokens,
+ MaxOutputTokens: textRequest.GetMaxTokens(),
Seed: int64(textRequest.Seed),
},
}
@@ -56,33 +197,41 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
}
- if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
- // 如果模型名以 gemini-2.5-pro 开头,不设置 ThinkingBudget
- if strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") {
- geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
- IncludeThoughts: true,
- }
- } else {
- budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
- if budgetTokens == 0 || budgetTokens > 24576 {
- budgetTokens = 24576
- }
- geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(int(budgetTokens)),
- IncludeThoughts: true,
- }
- }
- } else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
- geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
- ThinkingBudget: common.GetPointer(0),
+ adaptorWithExtraBody := false
+
+ if len(textRequest.ExtraBody) > 0 {
+ if !strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
+ var extraBody map[string]interface{}
+ if err := common.Unmarshal(textRequest.ExtraBody, &extraBody); err != nil {
+ return nil, fmt.Errorf("invalid extra body: %w", err)
+ }
+ // eg. {"google":{"thinking_config":{"thinking_budget":5324,"include_thoughts":true}}}
+ if googleBody, ok := extraBody["google"].(map[string]interface{}); ok {
+ adaptorWithExtraBody = true
+ if thinkingConfig, ok := googleBody["thinking_config"].(map[string]interface{}); ok {
+ if budget, ok := thinkingConfig["thinking_budget"].(float64); ok {
+ budgetInt := int(budget)
+ geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+ ThinkingBudget: common.GetPointer(budgetInt),
+ IncludeThoughts: true,
+ }
+ } else {
+ geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
+ IncludeThoughts: true,
+ }
+ }
+ }
}
}
}
- safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
+ if !adaptorWithExtraBody {
+ ThinkingAdaptor(&geminiRequest, info, textRequest)
+ }
+
+ safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
- safetySettings = append(safetySettings, GeminiChatSafetySettings{
+ safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category),
})
@@ -119,38 +268,35 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
tool.Function.Parameters = cleanedParams
functions = append(functions, tool.Function)
}
+ geminiTools := geminiRequest.GetTools()
if codeExecution {
- geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+ geminiTools = append(geminiTools, dto.GeminiChatTool{
CodeExecution: make(map[string]string),
})
}
if googleSearch {
- geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+ geminiTools = append(geminiTools, dto.GeminiChatTool{
GoogleSearch: make(map[string]string),
})
}
if len(functions) > 0 {
- geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+ geminiTools = append(geminiTools, dto.GeminiChatTool{
FunctionDeclarations: functions,
})
}
- // common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
- // json_data, _ := json.Marshal(geminiRequest.Tools)
- // common.SysLog("tools_json: " + string(json_data))
- } else if textRequest.Functions != nil {
- //geminiRequest.Tools = []GeminiChatTool{
- // {
- // FunctionDeclarations: textRequest.Functions,
- // },
- //}
+ geminiRequest.SetTools(geminiTools)
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
- if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
- cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
- geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
+ if len(textRequest.ResponseFormat.JsonSchema) > 0 {
+ // 先将json.RawMessage解析
+ var jsonSchema dto.FormatJsonSchema
+ if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
+ cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
+ geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
+ }
}
}
tool_call_ids := make(map[string]string)
@@ -162,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
continue
} else if message.Role == "tool" || message.Role == "function" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
- geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+ geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
Role: "user",
})
}
@@ -173,24 +319,34 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
- content := common.StrToMap(message.StringContent())
- functionResp := &FunctionResponse{
- Name: name,
- Response: GeminiFunctionResponseContent{
- Name: name,
- Content: content,
- },
+ var contentMap map[string]interface{}
+ contentStr := message.StringContent()
+
+ // 1. 尝试解析为 JSON 对象
+ if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
+ // 2. 如果失败,尝试解析为 JSON 数组
+ var contentSlice []interface{}
+ if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
+ // 如果是数组,包装成对象
+ contentMap = map[string]interface{}{"result": contentSlice}
+ } else {
+ // 3. 如果再次失败,作为纯文本处理
+ contentMap = map[string]interface{}{"content": contentStr}
+ }
}
- if content == nil {
- functionResp.Response.Content = message.StringContent()
+
+ functionResp := &dto.GeminiFunctionResponse{
+ Name: name,
+ Response: contentMap,
}
- *parts = append(*parts, GeminiPart{
+
+ *parts = append(*parts, dto.GeminiPart{
FunctionResponse: functionResp,
})
continue
}
- var parts []GeminiPart
- content := GeminiChatContent{
+ var parts []dto.GeminiPart
+ content := dto.GeminiChatContent{
Role: message.Role,
}
// isToolCall := false
@@ -204,8 +360,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
}
}
- toolCall := GeminiPart{
- FunctionCall: &FunctionCall{
+ toolCall := dto.GeminiPart{
+ FunctionCall: &dto.FunctionCall{
FunctionName: call.Function.Name,
Arguments: args,
},
@@ -222,7 +378,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if part.Text == "" {
continue
}
- parts = append(parts, GeminiPart{
+ parts = append(parts, dto.GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
@@ -234,18 +390,19 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
// 判断是否是url
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
// 是url,获取文件的类型和base64编码的数据
- fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
+ fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
if err != nil {
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
}
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
- return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
+ url := part.GetImageMedia().Url
+ return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
}
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
+ parts = append(parts, dto.GeminiPart{
+ InlineData: &dto.GeminiInlineData{
MimeType: fileData.MimeType, // 使用原始的 MimeType,因为大小写可能对API有意义
Data: fileData.Base64Data,
},
@@ -255,8 +412,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
+ parts = append(parts, dto.GeminiPart{
+ InlineData: &dto.GeminiInlineData{
MimeType: format,
Data: base64String,
},
@@ -270,8 +427,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
+ parts = append(parts, dto.GeminiPart{
+ InlineData: &dto.GeminiInlineData{
MimeType: format,
Data: base64String,
},
@@ -280,13 +437,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
- format, base64String, err := service.DecodeBase64FileData(part.GetInputAudio().Data)
+ base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
- parts = append(parts, GeminiPart{
- InlineData: &GeminiInlineData{
- MimeType: format,
+ parts = append(parts, dto.GeminiPart{
+ InlineData: &dto.GeminiInlineData{
+ MimeType: "audio/" + part.GetInputAudio().Format,
Data: base64String,
},
})
@@ -299,12 +456,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if content.Role == "assistant" {
content.Role = "model"
}
- geminiRequest.Contents = append(geminiRequest.Contents, content)
+ if len(content.Parts) > 0 {
+ geminiRequest.Contents = append(geminiRequest.Contents, content)
+ }
}
if len(system_content) > 0 {
- geminiRequest.SystemInstructions = &GeminiChatContent{
- Parts: []GeminiPart{
+ geminiRequest.SystemInstructions = &dto.GeminiChatContent{
+ Parts: []dto.GeminiPart{
{
Text: strings.Join(system_content, "\n"),
},
@@ -547,7 +706,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data
}
-func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
+func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -569,21 +728,20 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
}
}
-func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
+func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ Id: helper.GetResponseID(c),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
- content, _ := json.Marshal("")
isToolCall := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index),
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: "",
},
FinishReason: constant.FinishReasonStop,
}
@@ -637,10 +795,9 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
return &fullTextResponse
}
-func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
+func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false
- hasImage := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true
@@ -649,7 +806,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
choice := dto.ChatCompletionsStreamResponseChoice{
Index: int(candidate.Index),
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
- Role: "assistant",
+ //Role: "assistant",
},
}
var texts []string
@@ -671,7 +828,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := ""
texts = append(texts, imgText)
- hasImage = true
}
} else if part.FunctionCall != nil {
isTools = true
@@ -679,6 +835,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
+
} else if part.Thought {
isThought = true
texts = append(texts, part.Text)
@@ -708,28 +865,60 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Choices = choices
- return &response, isStop, hasImage
+ return &response, isStop
}
-func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
+ streamData, err := common.Marshal(resp)
+ if err != nil {
+ return fmt.Errorf("failed to marshal stream response: %w", err)
+ }
+ err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
+ if err != nil {
+ return fmt.Errorf("failed to handle stream format: %w", err)
+ }
+ return nil
+}
+
+func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
+ streamData, err := common.Marshal(resp)
+ if err != nil {
+ return fmt.Errorf("failed to marshal stream response: %w", err)
+ }
+ openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
+ return nil
+}
+
+func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
// responseText := ""
- id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
+ responseText := strings.Builder{}
var usage = &dto.Usage{}
var imageCount int
+ finishReason := constant.FinishReasonStop
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
- var geminiResponse GeminiChatResponse
- err := common.DecodeJsonStr(data, &geminiResponse)
+ var geminiResponse dto.GeminiChatResponse
+ err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
- common.LogError(c, "error unmarshalling stream response: "+err.Error())
+ logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
- response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
- if hasImage {
- imageCount++
+ for _, candidate := range geminiResponse.Candidates {
+ for _, part := range candidate.Content.Parts {
+ if part.InlineData != nil && part.InlineData.MimeType != "" {
+ imageCount++
+ }
+ if part.Text != "" {
+ responseText.WriteString(part.Text)
+ }
+ }
}
+
+ response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
+
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
@@ -738,19 +927,55 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
}
- err = helper.ObjectData(c, response)
+ logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
+ if info.SendResponseCount == 0 {
+ // send first response
+ emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
+ if response.IsToolCall() {
+ emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1)
+ emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall()
+ emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = ""
+ finishReason = constant.FinishReasonToolCalls
+ err = handleStream(c, info, emptyResponse)
+ if err != nil {
+ logger.LogError(c, err.Error())
+ }
+
+ response.ClearToolCalls()
+ if response.IsFinished() {
+ response.Choices[0].FinishReason = nil
+ }
+ } else {
+ err = handleStream(c, info, emptyResponse)
+ if err != nil {
+ logger.LogError(c, err.Error())
+ }
+ }
+ }
+
+ err = handleStream(c, info, response)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
}
if isStop {
- response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
- helper.ObjectData(c, response)
+ _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
}
return true
})
- var response *dto.ChatCompletionsStreamResponse
+ if info.SendResponseCount == 0 {
+ // 空补全,报错不计费
+ // empty response, throw an error
+ return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
+ }
if imageCount != 0 {
if usage.CompletionTokens == 0 {
@@ -761,47 +986,46 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
- if info.ShouldIncludeUsage {
- response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
- err := helper.ObjectData(c, response)
- if err != nil {
- common.SysError("send final response failed: " + err.Error())
+ if usage.CompletionTokens == 0 {
+ str := responseText.String()
+ if len(str) > 0 {
+ usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ // 空补全,不需要使用量
+ usage = &dto.Usage{}
}
}
- helper.Done(c)
+
+ response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+ err := handleFinalStream(c, info, response)
+ if err != nil {
+ common.SysLog("send final response failed: " + err.Error())
+ }
+ //if info.RelayFormat == relaycommon.RelayFormatOpenAI {
+ // helper.Done(c)
+ //}
//resp.Body.Close()
- return nil, usage
+ return usage, nil
}
-func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
+ service.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println(string(responseBody))
}
- var geminiResponse GeminiChatResponse
- err = common.DecodeJson(responseBody, &geminiResponse)
+ var geminiResponse dto.GeminiChatResponse
+ err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: "No candidates returned",
- Type: "server_error",
- Param: "",
- Code: 500,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
+ fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
@@ -812,40 +1036,64 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
- fullTextResponse.Usage = usage
- jsonResponse, err := json.Marshal(fullTextResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
}
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &usage
+
+ fullTextResponse.Usage = usage
+
+ switch info.RelayFormat {
+ case types.RelayFormatOpenAI:
+ responseBody, err = common.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ case types.RelayFormatClaude:
+ claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
+ claudeRespStr, err := common.Marshal(claudeResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ responseBody = claudeRespStr
+ case types.RelayFormatGemini:
+ break
+ }
+
+ service.IOCopyBytesGracefully(c, resp, responseBody)
+
+ return &usage, nil
}
-func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
- return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
+ return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- _ = resp.Body.Close()
- var geminiResponse GeminiEmbeddingResponse
- if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
- return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ var geminiResponse dto.GeminiBatchEmbeddingResponse
+ if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+ return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
openAIResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
- Data: []dto.OpenAIEmbeddingResponseItem{
- {
- Object: "embedding",
- Embedding: geminiResponse.Embedding.Values,
- Index: 0,
- },
- },
- Model: info.UpstreamModelName,
+ Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)),
+ Model: info.UpstreamModelName,
+ }
+
+ for i, embedding := range geminiResponse.Embeddings {
+ openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{
+ Object: "embedding",
+ Embedding: embedding.Values,
+ Index: i,
+ })
}
// calculate usage
@@ -853,21 +1101,72 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
// Google has not yet clarified how embedding models will be billed
// refer to openai billing method to use input tokens billing
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
- usage = &dto.Usage{
+ usage := &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
- openAIResponse.Usage = *usage.(*dto.Usage)
+ openAIResponse.Usage = *usage
+
+ jsonResponse, jsonErr := common.Marshal(openAIResponse)
+ if jsonErr != nil {
+ return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return usage, nil
+}
+
+func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseBody, readErr := io.ReadAll(resp.Body)
+ if readErr != nil {
+ return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+ _ = resp.Body.Close()
+
+ var geminiResponse dto.GeminiImageResponse
+ if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+ return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+
+ if len(geminiResponse.Predictions) == 0 {
+ return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+
+ // convert to openai format response
+ openAIResponse := dto.ImageResponse{
+ Created: common.GetTimestamp(),
+ Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
+ }
+
+ for _, prediction := range geminiResponse.Predictions {
+ if prediction.RaiFilteredReason != "" {
+ continue // skip filtered image
+ }
+ openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
+ B64Json: prediction.BytesBase64Encoded,
+ })
+ }
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
- return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
+ return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
+ // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
+ // each image has fixed 258 tokens
+ const imageTokens = 258
+ generatedImages := len(openAIResponse.Data)
+
+ usage := &dto.Usage{
+ PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
+ CompletionTokens: 0, // image generation does not calculate completion tokens
+ TotalTokens: imageTokens * generatedImages,
+ }
+
return usage, nil
}
diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go
new file mode 100644
index 00000000..885a1427
--- /dev/null
+++ b/relay/channel/jimeng/adaptor.go
@@ -0,0 +1,142 @@
+package jimeng
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
+ return errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+type LogoInfo struct {
+ AddLogo bool `json:"add_logo,omitempty"`
+ Position int `json:"position,omitempty"`
+ Language int `json:"language,omitempty"`
+ Opacity float64 `json:"opacity,omitempty"`
+ LogoTextContent string `json:"logo_text_content,omitempty"`
+}
+
+type imageRequestPayload struct {
+ ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L
+ Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English
+ Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random)
+ Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768]
+ Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768]
+ UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true
+ UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true
+ ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours)
+ LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information
+ ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input
+ BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ payload := imageRequestPayload{
+ ReqKey: request.Model,
+ Prompt: request.Prompt,
+ }
+ if request.ResponseFormat == "" || request.ResponseFormat == "url" {
+ payload.ReturnURL = true // Default to returning image URLs
+ }
+
+ if len(request.ExtraFields) > 0 {
+ if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
+ }
+ }
+
+ return payload, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ fullRequestURL, err := a.GetRequestURL(info)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ err = Sign(c, req, info.ApiKey)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := channel.DoRequest(c, req, info)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+ usage, err = jimengImageHandler(c, resp, info)
+ } else if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/jimeng/constants.go b/relay/channel/jimeng/constants.go
new file mode 100644
index 00000000..0d1764e5
--- /dev/null
+++ b/relay/channel/jimeng/constants.go
@@ -0,0 +1,9 @@
+package jimeng
+
+const (
+ ChannelName = "jimeng"
+)
+
+var ModelList = []string{
+ "jimeng_high_aes_general_v21_L",
+}
diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go
new file mode 100644
index 00000000..11a0117b
--- /dev/null
+++ b/relay/channel/jimeng/image.go
@@ -0,0 +1,89 @@
+package jimeng
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type ImageResponse struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data struct {
+ BinaryDataBase64 []string `json:"binary_data_base64"`
+ ImageUrls []string `json:"image_urls"`
+ RephraseResult string `json:"rephraser_result"`
+ RequestID string `json:"request_id"`
+ // Other fields are omitted for brevity
+ } `json:"data"`
+ RequestID string `json:"request_id"`
+ Status int `json:"status"`
+ TimeElapsed string `json:"time_elapsed"`
+}
+
+func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse {
+ imageResponse := dto.ImageResponse{
+ Created: info.StartTime.Unix(),
+ }
+
+ for _, base64Data := range response.Data.BinaryDataBase64 {
+ imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+ B64Json: base64Data,
+ })
+ }
+ for _, imageUrl := range response.Data.ImageUrls {
+ imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+ Url: imageUrl,
+ })
+ }
+
+ return &imageResponse
+}
+
+// jimengImageHandler handles the Jimeng image generation response
+func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+ var jimengResponse ImageResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
+ }
+ service.CloseResponseBodyGracefully(resp)
+
+ err = json.Unmarshal(responseBody, &jimengResponse)
+ if err != nil {
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+
+ // Check if the response indicates an error
+ if jimengResponse.Code != 10000 {
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: jimengResponse.Message,
+ Type: "jimeng_error",
+ Param: "",
+ Code: fmt.Sprintf("%d", jimengResponse.Code),
+ }, resp.StatusCode)
+ }
+
+ // Convert Jimeng response to OpenAI format
+ fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ return &dto.Usage{}, nil
+}
diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go
new file mode 100644
index 00000000..d8b598dc
--- /dev/null
+++ b/relay/channel/jimeng/sign.go
@@ -0,0 +1,176 @@
+package jimeng
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "net/url"
+ "one-api/logger"
+ "sort"
+ "strings"
+ "time"
+)
+
+// SignRequestForJimeng signs a Jimeng API request (legacy helper, kept commented out); supports either an *http.Request or separate header+url+body inputs.
+//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
+// var bodyBytes []byte
+// var err error
+//
+// if req.Body != nil {
+// bodyBytes, err = io.ReadAll(req.Body)
+// if err != nil {
+// return fmt.Errorf("read request body failed: %w", err)
+// }
+// _ = req.Body.Close()
+// req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
+// } else {
+// bodyBytes = []byte{}
+// }
+//
+// return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
+//}
+
+const HexPayloadHashKey = "HexPayloadHash"
+
+func SetPayloadHash(c *gin.Context, req any) error {
+ body, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+ logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
+ payloadHash := sha256.Sum256(body)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+ c.Set(HexPayloadHashKey, hexPayloadHash)
+ return nil
+}
+func getPayloadHash(c *gin.Context) string {
+ return c.GetString(HexPayloadHashKey)
+}
+
+func Sign(c *gin.Context, req *http.Request, apiKey string) error {
+ header := req.Header
+
+ var bodyBytes []byte
+ var err error
+
+ if req.Body != nil {
+ bodyBytes, err = io.ReadAll(req.Body)
+ if err != nil {
+ return err
+ }
+ _ = req.Body.Close()
+ req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
+ }
+
+ payloadHash := sha256.Sum256(bodyBytes)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+
+ method := c.Request.Method
+ u := req.URL
+ keyParts := strings.Split(apiKey, "|")
+ if len(keyParts) != 2 {
+ return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
+ }
+ accessKey := strings.TrimSpace(keyParts[0])
+ secretKey := strings.TrimSpace(keyParts[1])
+ t := time.Now().UTC()
+ xDate := t.Format("20060102T150405Z")
+ shortDate := t.Format("20060102")
+
+ host := u.Host
+ header.Set("Host", host)
+ header.Set("X-Date", xDate)
+ header.Set("X-Content-Sha256", hexPayloadHash)
+
+ // Sort and encode query parameters to create canonical query string
+ queryParams := u.Query()
+ sortedKeys := make([]string, 0, len(queryParams))
+ for k := range queryParams {
+ sortedKeys = append(sortedKeys, k)
+ }
+ sort.Strings(sortedKeys)
+ var queryParts []string
+ for _, k := range sortedKeys {
+ values := queryParams[k]
+ sort.Strings(values)
+ for _, v := range values {
+ queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
+ }
+ }
+ canonicalQueryString := strings.Join(queryParts, "&")
+
+ headersToSign := map[string]string{
+ "host": host,
+ "x-date": xDate,
+ "x-content-sha256": hexPayloadHash,
+ }
+ if header.Get("Content-Type") == "" {
+ header.Set("Content-Type", "application/json")
+ }
+ headersToSign["content-type"] = header.Get("Content-Type")
+
+ var signedHeaderKeys []string
+ for k := range headersToSign {
+ signedHeaderKeys = append(signedHeaderKeys, k)
+ }
+ sort.Strings(signedHeaderKeys)
+
+ var canonicalHeaders strings.Builder
+ for _, k := range signedHeaderKeys {
+ canonicalHeaders.WriteString(k)
+ canonicalHeaders.WriteString(":")
+ canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
+ canonicalHeaders.WriteString("\n")
+ }
+ signedHeaders := strings.Join(signedHeaderKeys, ";")
+
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+ method,
+ u.Path,
+ canonicalQueryString,
+ canonicalHeaders.String(),
+ signedHeaders,
+ hexPayloadHash,
+ )
+
+ hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
+ hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
+
+ region := "cn-north-1"
+ serviceName := "cv"
+ credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
+ stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
+ xDate,
+ credentialScope,
+ hexHashedCanonicalRequest,
+ )
+
+ kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
+ kRegion := hmacSHA256(kDate, []byte(region))
+ kService := hmacSHA256(kRegion, []byte(serviceName))
+ kSigning := hmacSHA256(kService, []byte("request"))
+ signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
+
+ authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+ accessKey,
+ credentialScope,
+ signedHeaders,
+ signature,
+ )
+ header.Set("Authorization", authorization)
+ return nil
+}
+
+// hmacSHA256 computes the HMAC-SHA256 of data with the given key.
+func hmacSHA256(key []byte, data []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(data)
+ return h.Sum(nil)
+}
diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go
index 85b6a83f..a383728f 100644
--- a/relay/channel/jina/adaptor.go
+++ b/relay/channel/jina/adaptor.go
@@ -11,6 +11,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -18,6 +19,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -39,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
- return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
- return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
@@ -73,11 +79,11 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank {
- err, usage = common_handler.RerankHandler(c, info, resp)
+ usage, err = common_handler.RerankHandler(c, info, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
- err, usage = openai.OpenaiHandler(c, resp, info)
+ usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/jina/constant.go b/relay/channel/jina/constant.go
index 45fc44c9..be290fb6 100644
--- a/relay/channel/jina/constant.go
+++ b/relay/channel/jina/constant.go
@@ -3,6 +3,7 @@ package jina
var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
+ "jina-reranker-m0",
}
var ChannelName = "jina"
diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go
index d0a15b0d..ff9b72ea 100644
--- a/relay/channel/minimax/relay-minimax.go
+++ b/relay/channel/minimax/relay-minimax.go
@@ -6,5 +6,5 @@ import (
)
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
}
diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go
index 44f57e61..f98ff869 100644
--- a/relay/channel/mistral/adaptor.go
+++ b/relay/channel/mistral/adaptor.go
@@ -8,6 +8,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -15,6 +16,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -35,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -69,11 +75,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
+ usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info)
+ usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go
index 75272e34..aa925781 100644
--- a/relay/channel/mistral/text.go
+++ b/relay/channel/mistral/text.go
@@ -1,13 +1,55 @@
package mistral
import (
+ "one-api/common"
"one-api/dto"
+ "regexp"
)
+var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
+
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
+ idMap := make(map[string]string)
for _, message := range request.Messages {
+ // 1. tool_calls.id
+ toolCalls := message.ParseToolCalls()
+ if toolCalls != nil {
+ for i := range toolCalls {
+ if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
+ if newId, ok := idMap[toolCalls[i].ID]; ok {
+ toolCalls[i].ID = newId
+ } else {
+ newId, err := common.GenerateRandomCharsKey(9)
+ if err == nil {
+ idMap[toolCalls[i].ID] = newId
+ toolCalls[i].ID = newId
+ }
+ }
+ }
+ }
+ message.SetToolCalls(toolCalls)
+ }
+
+ // 2. tool_call_id
+ if message.ToolCallId != "" {
+ if newId, ok := idMap[message.ToolCallId]; ok {
+ message.ToolCallId = newId
+ } else {
+ if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
+ newId, err := common.GenerateRandomCharsKey(9)
+ if err == nil {
+ idMap[message.ToolCallId] = newId
+ message.ToolCallId = newId
+ }
+ }
+ }
+ }
+
mediaMessages := message.ParseContent()
+ if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
+ mediaMessages = []dto.MediaContent{}
+ }
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.GetImageMedia()
@@ -29,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
- MaxTokens: request.MaxTokens,
+ MaxTokens: request.GetMaxTokens(),
Tools: request.Tools,
ToolChoice: request.ToolChoice,
}
diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go
index b889f225..f9da685f 100644
--- a/relay/channel/mokaai/adaptor.go
+++ b/relay/channel/mokaai/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -17,6 +18,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -48,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
suffix = "embeddings"
}
- fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
+ fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix)
return fullRequestURL, nil
}
@@ -84,11 +90,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
- err, usage = mokaEmbeddingHandler(c, resp)
+ return mokaEmbeddingHandler(c, info, resp)
default:
// err, usage = mokaHandler(c, resp)
diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go
index d7580d7a..d91aceb3 100644
--- a/relay/channel/mokaai/relay-mokaai.go
+++ b/relay/channel/mokaai/relay-mokaai.go
@@ -2,11 +2,15 @@ package mokaai
import (
"encoding/json"
- "github.com/gin-gonic/gin"
"io"
"net/http"
+ "one-api/common"
"one-api/dto"
+ relaycommon "one-api/relay/common"
"one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
)
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
@@ -26,7 +30,7 @@ func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.Embeddin
}
return &dto.EmbeddingRequest{
Input: input,
- Model: request.Model,
+ Model: request.Model,
}
}
@@ -47,19 +51,16 @@ func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEm
return &openAIEmbeddingResponse
}
-func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var baiduResponse dto.EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// if baiduResponse.ErrorMsg != "" {
// return &dto.OpenAIErrorWithStatusCode{
@@ -71,13 +72,12 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
// }, nil
// }
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
- jsonResponse, err := json.Marshal(fullTextResponse)
+ jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return &fullTextResponse.Usage, nil
}
-
diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go
new file mode 100644
index 00000000..e290c239
--- /dev/null
+++ b/relay/channel/moonshot/adaptor.go
@@ -0,0 +1,110 @@
+package moonshot
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/claude"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ adaptor := openai.Adaptor{}
+ return adaptor.ConvertClaudeRequest(c, info, req)
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not supported")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ adaptor := openai.Adaptor{}
+ return adaptor.ConvertImageRequest(c, info, request)
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
+ default:
+ if info.RelayMode == constant.RelayModeRerank {
+ return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeEmbeddings {
+ return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeChatCompletions {
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeCompletions {
+ return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
+ }
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ if info.IsStream {
+ return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
+ } else {
+ return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
+ }
+ default:
+ adaptor := openai.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
+ }
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
index 18069311..d6b5b697 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/channel/ollama/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -16,10 +17,21 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
- panic("implement me")
- return nil, nil
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+ openaiAdaptor := openai.Adaptor{}
+ openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
+ if err != nil {
+ return nil, err
+ }
+ openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
+ IncludeUsage: true,
+ }
+ return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest))
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -36,11 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if info.RelayFormat == types.RelayFormatClaude {
+ return info.ChannelBaseUrl + "/v1/chat/completions", nil
+ }
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
- return info.BaseUrl + "/api/embed", nil
+ return info.ChannelBaseUrl + "/api/embed", nil
default:
- return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}
@@ -54,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
- return requestOpenAI2Ollama(*request)
+ return requestOpenAI2Ollama(c, request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -74,14 +89,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
- if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
- } else {
- if info.RelayMode == relayconstant.RelayModeEmbeddings {
- err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayMode {
+ case relayconstant.RelayModeEmbeddings:
+ usage, err = ollamaEmbeddingHandler(c, info, resp)
+ default:
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info)
+ usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return
diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
index 15c64cdc..317c2a4a 100644
--- a/relay/channel/ollama/dto.go
+++ b/relay/channel/ollama/dto.go
@@ -1,6 +1,9 @@
package ollama
-import "one-api/dto"
+import (
+ "encoding/json"
+ "one-api/dto"
+)
type OllamaRequest struct {
Model string `json:"model,omitempty"`
@@ -19,6 +22,7 @@ type OllamaRequest struct {
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
+ Think json.RawMessage `json:"think,omitempty"`
}
type Options struct {
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
index 89a04646..be2029f5 100644
--- a/relay/channel/ollama/relay-ollama.go
+++ b/relay/channel/ollama/relay-ollama.go
@@ -1,18 +1,20 @@
package ollama
import (
- "bytes"
- "encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
+ "one-api/common"
"one-api/dto"
+ relaycommon "one-api/relay/common"
"one-api/service"
+ "one-api/types"
"strings"
+
+ "github.com/gin-gonic/gin"
)
-func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
+func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
@@ -22,7 +24,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
imageUrl := mediaMessage.GetImageMedia()
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
- fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+ fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama")
if err != nil {
return nil, err
}
@@ -48,7 +50,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
} else {
Stop, _ = request.Stop.([]string)
}
- return &OllamaRequest{
+ ollamaRequest := &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
@@ -58,14 +60,18 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
- MaxTokens: request.MaxTokens,
+ MaxTokens: request.GetMaxTokens(),
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
- }, nil
+ }
+ if think, ok := request.Extra["think"]; ok {
+ ollamaRequest.Think = think
+ }
+ return ollamaRequest, nil
}
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
@@ -82,22 +88,19 @@ func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequ
}
}
-func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- err = resp.Body.Close()
+ service.CloseResponseBodyGracefully(resp)
+ err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if ollamaEmbeddingResponse.Error != "" {
- return service.OpenAIErrorWrapper(err, "ollama_error", resp.StatusCode), nil
+ return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
@@ -106,46 +109,22 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
Object: "embedding",
})
usage := &dto.Usage{
- TotalTokens: promptTokens,
+ TotalTokens: info.PromptTokens,
CompletionTokens: 0,
- PromptTokens: promptTokens,
+ PromptTokens: info.PromptTokens,
}
embeddingResponse := &dto.OpenAIEmbeddingResponse{
Object: "list",
Data: data,
- Model: model,
+ Model: info.UpstreamModelName,
Usage: *usage,
}
- doResponseBody, err := json.Marshal(embeddingResponse)
+ doResponseBody, err := common.Marshal(embeddingResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- // Copy headers
- for k, v := range resp.Header {
- // 删除任何现有的相同头部,以防止重复添加头部
- c.Writer.Header().Del(k)
- for _, vv := range v {
- c.Writer.Header().Add(k, vv)
- }
- }
- // reset content length
- c.Writer.Header().Del("Content-Length")
- c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- return nil, usage
+ service.IOCopyBytesGracefully(c, resp, doResponseBody)
+ return usage, nil
}
func flattenEmbeddings(embeddings [][]float64) []float64 {
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index f0cf073f..939c0223 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -10,19 +10,19 @@ import (
"net/http"
"net/textproto"
"one-api/common"
- constant2 "one-api/constant"
+ "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
- "one-api/relay/channel/moonshot"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
- "one-api/relay/constant"
+ relayconstant "one-api/relay/constant"
"one-api/service"
+ "one-api/types"
"path/filepath"
"strings"
@@ -34,15 +34,55 @@ type Adaptor struct {
ResponseFormat string
}
-func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
- if !strings.Contains(request.Model, "claude") {
- return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
+// parseReasoningEffortFromModelSuffix 从模型名称中解析推理级别
+// support OAI models: o1-mini/o3-mini/o4-mini/o1/o3 etc...
+// minimal effort only available in gpt-5
+func parseReasoningEffortFromModelSuffix(model string) (string, string) {
+ effortSuffixes := []string{"-high", "-minimal", "-low", "-medium"}
+ for _, suffix := range effortSuffixes {
+ if strings.HasSuffix(model, suffix) {
+ effort := strings.TrimPrefix(suffix, "-")
+ originModel := strings.TrimSuffix(model, suffix)
+ return effort, originModel
+ }
}
+ return "", model
+}
+
+func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
+ // 使用 service.GeminiToOpenAIRequest 转换请求格式
+ openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
+ if err != nil {
+ return nil, err
+ }
+ return a.ConvertOpenAIRequest(c, info, openaiRequest)
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+ //if !strings.Contains(request.Model, "claude") {
+ // return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
+ //}
+ //if common.DebugEnabled {
+ // bodyBytes := []byte(common.GetJsonString(request))
+ // err := os.WriteFile(fmt.Sprintf("claude_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
+ // if err != nil {
+ // println(fmt.Sprintf("failed to save request body to file: %v", err))
+ // }
+ //}
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
if err != nil {
return nil, err
}
- if info.SupportStreamOptions {
+ //if common.DebugEnabled {
+ // println(fmt.Sprintf("convert claude to openai request result: %s", common.GetJsonString(aiRequest)))
+ // // Save request body to file for debugging
+ // bodyBytes := []byte(common.GetJsonString(aiRequest))
+ // err = os.WriteFile(fmt.Sprintf("claude_to_openai_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
+ // if err != nil {
+ // println(fmt.Sprintf("failed to save request body to file: %v", err))
+ // }
+ //}
+ if info.SupportStreamOptions && info.IsStream {
aiRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
@@ -54,7 +94,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
- if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
+ if info.ChannelSetting.ThinkingToContent {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
@@ -64,62 +104,86 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info.RelayFormat == relaycommon.RelayFormatClaude {
- return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
- }
- if info.RelayMode == constant.RelayModeRealtime {
- if strings.HasPrefix(info.BaseUrl, "https://") {
- baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
+ if info.RelayMode == relayconstant.RelayModeRealtime {
+ if strings.HasPrefix(info.ChannelBaseUrl, "https://") {
+ baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://")
baseUrl = "wss://" + baseUrl
- info.BaseUrl = baseUrl
- } else if strings.HasPrefix(info.BaseUrl, "http://") {
- baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
+ info.ChannelBaseUrl = baseUrl
+ } else if strings.HasPrefix(info.ChannelBaseUrl, "http://") {
+ baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://")
baseUrl = "ws://" + baseUrl
- info.BaseUrl = baseUrl
+ info.ChannelBaseUrl = baseUrl
}
}
switch info.ChannelType {
- case common.ChannelTypeAzure:
+ case constant.ChannelTypeAzure:
apiVersion := info.ApiVersion
if apiVersion == "" {
- apiVersion = constant2.AzureDefaultAPIVersion
+ apiVersion = constant.AzureDefaultAPIVersion
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
+
+ if info.RelayFormat == types.RelayFormatClaude {
+ task = strings.TrimPrefix(task, "messages")
+ task = "chat/completions" + task
+ }
+
+ // 特殊处理 responses API
+ if info.RelayMode == relayconstant.RelayModeResponses {
+ responsesApiVersion := "preview"
+
+ subUrl := "/openai/v1/responses"
+ if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") {
+ subUrl = "/openai/responses"
+ responsesApiVersion = apiVersion
+ }
+
+ if info.ChannelOtherSettings.AzureResponsesVersion != "" {
+ responsesApiVersion = info.ChannelOtherSettings.AzureResponsesVersion
+ }
+
+ requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
+ }
+
model_ := info.UpstreamModelName
// 2025年5月10日后创建的渠道不移除.
- if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
+ if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
model_ = strings.Replace(model_, ".", "", -1)
}
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- if info.RelayMode == constant.RelayModeRealtime {
+ if info.RelayMode == relayconstant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
- return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
- case common.ChannelTypeMiniMax:
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
+ case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
- case common.ChannelTypeCustom:
- url := info.BaseUrl
+ case constant.ChannelTypeCustom:
+ url := info.ChannelBaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
default:
- return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+ if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
+ }
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, header)
- if info.ChannelType == common.ChannelTypeAzure {
+ if info.ChannelType == constant.ChannelTypeAzure {
header.Set("api-key", info.ApiKey)
return nil
}
- if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
+ if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
header.Set("OpenAI-Organization", info.Organization)
}
- if info.RelayMode == constant.RelayModeRealtime {
+ if info.RelayMode == relayconstant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
@@ -138,8 +202,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
}
- if info.ChannelType == common.ChannelTypeOpenRouter {
- header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
+ if info.ChannelType == constant.ChannelTypeOpenRouter {
+ header.Set("HTTP-Referer", "https://www.newapi.ai")
header.Set("X-Title", "New API")
}
return nil
@@ -149,30 +213,112 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
- if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
+ if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
request.StreamOptions = nil
}
- if strings.HasPrefix(request.Model, "o") {
+ if info.ChannelType == constant.ChannelTypeOpenRouter {
+ if len(request.Usage) == 0 {
+ request.Usage = json.RawMessage(`{"include":true}`)
+ }
+ // 适配 OpenRouter 的 thinking 后缀
+ if strings.HasSuffix(info.UpstreamModelName, "-thinking") {
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
+ request.Model = info.UpstreamModelName
+ if len(request.Reasoning) == 0 {
+ reasoning := map[string]any{
+ "enabled": true,
+ }
+ if request.ReasoningEffort != "" && request.ReasoningEffort != "none" {
+ reasoning["effort"] = request.ReasoningEffort
+ }
+ marshal, err := common.Marshal(reasoning)
+ if err != nil {
+ return nil, fmt.Errorf("error marshalling reasoning: %w", err)
+ }
+ request.Reasoning = marshal
+ }
+ // 清空多余的ReasoningEffort
+ request.ReasoningEffort = ""
+ } else {
+ if len(request.Reasoning) == 0 {
+ // 适配 OpenAI 的 ReasoningEffort 格式
+ if request.ReasoningEffort != "" {
+ reasoning := map[string]any{
+ "enabled": true,
+ }
+ if request.ReasoningEffort != "none" {
+ reasoning["effort"] = request.ReasoningEffort
+ marshal, err := common.Marshal(reasoning)
+ if err != nil {
+ return nil, fmt.Errorf("error marshalling reasoning: %w", err)
+ }
+ request.Reasoning = marshal
+ }
+ }
+ }
+ request.ReasoningEffort = ""
+ }
+
+ // https://docs.anthropic.com/en/api/openai-sdk#extended-thinking-support
+ // 没有做排除3.5Haiku等,要出问题再加吧,最佳兼容性(不是
+ if request.THINKING != nil && strings.HasPrefix(info.UpstreamModelName, "anthropic") {
+ var thinking dto.Thinking // Claude标准Thinking格式
+ if err := json.Unmarshal(request.THINKING, &thinking); err != nil {
+ return nil, fmt.Errorf("error Unmarshal thinking: %w", err)
+ }
+
+ // 只有当 thinking.Type 是 "enabled" 时才处理
+ if thinking.Type == "enabled" {
+ // 检查 BudgetTokens 是否为 nil
+ if thinking.BudgetTokens == nil {
+ return nil, fmt.Errorf("BudgetTokens is nil when thinking is enabled")
+ }
+
+ reasoning := openrouter.RequestReasoning{
+ MaxTokens: *thinking.BudgetTokens,
+ }
+
+ marshal, err := common.Marshal(reasoning)
+ if err != nil {
+ return nil, fmt.Errorf("error marshalling reasoning: %w", err)
+ }
+
+ request.Reasoning = marshal
+ }
+
+ // 清空 THINKING
+ request.THINKING = nil
+ }
+
+ }
+ if strings.HasPrefix(info.UpstreamModelName, "o") || strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
- request.Temperature = nil
- if strings.HasSuffix(request.Model, "-high") {
- request.ReasoningEffort = "high"
- request.Model = strings.TrimSuffix(request.Model, "-high")
- } else if strings.HasSuffix(request.Model, "-low") {
- request.ReasoningEffort = "low"
- request.Model = strings.TrimSuffix(request.Model, "-low")
- } else if strings.HasSuffix(request.Model, "-medium") {
- request.ReasoningEffort = "medium"
- request.Model = strings.TrimSuffix(request.Model, "-medium")
+
+ if strings.HasPrefix(info.UpstreamModelName, "o") {
+ request.Temperature = nil
}
+
+ if strings.HasPrefix(info.UpstreamModelName, "gpt-5") {
+ if info.UpstreamModelName != "gpt-5-chat-latest" {
+ request.Temperature = nil
+ }
+ }
+
+ // 转换模型推理力度后缀
+ effort, originModel := parseReasoningEffortFromModelSuffix(info.UpstreamModelName)
+ if effort != "" {
+ request.ReasoningEffort = effort
+ info.UpstreamModelName = originModel
+ request.Model = originModel
+ }
+
info.ReasoningEffort = request.ReasoningEffort
- info.UpstreamModelName = request.Model
// o系列模型developer适配(o1-mini除外)
- if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") {
+ if !strings.HasPrefix(info.UpstreamModelName, "o1-mini") && !strings.HasPrefix(info.UpstreamModelName, "o1-preview") {
//修改第一个Message的内容,将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer"
@@ -193,7 +339,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
- if info.RelayMode == constant.RelayModeAudioSpeech {
+ if info.RelayMode == relayconstant.RelayModeAudioSpeech {
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshalling object: %w", err)
@@ -242,46 +388,48 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
- case constant.RelayModeImagesEdits:
+ case relayconstant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
- // 获取所有表单字段
- formData := c.Request.PostForm
- // 遍历表单字段并打印输出
- for key, values := range formData {
- if key == "model" {
- continue
+ // 使用已解析的 multipart 表单,避免重复解析
+ mf := c.Request.MultipartForm
+ if mf == nil {
+ if _, err := c.MultipartForm(); err != nil {
+ return nil, errors.New("failed to parse multipart form")
}
- for _, value := range values {
- writer.WriteField(key, value)
+ mf = c.Request.MultipartForm
+ }
+
+ // 写入所有非文件字段
+ if mf != nil {
+ for key, values := range mf.Value {
+ if key == "model" {
+ continue
+ }
+ for _, value := range values {
+ writer.WriteField(key, value)
+ }
}
}
- // Parse the multipart form to handle both single image and multiple images
- if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
- return nil, errors.New("failed to parse multipart form")
- }
-
- if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+ if mf != nil && mf.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
- if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+ if imageFiles, exists = mf.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
- if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+ if imageFiles, exists = mf.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
- for fieldName, files := range c.Request.MultipartForm.File {
+ for fieldName, files := range mf.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
- for _, file := range files {
- imageFiles = append(imageFiles, file)
- }
+ imageFiles = append(imageFiles, files...)
}
}
@@ -298,7 +446,6 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
- defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
@@ -322,15 +469,18 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
+
+ // 复制完立即关闭,避免在循环内使用 defer 占用资源
+ _ = file.Close()
}
// Handle mask file if present
- if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+ if maskFiles, exists := mf.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
- defer maskFile.Close()
+ // 复制完立即关闭,避免在循环内使用 defer 占用资源
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
@@ -348,6 +498,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
+ _ = maskFile.Close()
}
} else {
return nil, errors.New("no multipart form data found")
@@ -356,7 +507,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
- return bytes.NewReader(requestBody.Bytes()), nil
+ return &requestBody, nil
default:
return request, nil
@@ -384,57 +535,52 @@ func detectImageMimeType(filename string) string {
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
- // 模型后缀转换 reasoning effort
- if strings.HasSuffix(request.Model, "-high") {
- request.Reasoning.Effort = "high"
- request.Model = strings.TrimSuffix(request.Model, "-high")
- } else if strings.HasSuffix(request.Model, "-low") {
- request.Reasoning.Effort = "low"
- request.Model = strings.TrimSuffix(request.Model, "-low")
- } else if strings.HasSuffix(request.Model, "-medium") {
- request.Reasoning.Effort = "medium"
- request.Model = strings.TrimSuffix(request.Model, "-medium")
+ // 转换模型推理力度后缀
+ effort, originModel := parseReasoningEffortFromModelSuffix(request.Model)
+ if effort != "" {
+ request.Reasoning.Effort = effort
+ request.Model = originModel
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- if info.RelayMode == constant.RelayModeAudioTranscription ||
- info.RelayMode == constant.RelayModeAudioTranslation ||
- info.RelayMode == constant.RelayModeImagesEdits {
+ if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
+ info.RelayMode == relayconstant.RelayModeAudioTranslation ||
+ info.RelayMode == relayconstant.RelayModeImagesEdits {
return channel.DoFormRequest(a, c, info, requestBody)
- } else if info.RelayMode == constant.RelayModeRealtime {
+ } else if info.RelayMode == relayconstant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
}
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
- case constant.RelayModeRealtime:
+ case relayconstant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
- case constant.RelayModeAudioSpeech:
- err, usage = OpenaiTTSHandler(c, resp, info)
- case constant.RelayModeAudioTranslation:
+ case relayconstant.RelayModeAudioSpeech:
+ usage = OpenaiTTSHandler(c, resp, info)
+ case relayconstant.RelayModeAudioTranslation:
fallthrough
- case constant.RelayModeAudioTranscription:
+ case relayconstant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
- case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
- err, usage = OpenaiHandlerWithUsage(c, resp, info)
- case constant.RelayModeRerank:
- err, usage = common_handler.RerankHandler(c, info, resp)
- case constant.RelayModeResponses:
+ case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
+ usage, err = OpenaiHandlerWithUsage(c, info, resp)
+ case relayconstant.RelayModeRerank:
+ usage, err = common_handler.RerankHandler(c, info, resp)
+ case relayconstant.RelayModeResponses:
if info.IsStream {
- err, usage = OaiResponsesStreamHandler(c, resp, info)
+ usage, err = OaiResponsesStreamHandler(c, info, resp)
} else {
- err, usage = OaiResponsesHandler(c, resp, info)
+ usage, err = OaiResponsesHandler(c, info, resp)
}
default:
if info.IsStream {
- err, usage = OaiStreamHandler(c, resp, info)
+ usage, err = OaiStreamHandler(c, info, resp)
} else {
- err, usage = OpenaiHandler(c, resp, info)
+ usage, err = OpenaiHandler(c, info, resp)
}
}
return
@@ -442,17 +588,15 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
func (a *Adaptor) GetModelList() []string {
switch a.ChannelType {
- case common.ChannelType360:
+ case constant.ChannelType360:
return ai360.ModelList
- case common.ChannelTypeMoonshot:
- return moonshot.ModelList
- case common.ChannelTypeLingYiWanWu:
+ case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
- case common.ChannelTypeMiniMax:
+ case constant.ChannelTypeMiniMax:
return minimax.ModelList
- case common.ChannelTypeXinference:
+ case constant.ChannelTypeXinference:
return xinference.ModelList
- case common.ChannelTypeOpenRouter:
+ case constant.ChannelTypeOpenRouter:
return openrouter.ModelList
default:
return ModelList
@@ -461,17 +605,15 @@ func (a *Adaptor) GetModelList() []string {
func (a *Adaptor) GetChannelName() string {
switch a.ChannelType {
- case common.ChannelType360:
+ case constant.ChannelType360:
return ai360.ChannelName
- case common.ChannelTypeMoonshot:
- return moonshot.ChannelName
- case common.ChannelTypeLingYiWanWu:
+ case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ChannelName
- case common.ChannelTypeMiniMax:
+ case constant.ChannelTypeMiniMax:
return minimax.ChannelName
- case common.ChannelTypeXinference:
+ case constant.ChannelTypeXinference:
return xinference.ChannelName
- case common.ChannelTypeOpenRouter:
+ case constant.ChannelTypeOpenRouter:
return openrouter.ChannelName
default:
return ChannelName
diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go
index c703e414..af5b6724 100644
--- a/relay/channel/openai/constant.go
+++ b/relay/channel/openai/constant.go
@@ -12,13 +12,25 @@ var ModelList = []string{
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
+ "gpt-4.1", "gpt-4.1-2025-04-14",
+ "gpt-4.1-mini", "gpt-4.1-mini-2025-04-14",
+ "gpt-4.1-nano", "gpt-4.1-nano-2025-04-14",
+ "o1", "o1-2024-12-17",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
+ "o1-pro", "o1-pro-2025-03-19",
"o3-mini", "o3-mini-2025-01-31",
"o3-mini-high", "o3-mini-2025-01-31-high",
"o3-mini-low", "o3-mini-2025-01-31-low",
"o3-mini-medium", "o3-mini-2025-01-31-medium",
- "o1", "o1-2024-12-17",
+ "o3", "o3-2025-04-16",
+ "o3-pro", "o3-pro-2025-06-10",
+ "o3-deep-research", "o3-deep-research-2025-06-26",
+ "o4-mini", "o4-mini-2025-04-16",
+ "o4-mini-deep-research", "o4-mini-deep-research-2025-06-26",
+ "gpt-5", "gpt-5-2025-08-07", "gpt-5-chat-latest",
+ "gpt-5-mini", "gpt-5-mini-2025-08-07",
+ "gpt-5-nano", "gpt-5-nano-2025-08-07",
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
@@ -27,7 +39,7 @@ var ModelList = []string{
"text-moderation-latest", "text-moderation-stable",
"text-davinci-edit-001",
"davinci-002", "babbage-002",
- "dall-e-3",
+ "dall-e-3", "gpt-image-1",
"whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
}
diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go
index a068c544..e84f6cc4 100644
--- a/relay/channel/openai/helper.go
+++ b/relay/channel/openai/helper.go
@@ -4,30 +4,37 @@ import (
"encoding/json"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
+ "github.com/samber/lo"
+
"github.com/gin-gonic/gin"
)
// 辅助函数
-func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
+func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
+
switch info.RelayFormat {
- case relaycommon.RelayFormatOpenAI:
+ case types.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
- case relaycommon.RelayFormatClaude:
+ case types.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
+ case types.RelayFormatGemini:
+ return handleGeminiFormat(c, data, info)
}
return nil
}
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
+ if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
@@ -41,6 +48,32 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return nil
}
+func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
+ var streamResponse dto.ChatCompletionsStreamResponse
+ if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
+ logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
+ return err
+ }
+
+ geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
+
+ // 如果返回 nil,表示没有实际内容,跳过发送
+ if geminiResponse == nil {
+ return nil
+ }
+
+ geminiResponseStr, err := common.Marshal(geminiResponse)
+ if err != nil {
+ logger.LogError(c, "failed to marshal gemini response: "+err.Error())
+ return err
+ }
+
+ // send gemini format response
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
+ _ = helper.FlushWriter(c)
+ return nil
+}
+
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -74,14 +107,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
- common.SysError("error processing stream response: " + err.Error())
+ common.SysLog("error processing stream response: " + err.Error())
}
}
return nil
@@ -110,7 +143,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
@@ -151,19 +184,21 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
*containStreamUsage = true
*usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
- *shouldSendLastResp = false
+ *shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool {
+ return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
+ })
}
}
return nil
}
-func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
+func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
- case relaycommon.RelayFormatOpenAI:
+ case types.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
@@ -171,11 +206,11 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
}
helper.Done(c)
- case relaycommon.RelayFormatClaude:
+ case types.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
- if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
+ common.SysLog("error unmarshalling stream response: " + err.Error())
return
}
@@ -183,8 +218,37 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
- helper.ClaudeData(c, *resp)
+ _ = helper.ClaudeData(c, *resp)
}
+
+ case types.RelayFormatGemini:
+ var streamResponse dto.ChatCompletionsStreamResponse
+ if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
+ common.SysLog("error unmarshalling stream response: " + err.Error())
+ return
+ }
+
+ // 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
+ // 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
+ // 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
+ // 暂不知是否有程序会不兼容。
+
+ geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
+
+ // openai 流响应开头的空数据
+ if geminiResponse == nil {
+ return
+ }
+
+ geminiResponseStr, err := common.Marshal(geminiResponse)
+ if err != nil {
+ common.SysLog("error marshalling gemini response: " + err.Error())
+ return
+ }
+
+ // 发送最终的 Gemini 响应
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
+ _ = helper.FlushWriter(c)
}
}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
index 2e3d8df1..cce9235b 100644
--- a/relay/channel/openai/relay-openai.go
+++ b/relay/channel/openai/relay-openai.go
@@ -2,7 +2,6 @@ package openai
import (
"bytes"
- "encoding/json"
"fmt"
"io"
"math"
@@ -11,12 +10,16 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"os"
+ "path/filepath"
"strings"
+ "one-api/types"
+
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -33,7 +36,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
}
var lastStreamResponse dto.ChatCompletionsStreamResponse
- if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
+ if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
return err
}
@@ -104,182 +107,161 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
return helper.ObjectData(c, lastStreamResponse)
}
-func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
- common.LogError(c, "invalid response or response body")
- return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
+ logger.LogError(c, "invalid response or response body")
+ return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
- containStreamUsage := false
+ defer service.CloseResponseBodyGracefully(resp)
+
+ model := info.UpstreamModelName
var responseId string
var createAt int64 = 0
var systemFingerprint string
- model := info.UpstreamModelName
-
+ var containStreamUsage bool
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
- var forceFormat bool
- var thinkToContent bool
-
- if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
- forceFormat = forceFmt
- }
-
- if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
- thinkToContent = think2Content
- }
-
- var (
- lastStreamData string
- )
+ var lastStreamData string
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
- err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
+ err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
- common.SysError("error handling stream format: " + err.Error())
+ common.SysLog("error handling stream format: " + err.Error())
}
}
- lastStreamData = data
- streamItems = append(streamItems, data)
+ if len(data) > 0 {
+ lastStreamData = data
+ streamItems = append(streamItems, data)
+ }
return true
})
+ // 处理最后的响应
shouldSendLastResp := true
- var lastStreamResponse dto.ChatCompletionsStreamResponse
- err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
- if err == nil {
- responseId = lastStreamResponse.Id
- createAt = lastStreamResponse.Created
- systemFingerprint = lastStreamResponse.GetSystemFingerprint()
- model = lastStreamResponse.Model
- if service.ValidUsage(lastStreamResponse.Usage) {
- containStreamUsage = true
- usage = lastStreamResponse.Usage
- if !info.ShouldIncludeUsage {
- shouldSendLastResp = false
- }
- }
- for _, choice := range lastStreamResponse.Choices {
- if choice.FinishReason != nil {
- shouldSendLastResp = true
- }
- }
+ if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
+ &containStreamUsage, info, &shouldSendLastResp); err != nil {
+ logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
}
- if shouldSendLastResp {
- sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
- //err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
+ if info.RelayFormat == types.RelayFormatOpenAI {
+ if shouldSendLastResp {
+ _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
+ }
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
- common.SysError("error processing tokens: " + err.Error())
+ logger.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
- usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
- if info.ChannelType == common.ChannelTypeDeepSeek {
+ if info.ChannelType == constant.ChannelTypeDeepSeek {
if usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
}
+ HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
- handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
-
- return nil, usage
+ return usage, nil
}
-func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- err = resp.Body.Close()
+ if common.DebugEnabled {
+ println("upstream response body:", string(responseBody))
+ }
+ err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- err = common.DecodeJson(responseBody, &simpleResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
- }
- if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
- return &dto.OpenAIErrorWithStatusCode{
- Error: *simpleResponse.Error,
- StatusCode: resp.StatusCode,
- }, nil
- }
-
- forceFormat := false
- if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
- forceFormat = forceFmt
+ if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
+ return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
- if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
- completionTokens := 0
- for _, choice := range simpleResponse.Choices {
- ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
- completionTokens += ctkm
+ forceFormat := false
+ if info.ChannelSetting.ForceFormat {
+ forceFormat = true
+ }
+
+ usageModified := false
+ if simpleResponse.Usage.PromptTokens == 0 {
+ completionTokens := simpleResponse.Usage.CompletionTokens
+ if completionTokens == 0 {
+ for _, choice := range simpleResponse.Choices {
+ ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+ completionTokens += ctkm
+ }
}
simpleResponse.Usage = dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
+ usageModified = true
}
switch info.RelayFormat {
- case relaycommon.RelayFormatOpenAI:
- if forceFormat {
- responseBody, err = json.Marshal(simpleResponse)
+ case types.RelayFormatOpenAI:
+ if usageModified {
+ var bodyMap map[string]interface{}
+ err = common.Unmarshal(responseBody, &bodyMap)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+ }
+ bodyMap["usage"] = simpleResponse.Usage
+ responseBody, _ = common.Marshal(bodyMap)
+ }
+ if forceFormat {
+ responseBody, err = common.Marshal(simpleResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
} else {
break
}
- case relaycommon.RelayFormatClaude:
+ case types.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
- claudeRespStr, err := json.Marshal(claudeResp)
+ claudeRespStr, err := common.Marshal(claudeResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = claudeRespStr
+ case types.RelayFormatGemini:
+ geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
+ geminiRespStr, err := common.Marshal(geminiResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ responseBody = geminiRespStr
}
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- //return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
- common.SysError("error copying response body: " + err.Error())
- }
- resp.Body.Close()
- return nil, &simpleResponse.Usage
+ service.IOCopyBytesGracefully(c, resp, responseBody)
+
+ return &simpleResponse.Usage, nil
}
-func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
- // Analogous to nginx's load balancing, it will only retry if it can't be requested or
- // if the upstream returns a specific status code, once the upstream has already written the header,
- // the subsequent failure of the response body should be regarded as a non-recoverable error,
+ // Analogous to nginx's load balancing, it will only retry if it can't be requested or
+ // if the upstream returns a specific status code, once the upstream has already written the header,
+ // the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
- defer resp.Body.Close()
+ defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
@@ -290,40 +272,25 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
- common.LogError(c, err.Error())
+ logger.LogError(c, err.Error())
}
- return nil, usage
+ return usage
}
-func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
+ defer service.CloseResponseBodyGracefully(resp)
+
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
if err != nil {
- return service.OpenAIErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError), nil
+ return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
- }
- resp.Body.Close()
+	// write the re-read response body back to the client
+ service.IOCopyBytesGracefully(c, resp, responseBody)
usage := &dto.Usage{}
usage.PromptTokens = audioTokens
@@ -345,13 +312,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
-
+	ext := filepath.Ext(reqBody.File.Filename) // extract the file extension
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
+ defer reqFp.Close()
- tmpFp, err := os.CreateTemp("", "audio-*")
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.WithStack(err)
}
@@ -365,7 +333,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
return 0, errors.WithStack(err)
}
- duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
+ duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
if err != nil {
return 0, errors.WithStack(err)
}
@@ -373,9 +341,9 @@ func countAudioTokens(c *gin.Context) (int, error) {
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
}
-func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) {
+func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
- return service.OpenAIErrorWrapper(fmt.Errorf("invalid websocket connection"), "invalid_connection", http.StatusBadRequest), nil
+ return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
}
info.IsStream = true
@@ -413,7 +381,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
}
realtimeEvent := &dto.RealtimeEvent{}
- err = json.Unmarshal(message, realtimeEvent)
+ err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
@@ -432,7 +400,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
@@ -473,7 +441,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
- err = json.Unmarshal(message, realtimeEvent)
+ err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
@@ -505,7 +473,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
@@ -520,9 +488,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
localUsage = &dto.RealtimeUsage{}
// print now usage
}
- //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
- //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
- //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+ logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
+ logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+ logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
@@ -537,7 +505,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
- common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
@@ -563,7 +531,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
- common.LogError(c, "realtime error: "+err.Error())
+ logger.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
@@ -598,41 +566,26 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
return err
}
-func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- // Reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- // reset content length
- c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
var usageResp dto.SimpleResponse
- err = json.Unmarshal(responseBody, &usageResp)
+ err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
+
+	// write the response body back to the client
+ service.IOCopyBytesGracefully(c, resp, responseBody)
+
+ // Once we've written to the client, we should not return errors anymore
+ // because the upstream has already consumed resources and returned content
+ // We should still perform billing even if parsing fails
// format
if usageResp.InputTokens > 0 {
usageResp.PromptTokens += usageResp.InputTokens
@@ -644,5 +597,5 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
}
- return nil, &usageResp.Usage
+ return &usageResp.Usage, nil
}
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
index 1d1e060e..ab2aa8a4 100644
--- a/relay/channel/openai/relay_responses.go
+++ b/relay/channel/openai/relay_responses.go
@@ -1,80 +1,66 @@
package openai
import (
- "bytes"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
-func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
// read response body
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
- err = resp.Body.Close()
+ err = common.Unmarshal(responseBody, &responsesResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
- err = common.DecodeJson(responseBody, &responsesResponse)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
- }
- if responsesResponse.Error != nil {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: responsesResponse.Error.Message,
- Type: "openai_error",
- Code: responsesResponse.Error.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
+ return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
- // reset response body
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
- // We shouldn't set the header before we parse the response body, because the parse part may fail.
- // And then we will have to send an error response, but in this case, the header has already been set.
- // So the httpClient will be confused by the response.
- // For example, Postman will report error, and we cannot check the response at all.
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- // copy response body
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- common.SysError("error copying response body: " + err.Error())
- }
- resp.Body.Close()
+	// write the response body back to the client
+ service.IOCopyBytesGracefully(c, resp, responseBody)
+
// compute usage
usage := dto.Usage{}
- usage.PromptTokens = responsesResponse.Usage.InputTokens
- usage.CompletionTokens = responsesResponse.Usage.OutputTokens
- usage.TotalTokens = responsesResponse.Usage.TotalTokens
+ if responsesResponse.Usage != nil {
+ usage.PromptTokens = responsesResponse.Usage.InputTokens
+ usage.CompletionTokens = responsesResponse.Usage.OutputTokens
+ usage.TotalTokens = responsesResponse.Usage.TotalTokens
+ if responsesResponse.Usage.InputTokensDetails != nil {
+ usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
+ }
+ }
// 解析 Tools 用量
for _, tool := range responsesResponse.Tools {
- info.ResponsesUsageInfo.BuiltInTools[tool.Type].CallCount++
+ info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
}
- return nil, &usage
+ return &usage, nil
}
-func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
- common.LogError(c, "invalid response or response body")
- return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
+ logger.LogError(c, "invalid response or response body")
+ return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
}
+ defer service.CloseResponseBodyGracefully(resp)
+
var usage = &dto.Usage{}
var responseTextBuilder strings.Builder
@@ -82,13 +68,18 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
// 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse
- if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
+ if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
- usage.PromptTokens = streamResponse.Response.Usage.InputTokens
- usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
- usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+ if streamResponse.Response.Usage != nil {
+ usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+ usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+ usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+ if streamResponse.Response.Usage.InputTokensDetails != nil {
+ usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens
+ }
+ }
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
@@ -110,10 +101,16 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
tempStr := responseTextBuilder.String()
if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
- completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+ completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}
- return nil, usage
+ if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
+ usage.PromptTokens = usage.CompletionTokens
+ } else {
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ }
+
+ return usage, nil
}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
index 3a06e7ee..2a022a1b 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/channel/palm/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -16,6 +17,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -36,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -70,13 +76,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
- err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+ usage, err = palmHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
index c8e337de..3a6ec2f4 100644
--- a/relay/channel/palm/relay-palm.go
+++ b/relay/channel/palm/relay-palm.go
@@ -2,55 +2,32 @@ package palm
import (
"encoding/json"
- "fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
-func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
- palmRequest := PaLMChatRequest{
- Prompt: PaLMPrompt{
- Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
- },
- Temperature: textRequest.Temperature,
- CandidateCount: textRequest.N,
- TopP: textRequest.TopP,
- TopK: textRequest.MaxTokens,
- }
- for _, message := range textRequest.Messages {
- palmMessage := PaLMChatMessage{
- Content: message.StringContent(),
- }
- if message.Role == "user" {
- palmMessage.Author = "0"
- } else {
- palmMessage.Author = "1"
- }
- palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
- }
- return &palmRequest
-}
-
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
- content, _ := json.Marshal(candidate.Content)
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: candidate.Content,
},
FinishReason: "stop",
}
@@ -72,29 +49,24 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
return &response
}
-func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
+func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+ responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
- stopChan <- true
- return
- }
- err = resp.Body.Close()
- if err != nil {
- common.SysError("error closing stream response: " + err.Error())
+ common.SysLog("error reading stream response: " + err.Error())
stopChan <- true
return
}
+ service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -106,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ common.SysLog("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -124,52 +96,43 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
return false
}
})
- err := resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
- }
+ service.CloseResponseBodyGracefully(resp)
return nil, responseText
}
-func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
+ service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: palmResponse.Error.Message,
- Type: palmResponse.Error.Status,
- Param: "",
- Code: palmResponse.Error.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: palmResponse.Error.Message,
+ Type: palmResponse.Error.Status,
+ Param: "",
+ Code: palmResponse.Error.Code,
+ }, resp.StatusCode)
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
- completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
+ completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
usage := dto.Usage{
- PromptTokens: promptTokens,
+ PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
- TotalTokens: promptTokens + completionTokens,
+ TotalTokens: info.PromptTokens + completionTokens,
}
fullTextResponse.Usage = usage
- jsonResponse, err := json.Marshal(fullTextResponse)
+ jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &usage
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return &usage, nil
}
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
index ca206503..8ab9c854 100644
--- a/relay/channel/perplexity/adaptor.go
+++ b/relay/channel/perplexity/adaptor.go
@@ -9,6 +9,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -16,6 +17,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -36,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
+ return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -73,11 +79,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
+ usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info)
+ usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/perplexity/relay-perplexity.go b/relay/channel/perplexity/relay-perplexity.go
index 9772aead..7ebadd0f 100644
--- a/relay/channel/perplexity/relay-perplexity.go
+++ b/relay/channel/perplexity/relay-perplexity.go
@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
- MaxTokens: request.MaxTokens,
+ MaxTokens: request.GetMaxTokens(),
}
}
diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go
index 89236ea3..4c176c08 100644
--- a/relay/channel/siliconflow/adaptor.go
+++ b/relay/channel/siliconflow/adaptor.go
@@ -10,6 +10,7 @@ import (
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -17,20 +18,24 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
- panic("implement me")
- return nil, nil
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ adaptor := openai.Adaptor{}
+ return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
- return nil, errors.New("not implemented")
+ return nil, errors.New("not supported")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
+ adaptor := openai.Adaptor{}
+ return adaptor.ConvertImageRequest(c, info, request)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -38,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
- return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
- return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
- return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
- return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
+ return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
}
- return "", errors.New("invalid relay mode")
+ return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -76,20 +81,23 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeRerank:
- err, usage = siliconflowRerankHandler(c, resp)
+ usage, err = siliconflowRerankHandler(c, info, resp)
+ case constant.RelayModeEmbeddings:
+ usage, err = openai.OpenaiHandler(c, info, resp)
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions:
+ fallthrough
+ default:
if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
+ usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
- err, usage = openai.OpenaiHandler(c, resp, info)
+ usage, err = openai.OpenaiHandler(c, info, resp)
}
- case constant.RelayModeEmbeddings:
- err, usage = openai.OpenaiHandler(c, resp, info)
+
}
return
}
diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go
index a01e745c..b21faccb 100644
--- a/relay/channel/siliconflow/relay-siliconflow.go
+++ b/relay/channel/siliconflow/relay-siliconflow.go
@@ -2,26 +2,26 @@ package siliconflow
import (
"encoding/json"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
+ relaycommon "one-api/relay/common"
"one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
)
-func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
+ service.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
@@ -35,10 +35,10 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, usage
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return usage, nil
}
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
new file mode 100644
index 00000000..a5ada137
--- /dev/null
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -0,0 +1,380 @@
+package jimeng
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "one-api/model"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/pkg/errors"
+
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+// requestPayload is the body sent to Jimeng's CVSync2AsyncSubmitTask endpoint.
+type requestPayload struct {
+	ReqKey           string   `json:"req_key"`                      // fixed per-capability key (see convertToRequestPayload)
+	BinaryDataBase64 []string `json:"binary_data_base64,omitempty"` // inline base64-encoded image data
+	ImageUrls        []string `json:"image_urls,omitempty"`         // remote image URLs (alternative to base64)
+	Prompt           string   `json:"prompt,omitempty"`
+	Seed             int64    `json:"seed"`
+	AspectRatio      string   `json:"aspect_ratio"`
+}
+
+// responsePayload is the submit-call response; code 10000 indicates success.
+type responsePayload struct {
+	Code      int    `json:"code"`
+	Message   string `json:"message"`
+	RequestId string `json:"request_id"`
+	Data      struct {
+		TaskID string `json:"task_id"`
+	} `json:"data"`
+}
+
+// responseTask is the CVSync2AsyncGetResult response used when polling a task.
+type responseTask struct {
+	Code int `json:"code"`
+	Data struct {
+		BinaryDataBase64 []interface{} `json:"binary_data_base64"`
+		ImageUrls        interface{}   `json:"image_urls"`
+		RespData         string        `json:"resp_data"`
+		Status           string        `json:"status"` // e.g. "in_queue", "done"
+		VideoUrl         string        `json:"video_url"`
+	} `json:"data"`
+	Message     string `json:"message"`
+	RequestId   string `json:"request_id"`
+	Status      int    `json:"status"`
+	TimeElapsed string `json:"time_elapsed"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+// TaskAdaptor implements the task-channel adaptor for Volcengine Jimeng.
+type TaskAdaptor struct {
+	ChannelType int
+	accessKey   string
+	secretKey   string
+	baseURL     string
+}
+
+// Init captures channel settings and splits the configured API key into its
+// access/secret halves for request signing.
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.ChannelBaseUrl
+
+	// apiKey format: "access_key|secret_key". A key without exactly one '|'
+	// leaves both parts empty, so signing would produce an invalid signature.
+	keyParts := strings.Split(info.ApiKey, "|")
+	if len(keyParts) == 2 {
+		a.accessKey = strings.TrimSpace(keyParts[0])
+		a.secretKey = strings.TrimSpace(keyParts[1])
+	}
+}
+
+// ValidateRequestAndSetAction parses the request body, validates mandatory
+// fields and sets the default action. Returns a TaskError on bad input.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+	// Accept only POST /v1/video/generations as "generate" action.
+	action := constant.TaskActionGenerate
+	info.Action = action
+
+	req := relaycommon.TaskSubmitReq{}
+	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+	// Prompt is the only field required for a Jimeng submission.
+	if strings.TrimSpace(req.Prompt) == "" {
+		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+		return
+	}
+
+	// Store into context for later usage by BuildRequestBody.
+	c.Set("task_request", req)
+	return nil
+}
+
+// BuildRequestURL constructs the upstream submit URL (action and API version
+// are passed as query parameters per the Volcengine convention).
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+	return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
+}
+
+// BuildRequestHeader sets content headers and signs the request with the
+// channel's access/secret key pair.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	return a.signRequest(req, a.accessKey, a.secretKey)
+}
+
+// BuildRequestBody converts the request stored by ValidateRequestAndSetAction
+// into the Jimeng-specific JSON payload.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+	v, exists := c.Get("task_request")
+	if !exists {
+		return nil, fmt.Errorf("request not found in context")
+	}
+	req := v.(relaycommon.TaskSubmitReq)
+
+	body, err := a.convertToRequestPayload(&req)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request payload failed")
+	}
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to the common task-API helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles the upstream submit response: it reads and parses the
+// Jimeng payload, surfaces upstream failures as TaskErrors, echoes the task
+// id to the client on success, and returns the task id plus raw body for
+// bookkeeping.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	_ = resp.Body.Close()
+
+	// Parse Jimeng response
+	var jResp responsePayload
+	if err := json.Unmarshal(responseBody, &jResp); err != nil {
+		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	// 10000 is Jimeng's success code; anything else is an upstream failure.
+	if jResp.Code != 10000 {
+		// errors.New, not fmt.Errorf: Message is upstream data, not a format
+		// string — fmt.Errorf would misinterpret any '%' verbs it contains
+		// (and go vet flags non-constant format strings).
+		taskErr = service.TaskErrorWrapper(errors.New(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
+	return jResp.Data.TaskID, responseBody, nil
+}
+
+// FetchTask polls the upstream CVSync2AsyncGetResult endpoint for the status
+// of a previously submitted task. key must be in "ak|sk" form.
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+
+	uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
+	payload := map[string]string{
+		// This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
+		// NOTE(review): submit uses "jimeng_vgfm_i2v_l20" (convertToRequestPayload)
+		// while this poll uses the t2v key — confirm the mismatch is intentional.
+		"req_key": "jimeng_vgfm_t2v_l20",
+		"task_id": taskID,
+	}
+	payloadBytes, err := json.Marshal(payload)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshal fetch task payload failed")
+	}
+
+	req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Content-Type", "application/json")
+
+	keyParts := strings.Split(key, "|")
+	if len(keyParts) != 2 {
+		return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
+	}
+	accessKey := strings.TrimSpace(keyParts[0])
+	secretKey := strings.TrimSpace(keyParts[1])
+
+	if err := a.signRequest(req, accessKey, secretKey); err != nil {
+		return nil, errors.Wrap(err, "sign request failed")
+	}
+
+	return service.GetHttpClient().Do(req)
+}
+
+// GetModelList returns the model identifiers this channel advertises.
+func (a *TaskAdaptor) GetModelList() []string {
+	return []string{"jimeng_vgfm_t2v_l20"}
+}
+
+// GetChannelName returns the channel's display name.
+func (a *TaskAdaptor) GetChannelName() string {
+	return "jimeng"
+}
+
+// signRequest signs req in place using Volcengine's HMAC-SHA256 scheme:
+// hash the body, build a canonical request from method/path/query/headers,
+// derive a signing key from date/region/service, and set the Authorization
+// header. The request body is read and rewound, so req remains sendable.
+func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
+	var bodyBytes []byte
+	var err error
+
+	if req.Body != nil {
+		bodyBytes, err = io.ReadAll(req.Body)
+		if err != nil {
+			return errors.Wrap(err, "read request body failed")
+		}
+		_ = req.Body.Close()
+		req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
+	} else {
+		bodyBytes = []byte{}
+	}
+
+	// Payload hash is signed and also sent as X-Content-Sha256.
+	payloadHash := sha256.Sum256(bodyBytes)
+	hexPayloadHash := hex.EncodeToString(payloadHash[:])
+
+	t := time.Now().UTC()
+	xDate := t.Format("20060102T150405Z")
+	shortDate := t.Format("20060102")
+
+	// NOTE(review): Go's HTTP client derives the wire Host from req.Host /
+	// req.URL rather than this header; the canonical request below also uses
+	// req.URL.Host, so the signature stays consistent either way.
+	req.Header.Set("Host", req.URL.Host)
+	req.Header.Set("X-Date", xDate)
+	req.Header.Set("X-Content-Sha256", hexPayloadHash)
+
+	// Sort and encode query parameters to create canonical query string
+	queryParams := req.URL.Query()
+	sortedKeys := make([]string, 0, len(queryParams))
+	for k := range queryParams {
+		sortedKeys = append(sortedKeys, k)
+	}
+	sort.Strings(sortedKeys)
+	var queryParts []string
+	for _, k := range sortedKeys {
+		values := queryParams[k]
+		sort.Strings(values)
+		for _, v := range values {
+			queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
+		}
+	}
+	canonicalQueryString := strings.Join(queryParts, "&")
+
+	// Headers participating in the signature; content-type is included only
+	// when present on the request.
+	headersToSign := map[string]string{
+		"host":             req.URL.Host,
+		"x-date":           xDate,
+		"x-content-sha256": hexPayloadHash,
+	}
+	if req.Header.Get("Content-Type") != "" {
+		headersToSign["content-type"] = req.Header.Get("Content-Type")
+	}
+
+	var signedHeaderKeys []string
+	for k := range headersToSign {
+		signedHeaderKeys = append(signedHeaderKeys, k)
+	}
+	sort.Strings(signedHeaderKeys)
+
+	// Canonical headers: sorted "key:value\n" entries.
+	var canonicalHeaders strings.Builder
+	for _, k := range signedHeaderKeys {
+		canonicalHeaders.WriteString(k)
+		canonicalHeaders.WriteString(":")
+		canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
+		canonicalHeaders.WriteString("\n")
+	}
+	signedHeaders := strings.Join(signedHeaderKeys, ";")
+
+	canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+		req.Method,
+		req.URL.Path,
+		canonicalQueryString,
+		canonicalHeaders.String(),
+		signedHeaders,
+		hexPayloadHash,
+	)
+
+	hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
+	hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
+
+	// Region/service are fixed for the Jimeng CV service.
+	region := "cn-north-1"
+	serviceName := "cv"
+	credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
+	stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
+		xDate,
+		credentialScope,
+		hexHashedCanonicalRequest,
+	)
+
+	// Derive the signing key through the date -> region -> service -> request
+	// HMAC chain, then sign the string-to-sign.
+	kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
+	kRegion := hmacSHA256(kDate, []byte(region))
+	kService := hmacSHA256(kRegion, []byte(serviceName))
+	kSigning := hmacSHA256(kService, []byte("request"))
+	signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
+
+	authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+		accessKey,
+		credentialScope,
+		signedHeaders,
+		signature,
+	)
+	req.Header.Set("Authorization", authorization)
+	return nil
+}
+
+// hmacSHA256 returns HMAC-SHA256(key, data).
+func hmacSHA256(key []byte, data []byte) []byte {
+	h := hmac.New(sha256.New, key)
+	h.Write(data)
+	return h.Sum(nil)
+}
+
+// convertToRequestPayload maps the generic submit request onto Jimeng's
+// payload. req.Metadata is JSON-round-tripped onto the payload afterwards,
+// so metadata keys (e.g. "seed", "aspect_ratio") override the defaults set
+// here.
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+	r := requestPayload{
+		// NOTE(review): submit uses the i2v key while FetchTask and
+		// GetModelList use "jimeng_vgfm_t2v_l20" — confirm which capability
+		// is intended.
+		ReqKey:      "jimeng_vgfm_i2v_l20",
+		Prompt:      req.Prompt,
+		AspectRatio: "16:9", // Default aspect ratio
+		Seed:        -1,     // Default to random
+	}
+
+	// Handle one-of image_urls or binary_data_base64
+	if req.Image != "" {
+		if strings.HasPrefix(req.Image, "http") {
+			r.ImageUrls = []string{req.Image}
+		} else {
+			r.BinaryDataBase64 = []string{req.Image}
+		}
+	}
+	metadata := req.Metadata
+	medaBytes, err := json.Marshal(metadata)
+	if err != nil {
+		return nil, errors.Wrap(err, "metadata marshal metadata failed")
+	}
+	err = json.Unmarshal(medaBytes, &r)
+	if err != nil {
+		return nil, errors.Wrap(err, "unmarshal metadata failed")
+	}
+	return &r, nil
+}
+
+// ParseTaskResult converts a CVSync2AsyncGetResult body into the generic
+// TaskInfo. Code 10000 means the API call itself succeeded; the per-task
+// state then comes from data.status.
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	resTask := responseTask{}
+	if err := json.Unmarshal(respBody, &resTask); err != nil {
+		return nil, errors.Wrap(err, "unmarshal task result failed")
+	}
+	taskResult := relaycommon.TaskInfo{}
+	if resTask.Code == 10000 {
+		taskResult.Code = 0
+	} else {
+		taskResult.Code = resTask.Code // todo uni code
+		taskResult.Reason = resTask.Message
+		taskResult.Status = model.TaskStatusFailure
+		taskResult.Progress = "100%"
+	}
+	// Upstream states other than the two below (e.g. still processing) keep
+	// whatever Status was assigned above.
+	switch resTask.Data.Status {
+	case "in_queue":
+		taskResult.Status = model.TaskStatusQueued
+		taskResult.Progress = "10%"
+	case "done":
+		taskResult.Status = model.TaskStatusSuccess
+		taskResult.Progress = "100%"
+	}
+	taskResult.Url = resTask.Data.VideoUrl
+	return &taskResult, nil
+}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
new file mode 100644
index 00000000..1fecda08
--- /dev/null
+++ b/relay/channel/task/kling/adaptor.go
@@ -0,0 +1,340 @@
+package kling
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "github.com/samber/lo"
+ "io"
+ "net/http"
+ "one-api/model"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt"
+ "github.com/pkg/errors"
+
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+// SubmitReq is the client-facing video generation request accepted by the
+// Kling adaptor.
+type SubmitReq struct {
+	Prompt   string                 `json:"prompt"`
+	Model    string                 `json:"model,omitempty"`
+	Mode     string                 `json:"mode,omitempty"`
+	Image    string                 `json:"image,omitempty"`
+	Size     string                 `json:"size,omitempty"` // e.g. "1280x720"; mapped to an aspect ratio
+	Duration int                    `json:"duration,omitempty"`
+	Metadata map[string]interface{} `json:"metadata,omitempty"` // extra fields overlaid onto the upstream payload
+}
+
+// requestPayload is the body sent to Kling's video generation endpoints.
+type requestPayload struct {
+	Prompt      string  `json:"prompt,omitempty"`
+	Image       string  `json:"image,omitempty"`
+	Mode        string  `json:"mode,omitempty"`
+	Duration    string  `json:"duration,omitempty"`
+	AspectRatio string  `json:"aspect_ratio,omitempty"`
+	ModelName   string  `json:"model_name,omitempty"`
+	Model       string  `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
+	CfgScale    float64 `json:"cfg_scale,omitempty"`
+}
+
+// responsePayload covers both the submit response and the poll response;
+// code 0 indicates success.
+type responsePayload struct {
+	Code      int    `json:"code"`
+	Message   string `json:"message"`
+	TaskId    string `json:"task_id"`
+	RequestId string `json:"request_id"`
+	Data      struct {
+		TaskId        string `json:"task_id"`
+		TaskStatus    string `json:"task_status"` // submitted / processing / succeed / failed
+		TaskStatusMsg string `json:"task_status_msg"`
+		TaskResult    struct {
+			Videos []struct {
+				Id       string `json:"id"`
+				Url      string `json:"url"`
+				Duration string `json:"duration"`
+			} `json:"videos"`
+		} `json:"task_result"`
+		CreatedAt int64 `json:"created_at"`
+		UpdatedAt int64 `json:"updated_at"`
+	} `json:"data"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+// TaskAdaptor implements the task-channel adaptor for Kling video generation.
+type TaskAdaptor struct {
+	ChannelType int
+	apiKey      string
+	baseURL     string
+}
+
+// Init captures channel settings. The key is stored verbatim: it may be a
+// ready-to-use token or "access_key|secret_key"; the split happens lazily in
+// createJWTTokenWithKey.
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.ChannelBaseUrl
+	a.apiKey = info.ApiKey
+}
+
+// ValidateRequestAndSetAction parses the request body, validates mandatory
+// fields and sets the default action. Returns a TaskError on bad input.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+	// Accept only POST /v1/video/generations as "generate" action.
+	action := constant.TaskActionGenerate
+	info.Action = action
+
+	var req SubmitReq
+	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+	// Prompt is required even for image-to-video submissions.
+	if strings.TrimSpace(req.Prompt) == "" {
+		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+		return
+	}
+
+	// Store into context for later usage by BuildRequestBody.
+	c.Set("task_request", req)
+	return nil
+}
+
+// BuildRequestURL constructs the upstream URL based on the action.
+// NOTE(review): ValidateRequestAndSetAction always sets Action to
+// TaskActionGenerate, so the text2video branch appears unreachable from this
+// adaptor unless DoRequest's "action" override fires — confirm the routing.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+	path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
+	return fmt.Sprintf("%s%s", a.baseURL, path), nil
+}
+
+// BuildRequestHeader sets content headers and a freshly minted JWT bearer
+// token derived from the channel key.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+	token, err := a.createJWTToken()
+	if err != nil {
+		return fmt.Errorf("failed to create JWT token: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+token)
+	req.Header.Set("User-Agent", "kling-sdk/1.0")
+	return nil
+}
+
+// BuildRequestBody converts the request stored by ValidateRequestAndSetAction
+// into the Kling-specific JSON payload.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+	v, exists := c.Get("task_request")
+	if !exists {
+		return nil, fmt.Errorf("request not found in context")
+	}
+	req := v.(SubmitReq)
+
+	body, err := a.convertToRequestPayload(&req)
+	if err != nil {
+		return nil, err
+	}
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to the common helper, honoring an optional per-request
+// "action" override placed in the gin context.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	if action := c.GetString("action"); action != "" {
+		info.Action = action
+	}
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles the upstream submit response: it reads and parses the
+// Kling payload, surfaces upstream failures as TaskErrors, echoes the
+// response (with the task id promoted to the top level) to the client, and
+// returns the task id plus raw body for bookkeeping.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	// Close the body once it is fully read (the jimeng adaptor does the same);
+	// leaving it open leaks the underlying connection.
+	_ = resp.Body.Close()
+
+	var kResp responsePayload
+	err = json.Unmarshal(responseBody, &kResp)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
+		return
+	}
+	if kResp.Code != 0 {
+		// errors.New, not fmt.Errorf: Message is upstream data, not a format
+		// string — '%' verbs in it would corrupt the error text (go vet
+		// flags non-constant format strings).
+		taskErr = service.TaskErrorWrapperLocal(errors.New(kResp.Message), "task_failed", http.StatusBadRequest)
+		return
+	}
+	// Promote the nested task id so clients can read it at the top level.
+	kResp.TaskId = kResp.Data.TaskId
+	c.JSON(http.StatusOK, kResp)
+	return kResp.Data.TaskId, responseBody, nil
+}
+
+// FetchTask polls the upstream task-status endpoint corresponding to the
+// original submit action.
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+	action, ok := body["action"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid action")
+	}
+	path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
+	url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
+
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	// NOTE(review): on signing failure this silently falls back to sending
+	// the raw key as the bearer token — confirm this is the intended
+	// behavior rather than returning the error.
+	token, err := a.createJWTTokenWithKey(key)
+	if err != nil {
+		token = key
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Bearer "+token)
+	req.Header.Set("User-Agent", "kling-sdk/1.0")
+
+	return service.GetHttpClient().Do(req)
+}
+
+// GetModelList returns the model identifiers this channel advertises.
+func (a *TaskAdaptor) GetModelList() []string {
+	return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
+}
+
+// GetChannelName returns the channel's display name.
+func (a *TaskAdaptor) GetChannelName() string {
+	return "kling"
+}
+
+// ============================
+// helpers
+// ============================
+
+// convertToRequestPayload maps the generic submit request onto Kling's
+// payload, applying defaults. req.Metadata is JSON-round-tripped onto the
+// payload afterwards, so metadata keys override the defaults set here.
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+	r := requestPayload{
+		Prompt:      req.Prompt,
+		Image:       req.Image,
+		Mode:        defaultString(req.Mode, "std"),
+		Duration:    fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+		AspectRatio: a.getAspectRatio(req.Size),
+		ModelName:   req.Model,
+		Model:       req.Model, // Keep consistent with model_name, double writing improves compatibility
+		CfgScale:    0.5,
+	}
+	if r.ModelName == "" {
+		r.ModelName = "kling-v1"
+	}
+	metadata := req.Metadata
+	medaBytes, err := json.Marshal(metadata)
+	if err != nil {
+		return nil, errors.Wrap(err, "metadata marshal metadata failed")
+	}
+	err = json.Unmarshal(medaBytes, &r)
+	if err != nil {
+		return nil, errors.Wrap(err, "unmarshal metadata failed")
+	}
+	return &r, nil
+}
+
+// getAspectRatio maps a "WxH" size string onto one of Kling's supported
+// aspect ratios, defaulting to square.
+func (a *TaskAdaptor) getAspectRatio(size string) string {
+	switch size {
+	case "1024x1024", "512x512":
+		return "1:1"
+	case "1280x720", "1920x1080":
+		return "16:9"
+	case "720x1280", "1080x1920":
+		return "9:16"
+	default:
+		return "1:1"
+	}
+}
+
+// defaultString returns def when s is empty or whitespace-only.
+func defaultString(s, def string) string {
+	if strings.TrimSpace(s) == "" {
+		return def
+	}
+	return s
+}
+
+// defaultInt returns def when v is zero (the unset value).
+func defaultInt(v int, def int) int {
+	if v == 0 {
+		return def
+	}
+	return v
+}
+
+// ============================
+// JWT helpers
+// ============================
+
+// createJWTToken signs a JWT using the adaptor's configured API key.
+func (a *TaskAdaptor) createJWTToken() (string, error) {
+	return a.createJWTTokenWithKey(a.apiKey)
+}
+
+// createJWTTokenWithKey builds a Kling JWT from an API key of the form
+// "access_key|secret_key". A key without a '|' separator is assumed to be a
+// pre-issued bearer token and is returned as-is.
+// (A dead commented-out duplicate of this function was removed: it documented
+// the wrong separator and called itself recursively.)
+func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
+	keyParts := strings.Split(apiKey, "|")
+	accessKey := strings.TrimSpace(keyParts[0])
+	if len(keyParts) == 1 {
+		// No secret part: treat the whole key as a ready-to-use token.
+		return accessKey, nil
+	}
+	secretKey := strings.TrimSpace(keyParts[1])
+	now := time.Now().Unix()
+	claims := jwt.MapClaims{
+		"iss": accessKey,
+		"exp": now + 1800, // 30 minutes
+		"nbf": now - 5,    // tolerate small clock skew
+	}
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+	token.Header["typ"] = "JWT"
+	return token.SignedString([]byte(secretKey))
+}
+
+// ParseTaskResult converts a Kling poll response into the generic TaskInfo,
+// mapping upstream task states and extracting the first video URL if present.
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	taskInfo := &relaycommon.TaskInfo{}
+	resPayload := responsePayload{}
+	err := json.Unmarshal(respBody, &resPayload)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to unmarshal response body")
+	}
+	taskInfo.Code = resPayload.Code
+	taskInfo.TaskID = resPayload.Data.TaskId
+	taskInfo.Reason = resPayload.Message
+	// Task status enum values: submitted, processing, succeed, failed.
+	status := resPayload.Data.TaskStatus
+	switch status {
+	case "submitted":
+		taskInfo.Status = model.TaskStatusSubmitted
+	case "processing":
+		taskInfo.Status = model.TaskStatusInProgress
+	case "succeed":
+		taskInfo.Status = model.TaskStatusSuccess
+	case "failed":
+		taskInfo.Status = model.TaskStatusFailure
+	default:
+		return nil, fmt.Errorf("unknown task status: %s", status)
+	}
+	// Only the first generated video is surfaced to the caller.
+	if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
+		video := videos[0]
+		taskInfo.Url = video.Url
+	}
+	return taskInfo, nil
+}
diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go
index 03d60516..df2bb99e 100644
--- a/relay/channel/task/suno/adaptor.go
+++ b/relay/channel/task/suno/adaptor.go
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
ChannelType int
}
+func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
+ return nil, fmt.Errorf("not implement") // todo implement this method if needed
+}
+
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
}
@@ -55,7 +59,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
- baseURL := info.BaseUrl
+ baseURL := info.ChannelBaseUrl
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
return fullRequestURL, nil
}
@@ -135,7 +139,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
if err != nil {
- common.SysError(fmt.Sprintf("Get Task error: %v", err))
+ common.SysLog(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
defer req.Body.Close()
diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go
new file mode 100644
index 00000000..b0cc0bdc
--- /dev/null
+++ b/relay/channel/task/vidu/adaptor.go
@@ -0,0 +1,285 @@
+package vidu
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/gin-gonic/gin"
+
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+
+ "github.com/pkg/errors"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+// SubmitReq is the client-facing submit request accepted by this adaptor.
+// Only Prompt is mandatory; a non-empty Image switches the task to
+// image-to-video mode. Metadata entries are overlaid onto the upstream
+// payload (see convertToRequestPayload).
+type SubmitReq struct {
+	Prompt   string                 `json:"prompt"`
+	Model    string                 `json:"model,omitempty"`
+	Mode     string                 `json:"mode,omitempty"`
+	Image    string                 `json:"image,omitempty"`
+	Size     string                 `json:"size,omitempty"`
+	Duration int                    `json:"duration,omitempty"`
+	Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+// requestPayload is the JSON body sent to the Vidu img2video/text2video
+// endpoints.
+type requestPayload struct {
+	Model             string   `json:"model"`
+	Images            []string `json:"images"`
+	Prompt            string   `json:"prompt,omitempty"`
+	Duration          int      `json:"duration,omitempty"`
+	Seed              int      `json:"seed,omitempty"`
+	Resolution        string   `json:"resolution,omitempty"`
+	MovementAmplitude string   `json:"movement_amplitude,omitempty"`
+	Bgm               bool     `json:"bgm,omitempty"`
+	Payload           string   `json:"payload,omitempty"`
+	CallbackUrl       string   `json:"callback_url,omitempty"`
+}
+
+// responsePayload is the Vidu submit response; TaskId is the handle used for
+// later polling and State reports the initial task state.
+type responsePayload struct {
+	TaskId            string   `json:"task_id"`
+	State             string   `json:"state"`
+	Model             string   `json:"model"`
+	Images            []string `json:"images"`
+	Prompt            string   `json:"prompt"`
+	Duration          int      `json:"duration"`
+	Seed              int      `json:"seed"`
+	Resolution        string   `json:"resolution"`
+	Bgm               bool     `json:"bgm"`
+	MovementAmplitude string   `json:"movement_amplitude"`
+	Payload           string   `json:"payload"`
+	CreatedAt         string   `json:"created_at"`
+}
+
+// taskResultResponse is the Vidu task-query (creations) response consumed by
+// ParseTaskResult.
+type taskResultResponse struct {
+	State     string     `json:"state"`
+	ErrCode   string     `json:"err_code"`
+	Credits   int        `json:"credits"`
+	Payload   string     `json:"payload"`
+	Creations []creation `json:"creations"`
+}
+
+// creation is a single generated video entry inside taskResultResponse.
+type creation struct {
+	ID       string `json:"id"`
+	URL      string `json:"url"`
+	CoverURL string `json:"cover_url"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+// TaskAdaptor implements the task-channel adaptor interface for Vidu
+// video-generation tasks.
+type TaskAdaptor struct {
+	ChannelType int    // channel type id copied from the relay info
+	baseURL     string // upstream base URL copied from the relay info
+}
+
+// Init captures the channel type and upstream base URL from the relay info.
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.ChannelBaseUrl
+}
+
+// ValidateRequestAndSetAction binds and validates the incoming submit request
+// and decides the task action: image-to-video when an image is supplied,
+// text-to-video otherwise. The parsed request is stashed in the gin context
+// under "task_request" for BuildRequestBody to pick up later.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
+	var req SubmitReq
+	if err := c.ShouldBindJSON(&req); err != nil {
+		return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
+	}
+
+	// Prompt is mandatory for both generation modes.
+	if req.Prompt == "" {
+		return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
+	}
+
+	if req.Image != "" {
+		info.Action = constant.TaskActionGenerate
+	} else {
+		info.Action = constant.TaskActionTextGenerate
+	}
+
+	c.Set("task_request", req)
+	return nil
+}
+
+// BuildRequestBody converts the request stashed by ValidateRequestAndSetAction
+// into the upstream Vidu JSON payload.
+//
+// When the payload carries no images the effective action is downgraded to
+// text-to-video via the "action" context key, which DoRequest later copies
+// back into the relay info.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
+	v, exists := c.Get("task_request")
+	if !exists {
+		return nil, fmt.Errorf("request not found in context")
+	}
+	// Use the comma-ok form so an unexpected value type surfaces as an error
+	// instead of panicking the request goroutine.
+	req, ok := v.(SubmitReq)
+	if !ok {
+		return nil, fmt.Errorf("invalid task request type in context")
+	}
+
+	body, err := a.convertToRequestPayload(&req)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(body.Images) == 0 {
+		c.Set("action", constant.TaskActionTextGenerate)
+	}
+
+	data, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(data), nil
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+ var path string
+ switch info.Action {
+ case constant.TaskActionGenerate:
+ path = "/img2video"
+ default:
+ path = "/text2video"
+ }
+ return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
+}
+
+// BuildRequestHeader sets JSON content negotiation headers and the Vidu
+// "Token" authorization scheme on the outgoing request.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Token "+info.ApiKey)
+	return nil
+}
+
+// DoRequest forwards the built request upstream. If BuildRequestBody
+// downgraded the action (no images in the payload), the override stored under
+// the "action" context key is applied to the relay info first so the URL and
+// billing match the actual task type.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	if action := c.GetString("action"); action != "" {
+		info.Action = action
+	}
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse reads the upstream submit response, surfaces failures, and echoes
+// the payload back to the client. It returns the upstream task id so the relay
+// layer can poll for completion.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	// Always release the upstream body so the transport can reuse the
+	// connection (the original leaked it).
+	service.CloseResponseBodyGracefully(resp)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	var vResp responsePayload
+	err = json.Unmarshal(responseBody, &vResp)
+	if err != nil {
+		// Include the raw upstream body for diagnosis of malformed responses.
+		taskErr = service.TaskErrorWrapper(errors.Wrap(err, string(responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
+		return
+	}
+
+	if vResp.State == "failed" {
+		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
+		return
+	}
+
+	c.JSON(http.StatusOK, vResp)
+	return vResp.TaskId, responseBody, nil
+}
+
+// FetchTask queries the Vidu creations endpoint for a previously submitted
+// task so the relay layer can poll its progress. body must carry the upstream
+// task handle under the "task_id" key.
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+
+	url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
+
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Authorization", "Token "+key)
+
+	return service.GetHttpClient().Do(req)
+}
+
+// GetModelList lists the Vidu model identifiers this adaptor supports.
+func (a *TaskAdaptor) GetModelList() []string {
+	return []string{"viduq1", "vidu2.0", "vidu1.5"}
+}
+
+// GetChannelName returns the channel identifier used in logs and routing.
+func (a *TaskAdaptor) GetChannelName() string {
+	return "vidu"
+}
+
+// ============================
+// helpers
+// ============================
+
+// convertToRequestPayload maps the client request onto the upstream Vidu
+// payload, applying defaults: model "viduq1", 5s duration, 1080p resolution,
+// "auto" movement amplitude, no background music.
+//
+// Any entries in req.Metadata are then overlaid onto the payload via a JSON
+// round-trip, so metadata keys matching requestPayload JSON tags (including
+// seed, payload and callback_url) override the defaults above.
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+	var images []string
+	if req.Image != "" {
+		images = []string{req.Image}
+	}
+
+	r := requestPayload{
+		Model:             defaultString(req.Model, "viduq1"),
+		Images:            images,
+		Prompt:            req.Prompt,
+		Duration:          defaultInt(req.Duration, 5),
+		Resolution:        defaultString(req.Size, "1080p"),
+		MovementAmplitude: "auto",
+		Bgm:               false,
+	}
+	// No metadata: skip the pointless marshal/unmarshal round-trip.
+	if len(req.Metadata) == 0 {
+		return &r, nil
+	}
+	metaBytes, err := json.Marshal(req.Metadata)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshal metadata failed")
+	}
+	if err = json.Unmarshal(metaBytes, &r); err != nil {
+		return nil, errors.Wrap(err, "unmarshal metadata failed")
+	}
+	return &r, nil
+}
+
+// defaultString returns value, falling back to defaultValue when value is the
+// empty string.
+func defaultString(value, defaultValue string) string {
+	if value != "" {
+		return value
+	}
+	return defaultValue
+}
+
+// defaultInt returns value, falling back to defaultValue when value is zero.
+func defaultInt(value, defaultValue int) int {
+	if value != 0 {
+		return value
+	}
+	return defaultValue
+}
+
+// ParseTaskResult maps a Vidu task-query response body onto the generic
+// TaskInfo used by the relay layer.
+//
+// State mapping:
+//	created/queueing -> submitted
+//	processing       -> in progress
+//	success          -> success; the first creation URL (if any) becomes Url
+//	failed           -> failure; err_code (if present) becomes Reason
+// Any other state is reported as an error.
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	taskInfo := &relaycommon.TaskInfo{}
+
+	var taskResp taskResultResponse
+	err := json.Unmarshal(respBody, &taskResp)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to unmarshal response body")
+	}
+
+	state := taskResp.State
+	switch state {
+	case "created", "queueing":
+		taskInfo.Status = model.TaskStatusSubmitted
+	case "processing":
+		taskInfo.Status = model.TaskStatusInProgress
+	case "success":
+		taskInfo.Status = model.TaskStatusSuccess
+		if len(taskResp.Creations) > 0 {
+			taskInfo.Url = taskResp.Creations[0].URL
+		}
+	case "failed":
+		taskInfo.Status = model.TaskStatusFailure
+		if taskResp.ErrCode != "" {
+			taskInfo.Reason = taskResp.ErrCode
+		}
+	default:
+		return nil, fmt.Errorf("unknown task state: %s", state)
+	}
+
+	return taskInfo, nil
+}
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
index 44718a25..ab96ecaa 100644
--- a/relay/channel/tencent/adaptor.go
+++ b/relay/channel/tencent/adaptor.go
@@ -6,10 +6,11 @@ import (
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
- "one-api/service"
+ "one-api/types"
"strconv"
"strings"
@@ -24,6 +25,11 @@ type Adaptor struct {
Timestamp int64
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -47,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/", info.BaseUrl), nil
+ return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -63,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
- apiKey := c.Request.Header.Get("Authorization")
+ apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId
@@ -94,13 +100,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- var responseText string
- err, responseText = tencentStreamHandler(c, resp)
- usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ usage, err = tencentStreamHandler(c, info, resp)
} else {
- err, usage = tencentHandler(c, resp)
+ usage, err = tencentHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go
index 5630650f..f33a275c 100644
--- a/relay/channel/tencent/relay-tencent.go
+++ b/relay/channel/tencent/relay-tencent.go
@@ -8,17 +8,20 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strconv"
"strings"
"time"
+
+ "github.com/gin-gonic/gin"
)
// https://cloud.tencent.com/document/product/1729/97732
@@ -56,12 +59,11 @@ func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextRespon
},
}
if len(response.Choices) > 0 {
- content, _ := json.Marshal(response.Choices[0].Messages.Content)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
@@ -87,7 +89,7 @@ func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.Cha
return &response
}
-func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
+func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -104,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
continue
}
@@ -115,56 +117,47 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
err = helper.ObjectData(c, response)
if err != nil {
- common.SysError(err.Error())
+ common.SysLog(err.Error())
}
}
if err := scanner.Err(); err != nil {
- common.SysError("error reading stream: " + err.Error())
+ common.SysLog("error reading stream: " + err.Error())
}
helper.Done(c)
- err := resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
- }
+ service.CloseResponseBodyGracefully(resp)
- return nil, responseText
+ return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
}
-func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if tencentSb.Response.Error.Code != 0 {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: tencentSb.Response.Error.Message,
- Code: tencentSb.Response.Error.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: tencentSb.Response.Error.Message,
+ Code: tencentSb.Response.Error.Code,
+ }, resp.StatusCode)
}
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
- jsonResponse, err := json.Marshal(fullTextResponse)
+ jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
- _, err = c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
+ service.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return &fullTextResponse.Usage, nil
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
index 31f84abf..0b6b2674 100644
--- a/relay/channel/vertex/adaptor.go
+++ b/relay/channel/vertex/adaptor.go
@@ -14,6 +14,7 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/setting/model_setting"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -34,6 +35,7 @@ var claudeModelMap = map[string]string{
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
"claude-opus-4-20250514": "claude-opus-4@20250514",
+ "claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -43,6 +45,11 @@ type Adaptor struct {
AccountCredentials Credentials
}
+func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
+ geminiAdaptor := gemini.Adaptor{}
+ return geminiAdaptor.ConvertGeminiRequest(c, info, request)
+}
+
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
@@ -59,17 +66,17 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
+ geminiAdaptor := gemini.Adaptor{}
+ return geminiAdaptor.ConvertImageRequest(c, info, request)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
- } else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
- a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
+ } else {
+ a.RequestMode = RequestModeGemini
}
}
@@ -82,11 +89,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
+
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
- // suffix -thinking and -nothinking
- if strings.HasSuffix(info.OriginModelName, "-thinking") {
+ // 新增逻辑:处理 -thinking- 格式
+ if strings.Contains(info.UpstreamModelName, "-thinking-") {
+ parts := strings.Split(info.UpstreamModelName, "-thinking-")
+ info.UpstreamModelName = parts[0]
+ } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
- } else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
+ } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
@@ -96,6 +107,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else {
suffix = "generateContent"
}
+
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ suffix = "predict"
+ }
+
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
@@ -123,14 +139,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
model = v
}
- return fmt.Sprintf(
- "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
- region,
- adc.ProjectID,
- region,
- model,
- suffix,
- ), nil
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
+ adc.ProjectID,
+ model,
+ suffix,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+ region,
+ adc.ProjectID,
+ region,
+ model,
+ suffix,
+ ), nil
+ }
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
@@ -156,8 +181,62 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
+ if a.RequestMode == RequestModeGemini && strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ prompt := ""
+ for _, m := range request.Messages {
+ if m.Role == "user" {
+ prompt = m.StringContent()
+ if prompt != "" {
+ break
+ }
+ }
+ }
+ if prompt == "" {
+ if p, ok := request.Prompt.(string); ok {
+ prompt = p
+ }
+ }
+ if prompt == "" {
+ return nil, errors.New("prompt is required for image generation")
+ }
+
+ imgReq := dto.ImageRequest{
+ Model: request.Model,
+ Prompt: prompt,
+ N: 1,
+ Size: "1024x1024",
+ }
+ if request.N > 0 {
+ imgReq.N = uint(request.N)
+ }
+ if request.Size != "" {
+ imgReq.Size = request.Size
+ }
+ if len(request.ExtraBody) > 0 {
+ var extra map[string]any
+ if err := json.Unmarshal(request.ExtraBody, &extra); err == nil {
+ if n, ok := extra["n"].(float64); ok && n > 0 {
+ imgReq.N = uint(n)
+ }
+ if size, ok := extra["size"].(string); ok {
+ imgReq.Size = size
+ }
+ // accept aspectRatio in extra body (top-level or under parameters)
+ if ar, ok := extra["aspectRatio"].(string); ok && ar != "" {
+ imgReq.Size = ar
+ }
+ if params, ok := extra["parameters"].(map[string]any); ok {
+ if ar, ok := params["aspectRatio"].(string); ok && ar != "" {
+ imgReq.Size = ar
+ }
+ }
+ }
+ }
+ c.Set("request_model", request.Model)
+ return a.ConvertImageRequest(c, info, imgReq)
+ }
if a.RequestMode == RequestModeClaude {
- claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
+ claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
if err != nil {
return nil, err
}
@@ -166,7 +245,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
- geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
+ geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
if err != nil {
return nil, err
}
@@ -196,32 +275,35 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
- err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
+ return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
- usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
+ return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
} else {
- err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+ return gemini.GeminiChatStreamHandler(c, info, resp)
}
case RequestModeLlama:
- err, usage = openai.OaiStreamHandler(c, resp, info)
+ return openai.OaiStreamHandler(c, info, resp)
}
} else {
switch a.RequestMode {
case RequestModeClaude:
- err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
+ return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
- usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
+ return gemini.GeminiTextGenerationHandler(c, info, resp)
} else {
- err, usage = gemini.GeminiChatHandler(c, resp, info)
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ return gemini.GeminiImageHandler(c, info, resp)
+ }
+ return gemini.GeminiChatHandler(c, info, resp)
}
case RequestModeLlama:
- err, usage = openai.OpenaiHandler(c, resp, info)
+ return openai.OpenaiHandler(c, info, resp)
}
}
return
diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go
index d2596320..5ed87665 100644
--- a/relay/channel/vertex/relay-vertex.go
+++ b/relay/channel/vertex/relay-vertex.go
@@ -4,8 +4,11 @@ import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
- if common.IsJsonStr(other) {
- m := common.StrToMap(other)
+ if common.IsJsonObject(other) {
+ m, err := common.StrToMap(other)
+ if err != nil {
+ return other // return original if parsing fails
+ }
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go
index cc640803..9a4650d9 100644
--- a/relay/channel/vertex/service_account.go
+++ b/relay/channel/vertex/service_account.go
@@ -11,6 +11,7 @@ import (
"net/http"
"net/url"
relaycommon "one-api/relay/common"
+ "one-api/service"
"strings"
"fmt"
@@ -35,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{
})
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
- cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
+ var cacheKey string
+ if info.ChannelIsMultiKey {
+ cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
+ } else {
+ cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
+ }
val, err := Cache.Get(cacheKey)
if err == nil {
return val.(string), nil
@@ -45,7 +51,7 @@ func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
if err != nil {
return "", fmt.Errorf("failed to create signed JWT: %w", err)
}
- newToken, err := exchangeJwtForAccessToken(signedJWT)
+ newToken, err := exchangeJwtForAccessToken(signedJWT, info)
if err != nil {
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
}
@@ -96,14 +102,25 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) {
return signedToken, nil
}
-func exchangeJwtForAccessToken(signedJWT string) (string, error) {
+func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
- resp, err := http.PostForm(authURL, data)
+ var client *http.Client
+ var err error
+ if info.ChannelSetting.Proxy != "" {
+ client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
+ if err != nil {
+ return "", fmt.Errorf("new proxy http client failed: %w", err)
+ }
+ } else {
+ client = service.GetHttpClient()
+ }
+
+ resp, err := client.PostForm(authURL, data)
if err != nil {
return "", err
}
diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go
index a4a48ee9..b46cb952 100644
--- a/relay/channel/volcengine/adaptor.go
+++ b/relay/channel/volcengine/adaptor.go
@@ -1,15 +1,20 @@
package volcengine
import (
+ "bytes"
"errors"
"fmt"
"io"
+ "mime/multipart"
"net/http"
+ "net/textproto"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+ "one-api/types"
+ "path/filepath"
"strings"
"github.com/gin-gonic/gin"
@@ -18,10 +23,14 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
- panic("implement me")
- return nil, nil
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ adaptor := openai.Adaptor{}
+ return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -30,8 +39,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- //TODO implement me
- return nil, errors.New("not implemented")
+ switch info.RelayMode {
+ case constant.RelayModeImagesEdits:
+
+ var requestBody bytes.Buffer
+ writer := multipart.NewWriter(&requestBody)
+
+ writer.WriteField("model", request.Model)
+ // 获取所有表单字段
+ formData := c.Request.PostForm
+ // 遍历表单字段并打印输出
+ for key, values := range formData {
+ if key == "model" {
+ continue
+ }
+ for _, value := range values {
+ writer.WriteField(key, value)
+ }
+ }
+
+ // Parse the multipart form to handle both single image and multiple images
+ if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
+ return nil, errors.New("failed to parse multipart form")
+ }
+
+ if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+ // Check if "image" field exists in any form, including array notation
+ var imageFiles []*multipart.FileHeader
+ var exists bool
+
+ // First check for standard "image" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+ // If not found, check for "image[]" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+ // If still not found, iterate through all fields to find any that start with "image["
+ foundArrayImages := false
+ for fieldName, files := range c.Request.MultipartForm.File {
+ if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
+ foundArrayImages = true
+ for _, file := range files {
+ imageFiles = append(imageFiles, file)
+ }
+ }
+ }
+
+ // If no image fields found at all
+ if !foundArrayImages && (len(imageFiles) == 0) {
+ return nil, errors.New("image is required")
+ }
+ }
+ }
+
+ // Process all image files
+ for i, fileHeader := range imageFiles {
+ file, err := fileHeader.Open()
+ if err != nil {
+ return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
+ }
+ defer file.Close()
+
+ // If multiple images, use image[] as the field name
+ fieldName := "image"
+ if len(imageFiles) > 1 {
+ fieldName = "image[]"
+ }
+
+ // Determine MIME type based on file extension
+ mimeType := detectImageMimeType(fileHeader.Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
+ h.Set("Content-Type", mimeType)
+
+ part, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
+ }
+
+ if _, err := io.Copy(part, file); err != nil {
+ return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
+ }
+ }
+
+ // Handle mask file if present
+ if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+ maskFile, err := maskFiles[0].Open()
+ if err != nil {
+ return nil, errors.New("failed to open mask file")
+ }
+ defer maskFile.Close()
+
+ // Determine MIME type for mask file
+ mimeType := detectImageMimeType(maskFiles[0].Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
+ h.Set("Content-Type", mimeType)
+
+ maskPart, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, errors.New("create form file failed for mask")
+ }
+
+ if _, err := io.Copy(maskPart, maskFile); err != nil {
+ return nil, errors.New("copy mask file failed")
+ }
+ }
+ } else {
+ return nil, errors.New("no multipart form data found")
+ }
+
+ // 关闭 multipart 编写器以设置分界线
+ writer.Close()
+ c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+ return bytes.NewReader(requestBody.Bytes()), nil
+
+ default:
+ return request, nil
+ }
+}
+
+// detectImageMimeType determines the MIME type based on the file extension
+func detectImageMimeType(filename string) string {
+ ext := strings.ToLower(filepath.Ext(filename))
+ switch ext {
+ case ".jpg", ".jpeg":
+ return "image/jpeg"
+ case ".png":
+ return "image/png"
+ case ".webp":
+ return "image/webp"
+ default:
+ // Try to detect from extension if possible
+ if strings.HasPrefix(ext, ".jp") {
+ return "image/jpeg"
+ }
+ // Default to png as a fallback
+ return "image/png"
+ }
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -41,11 +188,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
- return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
+ return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
}
- return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
+ return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
case constant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
+ return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
+ case constant.RelayModeImagesGenerations:
+ return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
+ case constant.RelayModeImagesEdits:
+ return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
+ case constant.RelayModeRerank:
+ return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -81,17 +234,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
- switch info.RelayMode {
- case constant.RelayModeChatCompletions:
- if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
- } else {
- err, usage = openai.OpenaiHandler(c, resp, info)
- }
- case constant.RelayModeEmbeddings:
- err, usage = openai.OpenaiHandler(c, resp, info)
- }
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ adaptor := openai.Adaptor{}
+ usage, err = adaptor.DoResponse(c, resp, info)
return
}
diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go
index b5896415..d5671ab2 100644
--- a/relay/channel/xai/adaptor.go
+++ b/relay/channel/xai/adaptor.go
@@ -8,6 +8,7 @@ import (
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
+ "one-api/types"
"strings"
"one-api/relay/constant"
@@ -18,6 +19,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
//panic("implement me")
@@ -33,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
xaiRequest := ImageRequest{
Model: request.Model,
Prompt: request.Prompt,
- N: request.N,
+ N: int(request.N),
ResponseFormat: request.ResponseFormat,
}
return xaiRequest, nil
@@ -43,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+ return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -56,6 +62,15 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
+ if strings.HasSuffix(info.UpstreamModelName, "-search") {
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
+ request.Model = info.UpstreamModelName
+ toMap := request.ToMap()
+ toMap["search_parameters"] = map[string]any{
+ "mode": "on",
+ }
+ return toMap, nil
+ }
if strings.HasPrefix(request.Model, "grok-3-mini") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
@@ -95,15 +110,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
- err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
+ usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
default:
if info.IsStream {
- err, usage = xAIStreamHandler(c, resp, info)
+ usage, err = xAIStreamHandler(c, info, resp)
} else {
- err, usage = xAIHandler(c, resp, info)
+ usage, err = xAIHandler(c, info, resp)
}
}
return
diff --git a/relay/channel/xai/constants.go b/relay/channel/xai/constants.go
index 685fe3bb..311b4bb6 100644
--- a/relay/channel/xai/constants.go
+++ b/relay/channel/xai/constants.go
@@ -1,6 +1,8 @@
package xai
var ModelList = []string{
+ // grok-4
+ "grok-4", "grok-4-0709", "grok-4-0709-search",
// grok-3
"grok-3-beta", "grok-3-mini-beta",
// grok-3 mini
diff --git a/relay/channel/xai/dto.go b/relay/channel/xai/dto.go
index b8098475..107a980a 100644
--- a/relay/channel/xai/dto.go
+++ b/relay/channel/xai/dto.go
@@ -4,24 +4,24 @@ import "one-api/dto"
// ChatCompletionResponse represents the response from XAI chat completion API
type ChatCompletionResponse struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []dto.ChatCompletionsStreamResponseChoice
- Usage *dto.Usage `json:"usage"`
- SystemFingerprint string `json:"system_fingerprint"`
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Choices []dto.OpenAITextResponseChoice `json:"choices"`
+ Usage *dto.Usage `json:"usage"`
+ SystemFingerprint string `json:"system_fingerprint"`
}
// quality, size or style are not supported by xAI API at the moment.
type ImageRequest struct {
- Model string `json:"model"`
- Prompt string `json:"prompt" binding:"required"`
- N int `json:"n,omitempty"`
+ Model string `json:"model"`
+ Prompt string `json:"prompt" binding:"required"`
+ N int `json:"n,omitempty"`
// Size string `json:"size,omitempty"`
// Quality string `json:"quality,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
// Style string `json:"style,omitempty"`
// User string `json:"user,omitempty"`
// ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
-}
\ No newline at end of file
+}
diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go
index e019c2dc..5cae9c0a 100644
--- a/relay/channel/xai/text.go
+++ b/relay/channel/xai/text.go
@@ -1,9 +1,7 @@
package xai
import (
- "bytes"
"encoding/json"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -12,7 +10,10 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
+
+ "github.com/gin-gonic/gin"
)
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
@@ -34,7 +35,7 @@ func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage
return openAIResp
}
-func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
usage := &dto.Usage{}
var responseTextBuilder strings.Builder
var toolCount int
@@ -46,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
@@ -62,58 +63,45 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
- common.SysError(err.Error())
+ common.SysLog(err.Error())
}
return true
})
if !containStreamUsage {
- usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+ usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}
helper.Done(c)
- err := resp.Body.Close()
- if err != nil {
- //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- common.SysError("close_response_body_failed: " + err.Error())
- }
- return nil, usage
+ service.CloseResponseBodyGracefully(resp)
+ return usage, nil
}
-func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer service.CloseResponseBodyGracefully(resp)
+
responseBody, err := io.ReadAll(resp.Body)
- var response *dto.TextResponse
- err = common.DecodeJson(responseBody, &response)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
- return nil, nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ var xaiResponse ChatCompletionResponse
+ err = common.Unmarshal(responseBody, &xaiResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if xaiResponse.Usage != nil {
+ xaiResponse.Usage.CompletionTokens = xaiResponse.Usage.TotalTokens - xaiResponse.Usage.PromptTokens
+ xaiResponse.Usage.CompletionTokenDetails.TextTokens = xaiResponse.Usage.CompletionTokens - xaiResponse.Usage.CompletionTokenDetails.ReasoningTokens
}
- response.Usage.CompletionTokens = response.Usage.TotalTokens - response.Usage.PromptTokens
- response.Usage.CompletionTokenDetails.TextTokens = response.Usage.CompletionTokens - response.Usage.CompletionTokenDetails.ReasoningTokens
// new body
- encodeJson, err := common.EncodeJson(response)
+ encodeJson, err := common.Marshal(xaiResponse)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
- return nil, nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
- // set new body
- resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
+ service.IOCopyBytesGracefully(c, resp, encodeJson)
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
-
- return nil, &response.Usage
+ return xaiResponse.Usage, nil
}
diff --git a/relay/channel/xinference/dto.go b/relay/channel/xinference/dto.go
index 2f12ad10..35f339fe 100644
--- a/relay/channel/xinference/dto.go
+++ b/relay/channel/xinference/dto.go
@@ -1,7 +1,7 @@
package xinference
type XinRerankResponseDocument struct {
- Document string `json:"document,omitempty"`
+ Document any `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
index 7591e0e7..7ee76f1a 100644
--- a/relay/channel/xunfei/adaptor.go
+++ b/relay/channel/xunfei/adaptor.go
@@ -7,7 +7,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
- "one-api/service"
+ "one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -17,6 +17,11 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -74,18 +79,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return dummyResp, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 {
- return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+ return nil, types.NewError(errors.New("invalid auth"), types.ErrorCodeChannelInvalidKey)
}
if a.request == nil {
- return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
+ return nil, types.NewError(errors.New("request is nil"), types.ErrorCodeInvalidRequest)
}
if info.IsStream {
- err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
+ usage, err = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
} else {
- err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
+ usage, err = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
}
return
}
diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go
index 15d33510..9d5c190f 100644
--- a/relay/channel/xunfei/relay-xunfei.go
+++ b/relay/channel/xunfei/relay-xunfei.go
@@ -6,18 +6,18 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
"io"
- "net/http"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/helper"
- "one-api/service"
+ "one-api/types"
"strings"
"time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
)
// https://console.xfyun.cn/services/cbm
@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
xunfeiRequest.Parameter.Chat.Domain = domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N
- xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
+ xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
xunfeiRequest.Payload.Message.Text = messages
return &xunfeiRequest
}
@@ -61,12 +61,11 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse
},
}
}
- content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
- Content: content,
+ Content: response.Payload.Choices.Text[0].Content,
},
FinishReason: constant.FinishReasonStop,
}
@@ -127,11 +126,11 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
return callUrl
}
-func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
- return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
}
helper.SetEventStreamHeaders(c)
var usage dto.Usage
@@ -144,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -154,14 +153,14 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
return false
}
})
- return nil, &usage
+ return &usage, nil
}
-func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
if err != nil {
- return service.OpenAIErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
}
var usage dto.Usage
var content string
@@ -192,11 +191,11 @@ func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId s
response := responseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
_, _ = c.Writer.Write(jsonResponse)
- return nil, &usage
+ return &usage, nil
}
func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
@@ -207,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
if err != nil || resp.StatusCode != 101 {
return nil, nil, err
}
+
+ defer func() {
+ conn.Close()
+ }()
+
data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data)
if err != nil {
@@ -219,20 +223,19 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
for {
_, msg, err := conn.ReadMessage()
if err != nil {
- common.SysError("error reading stream response: " + err.Error())
+ common.SysLog("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
- err := conn.Close()
if err != nil {
- common.SysError("error closing websocket connection: " + err.Error())
+ common.SysLog("error closing websocket connection: " + err.Error())
}
break
}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
index b4d8fb30..bd27c90b 100644
--- a/relay/channel/zhipu/adaptor.go
+++ b/relay/channel/zhipu/adaptor.go
@@ -8,6 +8,7 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -15,6 +16,11 @@ import (
type Adaptor struct {
}
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -39,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.IsStream {
method = "sse-invoke"
}
- return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
+ return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -77,11 +83,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
return nil, errors.New("not implemented")
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
- err, usage = zhipuStreamHandler(c, resp)
+ usage, err = zhipuStreamHandler(c, info, resp)
} else {
- err, usage = zhipuHandler(c, resp)
+ usage, err = zhipuHandler(c, info, resp)
}
return
}
diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go
index b0cac858..8eb0dcc1 100644
--- a/relay/channel/zhipu/relay-zhipu.go
+++ b/relay/channel/zhipu/relay-zhipu.go
@@ -3,18 +3,21 @@ package zhipu
import (
"bufio"
"encoding/json"
- "github.com/gin-gonic/gin"
- "github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
+ relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
+ "one-api/types"
"strings"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt"
)
// https://open.bigmodel.cn/doc/api#chatglm_std
@@ -36,7 +39,7 @@ func getZhipuToken(apikey string) string {
split := strings.Split(apikey, ".")
if len(split) != 2 {
- common.SysError("invalid zhipu key: " + apikey)
+ common.SysLog("invalid zhipu key: " + apikey)
return ""
}
@@ -108,12 +111,11 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
Usage: response.Data.Usage,
}
for i, choice := range response.Data.Choices {
- content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
openaiChoice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: choice.Role,
- Content: content,
+ Content: strings.Trim(choice.Content, "\""),
},
FinishReason: "",
}
@@ -151,7 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt
return &response, &zhipuResponse.Usage
}
-func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
@@ -185,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -194,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
- common.SysError("error unmarshalling stream response: " + err.Error())
+ common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
- common.SysError("error marshalling stream response: " + err.Error())
+ common.SysLog("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage
@@ -211,45 +213,34 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
return false
}
})
- err := resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
- }
- return nil, usage
+ service.CloseResponseBodyGracefully(resp)
+ return usage, nil
}
-func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
+ service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if !zhipuResponse.Success {
- return &dto.OpenAIErrorWithStatusCode{
- Error: dto.OpenAIError{
- Message: zhipuResponse.Msg,
- Type: "zhipu_error",
- Param: "",
- Code: zhipuResponse.Code,
- },
- StatusCode: resp.StatusCode,
- }, nil
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: zhipuResponse.Msg,
+ Code: zhipuResponse.Code,
+ }, resp.StatusCode)
}
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
- return nil, &fullTextResponse.Usage
+ return &fullTextResponse.Usage, nil
}
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
index 222cdff8..37c0c352 100644
--- a/relay/channel/zhipu_4v/adaptor.go
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -7,9 +7,11 @@ import (
"net/http"
"one-api/dto"
"one-api/relay/channel"
+ "one-api/relay/channel/claude"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -17,10 +19,13 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
- panic("implement me")
- return nil, nil
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+ return req, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -37,19 +42,22 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
- switch info.RelayMode {
- case relayconstant.RelayModeEmbeddings:
- return fmt.Sprintf("%s/embeddings", baseUrl), nil
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ return fmt.Sprintf("%s/api/anthropic/v1/messages", info.ChannelBaseUrl), nil
default:
- return fmt.Sprintf("%s/chat/completions", baseUrl), nil
+ switch info.RelayMode {
+ case relayconstant.RelayModeEmbeddings:
+ return fmt.Sprintf("%s/api/paas/v4/embeddings", info.ChannelBaseUrl), nil
+ default:
+ return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.ChannelBaseUrl), nil
+ }
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
- token := getZhipuToken(info.ApiKey)
- req.Set("Authorization", token)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
@@ -80,13 +88,18 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
- if info.IsStream {
- err, usage = openai.OaiStreamHandler(c, resp, info)
- } else {
- err, usage = openai.OpenaiHandler(c, resp, info)
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayFormat {
+ case types.RelayFormatClaude:
+ if info.IsStream {
+ return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
+ } else {
+ return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
+ }
+ default:
+ adaptor := openai.Adaptor{}
+ return adaptor.DoResponse(c, resp, info)
}
- return
}
func (a *Adaptor) GetModelList() []string {
diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go
index 271dda8f..aec87dd5 100644
--- a/relay/channel/zhipu_4v/relay-zhipu_v4.go
+++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go
@@ -1,69 +1,10 @@
package zhipu_4v
import (
- "github.com/golang-jwt/jwt"
- "one-api/common"
"one-api/dto"
"strings"
- "sync"
- "time"
)
-// https://open.bigmodel.cn/doc/api#chatglm_std
-// chatglm_std, chatglm_lite
-// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
-// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
-
-var zhipuTokens sync.Map
-var expSeconds int64 = 24 * 3600
-
-func getZhipuToken(apikey string) string {
- data, ok := zhipuTokens.Load(apikey)
- if ok {
- tokenData := data.(tokenData)
- if time.Now().Before(tokenData.ExpiryTime) {
- return tokenData.Token
- }
- }
-
- split := strings.Split(apikey, ".")
- if len(split) != 2 {
- common.SysError("invalid zhipu key: " + apikey)
- return ""
- }
-
- id := split[0]
- secret := split[1]
-
- expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
- expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
-
- timestamp := time.Now().UnixNano() / 1e6
-
- payload := jwt.MapClaims{
- "api_key": id,
- "exp": expMillis,
- "timestamp": timestamp,
- }
-
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
-
- token.Header["alg"] = "HS256"
- token.Header["sign_type"] = "SIGN"
-
- tokenString, err := token.SignedString([]byte(secret))
- if err != nil {
- return ""
- }
-
- zhipuTokens.Store(apikey, tokenData{
- Token: tokenString,
- ExpiryTime: expiryTime,
- })
-
- return tokenString
-}
-
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
@@ -105,9 +46,10 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
- MaxTokens: request.MaxTokens,
+ MaxTokens: request.GetMaxTokens(),
Stop: Stop,
Tools: request.Tools,
ToolChoice: request.ToolChoice,
+ THINKING: request.THINKING,
}
}
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
index fb68a88a..59c052f6 100644
--- a/relay/claude_handler.go
+++ b/relay/claude_handler.go
@@ -2,10 +2,7 @@ package relay
import (
"bytes"
- "encoding/json"
- "errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -14,150 +11,121 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
+ "one-api/types"
"strings"
+
+ "github.com/gin-gonic/gin"
)
-func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
- textRequest = &dto.ClaudeRequest{}
- err = c.ShouldBindJSON(textRequest)
+func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+
+ info.InitChannelMeta(c)
+
+ claudeReq, ok := info.Request.(*dto.ClaudeRequest)
+
+ if !ok {
+ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ request, err := common.DeepCopy(claudeReq)
if err != nil {
- return nil, err
+ return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
- if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
- return nil, errors.New("field messages is required")
- }
- if textRequest.Model == "" {
- return nil, errors.New("field model is required")
- }
- return textRequest, nil
-}
-func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
-
- relayInfo := relaycommon.GenRelayInfoClaude(c)
-
- // get & validate textRequest 获取并验证文本请求
- textRequest, err := getAndValidateClaudeRequest(c)
+ err = helper.ModelMappedHelper(c, info, request)
if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
- if textRequest.Stream {
- relayInfo.IsStream = true
- }
-
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
- }
-
- textRequest.Model = relayInfo.UpstreamModelName
-
- promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
- // count messages token error 计算promptTokens错误
- if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
- if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
-
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
-
- if openaiErr != nil {
- return service.OpenAIErrorToClaudeError(openaiErr)
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
- var requestBody io.Reader
+ adaptor.Init(info)
- if textRequest.MaxTokens == 0 {
- textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
+ if request.MaxTokens == 0 {
+ request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
- strings.HasSuffix(textRequest.Model, "-thinking") {
- if textRequest.Thinking == nil {
+ strings.HasSuffix(request.Model, "-thinking") {
+ if request.Thinking == nil {
// 因为BudgetTokens 必须大于1024
- if textRequest.MaxTokens < 1280 {
- textRequest.MaxTokens = 1280
+ if request.MaxTokens < 1280 {
+ request.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
- textRequest.Thinking = &dto.Thinking{
+ request.Thinking = &dto.Thinking{
Type: "enabled",
- BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
+ BudgetTokens: common.GetPointer[int](int(float64(request.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
- textRequest.TopP = 0
- textRequest.Temperature = common.GetPointer[float64](1.0)
+ request.TopP = 0
+ request.Temperature = common.GetPointer[float64](1.0)
}
- textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
- relayInfo.UpstreamModelName = textRequest.Model
+ request.Model = strings.TrimSuffix(request.Model, "-thinking")
+ info.UpstreamModelName = request.Model
}
- convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
- if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
+ var requestBody io.Reader
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
+ body, err := common.GetRequestBody(c)
+ if err != nil {
+ return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+ requestBody = bytes.NewBuffer(body)
+ } else {
+ convertedRequest, err := adaptor.ConvertClaudeRequest(c, info, request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+ jsonData, err := common.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ // apply param override
+ if len(info.ParamOverride) > 0 {
+ jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+ }
+ }
+
+ if common.DebugEnabled {
+ println("requestBody: ", string(jsonData))
+ }
+ requestBody = bytes.NewBuffer(jsonData)
}
- jsonData, err := json.Marshal(convertedRequest)
- if common.DebugEnabled {
- println("requestBody: ", string(jsonData))
- }
- if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil {
- return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
if resp != nil {
httpResp = resp.(*http.Response)
- relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
- openaiErr = service.RelayErrorHandler(httpResp, false)
+ newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return service.OpenAIErrorToClaudeError(openaiErr)
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
}
}
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
//log.Printf("usage: %v", usage)
- if openaiErr != nil {
+ if newAPIError != nil {
// reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return service.OpenAIErrorToClaudeError(openaiErr)
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
}
- service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+
+ service.PostClaudeConsumeQuota(c, info, usage.(*dto.Usage))
return nil
}
-
-func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
- var promptTokens int
- var err error
- switch info.RelayMode {
- default:
- promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
- }
- info.PromptTokens = promptTokens
- return promptTokens, err
-}
diff --git a/relay/common/override.go b/relay/common/override.go
new file mode 100644
index 00000000..c8f216ed
--- /dev/null
+++ b/relay/common/override.go
@@ -0,0 +1,396 @@
+package common
+
+import (
+ "encoding/json"
+ "fmt"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "strings"
+)
+
+type ConditionOperation struct {
+ Path string `json:"path"` // JSON路径
+ Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
+ Value interface{} `json:"value"` // 匹配的值
+ Invert bool `json:"invert"` // 反选功能,true表示取反结果
+ PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为
+}
+
+type ParamOperation struct {
+ Path string `json:"path"`
+ Mode string `json:"mode"` // delete, set, move, prepend, append
+ Value interface{} `json:"value"`
+ KeepOrigin bool `json:"keep_origin"`
+ From string `json:"from,omitempty"`
+ To string `json:"to,omitempty"`
+ Conditions []ConditionOperation `json:"conditions,omitempty"` // 条件列表
+ Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
+}
+
+func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
+ if len(paramOverride) == 0 {
+ return jsonData, nil
+ }
+
+ // 尝试断言为操作格式
+ if operations, ok := tryParseOperations(paramOverride); ok {
+ // 使用新方法
+ result, err := applyOperations(string(jsonData), operations)
+ return []byte(result), err
+ }
+
+ // 直接使用旧方法
+ return applyOperationsLegacy(jsonData, paramOverride)
+}
+
+func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
+ // 检查是否包含 "operations" 字段
+ if opsValue, exists := paramOverride["operations"]; exists {
+ if opsSlice, ok := opsValue.([]interface{}); ok {
+ var operations []ParamOperation
+ for _, op := range opsSlice {
+ if opMap, ok := op.(map[string]interface{}); ok {
+ operation := ParamOperation{}
+
+ // 断言必要字段
+ if path, ok := opMap["path"].(string); ok {
+ operation.Path = path
+ }
+ if mode, ok := opMap["mode"].(string); ok {
+ operation.Mode = mode
+ } else {
+ return nil, false // mode 是必需的
+ }
+
+ // 可选字段
+ if value, exists := opMap["value"]; exists {
+ operation.Value = value
+ }
+ if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
+ operation.KeepOrigin = keepOrigin
+ }
+ if from, ok := opMap["from"].(string); ok {
+ operation.From = from
+ }
+ if to, ok := opMap["to"].(string); ok {
+ operation.To = to
+ }
+ if logic, ok := opMap["logic"].(string); ok {
+ operation.Logic = logic
+ } else {
+ operation.Logic = "OR" // 默认为OR
+ }
+
+ // 解析条件
+ if conditions, exists := opMap["conditions"]; exists {
+ if condSlice, ok := conditions.([]interface{}); ok {
+ for _, cond := range condSlice {
+ if condMap, ok := cond.(map[string]interface{}); ok {
+ condition := ConditionOperation{}
+ if path, ok := condMap["path"].(string); ok {
+ condition.Path = path
+ }
+ if mode, ok := condMap["mode"].(string); ok {
+ condition.Mode = mode
+ }
+ if value, ok := condMap["value"]; ok {
+ condition.Value = value
+ }
+ if invert, ok := condMap["invert"].(bool); ok {
+ condition.Invert = invert
+ }
+ if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
+ condition.PassMissingKey = passMissingKey
+ }
+ operation.Conditions = append(operation.Conditions, condition)
+ }
+ }
+ }
+ }
+
+ operations = append(operations, operation)
+ } else {
+ return nil, false
+ }
+ }
+ return operations, true
+ }
+ }
+
+ return nil, false
+}
+
+func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
+ if len(conditions) == 0 {
+ return true, nil // 没有条件,直接通过
+ }
+ results := make([]bool, len(conditions))
+ for i, condition := range conditions {
+ result, err := checkSingleCondition(jsonStr, condition)
+ if err != nil {
+ return false, err
+ }
+ results[i] = result
+ }
+
+ if strings.ToUpper(logic) == "AND" {
+ for _, result := range results {
+ if !result {
+ return false, nil
+ }
+ }
+ return true, nil
+ } else {
+ for _, result := range results {
+ if result {
+ return true, nil
+ }
+ }
+ return false, nil
+ }
+}
+
+func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
+ value := gjson.Get(jsonStr, condition.Path)
+ if !value.Exists() {
+ if condition.PassMissingKey {
+ return true, nil
+ }
+ return false, nil
+ }
+
+ // 利用gjson的类型解析
+ targetBytes, err := json.Marshal(condition.Value)
+ if err != nil {
+ return false, fmt.Errorf("failed to marshal condition value: %v", err)
+ }
+ targetValue := gjson.ParseBytes(targetBytes)
+
+ result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
+ if err != nil {
+ return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
+ }
+
+ if condition.Invert {
+ result = !result
+ }
+ return result, nil
+}
+
+// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
+func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
+ switch mode {
+ case "full":
+ return compareEqual(jsonValue, targetValue)
+ case "prefix":
+ return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
+ case "suffix":
+ return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
+ case "contains":
+ return strings.Contains(jsonValue.String(), targetValue.String()), nil
+ case "gt":
+ return compareNumeric(jsonValue, targetValue, "gt")
+ case "gte":
+ return compareNumeric(jsonValue, targetValue, "gte")
+ case "lt":
+ return compareNumeric(jsonValue, targetValue, "lt")
+ case "lte":
+ return compareNumeric(jsonValue, targetValue, "lte")
+ default:
+ return false, fmt.Errorf("unsupported comparison mode: %s", mode)
+ }
+}
+
+func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
+ // 对布尔值特殊处理
+ if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
+ (targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
+ return jsonValue.Bool() == targetValue.Bool(), nil
+ }
+
+ // 如果类型不同,报错
+ if jsonValue.Type != targetValue.Type {
+ return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
+ }
+
+ switch jsonValue.Type {
+ case gjson.True, gjson.False:
+ return jsonValue.Bool() == targetValue.Bool(), nil
+ case gjson.Number:
+ return jsonValue.Num == targetValue.Num, nil
+ case gjson.String:
+ return jsonValue.String() == targetValue.String(), nil
+ default:
+ return jsonValue.String() == targetValue.String(), nil
+ }
+}
+
+func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
+ // 只有数字类型才支持数值比较
+ if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
+ return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
+ }
+
+ jsonNum := jsonValue.Num
+ targetNum := targetValue.Num
+
+ switch operator {
+ case "gt":
+ return jsonNum > targetNum, nil
+ case "gte":
+ return jsonNum >= targetNum, nil
+ case "lt":
+ return jsonNum < targetNum, nil
+ case "lte":
+ return jsonNum <= targetNum, nil
+ default:
+ return false, fmt.Errorf("unsupported numeric operator: %s", operator)
+ }
+}
+
+// applyOperationsLegacy 原参数覆盖方法
+func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
+ reqMap := make(map[string]interface{})
+ err := json.Unmarshal(jsonData, &reqMap)
+ if err != nil {
+ return nil, err
+ }
+
+ for key, value := range paramOverride {
+ reqMap[key] = value
+ }
+
+ return json.Marshal(reqMap)
+}
+
+func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
+ result := jsonStr
+ for _, op := range operations {
+ // 检查条件是否满足
+ ok, err := checkConditions(result, op.Conditions, op.Logic)
+ if err != nil {
+ return "", err
+ }
+ if !ok {
+ continue // 条件不满足,跳过当前操作
+ }
+
+ switch op.Mode {
+ case "delete":
+ result, err = sjson.Delete(result, op.Path)
+ case "set":
+ if op.KeepOrigin && gjson.Get(result, op.Path).Exists() {
+ continue
+ }
+ result, err = sjson.Set(result, op.Path, op.Value)
+ case "move":
+ result, err = moveValue(result, op.From, op.To)
+ case "prepend":
+ result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, true)
+ case "append":
+ result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, false)
+ default:
+ return "", fmt.Errorf("unknown operation: %s", op.Mode)
+ }
+ if err != nil {
+ return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
+ }
+ }
+ return result, nil
+}
+
+func moveValue(jsonStr, fromPath, toPath string) (string, error) {
+ sourceValue := gjson.Get(jsonStr, fromPath)
+ if !sourceValue.Exists() {
+ return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
+ }
+ result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
+ if err != nil {
+ return "", err
+ }
+ return sjson.Delete(result, fromPath)
+}
+
+func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
+ current := gjson.Get(jsonStr, path)
+ switch {
+ case current.IsArray():
+ return modifyArray(jsonStr, path, value, isPrepend)
+ case current.Type == gjson.String:
+ return modifyString(jsonStr, path, value, isPrepend)
+ case current.Type == gjson.JSON:
+ return mergeObjects(jsonStr, path, value, keepOrigin)
+ }
+ return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
+}
+
+func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
+ current := gjson.Get(jsonStr, path)
+ var newArray []interface{}
+ // 添加新值
+ addValue := func() {
+ if arr, ok := value.([]interface{}); ok {
+ newArray = append(newArray, arr...)
+ } else {
+ newArray = append(newArray, value)
+ }
+ }
+ // 添加原值
+ addOriginal := func() {
+ current.ForEach(func(_, val gjson.Result) bool {
+ newArray = append(newArray, val.Value())
+ return true
+ })
+ }
+ if isPrepend {
+ addValue()
+ addOriginal()
+ } else {
+ addOriginal()
+ addValue()
+ }
+ return sjson.Set(jsonStr, path, newArray)
+}
+
+func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
+ current := gjson.Get(jsonStr, path)
+ valueStr := fmt.Sprintf("%v", value)
+ var newStr string
+ if isPrepend {
+ newStr = valueStr + current.String()
+ } else {
+ newStr = current.String() + valueStr
+ }
+ return sjson.Set(jsonStr, path, newStr)
+}
+
+func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
+ current := gjson.Get(jsonStr, path)
+ var currentMap, newMap map[string]interface{}
+
+ // 解析当前值
+ if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
+ return "", err
+ }
+ // 解析新值
+ switch v := value.(type) {
+ case map[string]interface{}:
+ newMap = v
+ default:
+ jsonBytes, _ := json.Marshal(v)
+ if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
+ return "", err
+ }
+ }
+ // 合并
+ result := make(map[string]interface{})
+ for k, v := range currentMap {
+ result[k] = v
+ }
+ for k, v := range newMap {
+ if !keepOrigin || result[k] == nil {
+ result[k] = v
+ }
+ }
+ return sjson.Set(jsonStr, path, result)
+}
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
index f4fc3c1e..032a577d 100644
--- a/relay/common/relay_info.go
+++ b/relay/common/relay_info.go
@@ -1,10 +1,13 @@
package common
import (
+ "errors"
+ "fmt"
"one-api/common"
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
+ "one-api/types"
"strings"
"time"
@@ -33,12 +36,6 @@ type ClaudeConvertInfo struct {
Done bool
}
-const (
- RelayFormatOpenAI = "openai"
- RelayFormatClaude = "claude"
- RelayFormatGemini = "gemini"
-)
-
type RerankerInfo struct {
Documents []any
ReturnDocuments bool
@@ -54,74 +51,216 @@ type ResponsesUsageInfo struct {
BuiltInTools map[string]*BuildInToolInfo
}
+type ChannelMeta struct {
+ ChannelType int
+ ChannelId int
+ ChannelIsMultiKey bool
+ ChannelMultiKeyIndex int
+ ChannelBaseUrl string
+ ApiType int
+ ApiVersion string
+ ApiKey string
+ Organization string
+ ChannelCreateTime int64
+ ParamOverride map[string]interface{}
+ ChannelSetting dto.ChannelSettings
+ ChannelOtherSettings dto.ChannelOtherSettings
+ UpstreamModelName string
+ IsModelMapped bool
+ SupportStreamOptions bool // 是否支持流式选项
+}
+
type RelayInfo struct {
- ChannelType int
- ChannelId int
TokenId int
TokenKey string
UserId int
- Group string
+ UsingGroup string // 使用的分组
+ UserGroup string // 用户所在分组
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
//SendLastReasoningResponse bool
- ApiType int
- IsStream bool
- IsPlayground bool
- UsePrice bool
- RelayMode int
- UpstreamModelName string
- OriginModelName string
- //RecodeModelName string
- RequestURLPath string
- ApiVersion string
- PromptTokens int
- ApiKey string
- Organization string
- BaseUrl string
- SupportStreamOptions bool
- ShouldIncludeUsage bool
- IsModelMapped bool
- ClientWs *websocket.Conn
- TargetWs *websocket.Conn
- InputAudioFormat string
- OutputAudioFormat string
- RealtimeTools []dto.RealTimeTool
- IsFirstRequest bool
- AudioUsage bool
- ReasoningEffort string
- ChannelSetting map[string]interface{}
- ParamOverride map[string]interface{}
- UserSetting map[string]interface{}
- UserEmail string
- UserQuota int
- RelayFormat string
- SendResponseCount int
- ChannelCreateTime int64
+ IsStream bool
+ IsGeminiBatchEmbedding bool
+ IsPlayground bool
+ UsePrice bool
+ RelayMode int
+ OriginModelName string
+ RequestURLPath string
+ PromptTokens int
+ ShouldIncludeUsage bool
+ DisablePing bool // 是否禁止向下游发送自定义 Ping
+ ClientWs *websocket.Conn
+ TargetWs *websocket.Conn
+ InputAudioFormat string
+ OutputAudioFormat string
+ RealtimeTools []dto.RealTimeTool
+ IsFirstRequest bool
+ AudioUsage bool
+ ReasoningEffort string
+ UserSetting dto.UserSetting
+ UserEmail string
+ UserQuota int
+ RelayFormat types.RelayFormat
+ SendResponseCount int
+ FinalPreConsumedQuota int // 最终预消耗的配额
+
+ PriceData types.PriceData
+
+ Request dto.Request
+
ThinkingContentInfo
*ClaudeConvertInfo
*RerankerInfo
*ResponsesUsageInfo
+ *ChannelMeta
+}
+
+func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
+ channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+ paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+ apiType, _ := common.ChannelType2APIType(channelType)
+ channelMeta := &ChannelMeta{
+ ChannelType: channelType,
+ ChannelId: common.GetContextKeyInt(c, constant.ContextKeyChannelId),
+ ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
+ ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
+ ChannelBaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
+ ApiType: apiType,
+ ApiVersion: c.GetString("api_version"),
+ ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
+ Organization: c.GetString("channel_organization"),
+ ChannelCreateTime: c.GetInt64("channel_create_time"),
+ ParamOverride: paramOverride,
+ UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+ IsModelMapped: false,
+ SupportStreamOptions: false,
+ }
+
+ if channelType == constant.ChannelTypeAzure {
+ channelMeta.ApiVersion = GetAPIVersion(c)
+ }
+ if channelType == constant.ChannelTypeVertexAi {
+ channelMeta.ApiVersion = c.GetString("region")
+ }
+
+ channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
+ if ok {
+ channelMeta.ChannelSetting = channelSetting
+ }
+
+ channelOtherSettings, ok := common.GetContextKeyType[dto.ChannelOtherSettings](c, constant.ContextKeyChannelOtherSetting)
+ if ok {
+ channelMeta.ChannelOtherSettings = channelOtherSettings
+ }
+
+ if streamSupportedChannels[channelMeta.ChannelType] {
+ channelMeta.SupportStreamOptions = true
+ }
+
+ info.ChannelMeta = channelMeta
+
+ // reset some fields based on channel meta
+ // 重置某些字段,例如模型名称等
+ if info.Request != nil {
+ info.Request.SetModelName(info.OriginModelName)
+ }
+}
+
+func (info *RelayInfo) ToString() string {
+ if info == nil {
+ return "RelayInfo"
+ }
+
+ // Basic info
+ b := &strings.Builder{}
+ fmt.Fprintf(b, "RelayInfo{ ")
+ fmt.Fprintf(b, "RelayFormat: %s, ", info.RelayFormat)
+ fmt.Fprintf(b, "RelayMode: %d, ", info.RelayMode)
+ fmt.Fprintf(b, "IsStream: %t, ", info.IsStream)
+ fmt.Fprintf(b, "IsPlayground: %t, ", info.IsPlayground)
+ fmt.Fprintf(b, "RequestURLPath: %q, ", info.RequestURLPath)
+ fmt.Fprintf(b, "OriginModelName: %q, ", info.OriginModelName)
+ fmt.Fprintf(b, "PromptTokens: %d, ", info.PromptTokens)
+ fmt.Fprintf(b, "ShouldIncludeUsage: %t, ", info.ShouldIncludeUsage)
+ fmt.Fprintf(b, "DisablePing: %t, ", info.DisablePing)
+ fmt.Fprintf(b, "SendResponseCount: %d, ", info.SendResponseCount)
+ fmt.Fprintf(b, "FinalPreConsumedQuota: %d, ", info.FinalPreConsumedQuota)
+
+ // User & token info (mask secrets)
+ fmt.Fprintf(b, "User{ Id: %d, Email: %q, Group: %q, UsingGroup: %q, Quota: %d }, ",
+ info.UserId, common.MaskEmail(info.UserEmail), info.UserGroup, info.UsingGroup, info.UserQuota)
+ fmt.Fprintf(b, "Token{ Id: %d, Unlimited: %t, Key: ***masked*** }, ", info.TokenId, info.TokenUnlimited)
+
+ // Time info
+ latencyMs := info.FirstResponseTime.Sub(info.StartTime).Milliseconds()
+ fmt.Fprintf(b, "Timing{ Start: %s, FirstResponse: %s, LatencyMs: %d }, ",
+ info.StartTime.Format(time.RFC3339Nano), info.FirstResponseTime.Format(time.RFC3339Nano), latencyMs)
+
+ // Audio / realtime
+ if info.InputAudioFormat != "" || info.OutputAudioFormat != "" || len(info.RealtimeTools) > 0 || info.AudioUsage {
+ fmt.Fprintf(b, "Realtime{ AudioUsage: %t, InFmt: %q, OutFmt: %q, Tools: %d }, ",
+ info.AudioUsage, info.InputAudioFormat, info.OutputAudioFormat, len(info.RealtimeTools))
+ }
+
+ // Reasoning
+ if info.ReasoningEffort != "" {
+ fmt.Fprintf(b, "ReasoningEffort: %q, ", info.ReasoningEffort)
+ }
+
+ // Price data (non-sensitive)
+ if info.PriceData.UsePrice {
+ fmt.Fprintf(b, "PriceData{ %s }, ", info.PriceData.ToSetting())
+ }
+
+ // Channel metadata (mask ApiKey)
+ if info.ChannelMeta != nil {
+ cm := info.ChannelMeta
+ fmt.Fprintf(b, "ChannelMeta{ Type: %d, Id: %d, IsMultiKey: %t, MultiKeyIndex: %d, BaseURL: %q, ApiType: %d, ApiVersion: %q, Organization: %q, CreateTime: %d, UpstreamModelName: %q, IsModelMapped: %t, SupportStreamOptions: %t, ApiKey: ***masked*** }, ",
+ cm.ChannelType, cm.ChannelId, cm.ChannelIsMultiKey, cm.ChannelMultiKeyIndex, cm.ChannelBaseUrl, cm.ApiType, cm.ApiVersion, cm.Organization, cm.ChannelCreateTime, cm.UpstreamModelName, cm.IsModelMapped, cm.SupportStreamOptions)
+ }
+
+ // Responses usage info (non-sensitive)
+ if info.ResponsesUsageInfo != nil && len(info.ResponsesUsageInfo.BuiltInTools) > 0 {
+ fmt.Fprintf(b, "ResponsesTools{ ")
+ first := true
+ for name, tool := range info.ResponsesUsageInfo.BuiltInTools {
+ if !first {
+ fmt.Fprintf(b, ", ")
+ }
+ first = false
+ if tool != nil {
+ fmt.Fprintf(b, "%s: calls=%d", name, tool.CallCount)
+ } else {
+ fmt.Fprintf(b, "%s: calls=0", name)
+ }
+ }
+ fmt.Fprintf(b, " }, ")
+ }
+
+ fmt.Fprintf(b, "}")
+ return b.String()
}
// 定义支持流式选项的通道类型
var streamSupportedChannels = map[int]bool{
- common.ChannelTypeOpenAI: true,
- common.ChannelTypeAnthropic: true,
- common.ChannelTypeAws: true,
- common.ChannelTypeGemini: true,
- common.ChannelCloudflare: true,
- common.ChannelTypeAzure: true,
- common.ChannelTypeVolcEngine: true,
- common.ChannelTypeOllama: true,
- common.ChannelTypeXai: true,
- common.ChannelTypeDeepSeek: true,
- common.ChannelTypeBaiduV2: true,
+ constant.ChannelTypeOpenAI: true,
+ constant.ChannelTypeAnthropic: true,
+ constant.ChannelTypeAws: true,
+ constant.ChannelTypeGemini: true,
+ constant.ChannelCloudflare: true,
+ constant.ChannelTypeAzure: true,
+ constant.ChannelTypeVolcEngine: true,
+ constant.ChannelTypeOllama: true,
+ constant.ChannelTypeXai: true,
+ constant.ChannelTypeDeepSeek: true,
+ constant.ChannelTypeBaiduV2: true,
}
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
- info := GenRelayInfo(c)
+ info := genBaseRelayInfo(c, nil)
+ info.RelayFormat = types.RelayFormatOpenAIRealtime
info.ClientWs = ws
info.InputAudioFormat = "pcm16"
info.OutputAudioFormat = "pcm16"
@@ -129,9 +268,9 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
return info
}
-func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
- info := GenRelayInfo(c)
- info.RelayFormat = RelayFormatClaude
+func GenRelayInfoClaude(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = &ClaudeConvertInfo{
LastMessagesType: LastMessageTypeNone,
@@ -139,123 +278,178 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
return info
}
-func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
- info := GenRelayInfo(c)
+func GenRelayInfoRerank(c *gin.Context, request *dto.RerankRequest) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeRerank
+ info.RelayFormat = types.RelayFormatRerank
info.RerankerInfo = &RerankerInfo{
- Documents: req.Documents,
- ReturnDocuments: req.GetReturnDocuments(),
+ Documents: request.Documents,
+ ReturnDocuments: request.GetReturnDocuments(),
}
return info
}
-func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
- info := GenRelayInfo(c)
+func GenRelayInfoOpenAIAudio(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatOpenAIAudio
+ return info
+}
+
+func GenRelayInfoEmbedding(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatEmbedding
+ return info
+}
+
+func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
info.RelayMode = relayconstant.RelayModeResponses
+ info.RelayFormat = types.RelayFormatOpenAIResponses
+
info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo),
}
- if len(req.Tools) > 0 {
- for _, tool := range req.Tools {
- info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{
- ToolName: tool.Type,
+ if len(request.Tools) > 0 {
+ for _, tool := range request.Tools {
+ toolType := common.Interface2String(tool["type"])
+ info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
+ ToolName: toolType,
CallCount: 0,
}
- switch tool.Type {
+ switch toolType {
case dto.BuildInToolWebSearchPreview:
- if tool.SearchContextSize == "" {
- tool.SearchContextSize = "medium"
+ searchContextSize := common.Interface2String(tool["search_context_size"])
+ if searchContextSize == "" {
+ searchContextSize = "medium"
}
- info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize
+ info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
}
}
}
- info.IsStream = req.Stream
return info
}
-func GenRelayInfo(c *gin.Context) *RelayInfo {
- channelType := c.GetInt("channel_type")
- channelId := c.GetInt("channel_id")
- channelSetting := c.GetStringMap("channel_setting")
- paramOverride := c.GetStringMap("param_override")
+func GenRelayInfoGemini(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatGemini
+ info.ShouldIncludeUsage = false
+
+ return info
+}
+
+func GenRelayInfoImage(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatOpenAIImage
+ return info
+}
+
+func GenRelayInfoOpenAI(c *gin.Context, request dto.Request) *RelayInfo {
+ info := genBaseRelayInfo(c, request)
+ info.RelayFormat = types.RelayFormatOpenAI
+ return info
+}
+
+func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
+
+ //channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+ //channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
+ //paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+
+ startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
+ if startTime.IsZero() {
+ startTime = time.Now()
+ }
+
+ isStream := false
+
+ if request != nil {
+ isStream = request.IsStream(c)
+ }
- tokenId := c.GetInt("token_id")
- tokenKey := c.GetString("token_key")
- userId := c.GetInt("id")
- group := c.GetString("group")
- tokenUnlimited := c.GetBool("token_unlimited_quota")
- startTime := c.GetTime(constant.ContextKeyRequestStartTime)
// firstResponseTime = time.Now() - 1 second
- apiType, _ := relayconstant.ChannelType2APIType(channelType)
-
info := &RelayInfo{
- UserQuota: c.GetInt(constant.ContextKeyUserQuota),
- UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
- UserEmail: c.GetString(constant.ContextKeyUserEmail),
- isFirstResponse: true,
- RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
- BaseUrl: c.GetString("base_url"),
- RequestURLPath: c.Request.URL.String(),
- ChannelType: channelType,
- ChannelId: channelId,
- TokenId: tokenId,
- TokenKey: tokenKey,
- UserId: userId,
- Group: group,
- TokenUnlimited: tokenUnlimited,
+ Request: request,
+
+ UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
+ UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
+ UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
+ UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
+ UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
+
+ OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+ PromptTokens: common.GetContextKeyInt(c, constant.ContextKeyPromptTokens),
+
+ TokenId: common.GetContextKeyInt(c, constant.ContextKeyTokenId),
+ TokenKey: common.GetContextKeyString(c, constant.ContextKeyTokenKey),
+ TokenUnlimited: common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited),
+
+ isFirstResponse: true,
+ RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
+ RequestURLPath: c.Request.URL.String(),
+ IsStream: isStream,
+
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
- OriginModelName: c.GetString("original_model"),
- UpstreamModelName: c.GetString("original_model"),
- //RecodeModelName: c.GetString("original_model"),
- IsModelMapped: false,
- ApiType: apiType,
- ApiVersion: c.GetString("api_version"),
- ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
- Organization: c.GetString("channel_organization"),
- ChannelSetting: channelSetting,
- ChannelCreateTime: c.GetInt64("channel_create_time"),
- ParamOverride: paramOverride,
- RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
},
}
+
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true
info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
info.RequestURLPath = "/v1" + info.RequestURLPath
}
- if info.BaseUrl == "" {
- info.BaseUrl = common.ChannelBaseURLs[channelType]
- }
- if info.ChannelType == common.ChannelTypeAzure {
- info.ApiVersion = GetAPIVersion(c)
- }
- if info.ChannelType == common.ChannelTypeVertexAi {
- info.ApiVersion = c.GetString("region")
- }
- if streamSupportedChannels[info.ChannelType] {
- info.SupportStreamOptions = true
- }
- // responses 模式不支持 StreamOptions
- if relayconstant.RelayModeResponses == info.RelayMode {
- info.SupportStreamOptions = false
+
+ userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
+ if ok {
+ info.UserSetting = userSetting
}
+
return info
}
+func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Request, ws *websocket.Conn) (*RelayInfo, error) {
+ switch relayFormat {
+ case types.RelayFormatOpenAI:
+ return GenRelayInfoOpenAI(c, request), nil
+ case types.RelayFormatOpenAIAudio:
+ return GenRelayInfoOpenAIAudio(c, request), nil
+ case types.RelayFormatOpenAIImage:
+ return GenRelayInfoImage(c, request), nil
+ case types.RelayFormatOpenAIRealtime:
+ return GenRelayInfoWs(c, ws), nil
+ case types.RelayFormatClaude:
+ return GenRelayInfoClaude(c, request), nil
+ case types.RelayFormatRerank:
+ if request, ok := request.(*dto.RerankRequest); ok {
+ return GenRelayInfoRerank(c, request), nil
+ }
+ return nil, errors.New("request is not a RerankRequest")
+ case types.RelayFormatGemini:
+ return GenRelayInfoGemini(c, request), nil
+ case types.RelayFormatEmbedding:
+ return GenRelayInfoEmbedding(c, request), nil
+ case types.RelayFormatOpenAIResponses:
+ if request, ok := request.(*dto.OpenAIResponsesRequest); ok {
+ return GenRelayInfoResponses(c, request), nil
+ }
+ return nil, errors.New("request is not a OpenAIResponsesRequest")
+ case types.RelayFormatTask:
+ return genBaseRelayInfo(c, nil), nil
+ case types.RelayFormatMjProxy:
+ return genBaseRelayInfo(c, nil), nil
+ default:
+ return nil, errors.New("invalid relay format")
+ }
+}
+
// SetPromptTokens records the prompt token count computed for this request.
func (info *RelayInfo) SetPromptTokens(promptTokens int) {
	info.PromptTokens = promptTokens
}
-func (info *RelayInfo) SetIsStream(isStream bool) {
- info.IsStream = isStream
-}
-
func (info *RelayInfo) SetFirstResponseTime() {
if info.isFirstResponse {
info.FirstResponseTime = time.Now()
@@ -275,9 +469,33 @@ type TaskRelayInfo struct {
ConsumeQuota bool
}
-func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
- info := &TaskRelayInfo{
- RelayInfo: GenRelayInfo(c),
+func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) {
+ relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil)
+ if err != nil {
+ return nil, err
}
- return info
+ info := &TaskRelayInfo{
+ RelayInfo: relayInfo,
+ }
+ info.InitChannelMeta(c)
+ return info, nil
+}
+
// TaskSubmitReq is the unified request body accepted for asynchronous task
// submission. Fields other than Prompt are optional and passed through to
// the upstream provider when set.
type TaskSubmitReq struct {
	Prompt   string                 `json:"prompt"`             // text prompt describing the task
	Model    string                 `json:"model,omitempty"`    // upstream model to use
	Mode     string                 `json:"mode,omitempty"`     // provider-specific mode — semantics depend on the provider
	Image    string                 `json:"image,omitempty"`    // reference image (presumably a URL or base64 payload — confirm against callers)
	Size     string                 `json:"size,omitempty"`     // output size specifier
	Duration int                    `json:"duration,omitempty"` // requested duration (presumably seconds — confirm against providers)
	Metadata map[string]interface{} `json:"metadata,omitempty"` // opaque extra parameters forwarded as-is
}
+
// TaskInfo is the normalized status payload for an asynchronous task.
type TaskInfo struct {
	Code     int    `json:"code"`               // numeric result code
	TaskID   string `json:"task_id"`            // task identifier
	Status   string `json:"status"`             // task status string — values depend on the provider
	Reason   string `json:"reason,omitempty"`   // failure reason, if any
	Url      string `json:"url,omitempty"`      // result resource URL, if available
	Progress string `json:"progress,omitempty"` // progress indicator (provider-formatted string)
}
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
index 7a4f44bb..29086585 100644
--- a/relay/common/relay_utils.go
+++ b/relay/common/relay_utils.go
@@ -6,7 +6,7 @@ import (
_ "image/gif"
_ "image/jpeg"
_ "image/png"
- "one-api/common"
+ "one-api/constant"
"strings"
)
@@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
- case common.ChannelTypeOpenAI:
+ case constant.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
- case common.ChannelTypeAzure:
+ case constant.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go
index 496278b5..05dbfa6d 100644
--- a/relay/common_handler/rerank.go
+++ b/relay/common_handler/rerank.go
@@ -1,34 +1,34 @@
package common_handler
import (
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
)
-func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
- }
- err = resp.Body.Close()
- if err != nil {
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
+ service.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}
var jinaResp dto.RerankResponse
- if info.ChannelType == common.ChannelTypeXinference {
+ if info.ChannelType == constant.ChannelTypeXinference {
var xinRerankResponse xinference.XinRerankResponse
- err = common.DecodeJson(responseBody, &xinRerankResponse)
+ err = common.Unmarshal(responseBody, &xinRerankResponse)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
@@ -38,10 +38,16 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
}
if info.ReturnDocuments {
var document any
- if result.Document == "" {
- document = info.Documents[result.Index]
- } else {
- document = result.Document
+ if result.Document != nil {
+ if doc, ok := result.Document.(string); ok {
+ if doc == "" {
+ document = info.Documents[result.Index]
+ } else {
+ document = doc
+ }
+ } else {
+ document = result.Document
+ }
}
respResult.Document = document
}
@@ -55,14 +61,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
},
}
} else {
- err = common.DecodeJson(responseBody, &jinaResp)
+ err = common.Unmarshal(responseBody, &jinaResp)
if err != nil {
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}
c.Writer.Header().Set("Content-Type", "application/json")
c.JSON(http.StatusOK, jinaResp)
- return nil, &jinaResp.Usage
+ return &jinaResp.Usage, nil
}
diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go
new file mode 100644
index 00000000..56d65a3f
--- /dev/null
+++ b/relay/compatible_handler.go
@@ -0,0 +1,430 @@
+package relay
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/logger"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/setting/operation_setting"
+ "one-api/types"
+ "strings"
+ "time"
+
+ "github.com/shopspring/decimal"
+
+ "github.com/gin-gonic/gin"
+)
+
// TextHelper relays an OpenAI-compatible text/chat completion request through
// the channel adaptor selected by info.ApiType: it deep-copies and adapts the
// request (model mapping, stream options, optional channel system prompt and
// parameter overrides), sends it upstream, hands the response to the adaptor,
// and finally settles quota from the returned usage.
func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	info.InitChannelMeta(c)

	textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}

	// Work on a deep copy so the original request on info stays untouched.
	request, err := common.DeepCopy(textReq)
	if err != nil {
		return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
	}

	// Stash the requested web-search context size; postConsumeQuota reads it
	// back for search-preview model billing.
	if request.WebSearchOptions != nil {
		c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
	}

	err = helper.ModelMappedHelper(c, info, request)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	includeUsage := true
	// Determine whether the caller asked for usage to be returned.
	if request.StreamOptions != nil {
		includeUsage = request.StreamOptions.IncludeUsage
	}

	// Drop StreamOptions when the channel cannot honor it (or the request is
	// not streaming).
	if !info.SupportStreamOptions || !request.Stream {
		request.StreamOptions = nil
	} else {
		// StreamOptions is supported: optionally force usage reporting on,
		// per global configuration.
		if constant.ForceStreamOption {
			request.StreamOptions = &dto.StreamOptions{
				IncludeUsage: true,
			}
		}
	}

	info.ShouldIncludeUsage = includeUsage

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)
	var requestBody io.Reader

	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
		// Pass-through mode: forward the raw client body without conversion.
		body, err := common.GetRequestBody(c)
		if err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
		}
		if common.DebugEnabled {
			println("requestBody: ", string(body))
		}
		requestBody = bytes.NewBuffer(body)
	} else {
		convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
		}

		if info.ChannelSetting.SystemPrompt != "" {
			// A channel-level system prompt is configured: inject it.
			// NOTE(review): this assumes the adaptor returned a
			// *dto.GeneralOpenAIRequest — the assertion panics otherwise;
			// confirm all adaptors on this path preserve the type.
			request := convertedRequest.(*dto.GeneralOpenAIRequest)
			containSystemPrompt := false
			for _, message := range request.Messages {
				if message.Role == request.GetSystemRoleName() {
					containSystemPrompt = true
					break
				}
			}
			if !containSystemPrompt {
				// No system message present: prepend one.
				systemMessage := dto.Message{
					Role:    request.GetSystemRoleName(),
					Content: info.ChannelSetting.SystemPrompt,
				}
				request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
			} else if info.ChannelSetting.SystemPromptOverride {
				common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
				// A system message exists and override is allowed: prepend the
				// channel prompt to the first system message's content.
				for i, message := range request.Messages {
					if message.Role == request.GetSystemRoleName() {
						if message.IsStringContent() {
							request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
						} else {
							contents := message.ParseContent()
							contents = append([]dto.MediaContent{
								{
									Type: dto.ContentTypeText,
									Text: info.ChannelSetting.SystemPrompt,
								},
							}, contents...)
							request.Messages[i].Content = contents
						}
						break
					}
				}
			}
		}

		jsonData, err := common.Marshal(convertedRequest)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
		}

		// Apply channel-level parameter overrides onto the serialized body.
		if len(info.ParamOverride) > 0 {
			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
			if err != nil {
				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
			}
		}

		logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))

		requestBody = bytes.NewBuffer(jsonData)
	}

	var httpResp *http.Response
	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	if resp != nil {
		httpResp = resp.(*http.Response)
		// Upstream may stream even if the client did not ask; detect via SSE content type.
		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
		if httpResp.StatusCode != http.StatusOK {
			newApiErr := service.RelayErrorHandler(httpResp, false)
			// Remap the upstream status code per channel configuration.
			service.ResetStatusCode(newApiErr, statusCodeMappingStr)
			return newApiErr
		}
	}

	usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
	if newApiErr != nil {
		// Remap the upstream status code per channel configuration.
		service.ResetStatusCode(newApiErr, statusCodeMappingStr)
		return newApiErr
	}

	// NOTE(review): usage.(*dto.Usage) panics if an adaptor returns a nil or
	// differently-typed usage with a nil error — confirm adaptor contracts.
	if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
		service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
	} else {
		postConsumeQuota(c, info, usage.(*dto.Usage), "")
	}
	return nil
}
+
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
+ if usage == nil {
+ usage = &dto.Usage{
+ PromptTokens: relayInfo.PromptTokens,
+ CompletionTokens: 0,
+ TotalTokens: relayInfo.PromptTokens,
+ }
+ extraContent += "(可能是请求出错)"
+ }
+ useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+ promptTokens := usage.PromptTokens
+ cacheTokens := usage.PromptTokensDetails.CachedTokens
+ imageTokens := usage.PromptTokensDetails.ImageTokens
+ audioTokens := usage.PromptTokensDetails.AudioTokens
+ completionTokens := usage.CompletionTokens
+ modelName := relayInfo.OriginModelName
+
+ tokenName := ctx.GetString("token_name")
+ completionRatio := relayInfo.PriceData.CompletionRatio
+ cacheRatio := relayInfo.PriceData.CacheRatio
+ imageRatio := relayInfo.PriceData.ImageRatio
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
+
+ // Convert values to decimal for precise calculation
+ dPromptTokens := decimal.NewFromInt(int64(promptTokens))
+ dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
+ dImageTokens := decimal.NewFromInt(int64(imageTokens))
+ dAudioTokens := decimal.NewFromInt(int64(audioTokens))
+ dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
+ dCompletionRatio := decimal.NewFromFloat(completionRatio)
+ dCacheRatio := decimal.NewFromFloat(cacheRatio)
+ dImageRatio := decimal.NewFromFloat(imageRatio)
+ dModelRatio := decimal.NewFromFloat(modelRatio)
+ dGroupRatio := decimal.NewFromFloat(groupRatio)
+ dModelPrice := decimal.NewFromFloat(modelPrice)
+ dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+
+ ratio := dModelRatio.Mul(dGroupRatio)
+
+ // openai web search 工具计费
+ var dWebSearchQuota decimal.Decimal
+ var webSearchPrice float64
+ // response api 格式工具计费
+ if relayInfo.ResponsesUsageInfo != nil {
+ if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
+ // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
+ webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
+ dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+ Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
+ webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
+ }
+ } else if strings.HasSuffix(modelName, "search-preview") {
+ // search-preview 模型不支持 response api
+ searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
+ if searchContextSize == "" {
+ searchContextSize = "medium"
+ }
+ webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
+ dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
+ searchContextSize, dWebSearchQuota.String())
+ }
+ // claude web search tool 计费
+ var dClaudeWebSearchQuota decimal.Decimal
+ var claudeWebSearchPrice float64
+ claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
+ if claudeWebSearchCallCount > 0 {
+ claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
+ dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
+ extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
+ claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
+ }
+ // file search tool 计费
+ var dFileSearchQuota decimal.Decimal
+ var fileSearchPrice float64
+ if relayInfo.ResponsesUsageInfo != nil {
+ if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
+ fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
+ dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
+ Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
+ fileSearchTool.CallCount, dFileSearchQuota.String())
+ }
+ }
+
+ var quotaCalculateDecimal decimal.Decimal
+
+ var audioInputQuota decimal.Decimal
+ var audioInputPrice float64
+ if !relayInfo.PriceData.UsePrice {
+ baseTokens := dPromptTokens
+ // 减去 cached tokens
+ var cachedTokensWithRatio decimal.Decimal
+ if !dCacheTokens.IsZero() {
+ baseTokens = baseTokens.Sub(dCacheTokens)
+ cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
+ }
+
+ // 减去 image tokens
+ var imageTokensWithRatio decimal.Decimal
+ if !dImageTokens.IsZero() {
+ baseTokens = baseTokens.Sub(dImageTokens)
+ imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
+ }
+
+ // 减去 Gemini audio tokens
+ if !dAudioTokens.IsZero() {
+ audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
+ if audioInputPrice > 0 {
+ // 重新计算 base tokens
+ baseTokens = baseTokens.Sub(dAudioTokens)
+ audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
+ }
+ }
+ promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
+
+ completionQuota := dCompletionTokens.Mul(dCompletionRatio)
+
+ quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
+
+ if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
+ quotaCalculateDecimal = decimal.NewFromInt(1)
+ }
+ } else {
+ quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
+ }
+ // 添加 responses tools call 调用的配额
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
+ // 添加 audio input 独立计费
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
+
+ quota := int(quotaCalculateDecimal.Round(0).IntPart())
+ totalTokens := promptTokens + completionTokens
+
+ var logContent string
+
+ // record all the consume log even if quota is 0
+ if totalTokens == 0 {
+ // in this case, must be some error happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ logContent += fmt.Sprintf("(可能是上游超时)")
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
+ } else {
+ if !ratio.IsZero() && quota == 0 {
+ quota = 1
+ }
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+ }
+
+ quotaDelta := quota - relayInfo.FinalPreConsumedQuota
+
+ //logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta)))
+
+ if quotaDelta > 0 {
+ logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
+ logger.FormatQuota(quotaDelta),
+ logger.FormatQuota(quota),
+ logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
+ ))
+ } else if quotaDelta < 0 {
+ logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
+ logger.FormatQuota(-quotaDelta),
+ logger.FormatQuota(quota),
+ logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
+ ))
+ }
+
+ if quotaDelta != 0 {
+ err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
+ if err != nil {
+ logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ }
+ }
+
+ logModel := modelName
+ if strings.HasPrefix(logModel, "gpt-4-gizmo") {
+ logModel = "gpt-4-gizmo-*"
+ logContent += fmt.Sprintf(",模型 %s", modelName)
+ }
+ if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
+ logModel = "gpt-4o-gizmo-*"
+ logContent += fmt.Sprintf(",模型 %s", modelName)
+ }
+ if extraContent != "" {
+ logContent += ", " + extraContent
+ }
+ other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
+ if imageTokens != 0 {
+ other["image"] = true
+ other["image_ratio"] = imageRatio
+ other["image_output"] = imageTokens
+ }
+ if !dWebSearchQuota.IsZero() {
+ if relayInfo.ResponsesUsageInfo != nil {
+ if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+ other["web_search"] = true
+ other["web_search_call_count"] = webSearchTool.CallCount
+ other["web_search_price"] = webSearchPrice
+ }
+ } else if strings.HasSuffix(modelName, "search-preview") {
+ other["web_search"] = true
+ other["web_search_call_count"] = 1
+ other["web_search_price"] = webSearchPrice
+ }
+ } else if !dClaudeWebSearchQuota.IsZero() {
+ other["web_search"] = true
+ other["web_search_call_count"] = claudeWebSearchCallCount
+ other["web_search_price"] = claudeWebSearchPrice
+ }
+ if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
+ if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
+ other["file_search"] = true
+ other["file_search_call_count"] = fileSearchTool.CallCount
+ other["file_search_price"] = fileSearchPrice
+ }
+ }
+ if !audioInputQuota.IsZero() {
+ other["audio_input_seperate_price"] = true
+ other["audio_input_token_count"] = audioTokens
+ other["audio_input_price"] = audioInputPrice
+ }
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: promptTokens,
+ CompletionTokens: completionTokens,
+ ModelName: logModel,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+}
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
deleted file mode 100644
index 3f1ecd78..00000000
--- a/relay/constant/api_type.go
+++ /dev/null
@@ -1,106 +0,0 @@
-package constant
-
-import (
- "one-api/common"
-)
-
-const (
- APITypeOpenAI = iota
- APITypeAnthropic
- APITypePaLM
- APITypeBaidu
- APITypeZhipu
- APITypeAli
- APITypeXunfei
- APITypeAIProxyLibrary
- APITypeTencent
- APITypeGemini
- APITypeZhipuV4
- APITypeOllama
- APITypePerplexity
- APITypeAws
- APITypeCohere
- APITypeDify
- APITypeJina
- APITypeCloudflare
- APITypeSiliconFlow
- APITypeVertexAi
- APITypeMistral
- APITypeDeepSeek
- APITypeMokaAI
- APITypeVolcEngine
- APITypeBaiduV2
- APITypeOpenRouter
- APITypeXinference
- APITypeXai
- APITypeCoze
- APITypeDummy // this one is only for count, do not add any channel after this
-)
-
-func ChannelType2APIType(channelType int) (int, bool) {
- apiType := -1
- switch channelType {
- case common.ChannelTypeOpenAI:
- apiType = APITypeOpenAI
- case common.ChannelTypeAnthropic:
- apiType = APITypeAnthropic
- case common.ChannelTypeBaidu:
- apiType = APITypeBaidu
- case common.ChannelTypePaLM:
- apiType = APITypePaLM
- case common.ChannelTypeZhipu:
- apiType = APITypeZhipu
- case common.ChannelTypeAli:
- apiType = APITypeAli
- case common.ChannelTypeXunfei:
- apiType = APITypeXunfei
- case common.ChannelTypeAIProxyLibrary:
- apiType = APITypeAIProxyLibrary
- case common.ChannelTypeTencent:
- apiType = APITypeTencent
- case common.ChannelTypeGemini:
- apiType = APITypeGemini
- case common.ChannelTypeZhipu_v4:
- apiType = APITypeZhipuV4
- case common.ChannelTypeOllama:
- apiType = APITypeOllama
- case common.ChannelTypePerplexity:
- apiType = APITypePerplexity
- case common.ChannelTypeAws:
- apiType = APITypeAws
- case common.ChannelTypeCohere:
- apiType = APITypeCohere
- case common.ChannelTypeDify:
- apiType = APITypeDify
- case common.ChannelTypeJina:
- apiType = APITypeJina
- case common.ChannelCloudflare:
- apiType = APITypeCloudflare
- case common.ChannelTypeSiliconFlow:
- apiType = APITypeSiliconFlow
- case common.ChannelTypeVertexAi:
- apiType = APITypeVertexAi
- case common.ChannelTypeMistral:
- apiType = APITypeMistral
- case common.ChannelTypeDeepSeek:
- apiType = APITypeDeepSeek
- case common.ChannelTypeMokaAI:
- apiType = APITypeMokaAI
- case common.ChannelTypeVolcEngine:
- apiType = APITypeVolcEngine
- case common.ChannelTypeBaiduV2:
- apiType = APITypeBaiduV2
- case common.ChannelTypeOpenRouter:
- apiType = APITypeOpenRouter
- case common.ChannelTypeXinference:
- apiType = APITypeXinference
- case common.ChannelTypeXai:
- apiType = APITypeXai
- case common.ChannelTypeCoze:
- apiType = APITypeCoze
- }
- if apiType == -1 {
- return APITypeOpenAI, false
- }
- return apiType, true
-}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index f22a20bd..85a1b9c5 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -29,6 +29,8 @@ const (
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeMidjourneyUpload
+ RelayModeMidjourneyVideo
+ RelayModeMidjourneyEdits
RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper
@@ -38,6 +40,9 @@ const (
RelayModeSunoFetchByID
RelayModeSunoSubmit
+ RelayModeVideoFetchByID
+ RelayModeVideoSubmit
+
RelayModeRerank
RelayModeResponses
@@ -77,8 +82,10 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
- } else if strings.HasPrefix(path, "/v1beta/models") {
+ } else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
relayMode = RelayModeGemini
+ } else if strings.HasPrefix(path, "/mj") {
+ relayMode = Path2RelayModeMidjourney(path)
}
return relayMode
}
@@ -102,6 +109,10 @@ func Path2RelayModeMidjourney(path string) int {
relayMode = RelayModeMidjourneyUpload
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
+ } else if strings.HasSuffix(path, "/mj/submit/video") {
+ relayMode = RelayModeMidjourneyVideo
+ } else if strings.HasSuffix(path, "/mj/submit/edits") {
+ relayMode = RelayModeMidjourneyEdits
} else if strings.HasSuffix(path, "/mj/submit/blend") {
relayMode = RelayModeMidjourneyBlend
} else if strings.HasSuffix(path, "/mj/submit/describe") {
diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go
new file mode 100644
index 00000000..26dcf971
--- /dev/null
+++ b/relay/embedding_handler.go
@@ -0,0 +1,76 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
+
+ embeddingReq, ok := info.Request.(*dto.EmbeddingRequest)
+ if !ok {
+ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ request, err := common.DeepCopy(embeddingReq)
+ if err != nil {
+ return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ err = helper.ModelMappedHelper(c, info, request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+ }
+
+ adaptor := GetAdaptor(info.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ }
+ adaptor.Init(info)
+
+ convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, info, *request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+ requestBody := bytes.NewBuffer(jsonData)
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+ resp, err := adaptor.DoRequest(c, info, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
+ return nil
+}
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
new file mode 100644
index 00000000..460fd2f5
--- /dev/null
+++ b/relay/gemini_handler.go
@@ -0,0 +1,266 @@
+package relay
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/logger"
+ "one-api/relay/channel/gemini"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+// isNoThinkingRequest reports whether the caller explicitly disabled thinking
+// by setting a thinking budget of 0 in the generation config.
+func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
+	cfg := req.GenerationConfig.ThinkingConfig
+	if cfg != nil && cfg.ThinkingBudget != nil {
+		// A zero budget means "no thinking"; any other value keeps thinking enabled.
+		return *cfg.ThinkingBudget == 0
+	}
+	return false
+}
+// trimModelThinking strips the thinking-related suffix from a model name.
+func trimModelThinking(modelName string) string {
+	// Strip the "-nothinking" suffix from the model name.
+	if strings.HasSuffix(modelName, "-nothinking") {
+		return strings.TrimSuffix(modelName, "-nothinking")
+	}
+	// Strip the "-thinking" suffix from the model name.
+	if strings.HasSuffix(modelName, "-thinking") {
+		return strings.TrimSuffix(modelName, "-thinking")
+	}
+
+	// Collapse "-thinking-<number>" (a name carrying an explicit budget) down to "-thinking".
+	if strings.Contains(modelName, "-thinking-") {
+		parts := strings.Split(modelName, "-thinking-")
+		if len(parts) > 1 {
+			return parts[0] + "-thinking" // NOTE(review): keeps "-thinking", unlike the branches above which strip it entirely — confirm this asymmetry is intended
+		}
+	}
+	return modelName
+}
+
+func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { // relays a native Gemini chat request through the channel adaptor and bills usage
+	info.InitChannelMeta(c)
+
+	geminiReq, ok := info.Request.(*dto.GeminiChatRequest)
+	if !ok {
+		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	request, err := common.DeepCopy(geminiReq)
+	if err != nil {
+		return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+	}
+
+	// map the public model name to the channel's upstream model name
+	err = helper.ModelMappedHelper(c, info, request)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+	}
+
+	if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
+		if isNoThinkingRequest(request) {
+			// the caller explicitly disabled thinking (budget 0)
+			if !strings.Contains(info.OriginModelName, "-nothinking") {
+				// bill against a dedicated "-nothinking" model when it has a configured price/ratio
+				noThinkingModelName := info.OriginModelName + "-nothinking"
+				containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
+				if containPrice {
+					info.OriginModelName = noThinkingModelName
+					info.UpstreamModelName = noThinkingModelName
+				}
+			}
+		}
+		if request.GenerationConfig.ThinkingConfig == nil {
+			gemini.ThinkingAdaptor(request, info)
+		}
+	}
+
+	adaptor := GetAdaptor(info.ApiType)
+	if adaptor == nil {
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+	}
+
+	adaptor.Init(info)
+
+	// Clean up empty system instruction so an all-empty block is not forwarded upstream
+	if request.SystemInstructions != nil {
+		hasContent := false
+		for _, part := range request.SystemInstructions.Parts {
+			if part.Text != "" {
+				hasContent = true
+				break
+			}
+		}
+		if !hasContent {
+			request.SystemInstructions = nil
+		}
+	}
+
+	var requestBody io.Reader
+	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
+		body, err := common.GetRequestBody(c)
+		if err != nil {
+			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+		}
+		requestBody = bytes.NewReader(body)
+	} else {
+		// convert the request into the adaptor's wire format
+		convertedRequest, err := adaptor.ConvertGeminiRequest(c, info, request)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+		}
+		jsonData, err := common.Marshal(convertedRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+		}
+
+		// apply channel-level parameter overrides
+		if len(info.ParamOverride) > 0 {
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+			if err != nil {
+				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+			}
+		}
+
+		logger.LogDebug(c, "Gemini request body: "+string(jsonData))
+
+		requestBody = bytes.NewReader(jsonData)
+	}
+
+	resp, err := adaptor.DoRequest(c, info, requestBody)
+	if err != nil {
+		logger.LogError(c, "Do gemini request failed: "+err.Error())
+		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+	}
+
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+		if httpResp.StatusCode != http.StatusOK {
+			newAPIError = service.RelayErrorHandler(httpResp, false)
+			// reset status code per the channel's status code mapping
+			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+			return newAPIError
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, info) // pass the guarded httpResp: asserting resp.(*http.Response) panics when resp is a nil interface
+	if openaiErr != nil {
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	return nil
+}
+
+func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { // relays native Gemini embedContents/batchEmbedContents requests
+	info.InitChannelMeta(c)
+
+	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
+	info.IsGeminiBatchEmbedding = isBatch
+
+	var req dto.Request
+	var err error
+	var inputTexts []string // NOTE(review): collected but never read below — confirm whether token counting was meant to use this
+
+	if isBatch {
+		batchRequest := &dto.GeminiBatchEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, batchRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = batchRequest
+		for _, r := range batchRequest.Requests {
+			for _, part := range r.Content.Parts {
+				if part.Text != "" {
+					inputTexts = append(inputTexts, part.Text)
+				}
+			}
+		}
+	} else {
+		singleRequest := &dto.GeminiEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, singleRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = singleRequest
+		for _, part := range singleRequest.Content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+
+	err = helper.ModelMappedHelper(c, info, req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+	}
+
+	adaptor := GetAdaptor(info.ApiType)
+	if adaptor == nil {
+		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+	}
+	adaptor.Init(info)
+
+	var requestBody io.Reader
+	jsonData, err := common.Marshal(req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+	}
+
+	// apply channel-level parameter overrides by merging into the marshalled body
+	if len(info.ParamOverride) > 0 {
+		reqMap := make(map[string]interface{})
+		_ = common.Unmarshal(jsonData, &reqMap) // cannot fail: jsonData was just produced by Marshal above
+		for key, value := range info.ParamOverride {
+			reqMap[key] = value
+		}
+		jsonData, err = common.Marshal(reqMap)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+		}
+	}
+	requestBody = bytes.NewReader(jsonData)
+
+	resp, err := adaptor.DoRequest(c, info, requestBody)
+	if err != nil {
+		logger.LogError(c, "Do gemini request failed: "+err.Error())
+		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+	}
+
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			newAPIError = service.RelayErrorHandler(httpResp, false)
+			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+			return newAPIError
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, httpResp, info) // pass the guarded httpResp: asserting resp.(*http.Response) panics when resp is a nil interface
+	if openaiErr != nil {
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	postConsumeQuota(c, info, usage.(*dto.Usage), "")
+	return nil
+}
diff --git a/relay/helper/common.go b/relay/helper/common.go
index 35d983f7..5b3e7674 100644
--- a/relay/helper/common.go
+++ b/relay/helper/common.go
@@ -4,27 +4,41 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/logger"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
)
+func FlushWriter(c *gin.Context) error { // FlushWriter flushes buffered response data to the client, shared by the SSE helpers below
+	if c.Writer == nil { // no writer attached (defensive; e.g. tests) — nothing to flush
+		return nil
+	}
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+		return nil
+	}
+	return errors.New("streaming error: flusher not found") // underlying writer does not support streaming
+}
+
func SetEventStreamHeaders(c *gin.Context) {
- // 检查是否已经设置过头部
- if _, exists := c.Get("event_stream_headers_set"); exists {
- return
- }
-
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("Transfer-Encoding", "chunked")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
-
- // 设置标志,表示头部已经设置过
- c.Set("event_stream_headers_set", true)
+ // 检查是否已经设置过头部
+ if _, exists := c.Get("event_stream_headers_set"); exists {
+ return
+ }
+
+ // 设置标志,表示头部已经设置过
+ c.Set("event_stream_headers_set", true)
+
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("Transfer-Encoding", "chunked")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
@@ -35,49 +49,33 @@ func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
}
- if flusher, ok := c.Writer.(http.Flusher); ok {
- flusher.Flush()
- } else {
- return errors.New("streaming error: flusher not found")
- }
+ _ = FlushWriter(c)
return nil
}
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
- if flusher, ok := c.Writer.(http.Flusher); ok {
- flusher.Flush()
- }
+ _ = FlushWriter(c)
}
func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
- if flusher, ok := c.Writer.(http.Flusher); ok {
- flusher.Flush()
- }
+ _ = FlushWriter(c)
}
func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r")
c.Render(-1, common.CustomEvent{Data: "data: " + str})
- if flusher, ok := c.Writer.(http.Flusher); ok {
- flusher.Flush()
- } else {
- return errors.New("streaming error: flusher not found")
- }
+ _ = FlushWriter(c)
return nil
}
func PingData(c *gin.Context) error {
c.Writer.Write([]byte(": PING\n\n"))
- if flusher, ok := c.Writer.(http.Flusher); ok {
- flusher.Flush()
- } else {
- return errors.New("streaming error: flusher not found")
- }
+ _ = FlushWriter(c)
return nil
}
@@ -85,7 +83,7 @@ func ObjectData(c *gin.Context, object interface{}) error {
if object == nil {
return errors.New("object is nil")
}
- jsonData, err := json.Marshal(object)
+ jsonData, err := common.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
@@ -98,7 +96,7 @@ func Done(c *gin.Context) {
func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
if ws == nil {
- common.LogError(c, "websocket connection is nil")
+ logger.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
@@ -111,14 +109,17 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
return fmt.Errorf("error marshalling object: %w", err)
}
if ws == nil {
- common.LogError(c, "websocket connection is nil")
+ logger.LogError(c, "websocket connection is nil")
return errors.New("websocket connection is nil")
}
//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
return ws.WriteMessage(1, jsonData)
}
-func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
+func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
+ if ws == nil {
+ return
+ }
errorObj := &dto.RealtimeEvent{
Type: "error",
EventId: GetLocalRealtimeID(c),
@@ -137,6 +138,24 @@ func GetLocalRealtimeID(c *gin.Context) string {
return fmt.Sprintf("evt_%s", logID)
}
+func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse { // builds the initial stream chunk announcing the assistant role with empty content
+	return &dto.ChatCompletionsStreamResponse{
+		Id:                id,
+		Object:            "chat.completion.chunk",
+		Created:           createAt,
+		Model:             model,
+		SystemFingerprint: systemFingerprint,
+		Choices: []dto.ChatCompletionsStreamResponseChoice{
+			{
+				Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+					Role:    "assistant",
+					Content: common.GetPointer(""), // role-only chunk: content is deliberately empty
+				},
+			},
+		},
+	}
+}
+
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go
index 9bf67c03..5b64cd8b 100644
--- a/relay/helper/model_mapped.go
+++ b/relay/helper/model_mapped.go
@@ -4,12 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
- "one-api/relay/common"
-
"github.com/gin-gonic/gin"
+ "one-api/dto"
+ "one-api/relay/common"
)
-func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request dto.Request) error {
// map model name
modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +50,8 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
info.UpstreamModelName = currentModel
}
}
+ if request != nil {
+ request.SetModelName(info.UpstreamModelName)
+ }
return nil
}
diff --git a/relay/helper/price.go b/relay/helper/price.go
index 89efa1da..fdc5b66d 100644
--- a/relay/helper/price.go
+++ b/relay/helper/price.go
@@ -2,33 +2,50 @@ package helper
import (
"fmt"
- "github.com/gin-gonic/gin"
"one-api/common"
- constant2 "one-api/constant"
relaycommon "one-api/relay/common"
- "one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
)
-type PriceData struct {
- ModelPrice float64
- ModelRatio float64
- CompletionRatio float64
- CacheRatio float64
- CacheCreationRatio float64
- ImageRatio float64
- GroupRatio float64
- UsePrice bool
- ShouldPreConsumedQuota int
+// HandleGroupRatio resolves the billing group ratio for this request: it applies an "auto_group" override from the context and prefers a user-specific group ratio when one is configured.
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) types.GroupRatioInfo {
+	groupRatioInfo := types.GroupRatioInfo{
+		GroupRatio:        1.0, // default ratio when no configuration matches
+		GroupSpecialRatio: -1,  // sentinel: no special ratio (see HasSpecialRatio)
+	}
+
+	// apply the "auto_group" selection stored in the context, if present
+	autoGroup, exists := ctx.Get("auto_group")
+	if exists {
+		if common.DebugEnabled {
+			println(fmt.Sprintf("final group: %s", autoGroup))
+		}
+		relayInfo.UsingGroup = autoGroup.(string) // NOTE(review): assumes the stored value is always a string; a non-string would panic — confirm against the setter
+	}
+
+	// check whether the user's group defines a special ratio toward the group in use
+	userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
+	if ok {
+		// user-group-specific ratio takes precedence over the plain group ratio
+		groupRatioInfo.GroupSpecialRatio = userGroupRatio
+		groupRatioInfo.GroupRatio = userGroupRatio
+		groupRatioInfo.HasSpecialRatio = true
+	} else {
+		// fall back to the normal group ratio
+		groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
+	}
+
+	return groupRatioInfo
+}
-func (p PriceData) ToSetting() string {
- return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
-}
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, meta *types.TokenCountMeta) (types.PriceData, error) {
+ modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
+
+ groupRatioInfo := HandleGroupRatio(c, info)
-func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
- modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
- groupRatio := setting.GetGroupRatio(info.Group)
var preConsumedQuota int
var modelRatio float64
var completionRatio float64
@@ -36,39 +53,40 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var imageRatio float64
var cacheCreationRatio float64
if !usePrice {
- preConsumedTokens := common.PreConsumedQuota
- if maxTokens != 0 {
- preConsumedTokens = promptTokens + maxTokens
+ preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
+ if meta.MaxTokens != 0 {
+ preConsumedTokens += meta.MaxTokens
}
var success bool
- modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
+ var matchName string
+ modelRatio, success, matchName = ratio_setting.GetModelRatio(info.OriginModelName)
if !success {
acceptUnsetRatio := false
- if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
- b, ok := accept.(bool)
- if ok {
- acceptUnsetRatio = b
- }
+ if info.UserSetting.AcceptUnsetRatioModel {
+ acceptUnsetRatio = true
}
if !acceptUnsetRatio {
- return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
+ return types.PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
}
}
- completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
- cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
- cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
- imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
- ratio := modelRatio * groupRatio
+ completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
+ cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
+ cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
+ imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
+ ratio := modelRatio * groupRatioInfo.GroupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
+ if meta.ImagePriceRatio != 0 {
+ modelPrice = modelPrice * meta.ImagePriceRatio
+ }
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
}
- priceData := PriceData{
+ priceData := types.PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
- GroupRatio: groupRatio,
+ GroupRatioInfo: groupRatioInfo,
UsePrice: usePrice,
CacheRatio: cacheRatio,
ImageRatio: imageRatio,
@@ -79,16 +97,39 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
if common.DebugEnabled {
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
}
-
+ info.PriceData = priceData
return priceData, nil
}
+// ModelPriceHelperPerCall computes per-call (fixed-price) billing data, used by the MJ and Task relays.
+func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
+	groupRatioInfo := HandleGroupRatio(c, info)
+
+	modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
+	// fall back to a default when no price is configured for this model
+	if !success {
+		defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] // NOTE(review): reads the default *ratio* map as a price fallback — confirm intended
+		if !ok {
+			modelPrice = 0.1 // last-resort default price
+		} else {
+			modelPrice = defaultPrice
+		}
+	}
+	quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
+	priceData := types.PerCallPriceData{
+		ModelPrice:     modelPrice,
+		Quota:          quota,
+		GroupRatioInfo: groupRatioInfo,
+	}
+	return priceData
+}
+
func ContainPriceOrRatio(modelName string) bool {
- _, ok := operation_setting.GetModelPrice(modelName, false)
+ _, ok := ratio_setting.GetModelPrice(modelName, false)
if ok {
return true
}
- _, ok = operation_setting.GetModelRatio(modelName)
+ _, ok, _ = ratio_setting.GetModelRatio(modelName)
if ok {
return true
}
diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go
index c1bc0d6e..725d178c 100644
--- a/relay/helper/stream_scanner.go
+++ b/relay/helper/stream_scanner.go
@@ -3,10 +3,12 @@ package helper
import (
"bufio"
"context"
+ "fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
+ "one-api/logger"
relaycommon "one-api/relay/common"
"one-api/setting/operation_setting"
"strings"
@@ -19,7 +21,7 @@ import (
)
const (
- InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
+ InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
DefaultPingInterval = 10 * time.Second
)
@@ -30,24 +32,26 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
return
}
- defer resp.Body.Close()
+ // 确保响应体总是被关闭
+ defer func() {
+ if resp.Body != nil {
+ resp.Body.Close()
+ }
+ }()
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
- if strings.HasPrefix(info.UpstreamModelName, "o") {
- // twice timeout for thinking model
- streamingTimeout *= 2
- }
var (
- stopChan = make(chan bool, 2)
+ stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
scanner = bufio.NewScanner(resp.Body)
ticker = time.NewTicker(streamingTimeout)
pingTicker *time.Ticker
- writeMutex sync.Mutex // Mutex to protect concurrent writes
+ writeMutex sync.Mutex // Mutex to protect concurrent writes
+ wg sync.WaitGroup // 用于等待所有 goroutine 退出
)
generalSettings := operation_setting.GetGeneralSetting()
- pingEnabled := generalSettings.PingIntervalEnabled
+ pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval
@@ -57,13 +61,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
pingTicker = time.NewTicker(pingInterval)
}
+ if common.DebugEnabled {
+ // print timeout and ping interval for debugging
+ println("relay timeout seconds:", common.RelayTimeout)
+ println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
+ println("ping interval seconds:", int64(pingInterval.Seconds()))
+ }
+
+ // 改进资源清理,确保所有 goroutine 正确退出
defer func() {
+ // 通知所有 goroutine 停止
+ common.SafeSendBool(stopChan, true)
+
ticker.Stop()
if pingTicker != nil {
pingTicker.Stop()
}
+
+ // 等待所有 goroutine 退出,最多等待5秒
+ done := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ logger.LogError(c, "timeout waiting for goroutines to exit")
+ }
+
close(stopChan)
}()
+
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
scanner.Split(bufio.ScanLines)
SetEventStreamHeaders(c)
@@ -73,35 +103,95 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
ctx = context.WithValue(ctx, "stop_chan", stopChan)
- // Handle ping data sending
+ // Handle ping data sending with improved error handling
if pingEnabled && pingTicker != nil {
+ wg.Add(1)
gopool.Go(func() {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
+ common.SafeSendBool(stopChan, true)
+ }
+ if common.DebugEnabled {
+ println("ping goroutine exited")
+ }
+ }()
+
+ // 添加超时保护,防止 goroutine 无限运行
+ maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
+ pingTimeout := time.NewTimer(maxPingDuration)
+ defer pingTimeout.Stop()
+
for {
select {
case <-pingTicker.C:
- writeMutex.Lock() // Lock before writing
- err := PingData(c)
- writeMutex.Unlock() // Unlock after writing
- if err != nil {
- common.LogError(c, "ping data error: "+err.Error())
- common.SafeSendBool(stopChan, true)
+ // 使用超时机制防止写操作阻塞
+ done := make(chan error, 1)
+ go func() {
+ writeMutex.Lock()
+ defer writeMutex.Unlock()
+ done <- PingData(c)
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ logger.LogError(c, "ping data error: "+err.Error())
+ return
+ }
+ if common.DebugEnabled {
+ println("ping data sent")
+ }
+ case <-time.After(10 * time.Second):
+ logger.LogError(c, "ping data send timeout")
+ return
+ case <-ctx.Done():
+ return
+ case <-stopChan:
return
}
- if common.DebugEnabled {
- println("ping data sent")
- }
case <-ctx.Done():
- if common.DebugEnabled {
- println("ping data goroutine stopped")
- }
+ return
+ case <-stopChan:
+ return
+ case <-c.Request.Context().Done():
+ // 监听客户端断开连接
+ return
+ case <-pingTimeout.C:
+ logger.LogError(c, "ping goroutine max duration reached")
return
}
}
})
}
+ // Scanner goroutine with improved error handling
+ wg.Add(1)
common.RelayCtxGo(ctx, func() {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
+ }
+ common.SafeSendBool(stopChan, true)
+ if common.DebugEnabled {
+ println("scanner goroutine exited")
+ }
+ }()
+
for scanner.Scan() {
+ // 检查是否需要停止
+ select {
+ case <-stopChan:
+ return
+ case <-ctx.Done():
+ return
+ case <-c.Request.Context().Done():
+ return
+ default:
+ }
+
ticker.Reset(streamingTimeout)
data := scanner.Text()
if common.DebugEnabled {
@@ -119,31 +209,54 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
data = strings.TrimSuffix(data, "\r")
if !strings.HasPrefix(data, "[DONE]") {
info.SetFirstResponseTime()
- writeMutex.Lock() // Lock before writing
- success := dataHandler(data)
- writeMutex.Unlock() // Unlock after writing
- if !success {
- break
+
+ // 使用超时机制防止写操作阻塞
+ done := make(chan bool, 1)
+ go func() {
+ writeMutex.Lock()
+ defer writeMutex.Unlock()
+ done <- dataHandler(data)
+ }()
+
+ select {
+ case success := <-done:
+ if !success {
+ return
+ }
+ case <-time.After(10 * time.Second):
+ logger.LogError(c, "data handler timeout")
+ return
+ case <-ctx.Done():
+ return
+ case <-stopChan:
+ return
}
+ } else {
+ // done, 处理完成标志,直接退出停止读取剩余数据防止出错
+ if common.DebugEnabled {
+ println("received [DONE], stopping scanner")
+ }
+ return
}
}
if err := scanner.Err(); err != nil {
if err != io.EOF {
- common.LogError(c, "scanner error: "+err.Error())
+ logger.LogError(c, "scanner error: "+err.Error())
}
}
-
- common.SafeSendBool(stopChan, true)
})
+ // 主循环等待完成或超时
select {
case <-ticker.C:
// 超时处理逻辑
- common.LogError(c, "streaming timeout")
- common.SafeSendBool(stopChan, true)
+ logger.LogError(c, "streaming timeout")
case <-stopChan:
// 正常结束
- common.LogInfo(c, "streaming finished")
+ logger.LogInfo(c, "streaming finished")
+ case <-c.Request.Context().Done():
+ // 客户端断开连接
+ logger.LogInfo(c, "client disconnected")
}
}
diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go
new file mode 100644
index 00000000..285f26aa
--- /dev/null
+++ b/relay/helper/valid_request.go
@@ -0,0 +1,301 @@
+package helper
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/logger"
+ relayconstant "one-api/relay/constant"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dto.Request, err error) {
+ relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+
+ switch format {
+ case types.RelayFormatOpenAI:
+ request, err = GetAndValidateTextRequest(c, relayMode)
+ case types.RelayFormatGemini:
+ request, err = GetAndValidateGeminiRequest(c)
+ case types.RelayFormatClaude:
+ request, err = GetAndValidateClaudeRequest(c)
+ case types.RelayFormatOpenAIResponses:
+ request, err = GetAndValidateResponsesRequest(c)
+
+ case types.RelayFormatOpenAIImage:
+ request, err = GetAndValidOpenAIImageRequest(c, relayMode)
+ case types.RelayFormatEmbedding:
+ request, err = GetAndValidateEmbeddingRequest(c, relayMode)
+ case types.RelayFormatRerank:
+ request, err = GetAndValidateRerankRequest(c)
+ case types.RelayFormatOpenAIAudio:
+ request, err = GetAndValidAudioRequest(c, relayMode)
+ case types.RelayFormatOpenAIRealtime:
+ request = &dto.BaseRequest{}
+ default:
+ return nil, fmt.Errorf("unsupported relay format: %s", format)
+ }
+ return request, err
+}
+
+func GetAndValidAudioRequest(c *gin.Context, relayMode int) (*dto.AudioRequest, error) {
+ audioRequest := &dto.AudioRequest{}
+ err := common.UnmarshalBodyReusable(c, audioRequest)
+ if err != nil {
+ return nil, err
+ }
+ switch relayMode {
+ case relayconstant.RelayModeAudioSpeech:
+ if audioRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ default:
+ err = c.Request.ParseForm()
+ if err != nil {
+ return nil, err
+ }
+ formData := c.Request.PostForm
+ if audioRequest.Model == "" {
+ audioRequest.Model = formData.Get("model")
+ }
+
+ if audioRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ audioRequest.ResponseFormat = formData.Get("response_format")
+ if audioRequest.ResponseFormat == "" {
+ audioRequest.ResponseFormat = "json"
+ }
+ }
+ return audioRequest, nil
+}
+
+func GetAndValidateRerankRequest(c *gin.Context) (*dto.RerankRequest, error) {
+ var rerankRequest *dto.RerankRequest
+ err := common.UnmarshalBodyReusable(c, &rerankRequest)
+ if err != nil {
+		logger.LogError(c, fmt.Sprintf("getAndValidateRerankRequest failed: %s", err.Error()))
+ return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ if rerankRequest.Query == "" {
+ return nil, types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+ if len(rerankRequest.Documents) == 0 {
+ return nil, types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+ return rerankRequest, nil
+}
+
+func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.EmbeddingRequest, error) {
+ var embeddingRequest *dto.EmbeddingRequest
+ err := common.UnmarshalBodyReusable(c, &embeddingRequest)
+ if err != nil {
+		logger.LogError(c, fmt.Sprintf("getAndValidateEmbeddingRequest failed: %s", err.Error()))
+ return nil, types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ if embeddingRequest.Input == nil {
+ return nil, fmt.Errorf("input is empty")
+ }
+ if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
+ embeddingRequest.Model = "omni-moderation-latest"
+ }
+ if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
+ embeddingRequest.Model = c.Param("model")
+ }
+ return embeddingRequest, nil
+}
+
+func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
+ request := &dto.OpenAIResponsesRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ if request.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if request.Input == nil {
+ return nil, errors.New("input is required")
+ }
+ return request, nil
+}
+
+func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageRequest, error) {
+ imageRequest := &dto.ImageRequest{}
+
+ switch relayMode {
+ case relayconstant.RelayModeImagesEdits:
+ _, err := c.MultipartForm()
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
+ }
+ formData := c.Request.PostForm
+ imageRequest.Prompt = formData.Get("prompt")
+ imageRequest.Model = formData.Get("model")
+ imageRequest.N = uint(common.String2Int(formData.Get("n")))
+ imageRequest.Quality = formData.Get("quality")
+ imageRequest.Size = formData.Get("size")
+
+ if imageRequest.Model == "gpt-image-1" {
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "standard"
+ }
+ }
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+
+ watermark := formData.Has("watermark")
+ if watermark {
+ imageRequest.Watermark = &watermark
+ }
+ default:
+ err := common.UnmarshalBodyReusable(c, imageRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ if imageRequest.Model == "" {
+ imageRequest.Model = "dall-e-3"
+ }
+
+ if strings.Contains(imageRequest.Size, "×") {
+			return nil, errors.New("invalid size parameter: please use 'x' instead of the multiplication sign '×'")
+ }
+
+ // Not "256x256", "512x512", or "1024x1024"
+ if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
+ if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
+ return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
+ }
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ } else if imageRequest.Model == "dall-e-3" {
+ if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
+ return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
+ }
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "standard"
+ }
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ } else if imageRequest.Model == "gpt-image-1" {
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "auto"
+ }
+ }
+
+ if imageRequest.Prompt == "" {
+ return nil, errors.New("prompt is required")
+ }
+
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+ }
+
+ return imageRequest, nil
+}
+
+func GetAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
+ textRequest = &dto.ClaudeRequest{}
+ err = c.ShouldBindJSON(textRequest)
+ if err != nil {
+ return nil, err
+ }
+ if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+ return nil, errors.New("field messages is required")
+ }
+ if textRequest.Model == "" {
+ return nil, errors.New("field model is required")
+ }
+
+ //if textRequest.Stream {
+ // relayInfo.IsStream = true
+ //}
+
+ return textRequest, nil
+}
+
+func GetAndValidateTextRequest(c *gin.Context, relayMode int) (*dto.GeneralOpenAIRequest, error) {
+ textRequest := &dto.GeneralOpenAIRequest{}
+ err := common.UnmarshalBodyReusable(c, textRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ if relayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
+ textRequest.Model = "text-moderation-latest"
+ }
+ if relayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
+ textRequest.Model = c.Param("model")
+ }
+
+ if textRequest.MaxTokens > math.MaxInt32/2 {
+ return nil, errors.New("max_tokens is invalid")
+ }
+ if textRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if textRequest.WebSearchOptions != nil {
+ if textRequest.WebSearchOptions.SearchContextSize != "" {
+ validSizes := map[string]bool{
+ "high": true,
+ "medium": true,
+ "low": true,
+ }
+ if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
+ return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
+ }
+ } else {
+ textRequest.WebSearchOptions.SearchContextSize = "medium"
+ }
+ }
+ switch relayMode {
+ case relayconstant.RelayModeCompletions:
+ if textRequest.Prompt == "" {
+ return nil, errors.New("field prompt is required")
+ }
+ case relayconstant.RelayModeChatCompletions:
+ if len(textRequest.Messages) == 0 {
+ return nil, errors.New("field messages is required")
+ }
+ case relayconstant.RelayModeEmbeddings:
+ case relayconstant.RelayModeModerations:
+ if textRequest.Input == nil || textRequest.Input == "" {
+ return nil, errors.New("field input is required")
+ }
+ case relayconstant.RelayModeEdits:
+ if textRequest.Instruction == "" {
+ return nil, errors.New("field instruction is required")
+ }
+ }
+ return textRequest, nil
+}
+
+func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
+
+ request := &dto.GeminiChatRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ if len(request.Contents) == 0 {
+ return nil, errors.New("contents is required")
+ }
+
+ //if c.Query("alt") == "sse" {
+ // relayInfo.IsStream = true
+ //}
+
+ return request, nil
+}
diff --git a/relay/image_handler.go b/relay/image_handler.go
new file mode 100644
index 00000000..c700424f
--- /dev/null
+++ b/relay/image_handler.go
@@ -0,0 +1,127 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
+
+ imageReq, ok := info.Request.(*dto.ImageRequest)
+ if !ok {
+ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ request, err := common.DeepCopy(imageReq)
+ if err != nil {
+ return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ err = helper.ModelMappedHelper(c, info, request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+ }
+
+ adaptor := GetAdaptor(info.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ }
+ adaptor.Init(info)
+
+ var requestBody io.Reader
+
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
+ body, err := common.GetRequestBody(c)
+ if err != nil {
+ return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+ requestBody = bytes.NewBuffer(body)
+ } else {
+ convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ if info.RelayMode == relayconstant.RelayModeImagesEdits {
+ requestBody = convertedRequest.(io.Reader)
+ } else {
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ // apply param override
+ if len(info.ParamOverride) > 0 {
+ jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+ }
+ }
+
+ if common.DebugEnabled {
+ println(fmt.Sprintf("image request body: %s", string(jsonData)))
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ }
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ resp, err := adaptor.DoRequest(c, info, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+
+ if usage.(*dto.Usage).TotalTokens == 0 {
+ usage.(*dto.Usage).TotalTokens = int(request.N)
+ }
+ if usage.(*dto.Usage).PromptTokens == 0 {
+ usage.(*dto.Usage).PromptTokens = int(request.N)
+ }
+
+ quality := "standard"
+ if request.Quality == "hd" {
+ quality = "hd"
+ }
+
+ var logContent string
+
+ if len(request.Size) > 0 {
+ logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
+ }
+
+ postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
+ return nil
+}
diff --git a/relay/relay-mj.go b/relay/mjproxy_handler.go
similarity index 77%
rename from relay/relay-mj.go
rename to relay/mjproxy_handler.go
index 9d0a2077..7c52cb6b 100644
--- a/relay/relay-mj.go
+++ b/relay/mjproxy_handler.go
@@ -13,9 +13,9 @@ import (
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
"one-api/service"
"one-api/setting"
- "one-api/setting/operation_setting"
"strconv"
"strings"
"time"
@@ -34,14 +34,13 @@ func RelayMidjourneyImage(c *gin.Context) {
}
var httpClient *http.Client
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
- if proxy, ok := channel.GetSetting()["proxy"]; ok {
- if proxyURL, ok := proxy.(string); ok && proxyURL != "" {
- if httpClient, err = service.NewProxyHttpClient(proxyURL); err != nil {
- c.JSON(400, gin.H{
- "error": "proxy_url_invalid",
- })
- return
- }
+ proxy := channel.GetSetting().Proxy
+ if proxy != "" {
+ if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
+ c.JSON(400, gin.H{
+ "error": "proxy_url_invalid",
+ })
+ return
}
}
}
@@ -106,6 +105,9 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
+ midjourneyTask.VideoUrl = midjRequest.VideoUrl
+ videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls)
+ midjourneyTask.VideoUrls = string(videoUrlsStr)
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
@@ -136,6 +138,9 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
} else {
midjourneyTask.ImageUrl = originTask.ImageUrl
}
+ if originTask.VideoUrl != "" {
+ midjourneyTask.VideoUrl = originTask.VideoUrl
+ }
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
@@ -148,6 +153,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.Buttons = buttons
}
}
+ if originTask.VideoUrls != "" {
+ var videoUrls []dto.ImgUrls
+ err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls)
+ if err == nil {
+ midjourneyTask.VideoUrls = videoUrls
+ }
+ }
if originTask.Properties != "" {
var properties dto.Properties
err := json.Unmarshal([]byte(originTask.Properties), &properties)
@@ -158,44 +170,31 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return
}
-func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
- startTime := time.Now().UnixNano() / int64(time.Millisecond)
- tokenId := c.GetInt("token_id")
- userId := c.GetInt("id")
- group := c.GetString("group")
- channelId := c.GetInt("channel_id")
- relayInfo := relaycommon.GenRelayInfo(c)
+func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse {
var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
}
+
+ info.InitChannelMeta(c)
+
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
- modelPrice, success := operation_setting.GetModelPrice(modelName, true)
- // 如果没有配置价格,则使用默认价格
- if !success {
- defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
- if !ok {
- modelPrice = 0.1
- } else {
- modelPrice = defaultPrice
- }
- }
- groupRatio := setting.GetGroupRatio(group)
- ratio := modelPrice * groupRatio
- userQuota, err := model.GetUserQuota(userId, false)
+
+ priceData := helper.ModelPriceHelperPerCall(c, info)
+
+ userQuota, err := model.GetUserQuota(info.UserId, false)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: err.Error(),
}
}
- quota := int(ratio * common.QuotaPerUnit)
- if userQuota-quota < 0 {
+ if userQuota-priceData.Quota < 0 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
@@ -210,31 +209,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
- err := service.PostConsumeQuota(relayInfo, quota, 0, true)
+ err := service.PostConsumeQuota(info, priceData.Quota, 0, true)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- //err = model.CacheUpdateUserQuota(userId)
- if err != nil {
- common.SysError("error update user quota cache: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
- other := make(map[string]interface{})
- other["model_price"] = modelPrice
- other["group_ratio"] = groupRatio
- model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
- quota, logContent, tokenId, userQuota, 0, false, group, other)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- channelId := c.GetInt("channel_id")
- model.UpdateChannelUsedQuota(channelId, quota)
+ common.SysLog("error consuming token remain quota: " + err.Error())
}
+
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
+ other := service.GenerateMjOtherInfo(priceData)
+ model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
+ ChannelId: info.ChannelId,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: priceData.Quota,
+ Content: logContent,
+ TokenId: info.TokenId,
+ Group: info.UsingGroup,
+ Other: other,
+ })
+ model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota)
+ model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota)
}
}()
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
- UserId: userId,
+ UserId: info.UserId,
Code: midjResponse.Code,
Action: constant.MjActionSwapFace,
MjId: midjResponse.Result,
@@ -242,7 +241,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
PromptEn: "",
Description: midjResponse.Description,
State: "",
- SubmitTime: startTime,
+ SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond),
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
FinishTime: 0,
ImageUrl: "",
@@ -250,7 +249,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
- Quota: quota,
+ Quota: priceData.Quota,
}
err = midjourneyTask.Insert()
if err != nil {
@@ -297,10 +296,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
if err != nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
}
- _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
- if err != nil {
- return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
- }
+ service.IOCopyBytesGracefully(c, nil, respBody)
return nil
}
@@ -369,14 +365,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
return nil
}
-func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
-
- tokenId := c.GetInt("token_id")
- //channelType := c.GetInt("channel")
- userId := c.GetInt("id")
- group := c.GetString("group")
- channelId := c.GetInt("channel_id")
- relayInfo := relaycommon.GenRelayInfo(c)
+func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse {
consumeQuota := true
var midjRequest dto.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest)
@@ -384,30 +373,37 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
}
- if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
+ relayInfo.InitChannelMeta(c)
+
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
if mjErr != nil {
return mjErr
}
- relayMode = relayconstant.RelayModeMidjourneyChange
+ relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange
+ }
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
+ midjRequest.Action = constant.MjActionVideo
}
- if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
midjRequest.Action = constant.MjActionImagine
- } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = constant.MjActionDescribe
- } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
+ midjRequest.Action = constant.MjActionEdits
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
midjRequest.Action = constant.MjActionShorten
- } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionBlend
- } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionUpload
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
mjId := ""
- if relayMode == relayconstant.RelayModeMidjourneyChange {
+ if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
@@ -417,7 +413,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
- } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
}
@@ -427,20 +423,28 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
mjId = params.TaskId
midjRequest.Action = params.Action
- } else if relayMode == relayconstant.RelayModeMidjourneyModal {
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal {
//if midjRequest.MaskBase64 == "" {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
//}
mjId = midjRequest.TaskId
midjRequest.Action = constant.MjActionModal
+ } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
+ midjRequest.Action = constant.MjActionVideo
+ if midjRequest.TaskId == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
+ } else if midjRequest.Action == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
+ }
+ mjId = midjRequest.TaskId
}
- originTask := model.GetByMJId(userId, mjId)
+ originTask := model.GetByMJId(relayInfo.UserId, mjId)
if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
if setting.MjActionCheckSuccessEnabled {
- if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+ if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
}
}
@@ -480,28 +484,18 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
modelName := service.CoverActionToModelName(midjRequest.Action)
- modelPrice, success := operation_setting.GetModelPrice(modelName, true)
- // 如果没有配置价格,则使用默认价格
- if !success {
- defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
- if !ok {
- modelPrice = 0.1
- } else {
- modelPrice = defaultPrice
- }
- }
- groupRatio := setting.GetGroupRatio(group)
- ratio := modelPrice * groupRatio
- userQuota, err := model.GetUserQuota(userId, false)
+
+ priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
+
+ userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
Description: err.Error(),
}
}
- quota := int(ratio * common.QuotaPerUnit)
- if consumeQuota && userQuota-quota < 0 {
+ if consumeQuota && userQuota-priceData.Quota < 0 {
return &dto.MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
@@ -516,22 +510,25 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func() {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
- err := service.PostConsumeQuota(relayInfo, quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
- }
- if quota != 0 {
- tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
- other := make(map[string]interface{})
- other["model_price"] = modelPrice
- other["group_ratio"] = groupRatio
- model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
- quota, logContent, tokenId, userQuota, 0, false, group, other)
- model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
- channelId := c.GetInt("channel_id")
- model.UpdateChannelUsedQuota(channelId, quota)
+ common.SysLog("error consuming token remain quota: " + err.Error())
}
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
+ other := service.GenerateMjOtherInfo(priceData)
+ model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: priceData.Quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota)
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota)
}
}()
@@ -543,7 +540,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
// other: 提交错误,description为错误描述
midjourneyTask := &model.Midjourney{
- UserId: userId,
+ UserId: relayInfo.UserId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
@@ -559,16 +556,16 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
- Quota: quota,
+ Quota: priceData.Quota,
}
if midjResponse.Code == 3 {
//无实例账号自动禁用渠道(No available account instance)
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
if err != nil {
- common.SysError("get_channel_null: " + err.Error())
+ common.SysLog("get_channel_null: " + err.Error())
}
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
- model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
+ model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
}
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
diff --git a/relay/relay-audio.go b/relay/relay-audio.go
deleted file mode 100644
index deb45c58..00000000
--- a/relay/relay-audio.go
+++ /dev/null
@@ -1,137 +0,0 @@
-package relay
-
-import (
- "errors"
- "fmt"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting"
- "strings"
-)
-
-func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
- audioRequest := &dto.AudioRequest{}
- err := common.UnmarshalBodyReusable(c, audioRequest)
- if err != nil {
- return nil, err
- }
- switch info.RelayMode {
- case relayconstant.RelayModeAudioSpeech:
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- if setting.ShouldCheckPromptSensitive() {
- words, err := service.CheckSensitiveInput(audioRequest.Input)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
- return nil, err
- }
- }
- default:
- err = c.Request.ParseForm()
- if err != nil {
- return nil, err
- }
- formData := c.Request.PostForm
- if audioRequest.Model == "" {
- audioRequest.Model = formData.Get("model")
- }
-
- if audioRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- audioRequest.ResponseFormat = formData.Get("response_format")
- if audioRequest.ResponseFormat == "" {
- audioRequest.ResponseFormat = "json"
- }
- }
- return audioRequest, nil
-}
-
-func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
- relayInfo := relaycommon.GenRelayInfo(c)
- audioRequest, err := getAndValidAudioRequest(c, relayInfo)
-
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
- return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
- }
-
- promptTokens := 0
- preConsumedTokens := common.PreConsumedQuota
- if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
- promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
- }
- preConsumedTokens = promptTokens
- relayInfo.PromptTokens = promptTokens
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
-
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
- }
-
- audioRequest.Model = relayInfo.UpstreamModelName
-
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
-
- ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
- }
-
- resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- openaiErr = service.RelayErrorHandler(httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
- if openaiErr != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
-
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
-
- return nil
-}
diff --git a/relay/relay-gemini.go b/relay/relay-gemini.go
deleted file mode 100644
index 93a2b7aa..00000000
--- a/relay/relay-gemini.go
+++ /dev/null
@@ -1,157 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/relay/channel/gemini"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
- request := &gemini.GeminiChatRequest{}
- err := common.UnmarshalBodyReusable(c, request)
- if err != nil {
- return nil, err
- }
- if len(request.Contents) == 0 {
- return nil, errors.New("contents is required")
- }
- return request, nil
-}
-
-// 流模式
-// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
-func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
- if c.Query("alt") == "sse" {
- relayInfo.IsStream = true
- }
-
- // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
- // relayInfo.IsStream = true
- // }
-}
-
-func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
- var inputTexts []string
- for _, content := range textRequest.Contents {
- for _, part := range content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- }
- }
- if len(inputTexts) == 0 {
- return nil, nil
- }
-
- sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
- return sensitiveWords, err
-}
-
-func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
- // 计算输入 token 数量
- var inputTexts []string
- for _, content := range req.Contents {
- for _, part := range content.Parts {
- if part.Text != "" {
- inputTexts = append(inputTexts, part.Text)
- }
- }
- }
-
- inputText := strings.Join(inputTexts, "\n")
- inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
- info.PromptTokens = inputTokens
- return inputTokens, err
-}
-
-func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
- req, err := getAndValidateGeminiRequest(c)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
- return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
- }
-
- relayInfo := relaycommon.GenRelayInfo(c)
-
- // 检查 Gemini 流式模式
- checkGeminiStreamMode(c, relayInfo)
-
- if setting.ShouldCheckPromptSensitive() {
- sensitiveWords, err := checkGeminiInputSensitive(req)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
- return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
- }
- }
-
- // model mapped 模型映射
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
- }
-
- if value, exists := c.Get("prompt_tokens"); exists {
- promptTokens := value.(int)
- relayInfo.SetPromptTokens(promptTokens)
- } else {
- promptTokens, err := getGeminiInputTokens(req, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
- }
- c.Set("prompt_tokens", promptTokens)
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
-
- // pre consume quota
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
-
- adaptor.Init(relayInfo)
-
- requestBody, err := json.Marshal(req)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
- }
-
- resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
- if err != nil {
- common.LogError(c, "Do gemini request failed: "+err.Error())
- return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
- }
-
- usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
-
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
- return nil
-}
diff --git a/relay/relay-image.go b/relay/relay-image.go
deleted file mode 100644
index dc63cce8..00000000
--- a/relay/relay-image.go
+++ /dev/null
@@ -1,240 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/model"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
- imageRequest := &dto.ImageRequest{}
-
- switch info.RelayMode {
- case relayconstant.RelayModeImagesEdits:
- _, err := c.MultipartForm()
- if err != nil {
- return nil, err
- }
- formData := c.Request.PostForm
- imageRequest.Prompt = formData.Get("prompt")
- imageRequest.Model = formData.Get("model")
- imageRequest.N = common.String2Int(formData.Get("n"))
- imageRequest.Quality = formData.Get("quality")
- imageRequest.Size = formData.Get("size")
-
- if imageRequest.Model == "gpt-image-1" {
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
- }
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
- default:
- err := common.UnmarshalBodyReusable(c, imageRequest)
- if err != nil {
- return nil, err
- }
-
- if imageRequest.Model == "" {
- imageRequest.Model = "dall-e-3"
- }
-
- if strings.Contains(imageRequest.Size, "×") {
- return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
- }
-
- // Not "256x256", "512x512", or "1024x1024"
- if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
- if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
- return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- } else if imageRequest.Model == "dall-e-3" {
- if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
- return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
- }
- if imageRequest.Quality == "" {
- imageRequest.Quality = "standard"
- }
- if imageRequest.Size == "" {
- imageRequest.Size = "1024x1024"
- }
- } else if imageRequest.Model == "gpt-image-1" {
- if imageRequest.Quality == "" {
- imageRequest.Quality = "auto"
- }
- }
-
- if imageRequest.Prompt == "" {
- return nil, errors.New("prompt is required")
- }
-
- if imageRequest.N == 0 {
- imageRequest.N = 1
- }
- }
-
- if setting.ShouldCheckPromptSensitive() {
- words, err := service.CheckSensitiveInput(imageRequest.Prompt)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
- return nil, err
- }
- }
- return imageRequest, nil
-}
-
-func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
- relayInfo := relaycommon.GenRelayInfo(c)
-
- imageRequest, err := getAndValidImageRequest(c, relayInfo)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
- return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
- }
-
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
- }
-
- imageRequest.Model = relayInfo.UpstreamModelName
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
- var preConsumedQuota int
- var quota int
- var userQuota int
- if !priceData.UsePrice {
- // modelRatio 16 = modelPrice $0.04
- // per 1 modelRatio = $0.04 / 16
- // priceData.ModelPrice = 0.0025 * priceData.ModelRatio
- var openaiErr *dto.OpenAIErrorWithStatusCode
- preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- } else {
- sizeRatio := 1.0
- // Size
- if imageRequest.Size == "256x256" {
- sizeRatio = 0.4
- } else if imageRequest.Size == "512x512" {
- sizeRatio = 0.45
- } else if imageRequest.Size == "1024x1024" {
- sizeRatio = 1
- } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
- sizeRatio = 2
- }
-
- qualityRatio := 1.0
- if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
- qualityRatio = 2.0
- if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
- qualityRatio = 1.5
- }
- }
-
- // reset model price
- priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
- quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
- userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota-quota < 0 {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), "insufficient_user_quota", http.StatusForbidden)
- }
- }
-
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
-
- var requestBody io.Reader
-
- convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
- }
- if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
- requestBody = convertedRequest.(io.Reader)
- } else {
- jsonData, err := json.Marshal(convertedRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(jsonData)
- }
-
- if common.DebugEnabled {
- println(fmt.Sprintf("image request body: %s", requestBody))
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
- if httpResp.StatusCode != http.StatusOK {
- openaiErr := service.RelayErrorHandler(httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
- if openaiErr != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
-
- if usage.(*dto.Usage).TotalTokens == 0 {
- usage.(*dto.Usage).TotalTokens = imageRequest.N
- }
- if usage.(*dto.Usage).PromptTokens == 0 {
- usage.(*dto.Usage).PromptTokens = imageRequest.N
- }
- quality := "standard"
- if imageRequest.Quality == "hd" {
- quality = "hd"
- }
-
- logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
- return nil
-}
diff --git a/relay/relay-responses.go b/relay/relay-responses.go
deleted file mode 100644
index fd3ddb5a..00000000
--- a/relay/relay-responses.go
+++ /dev/null
@@ -1,171 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting"
- "one-api/setting/model_setting"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
- request := &dto.OpenAIResponsesRequest{}
- err := common.UnmarshalBodyReusable(c, request)
- if err != nil {
- return nil, err
- }
- if request.Model == "" {
- return nil, errors.New("model is required")
- }
- if len(request.Input) == 0 {
- return nil, errors.New("input is required")
- }
- return request, nil
-
-}
-
-func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
- sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
- return sensitiveWords, err
-}
-
-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
- inputTokens, err := service.CountTokenInput(req.Input, req.Model)
- info.PromptTokens = inputTokens
- return inputTokens, err
-}
-
-func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
- req, err := getAndValidateResponsesRequest(c)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
- return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
- }
-
- relayInfo := relaycommon.GenRelayInfoResponses(c, req)
-
- if setting.ShouldCheckPromptSensitive() {
- sensitiveWords, err := checkInputSensitive(req, relayInfo)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
- return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
- }
- }
-
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
- }
- req.Model = relayInfo.UpstreamModelName
- if value, exists := c.Get("prompt_tokens"); exists {
- promptTokens := value.(int)
- relayInfo.SetPromptTokens(promptTokens)
- } else {
- promptTokens, err := getInputTokens(req, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
- }
- c.Set("prompt_tokens", promptTokens)
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
- // pre consume quota
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
- var requestBody io.Reader
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "get_request_body_error", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(body)
- } else {
- convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "convert_request_error", http.StatusBadRequest)
- }
- jsonData, err := json.Marshal(convertedRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "marshal_request_error", http.StatusInternalServerError)
- }
- // apply param override
- if len(relayInfo.ParamOverride) > 0 {
- reqMap := make(map[string]interface{})
- err = json.Unmarshal(jsonData, &reqMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
- }
- for key, value := range relayInfo.ParamOverride {
- reqMap[key] = value
- }
- jsonData, err = json.Marshal(reqMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
- }
- }
-
- if common.DebugEnabled {
- println("requestBody: ", string(jsonData))
- }
- requestBody = bytes.NewBuffer(jsonData)
- }
-
- var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- if resp != nil {
- httpResp = resp.(*http.Response)
-
- if httpResp.StatusCode != http.StatusOK {
- openaiErr = service.RelayErrorHandler(httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
- if openaiErr != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
-
- if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
- } else {
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
- }
- return nil
-}
diff --git a/relay/relay-text.go b/relay/relay-text.go
deleted file mode 100644
index f1105907..00000000
--- a/relay/relay-text.go
+++ /dev/null
@@ -1,517 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "math"
- "net/http"
- "one-api/common"
- "one-api/constant"
- "one-api/dto"
- "one-api/model"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
- "one-api/setting"
- "one-api/setting/model_setting"
- "one-api/setting/operation_setting"
- "strings"
- "time"
-
- "github.com/bytedance/gopkg/util/gopool"
- "github.com/shopspring/decimal"
-
- "github.com/gin-gonic/gin"
-)
-
-func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
- textRequest := &dto.GeneralOpenAIRequest{}
- err := common.UnmarshalBodyReusable(c, textRequest)
- if err != nil {
- return nil, err
- }
- if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
- textRequest.Model = "text-moderation-latest"
- }
- if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
- textRequest.Model = c.Param("model")
- }
-
- if textRequest.MaxTokens > math.MaxInt32/2 {
- return nil, errors.New("max_tokens is invalid")
- }
- if textRequest.Model == "" {
- return nil, errors.New("model is required")
- }
- if textRequest.WebSearchOptions != nil {
- if textRequest.WebSearchOptions.SearchContextSize != "" {
- validSizes := map[string]bool{
- "high": true,
- "medium": true,
- "low": true,
- }
- if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
- return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
- }
- } else {
- textRequest.WebSearchOptions.SearchContextSize = "medium"
- }
- }
- switch relayInfo.RelayMode {
- case relayconstant.RelayModeCompletions:
- if textRequest.Prompt == "" {
- return nil, errors.New("field prompt is required")
- }
- case relayconstant.RelayModeChatCompletions:
- if len(textRequest.Messages) == 0 {
- return nil, errors.New("field messages is required")
- }
- case relayconstant.RelayModeEmbeddings:
- case relayconstant.RelayModeModerations:
- if textRequest.Input == nil || textRequest.Input == "" {
- return nil, errors.New("field input is required")
- }
- case relayconstant.RelayModeEdits:
- if textRequest.Instruction == "" {
- return nil, errors.New("field instruction is required")
- }
- }
- relayInfo.IsStream = textRequest.Stream
- return textRequest, nil
-}
-
-func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-
- relayInfo := relaycommon.GenRelayInfo(c)
-
- // get & validate textRequest 获取并验证文本请求
- textRequest, err := getAndValidateTextRequest(c, relayInfo)
- if textRequest.WebSearchOptions != nil {
- c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
- }
-
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
- }
-
- if setting.ShouldCheckPromptSensitive() {
- words, err := checkRequestSensitive(textRequest, relayInfo)
- if err != nil {
- common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
- return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
- }
- }
-
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
- }
-
- textRequest.Model = relayInfo.UpstreamModelName
-
- // 获取 promptTokens,如果上下文中已经存在,则直接使用
- var promptTokens int
- if value, exists := c.Get("prompt_tokens"); exists {
- promptTokens = value.(int)
- relayInfo.PromptTokens = promptTokens
- } else {
- promptTokens, err = getPromptTokens(textRequest, relayInfo)
- // count messages token error 计算promptTokens错误
- if err != nil {
- return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
- }
- c.Set("prompt_tokens", promptTokens)
- }
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
-
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
- includeUsage := false
- // 判断用户是否需要返回使用情况
- if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
- includeUsage = true
- }
-
- // 如果不支持StreamOptions,将StreamOptions设置为nil
- if !relayInfo.SupportStreamOptions || !textRequest.Stream {
- textRequest.StreamOptions = nil
- } else {
- // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
- if constant.ForceStreamOption {
- textRequest.StreamOptions = &dto.StreamOptions{
- IncludeUsage: true,
- }
- }
- }
-
- if includeUsage {
- relayInfo.ShouldIncludeUsage = true
- }
-
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
- var requestBody io.Reader
-
- if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
- body, err := common.GetRequestBody(c)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
- }
- requestBody = bytes.NewBuffer(body)
- } else {
- convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
- }
- jsonData, err := json.Marshal(convertedRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
- }
-
- // apply param override
- if len(relayInfo.ParamOverride) > 0 {
- reqMap := make(map[string]interface{})
- err = json.Unmarshal(jsonData, &reqMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
- }
- for key, value := range relayInfo.ParamOverride {
- reqMap[key] = value
- }
- jsonData, err = json.Marshal(reqMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
- }
- }
-
- if common.DebugEnabled {
- println("requestBody: ", string(jsonData))
- }
- requestBody = bytes.NewBuffer(jsonData)
- }
-
- var httpResp *http.Response
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
-
- if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
-
- statusCodeMappingStr := c.GetString("status_code_mapping")
-
- if resp != nil {
- httpResp = resp.(*http.Response)
- relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
- if httpResp.StatusCode != http.StatusOK {
- openaiErr = service.RelayErrorHandler(httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
- if openaiErr != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
-
- if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
- service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
- } else {
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
- }
- return nil
-}
-
-func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
- var promptTokens int
- var err error
- switch info.RelayMode {
- case relayconstant.RelayModeChatCompletions:
- promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
- case relayconstant.RelayModeCompletions:
- promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
- case relayconstant.RelayModeModerations:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
- case relayconstant.RelayModeEmbeddings:
- promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
- default:
- err = errors.New("unknown relay mode")
- promptTokens = 0
- }
- info.PromptTokens = promptTokens
- return promptTokens, err
-}
-
-func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
- var err error
- var words []string
- switch info.RelayMode {
- case relayconstant.RelayModeChatCompletions:
- words, err = service.CheckSensitiveMessages(textRequest.Messages)
- case relayconstant.RelayModeCompletions:
- words, err = service.CheckSensitiveInput(textRequest.Prompt)
- case relayconstant.RelayModeModerations:
- words, err = service.CheckSensitiveInput(textRequest.Input)
- case relayconstant.RelayModeEmbeddings:
- words, err = service.CheckSensitiveInput(textRequest.Input)
- }
- return words, err
-}
-
-// 预扣费并返回用户剩余配额
-func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
- userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
- if err != nil {
- return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
- }
- if userQuota <= 0 {
- return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
- }
- if userQuota-preConsumedQuota < 0 {
- return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), "insufficient_user_quota", http.StatusForbidden)
- }
- relayInfo.UserQuota = userQuota
- if userQuota > 100*preConsumedQuota {
- // 用户额度充足,判断令牌额度是否充足
- if !relayInfo.TokenUnlimited {
- // 非无限令牌,判断令牌额度是否充足
- tokenQuota := c.GetInt("token_quota")
- if tokenQuota > 100*preConsumedQuota {
- // 令牌额度充足,信任令牌
- preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
- }
- } else {
- // in this case, we do not pre-consume quota
- // because the user has enough quota
- preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
- }
- }
-
- if preConsumedQuota > 0 {
- err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
- if err != nil {
- return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
- }
- err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
- if err != nil {
- return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
- }
- }
- return preConsumedQuota, userQuota, nil
-}
-
-func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
- if preConsumedQuota != 0 {
- gopool.Go(func() {
- relayInfoCopy := *relayInfo
-
- err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
- if err != nil {
- common.SysError("error return pre-consumed quota: " + err.Error())
- }
- })
- }
-}
-
-func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
- if usage == nil {
- usage = &dto.Usage{
- PromptTokens: relayInfo.PromptTokens,
- CompletionTokens: 0,
- TotalTokens: relayInfo.PromptTokens,
- }
- extraContent += "(可能是请求出错)"
- }
- useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
- promptTokens := usage.PromptTokens
- cacheTokens := usage.PromptTokensDetails.CachedTokens
- imageTokens := usage.PromptTokensDetails.ImageTokens
- completionTokens := usage.CompletionTokens
- modelName := relayInfo.OriginModelName
-
- tokenName := ctx.GetString("token_name")
- completionRatio := priceData.CompletionRatio
- cacheRatio := priceData.CacheRatio
- imageRatio := priceData.ImageRatio
- modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
- modelPrice := priceData.ModelPrice
-
- // Convert values to decimal for precise calculation
- dPromptTokens := decimal.NewFromInt(int64(promptTokens))
- dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
- dImageTokens := decimal.NewFromInt(int64(imageTokens))
- dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
- dCompletionRatio := decimal.NewFromFloat(completionRatio)
- dCacheRatio := decimal.NewFromFloat(cacheRatio)
- dImageRatio := decimal.NewFromFloat(imageRatio)
- dModelRatio := decimal.NewFromFloat(modelRatio)
- dGroupRatio := decimal.NewFromFloat(groupRatio)
- dModelPrice := decimal.NewFromFloat(modelPrice)
- dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
-
- ratio := dModelRatio.Mul(dGroupRatio)
-
- // openai web search 工具计费
- var dWebSearchQuota decimal.Decimal
- var webSearchPrice float64
- if relayInfo.ResponsesUsageInfo != nil {
- if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
- // 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
- webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
- dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
- Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
- Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
- webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
- }
- } else if strings.HasSuffix(modelName, "search-preview") {
- // search-preview 模型不支持 response api
- searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
- if searchContextSize == "" {
- searchContextSize = "medium"
- }
- webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
- dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
- Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
- searchContextSize, dWebSearchQuota.String())
- }
- // file search tool 计费
- var dFileSearchQuota decimal.Decimal
- var fileSearchPrice float64
- if relayInfo.ResponsesUsageInfo != nil {
- if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
- fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
- dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
- Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
- Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
- extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 $%s",
- fileSearchTool.CallCount, dFileSearchQuota.String())
- }
- }
-
- var quotaCalculateDecimal decimal.Decimal
- if !priceData.UsePrice {
- nonCachedTokens := dPromptTokens.Sub(dCacheTokens)
- cachedTokensWithRatio := dCacheTokens.Mul(dCacheRatio)
-
- promptQuota := nonCachedTokens.Add(cachedTokensWithRatio)
- if imageTokens > 0 {
- nonImageTokens := dPromptTokens.Sub(dImageTokens)
- imageTokensWithRatio := dImageTokens.Mul(dImageRatio)
- promptQuota = nonImageTokens.Add(imageTokensWithRatio)
- }
-
- completionQuota := dCompletionTokens.Mul(dCompletionRatio)
-
- quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
-
- if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
- quotaCalculateDecimal = decimal.NewFromInt(1)
- }
- } else {
- quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
- }
- // 添加 responses tools call 调用的配额
- quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
- quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
-
- quota := int(quotaCalculateDecimal.Round(0).IntPart())
- totalTokens := promptTokens + completionTokens
-
- var logContent string
- if !priceData.UsePrice {
- logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
- } else {
- logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
- }
-
- // record all the consume log even if quota is 0
- if totalTokens == 0 {
- // in this case, must be some error happened
- // we cannot just return, because we may have to return the pre-consumed quota
- quota = 0
- logContent += fmt.Sprintf("(可能是上游超时)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
- } else {
- model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
- model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
- }
-
- quotaDelta := quota - preConsumedQuota
- if quotaDelta != 0 {
- err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
- if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
- }
- }
-
- logModel := modelName
- if strings.HasPrefix(logModel, "gpt-4-gizmo") {
- logModel = "gpt-4-gizmo-*"
- logContent += fmt.Sprintf(",模型 %s", modelName)
- }
- if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
- logModel = "gpt-4o-gizmo-*"
- logContent += fmt.Sprintf(",模型 %s", modelName)
- }
- if extraContent != "" {
- logContent += ", " + extraContent
- }
- other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
- if imageTokens != 0 {
- other["image"] = true
- other["image_ratio"] = imageRatio
- other["image_output"] = imageTokens
- }
- if !dWebSearchQuota.IsZero() {
- if relayInfo.ResponsesUsageInfo != nil {
- if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
- other["web_search"] = true
- other["web_search_call_count"] = webSearchTool.CallCount
- other["web_search_price"] = webSearchPrice
- }
- } else if strings.HasSuffix(modelName, "search-preview") {
- other["web_search"] = true
- other["web_search_call_count"] = 1
- other["web_search_price"] = webSearchPrice
- }
- }
- if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
- if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
- other["file_search"] = true
- other["file_search_call_count"] = fileSearchTool.CallCount
- other["file_search_price"] = fileSearchPrice
- }
- }
- model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
- tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
-}
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
index 7bf0da9f..1ee85986 100644
--- a/relay/relay_adaptor.go
+++ b/relay/relay_adaptor.go
@@ -1,7 +1,8 @@
package relay
import (
- commonconstant "one-api/constant"
+ "github.com/gin-gonic/gin"
+ "one-api/constant"
"one-api/relay/channel"
"one-api/relay/channel/ali"
"one-api/relay/channel/aws"
@@ -14,15 +15,20 @@ import (
"one-api/relay/channel/deepseek"
"one-api/relay/channel/dify"
"one-api/relay/channel/gemini"
+ "one-api/relay/channel/jimeng"
"one-api/relay/channel/jina"
"one-api/relay/channel/mistral"
"one-api/relay/channel/mokaai"
+ "one-api/relay/channel/moonshot"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
+ taskjimeng "one-api/relay/channel/task/jimeng"
+ "one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
+ taskVidu "one-api/relay/channel/task/vidu"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
"one-api/relay/channel/volcengine"
@@ -30,7 +36,7 @@ import (
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_4v"
- "one-api/relay/constant"
+ "strconv"
)
func GetAdaptor(apiType int) channel.Adaptor {
@@ -91,16 +97,38 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &xai.Adaptor{}
case constant.APITypeCoze:
return &coze.Adaptor{}
+ case constant.APITypeJimeng:
+ return &jimeng.Adaptor{}
+ case constant.APITypeMoonshot:
+ return &moonshot.Adaptor{} // Moonshot uses Claude API
}
return nil
}
-func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
+func GetTaskPlatform(c *gin.Context) constant.TaskPlatform {
+ channelType := c.GetInt("channel_type")
+ if channelType > 0 {
+ return constant.TaskPlatform(strconv.Itoa(channelType))
+ }
+ return constant.TaskPlatform(c.GetString("platform"))
+}
+
+func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
switch platform {
//case constant.APITypeAIProxyLibrary:
// return &aiproxy.Adaptor{}
- case commonconstant.TaskPlatformSuno:
+ case constant.TaskPlatformSuno:
return &suno.TaskAdaptor{}
}
+ if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil {
+ switch channelType {
+ case constant.ChannelTypeKling:
+ return &kling.TaskAdaptor{}
+ case constant.ChannelTypeJimeng:
+ return &taskjimeng.TaskAdaptor{}
+ case constant.ChannelTypeVidu:
+ return &taskVidu.TaskAdaptor{}
+ }
+ }
return nil
}
diff --git a/relay/relay_embedding.go b/relay/relay_embedding.go
deleted file mode 100644
index b4909849..00000000
--- a/relay/relay_embedding.go
+++ /dev/null
@@ -1,116 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- relayconstant "one-api/relay/constant"
- "one-api/relay/helper"
- "one-api/service"
-)
-
-func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
- token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
- return token
-}
-
-func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
- if embeddingRequest.Input == nil {
- return fmt.Errorf("input is empty")
- }
- if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
- embeddingRequest.Model = "omni-moderation-latest"
- }
- if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
- embeddingRequest.Model = c.Param("model")
- }
- return nil
-}
-
-func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
- relayInfo := relaycommon.GenRelayInfo(c)
-
- var embeddingRequest *dto.EmbeddingRequest
- err := common.UnmarshalBodyReusable(c, &embeddingRequest)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
- }
-
- err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
- }
-
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
- }
-
- embeddingRequest.Model = relayInfo.UpstreamModelName
-
- promptToken := getEmbeddingPromptToken(*embeddingRequest)
- relayInfo.PromptTokens = promptToken
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
-
- convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
-
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
- }
- jsonData, err := json.Marshal(convertedRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
- }
- requestBody := bytes.NewBuffer(jsonData)
- statusCodeMappingStr := c.GetString("status_code_mapping")
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
-
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- openaiErr = service.RelayErrorHandler(httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
- if openaiErr != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
- return nil
-}
diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go
deleted file mode 100644
index 6ca98de7..00000000
--- a/relay/relay_rerank.go
+++ /dev/null
@@ -1,110 +0,0 @@
-package relay
-
-import (
- "bytes"
- "encoding/json"
- "fmt"
- "github.com/gin-gonic/gin"
- "net/http"
- "one-api/common"
- "one-api/dto"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/service"
-)
-
-func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
- token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
- for _, document := range rerankRequest.Documents {
- tkm, err := service.CountTokenInput(document, rerankRequest.Model)
- if err == nil {
- token += tkm
- }
- }
- return token
-}
-
-func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-
- var rerankRequest *dto.RerankRequest
- err := common.UnmarshalBodyReusable(c, &rerankRequest)
- if err != nil {
- common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
- return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
- }
-
- relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
-
- if rerankRequest.Query == "" {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
- }
- if len(rerankRequest.Documents) == 0 {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
- }
-
- err = helper.ModelMappedHelper(c, relayInfo)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
- }
-
- rerankRequest.Model = relayInfo.UpstreamModelName
-
- promptToken := getRerankPromptToken(*rerankRequest)
- relayInfo.PromptTokens = promptToken
-
- priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
- }
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
- if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
- }
- adaptor.Init(relayInfo)
-
- convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
- }
- jsonData, err := json.Marshal(convertedRequest)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
- }
- requestBody := bytes.NewBuffer(jsonData)
- statusCodeMappingStr := c.GetString("status_code_mapping")
- resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
- if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
- }
-
- var httpResp *http.Response
- if resp != nil {
- httpResp = resp.(*http.Response)
- if httpResp.StatusCode != http.StatusOK {
- openaiErr = service.RelayErrorHandler(httpResp, false)
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- }
-
- usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
- if openaiErr != nil {
- // reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
- }
- postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
- return nil
-}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index 26874ba6..95b8083b 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
@@ -15,8 +14,9 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
- "one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
)
/*
@@ -24,7 +24,14 @@ Task 任务通过平台、Action 区分任务
*/
func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
platform := constant.TaskPlatform(c.GetString("platform"))
- relayInfo := relaycommon.GenTaskRelayInfo(c)
+ if platform == "" {
+ platform = GetTaskPlatform(c)
+ }
+
+ relayInfo, err := relaycommon.GenTaskRelayInfo(c)
+ if err != nil {
+ return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError)
+ }
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
@@ -37,10 +44,13 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
return
}
- modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
- modelPrice, success := operation_setting.GetModelPrice(modelName, true)
+ modelName := relayInfo.OriginModelName
+ if modelName == "" {
+ modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
+ }
+ modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success {
- defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
+ defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
@@ -49,8 +59,14 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- ratio := modelPrice * groupRatio
+ groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
+ var ratio float64
+ userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
+ if hasUserGroupRatio {
+ ratio = modelPrice * userGroupRatio
+ } else {
+ ratio = modelPrice * groupRatio
+ }
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -85,7 +101,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
- relayInfo.BaseUrl = channel.GetBaseURL()
+ relayInfo.ChannelBaseUrl = channel.GetBaseURL()
relayInfo.ChannelId = originTask.ChannelId
}
}
@@ -115,16 +131,31 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
- common.SysError("error consuming token remain quota: " + err.Error())
+ common.SysLog("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
- logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
+ gRatio := groupRatio
+ if hasUserGroupRatio {
+ gRatio = userGroupRatio
+ }
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
- model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
- modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
+ if hasUserGroupRatio {
+ other["user_group_ratio"] = userGroupRatio
+ }
+ model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
@@ -137,10 +168,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
relayInfo.ConsumeQuota = true
// insert task
- task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
+ task := model.InitTask(platform, relayInfo)
task.TaskID = taskID
task.Quota = quota
task.Data = taskData
+ task.Action = relayInfo.Action
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
@@ -150,8 +182,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
- relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
- relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
+ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
+ relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
+ relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
}
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
@@ -226,6 +259,30 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return
}
+func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+ taskId := c.Param("task_id")
+ if taskId == "" {
+ taskId = c.GetString("task_id")
+ }
+ userId := c.GetInt("id")
+
+ originTask, exist, err := model.GetByTaskId(userId, taskId)
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
+ return
+ }
+ if !exist {
+ taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
+ return
+ }
+
+ respBody, err = json.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: TaskModel2Dto(originTask),
+ })
+ return
+}
+
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
TaskID: task.TaskID,
diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go
new file mode 100644
index 00000000..fa3c7bbb
--- /dev/null
+++ b/relay/rerank_handler.go
@@ -0,0 +1,99 @@
+package relay
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
+
+ rerankReq, ok := info.Request.(*dto.RerankRequest)
+ if !ok {
+ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ request, err := common.DeepCopy(rerankReq)
+ if err != nil {
+	return types.NewError(fmt.Errorf("failed to copy request to RerankRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ err = helper.ModelMappedHelper(c, info, request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+ }
+
+ adaptor := GetAdaptor(info.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ }
+ adaptor.Init(info)
+
+ var requestBody io.Reader
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
+ body, err := common.GetRequestBody(c)
+ if err != nil {
+ return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+ requestBody = bytes.NewBuffer(body)
+ } else {
+ convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+ jsonData, err := common.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+
+ // apply param override
+ if len(info.ParamOverride) > 0 {
+ jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+ }
+ }
+
+ if common.DebugEnabled {
+ println(fmt.Sprintf("Rerank request body: %s", string(jsonData)))
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ }
+
+ resp, err := adaptor.DoRequest(c, info, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
+ return nil
+}
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
new file mode 100644
index 00000000..f5f624c9
--- /dev/null
+++ b/relay/responses_handler.go
@@ -0,0 +1,105 @@
+package relay
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
+
+ responsesReq, ok := info.Request.(*dto.OpenAIResponsesRequest)
+ if !ok {
+ return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ request, err := common.DeepCopy(responsesReq)
+ if err != nil {
+	return types.NewError(fmt.Errorf("failed to copy request to OpenAIResponsesRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+ }
+
+ err = helper.ModelMappedHelper(c, info, request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+ }
+
+ adaptor := GetAdaptor(info.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+ }
+ adaptor.Init(info)
+ var requestBody io.Reader
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
+ body, err := common.GetRequestBody(c)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
+ }
+ requestBody = bytes.NewBuffer(body)
+ } else {
+ convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *request)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+ jsonData, err := common.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+ }
+ // apply param override
+ if len(info.ParamOverride) > 0 {
+ jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+ }
+ }
+
+ if common.DebugEnabled {
+ println("requestBody: ", string(jsonData))
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ }
+
+ var httpResp *http.Response
+ resp, err := adaptor.DoRequest(c, info, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+
+ if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
+ service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
+ } else {
+ postConsumeQuota(c, info, usage.(*dto.Usage), "")
+ }
+ return nil
+}
diff --git a/relay/websocket.go b/relay/websocket.go
index c815eb71..2d313154 100644
--- a/relay/websocket.go
+++ b/relay/websocket.go
@@ -1,118 +1,45 @@
package relay
import (
- "encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "net/http"
- "one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
- "one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
)
-func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
- relayInfo := relaycommon.GenRelayInfoWs(c, ws)
+func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
+ info.InitChannelMeta(c)
- // get & validate textRequest 获取并验证文本请求
- //realtimeEvent, err := getAndValidateWssRequest(c, ws)
- //if err != nil {
- // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
- // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
- //}
-
- // map model name
- modelMapping := c.GetString("model_mapping")
- //isModelMapped := false
- if modelMapping != "" && modelMapping != "{}" {
- modelMap := make(map[string]string)
- err := json.Unmarshal([]byte(modelMapping), &modelMap)
- if err != nil {
- return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
- }
- if modelMap[relayInfo.OriginModelName] != "" {
- relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName]
- // set upstream model name
- //isModelMapped = true
- }
- }
- //relayInfo.UpstreamModelName = textRequest.Model
- modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
-
- var preConsumedQuota int
- var ratio float64
- var modelRatio float64
- //err := service.SensitiveWordsCheck(textRequest)
-
- //if constant.ShouldCheckPromptSensitive() {
- // err = checkRequestSensitive(textRequest, relayInfo)
- // if err != nil {
- // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
- // }
- //}
-
- //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
- //// count messages token error 计算promptTokens错误
- //if err != nil {
- // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
- //}
- //
- if !getModelPriceSuccess {
- preConsumedTokens := common.PreConsumedQuota
- //if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
- // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
- //}
- modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
- ratio = modelRatio * groupRatio
- preConsumedQuota = int(float64(preConsumedTokens) * ratio)
- } else {
- preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
- relayInfo.UsePrice = true
- }
-
- // pre-consume quota 预消耗配额
- preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
- if openaiErr != nil {
- return openaiErr
- }
-
- defer func() {
- if openaiErr != nil {
- returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
- }
- }()
-
- adaptor := GetAdaptor(relayInfo.ApiType)
+ adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
- return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+ return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
- adaptor.Init(relayInfo)
+ adaptor.Init(info)
//var requestBody io.Reader
//firstWssRequest, _ := c.Get("first_wss_request")
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
statusCodeMappingStr := c.GetString("status_code_mapping")
- resp, err := adaptor.DoRequest(c, relayInfo, nil)
+ resp, err := adaptor.DoRequest(c, info, nil)
if err != nil {
- return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+ return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
if resp != nil {
- relayInfo.TargetWs = resp.(*websocket.Conn)
- defer relayInfo.TargetWs.Close()
+ info.TargetWs = resp.(*websocket.Conn)
+ defer info.TargetWs.Close()
}
- usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo)
- if openaiErr != nil {
+ usage, newAPIError := adaptor.DoResponse(c, nil, info)
+ if newAPIError != nil {
// reset status code 重置状态码
- service.ResetStatusCode(openaiErr, statusCodeMappingStr)
- return openaiErr
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
}
- service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
- userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
+ service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "")
return nil
}
diff --git a/router/api-router.go b/router/api-router.go
index 7bbc654a..7a60994d 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -16,6 +16,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/setup", controller.GetSetup)
apiRouter.POST("/setup", controller.PostSetup)
apiRouter.GET("/status", controller.GetStatus)
+ apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
apiRouter.GET("/notice", controller.GetNotice)
@@ -23,7 +24,7 @@ func SetApiRouter(router *gin.Engine) {
//apiRouter.GET("/midjourney", controller.GetMidjourney)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
- apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
+ apiRouter.GET("/verification", middleware.EmailVerificationRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
@@ -35,11 +36,15 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
+ apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
+
+ apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
userRoute := apiRouter.Group("/user")
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
+ userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
@@ -55,11 +60,20 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
- selfRoute.POST("/topup", controller.TopUp)
- selfRoute.POST("/pay", controller.RequestEpay)
+ selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
+ selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
+ selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay)
+ selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
+
+ // 2FA routes
+ selfRoute.GET("/2fa/status", controller.Get2FAStatus)
+ selfRoute.POST("/2fa/setup", controller.Setup2FA)
+ selfRoute.POST("/2fa/enable", controller.Enable2FA)
+ selfRoute.POST("/2fa/disable", controller.Disable2FA)
+ selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes)
}
adminRoute := userRoute.Group("/")
@@ -72,6 +86,10 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.POST("/manage", controller.ManageUser)
adminRoute.PUT("/", controller.UpdateUser)
adminRoute.DELETE("/:id", controller.DeleteUser)
+
+ // Admin 2FA routes
+ adminRoute.GET("/2fa/stats", controller.Admin2FAStats)
+ adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA)
}
}
optionRoute := apiRouter.Group("/option")
@@ -80,6 +98,13 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.GET("/", controller.GetOptions)
optionRoute.PUT("/", controller.UpdateOption)
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
+ optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
+ }
+ ratioSyncRoute := apiRouter.Group("/ratio_sync")
+ ratioSyncRoute.Use(middleware.RootAuth())
+ {
+ ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
+ ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
}
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
@@ -105,6 +130,9 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
+ channelRoute.GET("/tag/models", controller.GetTagModels)
+ channelRoute.POST("/copy/:id", controller.CopyChannel)
+ channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())
@@ -116,6 +144,7 @@ func SetApiRouter(router *gin.Engine) {
tokenRoute.POST("/", controller.AddToken)
tokenRoute.PUT("/", controller.UpdateToken)
tokenRoute.DELETE("/:id", controller.DeleteToken)
+ tokenRoute.POST("/batch", controller.DeleteTokenBatch)
}
redemptionRoute := apiRouter.Group("/redemption")
redemptionRoute.Use(middleware.AdminAuth())
@@ -125,6 +154,7 @@ func SetApiRouter(router *gin.Engine) {
redemptionRoute.GET("/:id", controller.GetRedemption)
redemptionRoute.POST("/", controller.AddRedemption)
redemptionRoute.PUT("/", controller.UpdateRedemption)
+ redemptionRoute.DELETE("/invalid", controller.DeleteInvalidRedemption)
redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
}
logRoute := apiRouter.Group("/log")
@@ -149,6 +179,16 @@ func SetApiRouter(router *gin.Engine) {
{
groupRoute.GET("/", controller.GetGroups)
}
+
+ prefillGroupRoute := apiRouter.Group("/prefill_group")
+ prefillGroupRoute.Use(middleware.AdminAuth())
+ {
+ prefillGroupRoute.GET("/", controller.GetPrefillGroups)
+ prefillGroupRoute.POST("/", controller.CreatePrefillGroup)
+ prefillGroupRoute.PUT("/", controller.UpdatePrefillGroup)
+ prefillGroupRoute.DELETE("/:id", controller.DeletePrefillGroup)
+ }
+
mjRoute := apiRouter.Group("/mj")
mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
@@ -158,5 +198,28 @@ func SetApiRouter(router *gin.Engine) {
taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask)
taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask)
}
+
+ vendorRoute := apiRouter.Group("/vendors")
+ vendorRoute.Use(middleware.AdminAuth())
+ {
+ vendorRoute.GET("/", controller.GetAllVendors)
+ vendorRoute.GET("/search", controller.SearchVendors)
+ vendorRoute.GET("/:id", controller.GetVendorMeta)
+ vendorRoute.POST("/", controller.CreateVendorMeta)
+ vendorRoute.PUT("/", controller.UpdateVendorMeta)
+ vendorRoute.DELETE("/:id", controller.DeleteVendorMeta)
+ }
+
+ modelsRoute := apiRouter.Group("/models")
+ modelsRoute.Use(middleware.AdminAuth())
+ {
+ modelsRoute.GET("/missing", controller.GetMissingModels)
+ modelsRoute.GET("/", controller.GetAllModelsMeta)
+ modelsRoute.GET("/search", controller.SearchModelsMeta)
+ modelsRoute.GET("/:id", controller.GetModelMeta)
+ modelsRoute.POST("/", controller.CreateModelMeta)
+ modelsRoute.PUT("/", controller.UpdateModelMeta)
+ modelsRoute.DELETE("/:id", controller.DeleteModelMeta)
+ }
}
}
diff --git a/router/main.go b/router/main.go
index b8ac4055..23576427 100644
--- a/router/main.go
+++ b/router/main.go
@@ -3,17 +3,19 @@ package router
import (
"embed"
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"os"
"strings"
+
+ "github.com/gin-gonic/gin"
)
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router)
SetDashboardRouter(router)
SetRelayRouter(router)
+ SetVideoRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = ""
diff --git a/router/relay-router.go b/router/relay-router.go
index 1115a491..e0f05e97 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -1,9 +1,11 @@
package router
import (
+ "one-api/constant"
"one-api/controller"
"one-api/middleware"
"one-api/relay"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
@@ -11,15 +13,50 @@ import (
func SetRelayRouter(router *gin.Engine) {
router.Use(middleware.CORS())
router.Use(middleware.DecompressRequestMiddleware())
+ router.Use(middleware.StatsMiddleware())
// https://platform.openai.com/docs/api-reference/introduction
modelsRouter := router.Group("/v1/models")
modelsRouter.Use(middleware.TokenAuth())
{
- modelsRouter.GET("", controller.ListModels)
- modelsRouter.GET("/:model", controller.RetrieveModel)
+ modelsRouter.GET("", func(c *gin.Context) {
+ switch {
+ case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+ controller.ListModels(c, constant.ChannelTypeAnthropic)
+ case c.GetHeader("x-goog-api-key") != "" || c.Query("key") != "": // 单独的适配
+ controller.RetrieveModel(c, constant.ChannelTypeGemini)
+ default:
+ controller.ListModels(c, constant.ChannelTypeOpenAI)
+ }
+ })
+
+ modelsRouter.GET("/:model", func(c *gin.Context) {
+ switch {
+ case c.GetHeader("x-api-key") != "" && c.GetHeader("anthropic-version") != "":
+ controller.RetrieveModel(c, constant.ChannelTypeAnthropic)
+ default:
+ controller.RetrieveModel(c, constant.ChannelTypeOpenAI)
+ }
+ })
}
+
+ geminiRouter := router.Group("/v1beta/models")
+ geminiRouter.Use(middleware.TokenAuth())
+ {
+ geminiRouter.GET("", func(c *gin.Context) {
+ controller.ListModels(c, constant.ChannelTypeGemini)
+ })
+ }
+
+ geminiCompatibleRouter := router.Group("/v1beta/openai/models")
+ geminiCompatibleRouter.Use(middleware.TokenAuth())
+ {
+ geminiCompatibleRouter.GET("", func(c *gin.Context) {
+ controller.ListModels(c, constant.ChannelTypeOpenAI)
+ })
+ }
+
playgroundRouter := router.Group("/pg")
- playgroundRouter.Use(middleware.UserAuth())
+ playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
{
playgroundRouter.POST("/chat/completions", controller.Playground)
}
@@ -27,28 +64,83 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.Use(middleware.TokenAuth())
relayV1Router.Use(middleware.ModelRequestRateLimit())
{
- // WebSocket 路由
+ // WebSocket 路由(统一到 Relay)
wsRouter := relayV1Router.Group("")
wsRouter.Use(middleware.Distribute())
- wsRouter.GET("/realtime", controller.WssRelay)
+ wsRouter.GET("/realtime", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIRealtime)
+ })
}
{
//http router
httpRouter := relayV1Router.Group("")
httpRouter.Use(middleware.Distribute())
- httpRouter.POST("/messages", controller.RelayClaude)
- httpRouter.POST("/completions", controller.Relay)
- httpRouter.POST("/chat/completions", controller.Relay)
- httpRouter.POST("/edits", controller.Relay)
- httpRouter.POST("/images/generations", controller.Relay)
- httpRouter.POST("/images/edits", controller.Relay)
+
+ // claude related routes
+ httpRouter.POST("/messages", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatClaude)
+ })
+
+ // chat related routes
+ httpRouter.POST("/completions", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAI)
+ })
+ httpRouter.POST("/chat/completions", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAI)
+ })
+
+ // response related routes
+ httpRouter.POST("/responses", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIResponses)
+ })
+
+ // image related routes
+ httpRouter.POST("/edits", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIImage)
+ })
+ httpRouter.POST("/images/generations", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIImage)
+ })
+ httpRouter.POST("/images/edits", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIImage)
+ })
+
+ // embedding related routes
+ httpRouter.POST("/embeddings", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatEmbedding)
+ })
+
+ // audio related routes
+ httpRouter.POST("/audio/transcriptions", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIAudio)
+ })
+ httpRouter.POST("/audio/translations", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIAudio)
+ })
+ httpRouter.POST("/audio/speech", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAIAudio)
+ })
+
+ // rerank related routes
+ httpRouter.POST("/rerank", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatRerank)
+ })
+
+ // gemini relay routes
+ httpRouter.POST("/engines/:model/embeddings", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatGemini)
+ })
+ httpRouter.POST("/models/*path", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatGemini)
+ })
+
+ // other relay routes
+ httpRouter.POST("/moderations", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatOpenAI)
+ })
+
+ // not implemented
httpRouter.POST("/images/variations", controller.RelayNotImplemented)
- httpRouter.POST("/embeddings", controller.Relay)
- httpRouter.POST("/engines/:model/embeddings", controller.Relay)
- httpRouter.POST("/audio/transcriptions", controller.Relay)
- httpRouter.POST("/audio/translations", controller.Relay)
- httpRouter.POST("/audio/speech", controller.Relay)
- httpRouter.POST("/responses", controller.Relay)
httpRouter.GET("/files", controller.RelayNotImplemented)
httpRouter.POST("/files", controller.RelayNotImplemented)
httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
@@ -60,8 +152,6 @@ func SetRelayRouter(router *gin.Engine) {
httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
- httpRouter.POST("/moderations", controller.Relay)
- httpRouter.POST("/rerank", controller.Relay)
}
relayMjRouter := router.Group("/mj")
@@ -85,7 +175,9 @@ func SetRelayRouter(router *gin.Engine) {
relayGeminiRouter.Use(middleware.Distribute())
{
// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
- relayGeminiRouter.POST("/models/*path", controller.Relay)
+ relayGeminiRouter.POST("/models/*path", func(c *gin.Context) {
+ controller.Relay(c, types.RelayFormatGemini)
+ })
}
}
@@ -101,6 +193,8 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/edits", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/video", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
diff --git a/router/video-router.go b/router/video-router.go
new file mode 100644
index 00000000..bcc05eae
--- /dev/null
+++ b/router/video-router.go
@@ -0,0 +1,34 @@
+package router
+
+import (
+ "one-api/controller"
+ "one-api/middleware"
+
+ "github.com/gin-gonic/gin"
+)
+
+func SetVideoRouter(router *gin.Engine) {
+ videoV1Router := router.Group("/v1")
+ videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+ {
+ videoV1Router.POST("/video/generations", controller.RelayTask)
+ videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+ }
+
+ klingV1Router := router.Group("/kling/v1")
+ klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
+ {
+ klingV1Router.POST("/videos/text2video", controller.RelayTask)
+ klingV1Router.POST("/videos/image2video", controller.RelayTask)
+ klingV1Router.GET("/videos/text2video/:task_id", controller.RelayTask)
+ klingV1Router.GET("/videos/image2video/:task_id", controller.RelayTask)
+ }
+
+ // Jimeng official API routes - direct mapping to official API format
+ jimengOfficialGroup := router.Group("jimeng")
+ jimengOfficialGroup.Use(middleware.JimengRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
+ {
+ // Maps to: /?Action=CVSync2AsyncSubmitTask&Version=2022-08-31 and /?Action=CVSync2AsyncGetResult&Version=2022-08-31
+ jimengOfficialGroup.POST("/", controller.RelayTask)
+ }
+}
diff --git a/service/audio.go b/service/audio.go
index d558e96f..c4b6f01b 100644
--- a/service/audio.go
+++ b/service/audio.go
@@ -3,6 +3,7 @@ package service
import (
"encoding/base64"
"fmt"
+ "strings"
)
func parseAudio(audioBase64 string, format string) (duration float64, err error) {
@@ -29,3 +30,19 @@ func parseAudio(audioBase64 string, format string) (duration float64, err error)
duration = float64(samplesCount) / float64(sampleRate)
return duration, nil
}
+
+func DecodeBase64AudioData(audioBase64 string) (string, error) {
+ // 检查并移除 data:audio/xxx;base64, 前缀
+ idx := strings.Index(audioBase64, ",")
+ if idx != -1 {
+ audioBase64 = audioBase64[idx+1:]
+ }
+
+ // 解码 Base64 数据
+ _, err := base64.StdEncoding.DecodeString(audioBase64)
+ if err != nil {
+ return "", fmt.Errorf("base64 decode error: %v", err)
+ }
+
+ return audioBase64, nil
+}
diff --git a/service/cf_worker.go b/service/cf_worker.go
index ae6e1ffe..4a7b4376 100644
--- a/service/cf_worker.go
+++ b/service/cf_worker.go
@@ -42,16 +42,16 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
}
-func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
+func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) {
if setting.EnableWorker() {
- common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
+ common.SysLog(fmt.Sprintf("downloading file from worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
req := &WorkerRequest{
URL: originUrl,
Key: setting.WorkerValidKey,
}
return DoWorkerRequest(req)
} else {
- common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
+ common.SysLog(fmt.Sprintf("downloading from origin with worker: %s, reason: %s", originUrl, strings.Join(reason, ", ")))
return http.Get(originUrl)
}
}
diff --git a/service/channel.go b/service/channel.go
index e3a76af4..faac6d10 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -4,9 +4,11 @@ import (
"fmt"
"net/http"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/setting/operation_setting"
+ "one-api/types"
"strings"
)
@@ -15,17 +17,17 @@ func formatNotifyType(channelId int, status int) string {
}
// disable & notify
-func DisableChannel(channelId int, channelName string, reason string) {
- success := model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
+func DisableChannel(channelError types.ChannelError, reason string) {
+ success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason)
if success {
- subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
- content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
- NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusAutoDisabled), subject, content)
+ subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId)
+ content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason)
+ NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content)
}
}
-func EnableChannel(channelId int, channelName string) {
- success := model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
+func EnableChannel(channelId int, usingKey string, channelName string) {
+ success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "")
if success {
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
@@ -33,14 +35,17 @@ func EnableChannel(channelId int, channelName string) {
}
}
-func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
+func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
- if err.LocalError {
+ if types.IsChannelError(err) {
+ return true
+ }
+ if types.IsSkipRetryError(err) {
return false
}
if err.StatusCode == http.StatusUnauthorized {
@@ -48,19 +53,22 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
}
if err.StatusCode == http.StatusForbidden {
switch channelType {
- case common.ChannelTypeGemini:
+ case constant.ChannelTypeGemini:
return true
}
}
- switch err.Error.Code {
+ oaiErr := err.ToOpenAIError()
+ switch oaiErr.Code {
case "invalid_api_key":
return true
case "account_deactivated":
return true
case "billing_not_active":
return true
+ case "pre_consume_token_quota_failed":
+ return true
}
- switch err.Error.Type {
+ switch oaiErr.Type {
case "insufficient_quota":
return true
case "insufficient_user_quota":
@@ -74,23 +82,16 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
return true
}
- lowerMessage := strings.ToLower(err.Error.Message)
+ lowerMessage := strings.ToLower(err.Error())
search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
- if search {
- return true
- }
-
- return false
+ return search
}
-func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
+func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
- if err != nil {
- return false
- }
- if openaiWithStatusErr != nil {
+ if newAPIError != nil {
return false
}
if status != common.ChannelStatusAutoDisabled {
diff --git a/service/convert.go b/service/convert.go
index cb964a46..ea219c4f 100644
--- a/service/convert.go
+++ b/service/convert.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"one-api/common"
+ "one-api/constant"
"one-api/dto"
"one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
@@ -19,12 +20,12 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
Stream: claudeRequest.Stream,
}
- isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
+ isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
- if claudeRequest.Thinking != nil {
+ if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
if isOpenRouter {
reasoning := openrouter.RequestReasoning{
- MaxTokens: claudeRequest.Thinking.BudgetTokens,
+ MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
}
reasoningJSON, err := json.Marshal(reasoning)
if err != nil {
@@ -152,9 +153,13 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
toolCalls = append(toolCalls, toolCall)
case "tool_result":
// Add tool result as a separate message
+ toolName := mediaMsg.Name
+ if toolName == "" {
+ toolName = claudeRequest.SearchToolNameByToolCallId(mediaMsg.ToolUseId)
+ }
oaiToolMessage := dto.Message{
Role: "tool",
- Name: &mediaMsg.Name,
+ Name: &toolName,
ToolCallId: mediaMsg.ToolUseId,
}
//oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
@@ -162,7 +167,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
- encodeJson, _ := common.EncodeJson(mediaContents)
+ encodeJson, _ := common.Marshal(mediaContents)
oaiToolMessage.SetStringContent(string(encodeJson))
}
openAIMessages = append(openAIMessages, oaiToolMessage)
@@ -187,28 +192,6 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
return &openAIRequest, nil
}
-func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
- claudeError := dto.ClaudeError{
- Type: "new_api_error",
- Message: openAIError.Error.Message,
- }
- return &dto.ClaudeErrorWithStatusCode{
- Error: claudeError,
- StatusCode: openAIError.StatusCode,
- }
-}
-
-func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
- openAIError := dto.OpenAIError{
- Message: claudeError.Error.Message,
- Type: "new_api_error",
- }
- return &dto.OpenAIErrorWithStatusCode{
- Error: openAIError,
- StatusCode: claudeError.StatusCode,
- }
-}
-
func generateStopBlock(index int) *dto.ClaudeResponse {
return &dto.ClaudeResponse{
Type: "content_block_stop",
@@ -239,49 +222,88 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
// Type: "ping",
//})
if openAIResponse.IsToolCall() {
+ info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
resp := &dto.ClaudeResponse{
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
- Id: openAIResponse.GetFirstToolCall().ID,
- Type: "tool_use",
- Name: openAIResponse.GetFirstToolCall().Function.Name,
+ Id: openAIResponse.GetFirstToolCall().ID,
+ Type: "tool_use",
+ Name: openAIResponse.GetFirstToolCall().Function.Name,
+ Input: map[string]interface{}{},
},
}
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
} else {
- //resp := &dto.ClaudeResponse{
- // Type: "content_block_start",
- // ContentBlock: &dto.ClaudeMediaMessage{
- // Type: "text",
- // Text: common.GetPointer[string](""),
- // },
- //}
- //resp.SetIndex(0)
- //claudeResponses = append(claudeResponses, resp)
+
+ }
+ // 判断首个响应是否存在内容(非标准的 OpenAI 响应)
+ if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Index: &info.ClaudeConvertInfo.Index,
+ Type: "content_block_start",
+ ContentBlock: &dto.ClaudeMediaMessage{
+ Type: "text",
+ Text: common.GetPointer[string](""),
+ },
+ })
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Type: "content_block_delta",
+ Delta: &dto.ClaudeMediaMessage{
+ Type: "text",
+ Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
+ },
+ })
+ info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
}
return claudeResponses
}
if len(openAIResponse.Choices) == 0 {
// no choices
- // TODO: handle this case
+ // 可能为非标准的 OpenAI 响应,判断是否已经完成
+ if info.Done {
+ claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+ oaiUsage := info.ClaudeConvertInfo.Usage
+ if oaiUsage != nil {
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Type: "message_delta",
+ Usage: &dto.ClaudeUsage{
+ InputTokens: oaiUsage.PromptTokens,
+ OutputTokens: oaiUsage.CompletionTokens,
+ CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
+ CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
+ },
+ Delta: &dto.ClaudeMediaMessage{
+ StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
+ },
+ })
+ }
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Type: "message_stop",
+ })
+ }
return claudeResponses
} else {
chosenChoice := openAIResponse.Choices[0]
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done
info.FinishReason = *chosenChoice.FinishReason
- return claudeResponses
+ if !info.Done {
+ return claudeResponses
+ }
}
if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
- if info.ClaudeConvertInfo.Usage != nil {
+ oaiUsage := info.ClaudeConvertInfo.Usage
+ if oaiUsage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
- InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
- OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
+ InputTokens: oaiUsage.PromptTokens,
+ OutputTokens: oaiUsage.CompletionTokens,
+ CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
+ CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
},
Delta: &dto.ClaudeMediaMessage{
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
@@ -385,22 +407,26 @@ func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relayco
}
for _, choice := range openAIResponse.Choices {
stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
- claudeContent := dto.ClaudeMediaMessage{}
if choice.FinishReason == "tool_calls" {
- claudeContent.Type = "tool_use"
- claudeContent.Id = choice.Message.ToolCallId
- claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
- var mapParams map[string]interface{}
- if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
- claudeContent.Input = mapParams
- } else {
- claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
+ for _, toolUse := range choice.Message.ParseToolCalls() {
+ claudeContent := dto.ClaudeMediaMessage{}
+ claudeContent.Type = "tool_use"
+ claudeContent.Id = toolUse.ID
+ claudeContent.Name = toolUse.Function.Name
+ var mapParams map[string]interface{}
+ if err := common.Unmarshal([]byte(toolUse.Function.Arguments), &mapParams); err == nil {
+ claudeContent.Input = mapParams
+ } else {
+ claudeContent.Input = toolUse.Function.Arguments
+ }
+ contents = append(contents, claudeContent)
}
} else {
+ claudeContent := dto.ClaudeMediaMessage{}
claudeContent.Type = "text"
claudeContent.SetText(choice.Message.StringContent())
+ contents = append(contents, claudeContent)
}
- contents = append(contents, claudeContent)
}
claudeResponse.Content = contents
claudeResponse.StopReason = stopReason
@@ -418,6 +444,8 @@ func stopReasonOpenAI2Claude(reason string) string {
return "end_turn"
case "stop_sequence":
return "stop_sequence"
+ case "length":
+ fallthrough
case "max_tokens":
return "max_tokens"
case "tool_calls":
@@ -434,3 +462,353 @@ func toJSONString(v interface{}) string {
}
return string(b)
}
+
+func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
+ openaiRequest := &dto.GeneralOpenAIRequest{
+ Model: info.UpstreamModelName,
+ Stream: info.IsStream,
+ }
+
+ // 转换 messages
+ var messages []dto.Message
+ for _, content := range geminiRequest.Contents {
+ message := dto.Message{
+ Role: convertGeminiRoleToOpenAI(content.Role),
+ }
+
+ // 处理 parts
+ var mediaContents []dto.MediaContent
+ var toolCalls []dto.ToolCallRequest
+ for _, part := range content.Parts {
+ if part.Text != "" {
+ mediaContent := dto.MediaContent{
+ Type: "text",
+ Text: part.Text,
+ }
+ mediaContents = append(mediaContents, mediaContent)
+ } else if part.InlineData != nil {
+ mediaContent := dto.MediaContent{
+ Type: "image_url",
+ ImageUrl: &dto.MessageImageUrl{
+ Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
+ Detail: "auto",
+ MimeType: part.InlineData.MimeType,
+ },
+ }
+ mediaContents = append(mediaContents, mediaContent)
+ } else if part.FileData != nil {
+ mediaContent := dto.MediaContent{
+ Type: "image_url",
+ ImageUrl: &dto.MessageImageUrl{
+ Url: part.FileData.FileUri,
+ Detail: "auto",
+ MimeType: part.FileData.MimeType,
+ },
+ }
+ mediaContents = append(mediaContents, mediaContent)
+ } else if part.FunctionCall != nil {
+ // 处理 Gemini 的工具调用
+ toolCall := dto.ToolCallRequest{
+ ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
+ Type: "function",
+ Function: dto.FunctionRequest{
+ Name: part.FunctionCall.FunctionName,
+ Arguments: toJSONString(part.FunctionCall.Arguments),
+ },
+ }
+ toolCalls = append(toolCalls, toolCall)
+ } else if part.FunctionResponse != nil {
+ // 处理 Gemini 的工具响应,创建单独的 tool 消息
+ toolMessage := dto.Message{
+ Role: "tool",
+ ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
+ }
+ toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
+ messages = append(messages, toolMessage)
+ }
+ }
+
+ // 设置消息内容
+ if len(toolCalls) > 0 {
+ // 如果有工具调用,设置工具调用
+ message.SetToolCalls(toolCalls)
+ } else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
+ // 如果只有一个文本内容,直接设置字符串
+ message.Content = mediaContents[0].Text
+ } else if len(mediaContents) > 0 {
+ // 如果有多个内容或包含媒体,设置为数组
+ message.SetMediaContent(mediaContents)
+ }
+
+ // 只有当消息有内容或工具调用时才添加
+ if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
+ messages = append(messages, message)
+ }
+ }
+
+ openaiRequest.Messages = messages
+
+ if geminiRequest.GenerationConfig.Temperature != nil {
+ openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
+ }
+ if geminiRequest.GenerationConfig.TopP > 0 {
+ openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
+ }
+ if geminiRequest.GenerationConfig.TopK > 0 {
+ openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
+ }
+ if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
+ openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
+ }
+ // gemini stop sequences 最多 5 个,openai stop 最多 4 个
+ if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
+ openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
+ }
+ if geminiRequest.GenerationConfig.CandidateCount > 0 {
+ openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
+ }
+
+ // 转换工具调用
+ if len(geminiRequest.GetTools()) > 0 {
+ var tools []dto.ToolCallRequest
+ for _, tool := range geminiRequest.GetTools() {
+ if tool.FunctionDeclarations != nil {
+ // 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
+ functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
+ if ok {
+ for _, function := range functionDeclarations {
+ openAITool := dto.ToolCallRequest{
+ Type: "function",
+ Function: dto.FunctionRequest{
+ Name: function.Name,
+ Description: function.Description,
+ Parameters: function.Parameters,
+ },
+ }
+ tools = append(tools, openAITool)
+ }
+ }
+ }
+ }
+ if len(tools) > 0 {
+ openaiRequest.Tools = tools
+ }
+ }
+
+ // gemini system instructions
+ if geminiRequest.SystemInstructions != nil {
+ // 将系统指令作为第一条消息插入
+ systemMessage := dto.Message{
+ Role: "system",
+ Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
+ }
+ openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
+ }
+
+ return openaiRequest, nil
+}
+
// convertGeminiRoleToOpenAI maps a Gemini conversation role onto its OpenAI
// equivalent. Unrecognized roles fall back to "user".
func convertGeminiRoleToOpenAI(geminiRole string) string {
	translations := map[string]string{
		"user":     "user",
		"model":    "assistant",
		"function": "function",
	}
	if mapped, ok := translations[geminiRole]; ok {
		return mapped
	}
	return "user"
}
+
+func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
+ var texts []string
+ for _, part := range parts {
+ if part.Text != "" {
+ texts = append(texts, part.Text)
+ }
+ }
+ return strings.Join(texts, "\n")
+}
+
// ResponseOpenAI2Gemini converts a non-streaming OpenAI chat completion
// response into the Gemini chat response format.
//
// Each OpenAI choice becomes one Gemini candidate: tool calls are mapped to
// functionCall parts (arguments re-parsed from their JSON string), otherwise
// the message text becomes a single text part. Token usage is copied from
// the OpenAI response; safety ratings are left empty since OpenAI supplies
// none.
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
	geminiResponse := &dto.GeminiChatResponse{
		Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
		PromptFeedback: dto.GeminiChatPromptFeedback{
			SafetyRatings: []dto.GeminiChatSafetyRating{},
		},
		UsageMetadata: dto.GeminiUsageMetadata{
			PromptTokenCount:     openAIResponse.PromptTokens,
			CandidatesTokenCount: openAIResponse.CompletionTokens,
			TotalTokenCount:      openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
		},
	}

	for _, choice := range openAIResponse.Choices {
		candidate := dto.GeminiChatCandidate{
			Index:         int64(choice.Index),
			SafetyRatings: []dto.GeminiChatSafetyRating{},
		}

		// Map the OpenAI finish reason onto Gemini's enum; anything
		// unrecognized (including "tool_calls") is reported as STOP.
		var finishReason string
		switch choice.FinishReason {
		case "stop":
			finishReason = "STOP"
		case "length":
			finishReason = "MAX_TOKENS"
		case "content_filter":
			finishReason = "SAFETY"
		case "tool_calls":
			finishReason = "STOP"
		default:
			finishReason = "STOP"
		}
		candidate.FinishReason = &finishReason

		// Gemini always attributes assistant output to the "model" role.
		content := dto.GeminiChatContent{
			Role:  "model",
			Parts: make([]dto.GeminiPart, 0),
		}

		// Tool calls take precedence over plain text content.
		toolCalls := choice.Message.ParseToolCalls()
		if len(toolCalls) > 0 {
			for _, toolCall := range toolCalls {
				// Arguments arrive as a JSON string; decode them into a
				// map. On decode failure wrap the raw string instead of
				// dropping the payload.
				var args map[string]interface{}
				if toolCall.Function.Arguments != "" {
					if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
						args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
					}
				} else {
					args = make(map[string]interface{})
				}

				part := dto.GeminiPart{
					FunctionCall: &dto.FunctionCall{
						FunctionName: toolCall.Function.Name,
						Arguments:    args,
					},
				}
				content.Parts = append(content.Parts, part)
			}
		} else {
			// Plain text content; an empty message yields no parts.
			textContent := choice.Message.StringContent()
			if textContent != "" {
				part := dto.GeminiPart{
					Text: textContent,
				}
				content.Parts = append(content.Parts, part)
			}
		}

		candidate.Content = content
		geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
	}

	return geminiResponse
}
+
+// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
+func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
+ // 检查是否有实际内容或结束标志
+ hasContent := false
+ hasFinishReason := false
+ for _, choice := range openAIResponse.Choices {
+ if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
+ hasContent = true
+ }
+ if choice.FinishReason != nil {
+ hasFinishReason = true
+ }
+ }
+
+ // 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
+ if !hasContent && !hasFinishReason {
+ return nil
+ }
+
+ geminiResponse := &dto.GeminiChatResponse{
+ Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
+ PromptFeedback: dto.GeminiChatPromptFeedback{
+ SafetyRatings: []dto.GeminiChatSafetyRating{},
+ },
+ UsageMetadata: dto.GeminiUsageMetadata{
+ PromptTokenCount: info.PromptTokens,
+ CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
+ TotalTokenCount: info.PromptTokens,
+ },
+ }
+
+ for _, choice := range openAIResponse.Choices {
+ candidate := dto.GeminiChatCandidate{
+ Index: int64(choice.Index),
+ SafetyRatings: []dto.GeminiChatSafetyRating{},
+ }
+
+ // 设置结束原因
+ if choice.FinishReason != nil {
+ var finishReason string
+ switch *choice.FinishReason {
+ case "stop":
+ finishReason = "STOP"
+ case "length":
+ finishReason = "MAX_TOKENS"
+ case "content_filter":
+ finishReason = "SAFETY"
+ case "tool_calls":
+ finishReason = "STOP"
+ default:
+ finishReason = "STOP"
+ }
+ candidate.FinishReason = &finishReason
+ }
+
+ // 转换消息内容
+ content := dto.GeminiChatContent{
+ Role: "model",
+ Parts: make([]dto.GeminiPart, 0),
+ }
+
+ // 处理工具调用
+ if choice.Delta.ToolCalls != nil {
+ for _, toolCall := range choice.Delta.ToolCalls {
+ // 解析参数
+ var args map[string]interface{}
+ if toolCall.Function.Arguments != "" {
+ if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
+ args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
+ }
+ } else {
+ args = make(map[string]interface{})
+ }
+
+ part := dto.GeminiPart{
+ FunctionCall: &dto.FunctionCall{
+ FunctionName: toolCall.Function.Name,
+ Arguments: args,
+ },
+ }
+ content.Parts = append(content.Parts, part)
+ }
+ } else {
+ // 处理文本内容
+ textContent := choice.Delta.GetContentString()
+ if textContent != "" {
+ part := dto.GeminiPart{
+ Text: textContent,
+ }
+ content.Parts = append(content.Parts, part)
+ }
+ }
+
+ candidate.Content = content
+ geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
+ }
+
+ return geminiResponse
+}
diff --git a/service/error.go b/service/error.go
index 1bf5992b..ef5cbbde 100644
--- a/service/error.go
+++ b/service/error.go
@@ -1,12 +1,13 @@
package service
import (
- "encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
+ "one-api/types"
"strconv"
"strings"
)
@@ -25,39 +26,43 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
}
}
-// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
-func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
- text := err.Error()
- lowerText := strings.ToLower(text)
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
- }
- openAIError := dto.OpenAIError{
- Message: text,
- Type: "new_api_error",
- Code: code,
- }
- return &dto.OpenAIErrorWithStatusCode{
- Error: openAIError,
- StatusCode: statusCode,
- }
-}
-
-func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
- openaiErr := OpenAIErrorWrapper(err, code, statusCode)
- openaiErr.LocalError = true
- return openaiErr
-}
+//// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
+//func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
+// text := err.Error()
+// lowerText := strings.ToLower(text)
+// if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
+// if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+// common.SysLog(fmt.Sprintf("error: %s", text))
+// text = "请求上游地址失败"
+// }
+// }
+// openAIError := dto.OpenAIError{
+// Message: text,
+// Type: "new_api_error",
+// Code: code,
+// }
+// return &dto.OpenAIErrorWithStatusCode{
+// Error: openAIError,
+// StatusCode: statusCode,
+// }
+//}
+//
+//func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
+// openaiErr := OpenAIErrorWrapper(err, code, statusCode)
+// openaiErr.LocalError = true
+// return openaiErr
+//}
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
- if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
- common.SysLog(fmt.Sprintf("error: %s", text))
- text = "请求上游地址失败"
+ if !strings.HasPrefix(lowerText, "get file base64 from url") {
+ if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+ common.SysLog(fmt.Sprintf("error: %s", text))
+ text = "请求上游地址失败"
+ }
}
- claudeError := dto.ClaudeError{
+ claudeError := types.ClaudeError{
Message: text,
Type: "new_api_error",
}
@@ -73,61 +78,53 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
return claudeErr
}
-func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
- errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
- StatusCode: resp.StatusCode,
- Error: dto.OpenAIError{
- Type: "upstream_error",
- Code: "bad_response_status_code",
- Param: strconv.Itoa(resp.StatusCode),
- },
- }
+func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
+ newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
+
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
- err = resp.Body.Close()
- if err != nil {
- return
- }
+ CloseResponseBodyGracefully(resp)
var errResponse dto.GeneralErrorResponse
- err = json.Unmarshal(responseBody, &errResponse)
+
+ err = common.Unmarshal(responseBody, &errResponse)
if err != nil {
if showBodyWhenFail {
- errWithStatusCode.Error.Message = string(responseBody)
+ newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
} else {
- errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
+ if common.DebugEnabled {
+ println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
+ }
+ newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
}
return
}
if errResponse.Error.Message != "" {
- // OpenAI format error, so we override the default one
- errWithStatusCode.Error = errResponse.Error
+ // General format error (OpenAI, Anthropic, Gemini, etc.)
+ newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
} else {
- errWithStatusCode.Error.Message = errResponse.ToMessage()
- }
- if errWithStatusCode.Error.Message == "" {
- errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
+ newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
}
return
}
-func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) {
+func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) {
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
return
}
statusCodeMapping := make(map[string]string)
- err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
+ err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
if err != nil {
return
}
- if openaiErr.StatusCode == http.StatusOK {
+ if newApiErr.StatusCode == http.StatusOK {
return
}
- codeStr := strconv.Itoa(openaiErr.StatusCode)
+ codeStr := strconv.Itoa(newApiErr.StatusCode)
if _, ok := statusCodeMapping[codeStr]; ok {
intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
- openaiErr.StatusCode = intCode
+ newApiErr.StatusCode = intCode
}
}
diff --git a/service/file_decoder.go b/service/file_decoder.go
index bbb188f8..94f3f028 100644
--- a/service/file_decoder.go
+++ b/service/file_decoder.go
@@ -1,17 +1,145 @@
package service
import (
+ "bytes"
"encoding/base64"
"fmt"
+ "image"
"io"
+ "net/http"
+ "one-api/common"
"one-api/constant"
- "one-api/dto"
+ "one-api/logger"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
)
-func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
// GetFileTypeFromUrl determines the MIME type of the file at url (e.g.
// image/jpeg, image/png, application/pdf), trying strategies in order:
//  1. the response's Content-Type header (unless it is the generic
//     application/octet-stream),
//  2. the filename extension advertised in the Content-Disposition header,
//  3. the extension of the URL path's last segment (query string ignored),
//  4. content sniffing over progressively larger body prefixes
//     (http.DetectContentType, then image.DecodeConfig).
//
// If every strategy fails it returns "application/octet-stream" with a nil
// error; a non-nil error is returned only when the download itself fails.
func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
	// The reason strings are forwarded for download auditing/logging.
	response, err := DoDownloadRequest(url, []string{"get_mime_type", strings.Join(reason, ", ")}...)
	if err != nil {
		common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
		return "", err
	}
	defer response.Body.Close()

	if response.StatusCode != 200 {
		logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
		return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
	}

	// Strategy 1: trust the Content-Type header, stripping any parameters
	// (e.g. "; charset=utf-8"), unless it is the uninformative default.
	if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" {
		if i := strings.Index(headerType, ";"); i != -1 {
			headerType = headerType[:i]
		}
		if headerType != "application/octet-stream" {
			return headerType, nil
		}
	}

	// Strategy 2: derive the type from the filename in Content-Disposition,
	// if it carries an extension with a known mapping.
	if cd := response.Header.Get("Content-Disposition"); cd != "" {
		parts := strings.Split(cd, ";")
		for _, part := range parts {
			part = strings.TrimSpace(part)
			if strings.HasPrefix(strings.ToLower(part), "filename=") {
				name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
				// Strip surrounding double quotes, if present.
				if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
					name = name[1 : len(name)-1]
				}
				if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
					ext := strings.ToLower(name[dot+1:])
					if ext != "" {
						mt := GetMimeTypeByExtension(ext)
						if mt != "application/octet-stream" {
							return mt, nil
						}
					}
				}
				break
			}
		}
	}

	// Strategy 3: fall back to the extension of the URL path's last
	// segment, with any query string removed first.
	cleanedURL := url
	if q := strings.Index(cleanedURL, "?"); q != -1 {
		cleanedURL = cleanedURL[:q]
	}
	if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
		last := cleanedURL[slash+1:]
		if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
			ext := strings.ToLower(last[dot+1:])
			if ext != "" {
				mt := GetMimeTypeByExtension(ext)
				if mt != "application/octet-stream" {
					return mt, nil
				}
			}
		}
	}

	// Strategy 4: sniff the body, growing the buffer in steps so small
	// files are read once and formats with deep headers still get detected.
	var readData []byte
	limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024}
	for _, limit := range limits {
		logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
		if len(readData) < limit {
			need := limit - len(readData)
			tmp := make([]byte, need)
			// Best effort: a short read (EOF) is fine — we sniff whatever
			// arrived — so the read error is deliberately ignored.
			n, _ := io.ReadFull(response.Body, tmp)
			if n > 0 {
				readData = append(readData, tmp[:n]...)
			}
		}

		if len(readData) == 0 {
			continue
		}

		sniffed := http.DetectContentType(readData)
		if sniffed != "" && sniffed != "application/octet-stream" {
			return sniffed, nil
		}

		// http.DetectContentType was inconclusive; try the image decoders.
		// NOTE(review): image.DecodeConfig only recognizes formats whose
		// decoders are registered via imports — confirm the relevant
		// image/* (and bmp/tiff) packages are imported in this binary.
		if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
			switch strings.ToLower(format) {
			case "jpeg", "jpg":
				return "image/jpeg", nil
			case "png":
				return "image/png", nil
			case "gif":
				return "image/gif", nil
			case "bmp":
				return "image/bmp", nil
			case "tiff":
				return "image/tiff", nil
			default:
				if format != "" {
					return "image/" + strings.ToLower(format), nil
				}
			}
		}
	}

	// Nothing matched: report the generic binary type rather than failing.
	return "application/octet-stream", nil
}
+
+func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
+ contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
+
+ // Check if the file has already been downloaded in this request
+ if cachedData, exists := c.Get(contextKey); exists {
+ if common.DebugEnabled {
+ logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
+ }
+ return cachedData.(*types.LocalFileData), nil
+ }
+
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
- resp, err := DoDownloadRequest(url)
+ resp, err := DoDownloadRequest(url, reason...)
if err != nil {
return nil, err
}
@@ -30,9 +158,105 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
- return &dto.LocalFileData{
+ mimeType := resp.Header.Get("Content-Type")
+ if len(strings.Split(mimeType, ";")) > 1 {
+ // If Content-Type has parameters, take the first part
+ mimeType = strings.Split(mimeType, ";")[0]
+ }
+ if mimeType == "application/octet-stream" {
+ logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
+ // try to guess the MIME type from the url last segment
+ urlParts := strings.Split(url, "/")
+ if len(urlParts) > 0 {
+ lastSegment := urlParts[len(urlParts)-1]
+ if strings.Contains(lastSegment, ".") {
+ // Extract the file extension
+ filename := strings.Split(lastSegment, ".")
+ if len(filename) > 1 {
+ ext := strings.ToLower(filename[len(filename)-1])
+ // Guess MIME type based on file extension
+ mimeType = GetMimeTypeByExtension(ext)
+ }
+ }
+ } else {
+ // try to guess the MIME type from the file extension
+ fileName := resp.Header.Get("Content-Disposition")
+ if fileName != "" {
+ // Extract the filename from the Content-Disposition header
+ parts := strings.Split(fileName, ";")
+ for _, part := range parts {
+ if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
+ fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
+ // Remove quotes if present
+ if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
+ fileName = fileName[1 : len(fileName)-1]
+ }
+ // Guess MIME type based on file extension
+ if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
+ mimeType = GetMimeTypeByExtension(ext)
+ }
+ break
+ }
+ }
+ }
+ }
+ }
+ data := &types.LocalFileData{
Base64Data: base64Data,
- MimeType: resp.Header.Get("Content-Type"),
+ MimeType: mimeType,
Size: int64(len(fileBytes)),
- }, nil
+ }
+ // Store the file data in the context to avoid re-downloading
+ c.Set(contextKey, data)
+
+ return data, nil
+}
+
// GetMimeTypeByExtension maps a filename extension (case-insensitive, no
// leading dot) to a MIME type, returning "application/octet-stream" for
// anything unrecognized. The video/* and audio/mp3 values are the literal
// names this relay has always emitted, kept unchanged for compatibility.
func GetMimeTypeByExtension(ext string) string {
	lookup := map[string]string{
		// Text files — all reported as plain text.
		"txt": "text/plain", "md": "text/plain", "markdown": "text/plain",
		"csv": "text/plain", "json": "text/plain", "xml": "text/plain",
		"html": "text/plain", "htm": "text/plain",

		// Image files.
		"jpg": "image/jpeg", "jpeg": "image/jpeg",
		"png": "image/png",
		"gif": "image/gif",

		// Audio files.
		"mp3":  "audio/mp3",
		"wav":  "audio/wav",
		"mpeg": "audio/mpeg",

		// Video files.
		"mp4": "video/mp4", "wmv": "video/wmv", "flv": "video/flv",
		"mov": "video/mov", "mpg": "video/mpg", "avi": "video/avi",
		"mpegps": "video/mpegps",

		// Document files.
		"pdf": "application/pdf",
	}
	if mimeType, ok := lookup[strings.ToLower(ext)]; ok {
		return mimeType
	}
	return "application/octet-stream"
}
diff --git a/service/http.go b/service/http.go
new file mode 100644
index 00000000..357a2e78
--- /dev/null
+++ b/service/http.go
@@ -0,0 +1,59 @@
+package service
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/logger"
+
+ "github.com/gin-gonic/gin"
+)
+
+func CloseResponseBodyGracefully(httpResponse *http.Response) {
+ if httpResponse == nil || httpResponse.Body == nil {
+ return
+ }
+ err := httpResponse.Body.Close()
+ if err != nil {
+ common.SysError("failed to close response body: " + err.Error())
+ }
+}
+
+func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
+ if c.Writer == nil {
+ return
+ }
+
+ body := io.NopCloser(bytes.NewBuffer(data))
+
+ // We shouldn't set the header before we parse the response body, because the parse part may fail.
+ // And then we will have to send an error response, but in this case, the header has already been set.
+ // So the httpClient will be confused by the response.
+ // For example, Postman will report error, and we cannot check the response at all.
+ if src != nil {
+ for k, v := range src.Header {
+ // avoid setting Content-Length
+ if k == "Content-Length" {
+ continue
+ }
+ c.Writer.Header().Set(k, v[0])
+ }
+ }
+
+ // set Content-Length header manually BEFORE calling WriteHeader
+ c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
+
+ // Write header with status code (this sends the headers)
+ if src != nil {
+ c.Writer.WriteHeader(src.StatusCode)
+ } else {
+ c.Writer.WriteHeader(http.StatusOK)
+ }
+
+ _, err := io.Copy(c.Writer, body)
+ if err != nil {
+ logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
+ }
+}
diff --git a/service/http_client.go b/service/http_client.go
index 64a361cf..b191ddd7 100644
--- a/service/http_client.go
+++ b/service/http_client.go
@@ -13,9 +13,8 @@ import (
)
var httpClient *http.Client
-var impatientHTTPClient *http.Client
-func init() {
+func InitHttpClient() {
if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
@@ -23,20 +22,12 @@ func init() {
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
-
- impatientHTTPClient = &http.Client{
- Timeout: 5 * time.Second,
- }
}
func GetHttpClient() *http.Client {
return httpClient
}
-func GetImpatientHttpClient() *http.Client {
- return impatientHTTPClient
-}
-
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
if proxyURL == "" {
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
index 75457b97..7a609c9f 100644
--- a/service/log_info_generate.go
+++ b/service/log_info_generate.go
@@ -1,14 +1,17 @@
package service
import (
+ "one-api/common"
+ "one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
+ "one-api/types"
"github.com/gin-gonic/gin"
)
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, modelPrice float64) map[string]interface{} {
+ cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
other := make(map[string]interface{})
other["model_ratio"] = modelRatio
other["group_ratio"] = groupRatio
@@ -16,6 +19,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["cache_tokens"] = cacheTokens
other["cache_ratio"] = cacheRatio
other["model_price"] = modelPrice
+ other["user_group_ratio"] = userGroupRatio
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
if relayInfo.ReasoningEffort != "" {
other["reasoning_effort"] = relayInfo.ReasoningEffort
@@ -24,14 +28,25 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
other["is_model_mapped"] = true
other["upstream_model_name"] = relayInfo.UpstreamModelName
}
+
+ isSystemPromptOverwritten := common.GetContextKeyBool(ctx, constant.ContextKeySystemPromptOverride)
+ if isSystemPromptOverwritten {
+ other["is_system_prompt_overwritten"] = true
+ }
+
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
+ isMultiKey := common.GetContextKeyBool(ctx, constant.ContextKeyChannelIsMultiKey)
+ if isMultiKey {
+ adminInfo["is_multi_key"] = true
+ adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
+ }
other["admin_info"] = adminInfo
return other
}
-func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
+func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
info["ws"] = true
info["audio_input"] = usage.InputTokenDetails.AudioTokens
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
@@ -42,8 +57,8 @@ func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
return info
}
-func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
+func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
info["audio"] = true
info["audio_input"] = usage.PromptTokensDetails.AudioTokens
info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
@@ -55,10 +70,20 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
- cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} {
- info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
+ cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
info["claude"] = true
info["cache_creation_tokens"] = cacheCreationTokens
info["cache_creation_ratio"] = cacheCreationRatio
return info
}
+
+func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} {
+ other := make(map[string]interface{})
+ other["model_price"] = priceData.ModelPrice
+ other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
+ if priceData.GroupRatioInfo.HasSpecialRatio {
+ other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
+ }
+ return other
+}
diff --git a/service/midjourney.go b/service/midjourney.go
index 635c29ae..916d02d0 100644
--- a/service/midjourney.go
+++ b/service/midjourney.go
@@ -3,7 +3,6 @@ package service
import (
"context"
"encoding/json"
- "github.com/gin-gonic/gin"
"io"
"log"
"net/http"
@@ -15,6 +14,8 @@ import (
"strconv"
"strings"
"time"
+
+ "github.com/gin-gonic/gin"
)
func CoverActionToModelName(mjAction string) string {
@@ -38,6 +39,10 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
switch relayMode {
case relayconstant.RelayModeMidjourneyImagine:
action = constant.MjActionImagine
+ case relayconstant.RelayModeMidjourneyVideo:
+ action = constant.MjActionVideo
+ case relayconstant.RelayModeMidjourneyEdits:
+ action = constant.MjActionEdits
case relayconstant.RelayModeMidjourneyDescribe:
action = constant.MjActionDescribe
case relayconstant.RelayModeMidjourneyBlend:
@@ -199,7 +204,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
req = req.WithContext(ctx)
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- auth := c.Request.Header.Get("Authorization")
+ auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
if auth != "" {
auth = strings.TrimPrefix(auth, "Bearer ")
req.Header.Set("mj-api-secret", auth)
@@ -207,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
defer cancel()
resp, err := GetHttpClient().Do(req)
if err != nil {
- common.SysError("do request failed: " + err.Error())
+ common.SysLog("do request failed: " + err.Error())
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
@@ -228,10 +233,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
- err = resp.Body.Close()
- if err != nil {
- return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
- }
+ CloseResponseBodyGracefully(resp)
respStr := string(responseBody)
log.Printf("respStr: %s", respStr)
if respStr == "" {
diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go
new file mode 100644
index 00000000..08e3f68f
--- /dev/null
+++ b/service/pre_consume_quota.go
@@ -0,0 +1,79 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/logger"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-gonic/gin"
+)
+
+func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
+ if preConsumedQuota != 0 {
+ logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
+ gopool.Go(func() {
+ relayInfoCopy := *relayInfo
+
+ err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
+ if err != nil {
+ common.SysLog("error return pre-consumed quota: " + err.Error())
+ }
+ })
+ }
+}
+
+// PreConsumeQuota checks if the user has enough quota to pre-consume.
+// It returns the pre-consumed quota if successful, or an error if not.
+func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
+ userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
+ if err != nil {
+ return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
+ }
+ if userQuota <= 0 {
+ return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ }
+ if userQuota-preConsumedQuota < 0 {
+ return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ }
+
+ trustQuota := common.GetTrustQuota()
+
+ relayInfo.UserQuota = userQuota
+ if userQuota > trustQuota {
+ // 用户额度充足,判断令牌额度是否充足
+ if !relayInfo.TokenUnlimited {
+ // 非无限令牌,判断令牌额度是否充足
+ tokenQuota := c.GetInt("token_quota")
+ if tokenQuota > trustQuota {
+ // 令牌额度充足,信任令牌
+ preConsumedQuota = 0
+ logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
+ }
+ } else {
+ // in this case, we do not pre-consume quota
+ // because the user has enough quota
+ preConsumedQuota = 0
+ logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
+ }
+ }
+
+ if preConsumedQuota > 0 {
+ err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+ if err != nil {
+ return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
+ }
+ err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
+ if err != nil {
+ return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
+ }
+ logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
+ }
+ relayInfo.FinalPreConsumedQuota = preConsumedQuota
+ return preConsumedQuota, nil
+}
diff --git a/service/quota.go b/service/quota.go
index 0d11b4a0..8f65bd20 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -3,14 +3,17 @@ package service
import (
"errors"
"fmt"
+ "log"
+ "math"
"one-api/common"
- constant2 "one-api/constant"
+ "one-api/constant"
"one-api/dto"
+ "one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
- "one-api/relay/helper"
"one-api/setting"
- "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
+ "one-api/types"
"strings"
"time"
@@ -35,6 +38,14 @@ type QuotaInfo struct {
GroupRatio float64
}
+func hasCustomModelRatio(modelName string, currentRatio float64) bool {
+ defaultRatio, exists := ratio_setting.GetDefaultModelRatioMap()[modelName]
+ if !exists {
+ return true
+ }
+ return currentRatio != defaultRatio
+}
+
func calculateAudioQuota(info QuotaInfo) int {
if info.UsePrice {
modelPrice := decimal.NewFromFloat(info.ModelPrice)
@@ -45,9 +56,9 @@ func calculateAudioQuota(info QuotaInfo) int {
return int(quota.IntPart())
}
- completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
- audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
- audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
groupRatio := decimal.NewFromFloat(info.GroupRatio)
modelRatio := decimal.NewFromFloat(info.ModelRatio)
@@ -93,8 +104,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
textOutTokens := usage.OutputTokenDetails.TextTokens
audioInputTokens := usage.InputTokenDetails.AudioTokens
audioOutTokens := usage.OutputTokenDetails.AudioTokens
- groupRatio := setting.GetGroupRatio(relayInfo.Group)
- modelRatio, _ := operation_setting.GetModelRatio(modelName)
+ groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
+ modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)
+
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.UsingGroup = autoGroup.(string)
+ }
+
+ actualGroupRatio := groupRatio
+ userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
+ if ok {
+ actualGroupRatio = userGroupRatio
+ }
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -108,30 +132,29 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
ModelName: modelName,
UsePrice: relayInfo.UsePrice,
ModelRatio: modelRatio,
- GroupRatio: groupRatio,
+ GroupRatio: actualGroupRatio,
}
quota := calculateAudioQuota(quotaInfo)
if userQuota < quota {
- return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
+ return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
- return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+ return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
}
err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
- common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
+ logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
return nil
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
- usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
- modelPrice float64, usePrice bool, extraContent string) {
+ usage *dto.RealtimeUsage, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -141,9 +164,14 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioOutTokens := usage.OutputTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
- completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
- audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
- audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
+
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
+ usePrice := relayInfo.PriceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -177,8 +205,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
@@ -189,13 +217,24 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
- model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
- tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: usage.InputTokens,
+ CompletionTokens: usage.OutputTokens,
+ ModelName: logModel,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
}
-func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
@@ -203,19 +242,30 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
- completionRatio := priceData.CompletionRatio
- modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
- modelPrice := priceData.ModelPrice
-
- cacheRatio := priceData.CacheRatio
+ completionRatio := relayInfo.PriceData.CompletionRatio
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
+ cacheRatio := relayInfo.PriceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
- cacheCreationRatio := priceData.CacheCreationRatio
+ cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
+ if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
+ promptTokens -= cacheTokens
+ isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
+ if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
+ maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
+ if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
+ cacheCreationTokens = maybeCacheCreationTokens
+ }
+ }
+ promptTokens -= cacheCreationTokens
+ }
+
calculateQuota := 0.0
- if !priceData.UsePrice {
+ if !relayInfo.PriceData.UsePrice {
calculateQuota = float64(promptTokens)
calculateQuota += float64(cacheTokens) * cacheRatio
calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
@@ -240,29 +290,77 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游出错)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
- quotaDelta := quota - preConsumedQuota
+ quotaDelta := quota - relayInfo.FinalPreConsumedQuota
+
+ if quotaDelta > 0 {
+ logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
+ logger.FormatQuota(quotaDelta),
+ logger.FormatQuota(quota),
+ logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
+ ))
+ } else if quotaDelta < 0 {
+ logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
+ logger.FormatQuota(-quotaDelta),
+ logger.FormatQuota(quota),
+ logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
+ ))
+ }
+
if quotaDelta != 0 {
- err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
- cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
- model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
- tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
+ cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: promptTokens,
+ CompletionTokens: completionTokens,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+
}
-func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
- usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
+ if priceData.CacheCreationRatio == 1 {
+ return 0
+ }
+ quotaPrice := priceData.ModelRatio / common.QuotaPerUnit
+ promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio
+ promptCacheReadPrice := quotaPrice * priceData.CacheRatio
+ completionPrice := quotaPrice * priceData.CompletionRatio
+
+ cost, _ := usage.Cost.(float64)
+ totalPromptTokens := float64(usage.PromptTokens)
+ completionTokens := float64(usage.CompletionTokens)
+ promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens)
+
+ return int(math.Round((cost -
+ totalPromptTokens*quotaPrice +
+ promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) -
+ completionTokens*completionPrice) /
+ (promptCacheCreatePrice - quotaPrice)))
+}
+
+func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -272,14 +370,14 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
tokenName := ctx.GetString("token_name")
- completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
- audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
- audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
- modelRatio := priceData.ModelRatio
- groupRatio := priceData.GroupRatio
- modelPrice := priceData.ModelPrice
- usePrice := priceData.UsePrice
+ modelRatio := relayInfo.PriceData.ModelRatio
+ groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
+ modelPrice := relayInfo.PriceData.ModelPrice
+ usePrice := relayInfo.PriceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -313,18 +411,33 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
- common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
- "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
- quotaDelta := quota - preConsumedQuota
+ quotaDelta := quota - relayInfo.FinalPreConsumedQuota
+
+ if quotaDelta > 0 {
+ logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
+ logger.FormatQuota(quotaDelta),
+ logger.FormatQuota(quota),
+ logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
+ ))
+ } else if quotaDelta < 0 {
+ logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
+ logger.FormatQuota(-quotaDelta),
+ logger.FormatQuota(quota),
+ logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
+ ))
+ }
+
if quotaDelta != 0 {
- err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
- common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
@@ -333,9 +446,21 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
- completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
- model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
- tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: usage.PromptTokens,
+ CompletionTokens: usage.CompletionTokens,
+ ModelName: logModel,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
@@ -353,7 +478,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
- return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+ return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
}
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
@@ -397,8 +522,8 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
gopool.Go(func() {
userSetting := relayInfo.UserSetting
threshold := common.QuotaRemindThreshold
- if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
- threshold = int(userCustomThreshold.(float64))
+ if userSetting.QuotaWarningThreshold != 0 {
+ threshold = int(userSetting.QuotaWarningThreshold)
}
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
@@ -411,7 +536,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}"
- err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
+ err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}
diff --git a/service/sensitive.go b/service/sensitive.go
index b3e3c4d6..ed033daa 100644
--- a/service/sensitive.go
+++ b/service/sensitive.go
@@ -2,7 +2,6 @@ package service
import (
"errors"
- "fmt"
"one-api/dto"
"one-api/setting"
"strings"
@@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
return nil, nil
}
-func CheckSensitiveText(text string) ([]string, error) {
- if ok, words := SensitiveWordContains(text); ok {
- return words, errors.New("sensitive words detected")
- }
- return nil, nil
-}
-
-func CheckSensitiveInput(input any) ([]string, error) {
- switch v := input.(type) {
- case string:
- return CheckSensitiveText(v)
- case []string:
- var builder strings.Builder
- for _, s := range v {
- builder.WriteString(s)
- }
- return CheckSensitiveText(builder.String())
- }
- return CheckSensitiveText(fmt.Sprintf("%v", input))
+func CheckSensitiveText(text string) (bool, []string) {
+ return SensitiveWordContains(text)
}
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
@@ -71,7 +53,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
return false, nil, text
}
checkText := strings.ToLower(text)
- m := InitAc(setting.SensitiveWords)
+ m := getOrBuildAC(setting.SensitiveWords)
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
if len(hits) > 0 {
words := make([]string, 0, len(hits))
diff --git a/service/str.go b/service/str.go
index 4390e99b..61054bdc 100644
--- a/service/str.go
+++ b/service/str.go
@@ -3,8 +3,12 @@ package service
import (
"bytes"
"fmt"
- goahocorasick "github.com/anknown/ahocorasick"
+ "hash/fnv"
+ "sort"
"strings"
+ "sync"
+
+ goahocorasick "github.com/anknown/ahocorasick"
)
func SundaySearch(text string, pattern string) bool {
@@ -56,26 +60,73 @@ func RemoveDuplicate(s []string) []string {
return result
}
-func InitAc(words []string) *goahocorasick.Machine {
+func InitAc(dict []string) *goahocorasick.Machine {
m := new(goahocorasick.Machine)
- dict := readRunes(words)
- if err := m.Build(dict); err != nil {
+ runes := readRunes(dict)
+ if err := m.Build(runes); err != nil {
fmt.Println(err)
return nil
}
return m
}
-func readRunes(words []string) [][]rune {
- var dict [][]rune
+var acCache sync.Map
- for _, word := range words {
+func acKey(dict []string) string {
+ if len(dict) == 0 {
+ return ""
+ }
+ normalized := make([]string, 0, len(dict))
+ for _, w := range dict {
+ w = strings.ToLower(strings.TrimSpace(w))
+ if w != "" {
+ normalized = append(normalized, w)
+ }
+ }
+ if len(normalized) == 0 {
+ return ""
+ }
+ sort.Strings(normalized)
+ hasher := fnv.New64a()
+ for _, w := range normalized {
+ hasher.Write([]byte{0})
+ hasher.Write([]byte(w))
+ }
+ return fmt.Sprintf("%x", hasher.Sum64())
+}
+
+func getOrBuildAC(dict []string) *goahocorasick.Machine {
+ key := acKey(dict)
+ if key == "" {
+ return nil
+ }
+ if v, ok := acCache.Load(key); ok {
+ if m, ok2 := v.(*goahocorasick.Machine); ok2 {
+ return m
+ }
+ }
+ m := InitAc(dict)
+ if m == nil {
+ return nil
+ }
+ if actual, loaded := acCache.LoadOrStore(key, m); loaded {
+ if cached, ok := actual.(*goahocorasick.Machine); ok {
+ return cached
+ }
+ }
+ return m
+}
+
+func readRunes(dict []string) [][]rune {
+ var runes [][]rune
+
+ for _, word := range dict {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
- dict = append(dict, bytes.Runes(l))
+ runes = append(runes, bytes.Runes(l))
}
- return dict
+ return runes
}
func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
@@ -85,7 +136,7 @@ func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []str
if len(findText) == 0 {
return false, nil
}
- m := InitAc(dict)
+ m := getOrBuildAC(dict)
if m == nil {
return false, nil
}
diff --git a/service/token_counter.go b/service/token_counter.go
index d63b54ad..bac6c067 100644
--- a/service/token_counter.go
+++ b/service/token_counter.go
@@ -11,126 +11,165 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
- "one-api/setting/operation_setting"
+ "one-api/types"
"strings"
+ "sync"
"unicode/utf8"
- "github.com/pkoukk/tiktoken-go"
+ "github.com/gin-gonic/gin"
+ "github.com/tiktoken-go/tokenizer"
+ "github.com/tiktoken-go/tokenizer/codec"
)
// tokenEncoderMap won't grow after initialization
-var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
-var defaultTokenEncoder *tiktoken.Tiktoken
-var o200kTokenEncoder *tiktoken.Tiktoken
+var defaultTokenEncoder tokenizer.Codec
+
+// tokenEncoderMap is used to store token encoders for different models
+var tokenEncoderMap = make(map[string]tokenizer.Codec)
+
+// tokenEncoderMutex protects tokenEncoderMap for concurrent access
+var tokenEncoderMutex sync.RWMutex
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
- cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
- if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
- }
- defaultTokenEncoder = cl100TokenEncoder
- o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
- if err != nil {
- common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
- }
- for model, _ := range operation_setting.GetDefaultModelRatioMap() {
- if strings.HasPrefix(model, "gpt-3.5") {
- tokenEncoderMap[model] = cl100TokenEncoder
- } else if strings.HasPrefix(model, "gpt-4") {
- if strings.HasPrefix(model, "gpt-4o") {
- tokenEncoderMap[model] = o200kTokenEncoder
- } else {
- tokenEncoderMap[model] = defaultTokenEncoder
- }
- } else if strings.HasPrefix(model, "o") {
- tokenEncoderMap[model] = o200kTokenEncoder
- } else {
- tokenEncoderMap[model] = defaultTokenEncoder
- }
- }
+ defaultTokenEncoder = codec.NewCl100kBase()
common.SysLog("token encoders initialized")
}
-func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
- if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
- return o200kTokenEncoder
+func getTokenEncoder(model string) tokenizer.Codec {
+ // First, try to get the encoder from cache with read lock
+ tokenEncoderMutex.RLock()
+ if encoder, exists := tokenEncoderMap[model]; exists {
+ tokenEncoderMutex.RUnlock()
+ return encoder
}
- return defaultTokenEncoder
+ tokenEncoderMutex.RUnlock()
+
+ // If not in cache, create new encoder with write lock
+ tokenEncoderMutex.Lock()
+ defer tokenEncoderMutex.Unlock()
+
+ // Double-check if another goroutine already created the encoder
+ if encoder, exists := tokenEncoderMap[model]; exists {
+ return encoder
+ }
+
+ // Create new encoder
+ modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
+ if err != nil {
+ // Cache the default encoder for this model to avoid repeated failures
+ tokenEncoderMap[model] = defaultTokenEncoder
+ return defaultTokenEncoder
+ }
+
+ // Cache the new encoder
+ tokenEncoderMap[model] = modelCodec
+ return modelCodec
}
-func getTokenEncoder(model string) *tiktoken.Tiktoken {
- tokenEncoder, ok := tokenEncoderMap[model]
- if ok && tokenEncoder != nil {
- return tokenEncoder
- }
- // 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型
- if ok {
- tokenEncoder, err := tiktoken.EncodingForModel(model)
- if err != nil {
- common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
- tokenEncoder = getModelDefaultTokenEncoder(model)
- }
- tokenEncoderMap[model] = tokenEncoder
- return tokenEncoder
- }
- // 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
- return getModelDefaultTokenEncoder(model)
-}
-
-func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
+func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
if text == "" {
return 0
}
- return len(tokenEncoder.Encode(text, nil, nil))
+ tkm, _ := tokenEncoder.Count(text)
+ return tkm
}
-func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
- if imageUrl == nil {
+func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
+ if fileMeta == nil {
return 0, fmt.Errorf("image_url_is_nil")
}
+
+ // Defaults for 4o/4.1/4.5 family unless overridden below
baseTokens := 85
- if model == "glm-4v" {
+ tileTokens := 170
+
+ // Model classification
+ lowerModel := strings.ToLower(model)
+
+ // Special cases from existing behavior
+ if strings.HasPrefix(lowerModel, "glm-4") {
return 1047, nil
}
- if imageUrl.Detail == "low" {
+
+ // Patch-based models (32x32 patches, capped at 1536, with multiplier)
+ isPatchBased := false
+ multiplier := 1.0
+ switch {
+ case strings.Contains(lowerModel, "gpt-4.1-mini"):
+ isPatchBased = true
+ multiplier = 1.62
+ case strings.Contains(lowerModel, "gpt-4.1-nano"):
+ isPatchBased = true
+ multiplier = 2.46
+ case strings.HasPrefix(lowerModel, "o4-mini"):
+ isPatchBased = true
+ multiplier = 1.72
+ case strings.HasPrefix(lowerModel, "gpt-5-mini"):
+ isPatchBased = true
+ multiplier = 1.62
+ case strings.HasPrefix(lowerModel, "gpt-5-nano"):
+ isPatchBased = true
+ multiplier = 2.46
+ }
+
+ // Tile-based model tokens and bases per doc
+ if !isPatchBased {
+ if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
+ baseTokens = 2833
+ tileTokens = 5667
+ } else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
+ baseTokens = 70
+ tileTokens = 140
+ } else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
+ baseTokens = 75
+ tileTokens = 150
+ } else if strings.Contains(lowerModel, "computer-use-preview") {
+ baseTokens = 65
+ tileTokens = 129
+ } else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
+ baseTokens = 85
+ tileTokens = 170
+ }
+ }
+
+ // Respect existing feature flags/short-circuits
+ if fileMeta.Detail == "low" && !isPatchBased {
return baseTokens, nil
}
if !constant.GetMediaTokenNotStream && !stream {
return 3 * baseTokens, nil
}
-
- // 同步One API的图片计费逻辑
- if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
- imageUrl.Detail = "high"
+ // Normalize detail
+ if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
+ fileMeta.Detail = "high"
}
-
- tileTokens := 170
- if strings.HasPrefix(model, "gpt-4o-mini") {
- tileTokens = 5667
- baseTokens = 2833
- }
- // 是否统计图片token
+ // Whether to count image tokens at all
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
- if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
- return 3 * baseTokens, nil
- }
+
+ // Decode image to get dimensions
var config image.Config
var err error
var format string
var b64str string
- if strings.HasPrefix(imageUrl.Url, "http") {
- config, format, err = DecodeUrlImageData(imageUrl.Url)
+
+ if fileMeta.ParsedData != nil {
+ config, format, b64str, err = DecodeBase64ImageData(fileMeta.ParsedData.Base64Data)
} else {
- common.SysLog(fmt.Sprintf("decoding image"))
- config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
+ if strings.HasPrefix(fileMeta.OriginData, "http") {
+ config, format, err = DecodeUrlImageData(fileMeta.OriginData)
+ } else {
+ common.SysLog("decoding image")
+ config, format, b64str, err = DecodeBase64ImageData(fileMeta.OriginData)
+ }
+ fileMeta.MimeType = format
}
+
if err != nil {
return 0, err
}
- imageUrl.MimeType = format
if config.Width == 0 || config.Height == 0 {
// not an image
@@ -138,63 +177,184 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
// file type
return 3 * baseTokens, nil
}
- return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
+ return 0, fmt.Errorf("fail to decode base64 config: %s", fileMeta.OriginData)
}
- shortSide := config.Width
- otherSide := config.Height
- log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
- // 缩放倍数
- scale := 1.0
- if config.Height < shortSide {
- shortSide = config.Height
- otherSide = config.Width
+ width := config.Width
+ height := config.Height
+ log.Printf("format: %s, width: %d, height: %d", format, width, height)
+
+ if isPatchBased {
+ // 32x32 patch-based calculation with 1536 cap and model multiplier
+ ceilDiv := func(a, b int) int { return (a + b - 1) / b }
+ rawPatchesW := ceilDiv(width, 32)
+ rawPatchesH := ceilDiv(height, 32)
+ rawPatches := rawPatchesW * rawPatchesH
+ if rawPatches > 1536 {
+ // scale down
+ area := float64(width * height)
+ r := math.Sqrt(float64(32*32*1536) / area)
+ wScaled := float64(width) * r
+ hScaled := float64(height) * r
+ // adjust to fit whole number of patches after scaling
+ adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
+ adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
+ adj := math.Min(adjW, adjH)
+ if !math.IsNaN(adj) && adj > 0 {
+ r = r * adj
+ }
+ wScaled = float64(width) * r
+ hScaled = float64(height) * r
+ patchesW := math.Ceil(wScaled / 32.0)
+ patchesH := math.Ceil(hScaled / 32.0)
+ imageTokens := int(patchesW * patchesH)
+ if imageTokens > 1536 {
+ imageTokens = 1536
+ }
+ return int(math.Round(float64(imageTokens) * multiplier)), nil
+ }
+ // below cap
+ imageTokens := rawPatches
+ return int(math.Round(float64(imageTokens) * multiplier)), nil
}
- // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768
- if shortSide > 768 {
- scale = float64(shortSide) / 768
- shortSide = 768
+ // Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
+ // Step 1: fit within 2048x2048 square
+ maxSide := math.Max(float64(width), float64(height))
+ fitScale := 1.0
+ if maxSide > 2048 {
+ fitScale = maxSide / 2048.0
}
- // 将另一边按照相同的比例缩小,向上取整
- otherSide = int(math.Ceil(float64(otherSide) / scale))
- log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
- // 计算图片的token数量(边的长度除以512,向上取整)
- tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
- log.Printf("tiles: %d", tiles)
+ fitW := int(math.Round(float64(width) / fitScale))
+ fitH := int(math.Round(float64(height) / fitScale))
+
+ // Step 2: scale so that shortest side is exactly 768
+ minSide := math.Min(float64(fitW), float64(fitH))
+ if minSide == 0 {
+ return baseTokens, nil
+ }
+ shortScale := 768.0 / minSide
+ finalW := int(math.Round(float64(fitW) * shortScale))
+ finalH := int(math.Round(float64(fitH) * shortScale))
+
+ // Count 512px tiles
+ tilesW := (finalW + 512 - 1) / 512
+ tilesH := (finalH + 512 - 1) / 512
+ tiles := tilesW * tilesH
+
+ if common.DebugEnabled {
+ log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
+ }
+
return tiles*tileTokens + baseTokens, nil
}
-func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
- tkm := 0
- msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
- if err != nil {
- return 0, err
- }
- tkm += msgTokens
- if request.Tools != nil {
- openaiTools := request.Tools
- countStr := ""
- for _, tool := range openaiTools {
- countStr = tool.Function.Name
- if tool.Function.Description != "" {
- countStr += tool.Function.Description
- }
- if tool.Function.Parameters != nil {
- countStr += fmt.Sprintf("%v", tool.Function.Parameters)
- }
- }
- toolTokens, err := CountTokenInput(countStr, request.Model)
- if err != nil {
- return 0, err
- }
- tkm += 8
- tkm += toolTokens
+func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
+ if meta == nil {
+ return 0, errors.New("token count meta is nil")
}
+ if info.RelayFormat == types.RelayFormatOpenAIRealtime {
+ return 0, nil
+ }
+
+ model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
+ tkm := 0
+
+ if meta.TokenType == types.TokenTypeTextNumber {
+ tkm += utf8.RuneCountInString(meta.CombineText)
+ } else {
+ tkm += CountTextToken(meta.CombineText, model)
+ }
+
+ if info.RelayFormat == types.RelayFormatOpenAI {
+ tkm += meta.ToolsCount * 8
+ tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
+ tkm += meta.NameCount * 3
+ tkm += 3
+ }
+
+ shouldFetchFiles := true
+
+ if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini {
+ shouldFetchFiles = false
+ }
+
+ if shouldFetchFiles {
+ for _, file := range meta.Files {
+ if strings.HasPrefix(file.OriginData, "http") {
+ mimeType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
+ if err != nil {
+ return 0, fmt.Errorf("error getting file base64 from url: %v", err)
+ }
+ if strings.HasPrefix(mimeType, "image/") {
+ file.FileType = types.FileTypeImage
+ } else if strings.HasPrefix(mimeType, "video/") {
+ file.FileType = types.FileTypeVideo
+ } else if strings.HasPrefix(mimeType, "audio/") {
+ file.FileType = types.FileTypeAudio
+ } else {
+ file.FileType = types.FileTypeFile
+ }
+ file.MimeType = mimeType
+ }
+ }
+ }
+
+ for _, file := range meta.Files {
+ switch file.FileType {
+ case types.FileTypeImage:
+ if info.RelayFormat == types.RelayFormatGemini {
+ tkm += 256
+ } else {
+ token, err := getImageToken(file, model, info.IsStream)
+ if err != nil {
+ return 0, fmt.Errorf("error counting image token: %v", err)
+ }
+ tkm += token
+ }
+ case types.FileTypeAudio:
+ tkm += 256
+ case types.FileTypeVideo:
+ tkm += 4096 * 2
+ case types.FileTypeFile:
+ tkm += 4096
+ default:
+ tkm += 4096 // Default case for unknown file types
+ }
+ }
+
+ common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
return tkm, nil
}
+//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
+// tkm := 0
+// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
+// if err != nil {
+// return 0, err
+// }
+// tkm += msgTokens
+// if request.Tools != nil {
+// openaiTools := request.Tools
+// countStr := ""
+// for _, tool := range openaiTools {
+// countStr = tool.Function.Name
+// if tool.Function.Description != "" {
+// countStr += tool.Function.Description
+// }
+// if tool.Function.Parameters != nil {
+// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+// }
+// }
+// toolTokens := CountTokenInput(countStr, request.Model)
+// tkm += 8
+// tkm += toolTokens
+// }
+//
+// return tkm, nil
+//}
+
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
tkm := 0
@@ -207,10 +367,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
// Count tokens in system message
if request.System != "" {
- systemTokens, err := CountTokenInput(request.System, model)
- if err != nil {
- return 0, err
- }
+ systemTokens := CountTokenInput(request.System, model)
tkm += systemTokens
}
@@ -261,12 +418,16 @@ func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream
//}
tokenNum += 1000
case "tool_use":
- tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
- inputJSON, _ := json.Marshal(mediaMessage.Input)
- tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
+ if mediaMessage.Input != nil {
+ tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
+ inputJSON, _ := json.Marshal(mediaMessage.Input)
+ tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
+ }
case "tool_result":
- contentJSON, _ := json.Marshal(mediaMessage.Content)
- tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
+ if mediaMessage.Content != nil {
+ contentJSON, _ := json.Marshal(mediaMessage.Content)
+ tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
+ }
}
}
}
@@ -305,10 +466,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
switch request.Type {
case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil {
- msgTokens, err := CountTextToken(request.Session.Instructions, model)
- if err != nil {
- return 0, 0, err
- }
+ msgTokens := CountTextToken(request.Session.Instructions, model)
textToken += msgTokens
}
case dto.RealtimeEventResponseAudioDelta:
@@ -320,10 +478,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token
- tkm, err := CountTextToken(request.Delta, model)
- if err != nil {
- return 0, 0, fmt.Errorf("error counting text token: %v", err)
- }
+ tkm := CountTextToken(request.Delta, model)
textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend:
// count audio token
@@ -338,10 +493,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
case "message":
for _, content := range request.Item.Content {
if content.Type == "input_text" {
- tokens, err := CountTextToken(content.Text, model)
- if err != nil {
- return 0, 0, err
- }
+ tokens := CountTextToken(content.Text, model)
textToken += tokens
}
}
@@ -352,10 +504,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools {
- toolTokens, err := CountTokenInput(tool, model)
- if err != nil {
- return 0, 0, err
- }
+ toolTokens := CountTokenInput(tool, model)
textToken += 8
textToken += toolTokens
}
@@ -365,60 +514,57 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
return textToken, audioToken, nil
}
-func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
- //recover when panic
- tokenEncoder := getTokenEncoder(model)
- // Reference:
- // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- // https://github.com/pkoukk/tiktoken-go/issues/6
- //
- // Every message follows <|start|>{role/name}\n{content}<|end|>\n
- var tokensPerMessage int
- var tokensPerName int
- if model == "gpt-3.5-turbo-0301" {
- tokensPerMessage = 4
- tokensPerName = -1 // If there's a name, the role is omitted
- } else {
- tokensPerMessage = 3
- tokensPerName = 1
- }
- tokenNum := 0
- for _, message := range messages {
- tokenNum += tokensPerMessage
- tokenNum += getTokenNum(tokenEncoder, message.Role)
- if len(message.Content) > 0 {
- if message.Name != nil {
- tokenNum += tokensPerName
- tokenNum += getTokenNum(tokenEncoder, *message.Name)
- }
- arrayContent := message.ParseContent()
- for _, m := range arrayContent {
- if m.Type == dto.ContentTypeImageURL {
- imageUrl := m.GetImageMedia()
- imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
- if err != nil {
- return 0, err
- }
- tokenNum += imageTokenNum
- log.Printf("image token num: %d", imageTokenNum)
- } else if m.Type == dto.ContentTypeInputAudio {
- // TODO: 音频token数量计算
- tokenNum += 100
- } else if m.Type == dto.ContentTypeFile {
- tokenNum += 5000
- } else if m.Type == dto.ContentTypeVideoUrl {
- tokenNum += 5000
- } else {
- tokenNum += getTokenNum(tokenEncoder, m.Text)
- }
- }
- }
- }
- tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
- return tokenNum, nil
-}
+//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
+// //recover when panic
+// tokenEncoder := getTokenEncoder(model)
+// // Reference:
+// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+// // https://github.com/pkoukk/tiktoken-go/issues/6
+// //
+// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
+// var tokensPerMessage int
+// var tokensPerName int
+//
+// tokensPerMessage = 3
+// tokensPerName = 1
+//
+// tokenNum := 0
+// for _, message := range messages {
+// tokenNum += tokensPerMessage
+// tokenNum += getTokenNum(tokenEncoder, message.Role)
+// if message.Content != nil {
+// if message.Name != nil {
+// tokenNum += tokensPerName
+// tokenNum += getTokenNum(tokenEncoder, *message.Name)
+// }
+// arrayContent := message.ParseContent()
+// for _, m := range arrayContent {
+// if m.Type == dto.ContentTypeImageURL {
+// imageUrl := m.GetImageMedia()
+// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
+// if err != nil {
+// return 0, err
+// }
+// tokenNum += imageTokenNum
+// log.Printf("image token num: %d", imageTokenNum)
+// } else if m.Type == dto.ContentTypeInputAudio {
+// // TODO: 音频token数量计算
+// tokenNum += 100
+// } else if m.Type == dto.ContentTypeFile {
+// tokenNum += 5000
+// } else if m.Type == dto.ContentTypeVideoUrl {
+// tokenNum += 5000
+// } else {
+// tokenNum += getTokenNum(tokenEncoder, m.Text)
+// }
+// }
+// }
+// }
+// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+// return tokenNum, nil
+//}
-func CountTokenInput(input any, model string) (int, error) {
+func CountTokenInput(input any, model string) int {
switch v := input.(type) {
case string:
return CountTextToken(v, model)
@@ -441,13 +587,13 @@ func CountTokenInput(input any, model string) (int, error) {
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0
for _, message := range messages {
- tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
+ tkm := CountTokenInput(message.Delta.GetContentString(), model)
tokens += tkm
if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls {
- tkm, _ := CountTokenInput(tool.Function.Name, model)
+ tkm := CountTokenInput(tool.Function.Name, model)
tokens += tkm
- tkm, _ = CountTokenInput(tool.Function.Arguments, model)
+ tkm = CountTokenInput(tool.Function.Arguments, model)
tokens += tkm
}
}
@@ -455,9 +601,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens
}
-func CountTTSToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) int {
if strings.HasPrefix(model, "tts") {
- return utf8.RuneCountInString(text), nil
+ return utf8.RuneCountInString(text)
} else {
return CountTextToken(text, model)
}
@@ -492,8 +638,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
//}
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTextToken(text string, model string) (int, error) {
- var err error
+func CountTextToken(text string, model string) int {
+ if text == "" {
+ return 0
+ }
tokenEncoder := getTokenEncoder(model)
- return getTokenNum(tokenEncoder, text), err
+ return getTokenNum(tokenEncoder, text)
}
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
index c52e1e15..ca9c0830 100644
--- a/service/usage_helpr.go
+++ b/service/usage_helpr.go
@@ -16,13 +16,13 @@ import (
// return 0, errors.New("unknown relay mode")
//}
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
- ctkm, err := CountTextToken(responseText, modeName)
+ ctkm := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return usage, err
+ return usage
}
func ValidUsage(usage *dto.Usage) bool {
diff --git a/service/user_notify.go b/service/user_notify.go
index 51f1ff99..7c864a1b 100644
--- a/service/user_notify.go
+++ b/service/user_notify.go
@@ -3,7 +3,6 @@ package service
import (
"fmt"
"one-api/common"
- "one-api/constant"
"one-api/dto"
"one-api/model"
"strings"
@@ -13,20 +12,20 @@ func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
if err != nil {
- common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
+ common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error()))
}
}
-func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}, data dto.Notify) error {
- notifyType, ok := userSetting[constant.UserSettingNotifyType]
- if !ok {
- notifyType = constant.NotifyTypeEmail
+func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error {
+ notifyType := userSetting.NotifyType
+ if notifyType == "" {
+ notifyType = dto.NotifyTypeEmail
}
// Check notification limit
canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
- common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
+ common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
@@ -34,34 +33,23 @@ func NotifyUser(userId int, userEmail string, userSetting map[string]interface{}
}
switch notifyType {
- case constant.NotifyTypeEmail:
+ case dto.NotifyTypeEmail:
// check setting email
- if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
- userEmail = settingEmail.(string)
- }
+ userEmail = userSetting.NotificationEmail
if userEmail == "" {
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
return nil
}
return sendEmailNotify(userEmail, data)
- case constant.NotifyTypeWebhook:
- webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
- if !ok {
- common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
- return nil
- }
- webhookURLStr, ok := webhookURL.(string)
- if !ok {
- common.SysError(fmt.Sprintf("user %d webhook url is not string type", userId))
+ case dto.NotifyTypeWebhook:
+ webhookURLStr := userSetting.WebhookUrl
+ if webhookURLStr == "" {
+ common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}
// 获取 webhook secret
- var webhookSecret string
- if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
- webhookSecret, _ = secret.(string)
- }
-
+ webhookSecret := userSetting.WebhookSecret
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
}
return nil
diff --git a/service/webhook.go b/service/webhook.go
index ad2967eb..8faccda3 100644
--- a/service/webhook.go
+++ b/service/webhook.go
@@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
}
// 发送请求
- client := GetImpatientHttpClient()
+ client := GetHttpClient()
resp, err = client.Do(req)
if err != nil {
return fmt.Errorf("failed to send webhook request: %v", err)
diff --git a/setting/auto_group.go b/setting/auto_group.go
new file mode 100644
index 00000000..5a87ae56
--- /dev/null
+++ b/setting/auto_group.go
@@ -0,0 +1,31 @@
+package setting
+
+import "encoding/json"
+
+var AutoGroups = []string{
+ "default",
+}
+
+var DefaultUseAutoGroup = false
+
+func ContainsAutoGroup(group string) bool {
+ for _, autoGroup := range AutoGroups {
+ if autoGroup == group {
+ return true
+ }
+ }
+ return false
+}
+
+func UpdateAutoGroupsByJsonString(jsonString string) error {
+ AutoGroups = make([]string, 0)
+ return json.Unmarshal([]byte(jsonString), &AutoGroups)
+}
+
+func AutoGroups2JsonString() string {
+ jsonBytes, err := json.Marshal(AutoGroups)
+ if err != nil {
+ return "[]"
+ }
+ return string(jsonBytes)
+}
diff --git a/setting/chat.go b/setting/chat.go
index ef308000..bd1e26e3 100644
--- a/setting/chat.go
+++ b/setting/chat.go
@@ -6,8 +6,14 @@ import (
)
var Chats = []map[string]string{
+ //{
+ // "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
+ //},
{
- "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
+ "Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}",
+ },
+ {
+ "流畅阅读": "fluentread",
},
{
"Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
@@ -31,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error {
func Chats2JsonString() string {
jsonBytes, err := json.Marshal(Chats)
if err != nil {
- common.SysError("error marshalling chats: " + err.Error())
+ common.SysLog("error marshalling chats: " + err.Error())
return "[]"
}
return string(jsonBytes)
diff --git a/setting/console_setting/config.go b/setting/console_setting/config.go
new file mode 100644
index 00000000..6327e558
--- /dev/null
+++ b/setting/console_setting/config.go
@@ -0,0 +1,39 @@
+package console_setting
+
+import "one-api/setting/config"
+
+type ConsoleSetting struct {
+ ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
+ UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
+ Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
+ FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
+ ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
+ UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
+ AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
+ FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
+}
+
+// 默认配置
+var defaultConsoleSetting = ConsoleSetting{
+ ApiInfo: "",
+ UptimeKumaGroups: "",
+ Announcements: "",
+ FAQ: "",
+ ApiInfoEnabled: true,
+ UptimeKumaEnabled: true,
+ AnnouncementsEnabled: true,
+ FAQEnabled: true,
+}
+
+// 全局实例
+var consoleSetting = defaultConsoleSetting
+
+func init() {
+ // 注册到全局配置管理器,键名为 console_setting
+ config.GlobalConfig.Register("console_setting", &consoleSetting)
+}
+
+// GetConsoleSetting 获取 ConsoleSetting 配置实例
+func GetConsoleSetting() *ConsoleSetting {
+ return &consoleSetting
+}
\ No newline at end of file
diff --git a/setting/console_setting/validation.go b/setting/console_setting/validation.go
new file mode 100644
index 00000000..fda6453d
--- /dev/null
+++ b/setting/console_setting/validation.go
@@ -0,0 +1,304 @@
+package console_setting
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/url"
+ "regexp"
+ "sort"
+ "strings"
+ "time"
+)
+
+var (
+ urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
+ dangerousChars = []string{"
-