Merge branch 'alpha' into main

This commit is contained in:
Calcium-Ion
2025-08-23 15:27:08 +08:00
committed by GitHub
669 changed files with 78247 additions and 30869 deletions

View File

@@ -4,4 +4,5 @@
.vscode .vscode
.gitignore .gitignore
Makefile Makefile
docs docs
.eslintcache

View File

@@ -7,6 +7,8 @@
# 调试相关配置 # 调试相关配置
# 启用pprof # 启用pprof
# ENABLE_PPROF=true # ENABLE_PPROF=true
# 启用调试模式
# DEBUG=true
# 数据库相关配置 # 数据库相关配置
# 数据库连接字符串 # 数据库连接字符串
@@ -41,6 +43,14 @@
# 更新任务启用 # 更新任务启用
# UPDATE_TASK=true # UPDATE_TASK=true
# 对话超时设置
# 所有请求超时时间单位秒默认为0表示不限制
# RELAY_TIMEOUT=0
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
# STREAMING_TIMEOUT=300
# Gemini 识别图片 最大图片数量
# GEMINI_VISION_MAX_IMAGE_NUM=16
# 会话密钥 # 会话密钥
# SESSION_SECRET=random_string # SESSION_SECRET=random_string
@@ -58,8 +68,6 @@
# GET_MEDIA_TOKEN_NOT_STREAM=true # GET_MEDIA_TOKEN_NOT_STREAM=true
# 设置 Dify 渠道是否输出工作流和节点信息到客户端 # 设置 Dify 渠道是否输出工作流和节点信息到客户端
# DIFY_DEBUG=true # DIFY_DEBUG=true
# 设置流式一次回复的超时时间
# STREAMING_TIMEOUT=90
# 节点类型 # 节点类型

View File

@@ -0,0 +1,19 @@
### PR 类型
- [ ] Bug 修复
- [ ] 新功能
- [ ] 文档更新
- [ ] 其他
### PR 是否包含破坏性更新?
- [ ]
- [ ]
### PR 描述
**请在下方详细描述您的 PR包括目的、实现细节等。**
### **重要提示**
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**

View File

@@ -1,14 +1,15 @@
name: Publish Docker image (amd64) name: Publish Docker image (alpha)
on: on:
push: push:
tags: branches:
- '*' - alpha
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:
description: 'reason' description: "reason"
required: false required: false
jobs: jobs:
push_to_registries: push_to_registries:
name: Push Docker image to multiple registries name: Push Docker image to multiple registries
@@ -22,7 +23,7 @@ jobs:
- name: Save version info - name: Save version info
run: | run: |
git describe --tags > VERSION echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
- name: Log in to Docker Hub - name: Log in to Docker Hub
uses: docker/login-action@v3 uses: docker/login-action@v3
@@ -37,6 +38,9 @@ jobs:
username: ${{ github.actor }} username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Extract metadata (tags, labels) for Docker - name: Extract metadata (tags, labels) for Docker
id: meta id: meta
uses: docker/metadata-action@v5 uses: docker/metadata-action@v5
@@ -44,11 +48,15 @@ jobs:
images: | images: |
calciumion/new-api calciumion/new-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
tags: |
type=raw,value=alpha
type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
- name: Build and push Docker images - name: Build and push Docker images
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,14 +1,9 @@
name: Publish Docker image (arm64) name: Publish Docker image (Multi Registries)
on: on:
push: push:
tags: tags:
- '*' - '*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
push_to_registries: push_to_registries:
name: Push Docker image to multiple registries name: Push Docker image to multiple registries

View File

@@ -3,6 +3,11 @@ permissions:
contents: write contents: write
on: on:
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
push: push:
tags: tags:
- '*' - '*'
@@ -15,16 +20,16 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
- uses: actions/setup-node@v3 - uses: oven-sh/setup-bun@v2
with: with:
node-version: 18 bun-version: latest
- name: Build Frontend - name: Build Frontend
env: env:
CI: "" CI: ""
run: | run: |
cd web cd web
npm install bun install
REACT_APP_VERSION=$(git describe --tags) npm run build DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd .. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3

View File

@@ -3,6 +3,11 @@ permissions:
contents: write contents: write
on: on:
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
push: push:
tags: tags:
- '*' - '*'
@@ -15,16 +20,17 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
- uses: actions/setup-node@v3 - uses: oven-sh/setup-bun@v2
with: with:
node-version: 18 bun-version: latest
- name: Build Frontend - name: Build Frontend
env: env:
CI: "" CI: ""
NODE_OPTIONS: "--max-old-space-size=4096"
run: | run: |
cd web cd web
npm install bun install
REACT_APP_VERSION=$(git describe --tags) npm run build DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd .. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3

View File

@@ -0,0 +1,21 @@
name: Check PR Branching Strategy
on:
pull_request:
types: [opened, synchronize, reopened, edited]
jobs:
check-branching-strategy:
runs-on: ubuntu-latest
steps:
- name: Enforce branching strategy
run: |
if [[ "${{ github.base_ref }}" == "main" ]]; then
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
exit 1
fi
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
exit 1
fi
echo "Branching strategy check passed."

View File

@@ -3,6 +3,11 @@ permissions:
contents: write contents: write
on: on:
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
push: push:
tags: tags:
- '*' - '*'
@@ -18,16 +23,16 @@ jobs:
uses: actions/checkout@v3 uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
- uses: actions/setup-node@v3 - uses: oven-sh/setup-bun@v2
with: with:
node-version: 18 bun-version: latest
- name: Build Frontend - name: Build Frontend
env: env:
CI: "" CI: ""
run: | run: |
cd web cd web
npm install bun install
REACT_APP_VERSION=$(git describe --tags) npm run build DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
cd .. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3

3
.gitignore vendored
View File

@@ -10,4 +10,5 @@ web/dist
.env .env
one-api one-api
.DS_Store .DS_Store
tiktoken_cache tiktoken_cache
.eslintcache

View File

@@ -2,6 +2,7 @@ FROM oven/bun:latest AS builder
WORKDIR /build WORKDIR /build
COPY web/package.json . COPY web/package.json .
COPY web/bun.lock .
RUN bun install RUN bun install
COPY ./web . COPY ./web .
COPY ./VERSION . COPY ./VERSION .
@@ -24,8 +25,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
FROM alpine FROM alpine
RUN apk update \ RUN apk upgrade --no-cache \
&& apk upgrade \
&& apk add --no-cache ca-certificates tzdata ffmpeg \ && apk add --no-cache ca-certificates tzdata ffmpeg \
&& update-ca-certificates && update-ca-certificates

240
LICENSE
View File

@@ -1,201 +1,103 @@
Apache License # **New API 许可协议 (Licensing)**
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
1. Definitions. **核心原则:**
"License" shall mean the terms and conditions for use, reproduction, - **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
and distribution as defined by Sections 1 through 9 of this document. - **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
"Licensor" shall mean the copyright owner or entity authorized by ---
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all ## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity - 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
exercising permissions granted by this License. - **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
"Source" form shall mean the preferred form for making modifications, ## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical 在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or - **场景一:移除品牌和版权信息**
Object form, made available under the License, as indicated by a 您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object - **场景二:规避 AGPLv3 开源义务**
form, that is based on (or derived from) the Work and for which the 您基于 New API 进行了修改,并希望:
editorial revisions, annotations, elaborations, or other modifications - 通过网络提供服务SaaS但**不希望**向您的服务用户公开您修改后的源代码。
represent, as a whole, an original work of authorship. For the purposes - 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including - **场景三:企业政策与集成需求**
the original version of the Work and any modifications or additions - 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
to that Work or Derivative Works thereof, that is intentionally - 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity - **场景四:需要商业支持与保障**
on behalf of whom a Contribution has been received by Licensor and 您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of **获取商业许可:**
this License, each Contributor hereby grants to You a perpetual, 请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of ## **3. 贡献 (Contributions)**
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the - 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request都将被视为在 **AGPLv3** 许可证下提供。
Work or Derivative Works thereof in any medium, with or without - 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
modifications, and in Source or Object form, provided that You - 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
meet the following conditions:
(a) You must give any other recipients of the Work or ## **4. 其他条款 (Other Terms)**
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices - 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
stating that You changed the files; and - 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
(c) You must retain, in the Source form of any Derivative Works ---
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its # **New API Licensing**
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and This project uses a **Usage-Based Dual Licensing** model.
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, **Core Principles:**
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade - **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
names, trademarks, service marks, or product names of the Licensor, - **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or ---
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, ## **1. Open Source License: AGPLv3 For Basic Usage**
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing - Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
the Work or Derivative Works thereof, You may choose to offer, - **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
and charge a fee for, acceptance of support, warranty, indemnity, - **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the projects code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
or other liability obligations and/or rights consistent with this - Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS ## **2. Commercial License For Advanced Scenarios & Closed Source Needs**
APPENDIX: How to apply the Apache License to your work. You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
To apply the Apache License to your work, attach the following - **Scenario 1: Removal of Branding and Copyright**
boilerplate notice, with the fields enclosed by brackets "[]" You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner] - **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
You have modified New API and wish to:
- Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
- Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
Licensed under the Apache License, Version 2.0 (the "License"); - **Scenario 3: Enterprise Policy & Integration Needs**
you may not use this file except in compliance with the License. - Your organizations policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
You may obtain a copy of the License at - You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
http://www.apache.org/licenses/LICENSE-2.0 - **Scenario 4: Commercial Support and Assurances**
You require commercial assurances not provided by AGPLv3, such as official technical support.
Unless required by applicable law or agreed to in writing, software **Obtaining a Commercial License:**
distributed under the License is distributed on an "AS IS" BASIS, Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and ## **3. Contributions**
limitations under the License.
- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
## **4. Other Terms**
- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).

View File

@@ -40,6 +40,28 @@
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes. > - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China. > - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
<h2>🤝 Trusted Partners</h2>
<p id="premium-sponsors">&nbsp;</p>
<p align="center"><strong>No particular order</strong></p>
<p align="center">
<a href="https://www.cherry-ai.com/" target=_blank><img
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
/></a>
<a href="https://bda.pku.edu.cn/" target=_blank><img
src="./docs/images/pku.png" alt="Peking University" height="120"
/></a>
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
src="./docs/images/ucloud.png" alt="UCloud" height="120"
/></a>
<a href="https://www.aliyun.com/" target=_blank><img
src="./docs/images/aliyun.png" alt="Alibaba Cloud" height="120"
/></a>
<a href="https://io.net/" target=_blank><img
src="./docs/images/io-net.png" alt="IO.NET" height="120"
/></a>
</p>
<p>&nbsp;</p>
## 📚 Documentation ## 📚 Documentation
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/) For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
@@ -100,7 +122,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables): For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false` - `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds - `STREAMING_TIMEOUT`: Streaming response timeout, default is 300 seconds
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true` - `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true` - `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true` - `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`

View File

@@ -40,6 +40,28 @@
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 > - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
<h2>🤝 我们信任的合作伙伴</h2>
<p id="premium-sponsors">&nbsp;</p>
<p align="center"><strong>排名不分先后</strong></p>
<p align="center">
<a href="https://www.cherry-ai.com/" target=_blank><img
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
/></a>
<a href="https://bda.pku.edu.cn/" target=_blank><img
src="./docs/images/pku.png" alt="北京大学" height="120"
/></a>
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
src="./docs/images/ucloud.png" alt="UCloud 优刻得" height="120"
/></a>
<a href="https://www.aliyun.com/" target=_blank><img
src="./docs/images/aliyun.png" alt="阿里云" height="120"
/></a>
<a href="https://io.net/" target=_blank><img
src="./docs/images/io-net.png" alt="IO.NET" height="120"
/></a>
</p>
<p>&nbsp;</p>
## 📚 文档 ## 📚 文档
详细文档请访问我们的官方Wiki[https://docs.newapi.pro/](https://docs.newapi.pro/) 详细文档请访问我们的官方Wiki[https://docs.newapi.pro/](https://docs.newapi.pro/)
@@ -100,7 +122,7 @@ New API提供了丰富的功能详细特性请参考[特性说明](https://do
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables) 详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables)
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false` - `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒 - `STREAMING_TIMEOUT`:流式回复超时时间,默认300秒
- `DIFY_DEBUG`Dify渠道是否输出工作流和节点信息默认 `true` - `DIFY_DEBUG`Dify渠道是否输出工作流和节点信息默认 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数默认 `true` - `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数默认 `true`
- `GET_MEDIA_TOKEN`是否统计图片token默认 `true` - `GET_MEDIA_TOKEN`是否统计图片token默认 `true`
@@ -180,7 +202,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
其他基于New API的项目 其他基于New API的项目
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版 - [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版
- [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的前端美化版本
## 帮助支持 ## 帮助支持

75
common/api_type.go Normal file
View File

@@ -0,0 +1,75 @@
package common
import "one-api/constant"
func ChannelType2APIType(channelType int) (int, bool) {
apiType := -1
switch channelType {
case constant.ChannelTypeOpenAI:
apiType = constant.APITypeOpenAI
case constant.ChannelTypeAnthropic:
apiType = constant.APITypeAnthropic
case constant.ChannelTypeBaidu:
apiType = constant.APITypeBaidu
case constant.ChannelTypePaLM:
apiType = constant.APITypePaLM
case constant.ChannelTypeZhipu:
apiType = constant.APITypeZhipu
case constant.ChannelTypeAli:
apiType = constant.APITypeAli
case constant.ChannelTypeXunfei:
apiType = constant.APITypeXunfei
case constant.ChannelTypeAIProxyLibrary:
apiType = constant.APITypeAIProxyLibrary
case constant.ChannelTypeTencent:
apiType = constant.APITypeTencent
case constant.ChannelTypeGemini:
apiType = constant.APITypeGemini
case constant.ChannelTypeZhipu_v4:
apiType = constant.APITypeZhipuV4
case constant.ChannelTypeOllama:
apiType = constant.APITypeOllama
case constant.ChannelTypePerplexity:
apiType = constant.APITypePerplexity
case constant.ChannelTypeAws:
apiType = constant.APITypeAws
case constant.ChannelTypeCohere:
apiType = constant.APITypeCohere
case constant.ChannelTypeDify:
apiType = constant.APITypeDify
case constant.ChannelTypeJina:
apiType = constant.APITypeJina
case constant.ChannelCloudflare:
apiType = constant.APITypeCloudflare
case constant.ChannelTypeSiliconFlow:
apiType = constant.APITypeSiliconFlow
case constant.ChannelTypeVertexAi:
apiType = constant.APITypeVertexAi
case constant.ChannelTypeMistral:
apiType = constant.APITypeMistral
case constant.ChannelTypeDeepSeek:
apiType = constant.APITypeDeepSeek
case constant.ChannelTypeMokaAI:
apiType = constant.APITypeMokaAI
case constant.ChannelTypeVolcEngine:
apiType = constant.APITypeVolcEngine
case constant.ChannelTypeBaiduV2:
apiType = constant.APITypeBaiduV2
case constant.ChannelTypeOpenRouter:
apiType = constant.APITypeOpenRouter
case constant.ChannelTypeXinference:
apiType = constant.APITypeXinference
case constant.ChannelTypeXai:
apiType = constant.APITypeXai
case constant.ChannelTypeCoze:
apiType = constant.APITypeCoze
case constant.ChannelTypeJimeng:
apiType = constant.APITypeJimeng
case constant.ChannelTypeMoonshot:
apiType = constant.APITypeMoonshot
}
if apiType == -1 {
return constant.APITypeOpenAI, false
}
return apiType, true
}

View File

@@ -83,6 +83,7 @@ var GitHubClientId = ""
var GitHubClientSecret = "" var GitHubClientSecret = ""
var LinuxDOClientId = "" var LinuxDOClientId = ""
var LinuxDOClientSecret = "" var LinuxDOClientSecret = ""
var LinuxDOMinimumTrustLevel = 0
var WeChatServerAddress = "" var WeChatServerAddress = ""
var WeChatServerToken = "" var WeChatServerToken = ""
@@ -195,105 +196,7 @@ const (
) )
const ( const (
ChannelTypeUnknown = 0 TopUpStatusPending = "pending"
ChannelTypeOpenAI = 1 TopUpStatusSuccess = "success"
ChannelTypeMidjourney = 2 TopUpStatusExpired = "expired"
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://api.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
}

21
common/copy.go Normal file
View File

@@ -0,0 +1,21 @@
package common
import (
"fmt"
"github.com/antlabs/pcopy"
)
func DeepCopy[T any](src *T) (*T, error) {
if src == nil {
return nil, fmt.Errorf("copy source cannot be nil")
}
var dst T
err := pcopy.Copy(&dst, src)
if err != nil {
return nil, err
}
if &dst == nil {
return nil, fmt.Errorf("copy result cannot be nil")
}
return &dst, nil
}

View File

@@ -9,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync"
) )
type stringWriter interface { type stringWriter interface {
@@ -52,6 +53,8 @@ type CustomEvent struct {
Id string Id string
Retry uint Retry uint
Data interface{} Data interface{}
Mutex sync.Mutex
} }
func encode(writer io.Writer, event CustomEvent) error { func encode(writer io.Writer, event CustomEvent) error {
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
} }
func (r CustomEvent) WriteContentType(w http.ResponseWriter) { func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
header := w.Header() header := w.Header()
header["Content-Type"] = contentType header["Content-Type"] = contentType

View File

@@ -1,8 +1,15 @@
package common package common
const (
DatabaseTypeMySQL = "mysql"
DatabaseTypeSQLite = "sqlite"
DatabaseTypePostgreSQL = "postgres"
)
var UsingSQLite = false var UsingSQLite = false
var UsingPostgreSQL = false var UsingPostgreSQL = false
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
var UsingMySQL = false var UsingMySQL = false
var UsingClickHouse = false var UsingClickHouse = false
var SQLitePath = "one-api.db?_busy_timeout=5000" var SQLitePath = "one-api.db?_busy_timeout=30000"

View File

@@ -0,0 +1,32 @@
package common
import "one-api/constant"
// EndpointInfo 描述单个端点的默认请求信息
// path: 上游路径
// method: HTTP 请求方式,例如 POST/GET
// 目前均为 POST后续可扩展
//
// json 标签用于直接序列化到 API 输出
// 例如:{"path":"/v1/chat/completions","method":"POST"}
type EndpointInfo struct {
Path string `json:"path"`
Method string `json:"method"`
}
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
}
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
info, ok := defaultEndpointInfoMap[et]
return info, ok
}

41
common/endpoint_type.go Normal file
View File

@@ -0,0 +1,41 @@
package common
import "one-api/constant"
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
var endpointTypes []constant.EndpointType
switch channelType {
case constant.ChannelTypeJina:
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
//case constant.ChannelTypeSunoAPI:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
//case constant.ChannelTypeKling:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
//case constant.ChannelTypeJimeng:
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
case constant.ChannelTypeAws:
fallthrough
case constant.ChannelTypeAnthropic:
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
case constant.ChannelTypeVertexAi:
fallthrough
case constant.ChannelTypeGemini:
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
default:
if IsOpenAIResponseOnlyModel(modelName) {
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
} else {
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
}
}
if IsImageGenerationModel(modelName) {
// add to first
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
}
return endpointTypes
}

View File

@@ -2,10 +2,13 @@ package common
import ( import (
"bytes" "bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"io" "io"
"net/http"
"one-api/constant"
"strings" "strings"
"time"
"github.com/gin-gonic/gin"
) )
const KeyRequestBody = "key_request_body" const KeyRequestBody = "key_request_body"
@@ -29,9 +32,12 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil { if err != nil {
return err return err
} }
//if DebugEnabled {
// println("UnmarshalBodyReusable request body:", string(requestBody))
//}
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") { if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v) err = Unmarshal(requestBody, &v)
} else { } else {
// skip for now // skip for now
// TODO: someday non json request have variant model, we will need to implementation this // TODO: someday non json request have variant model, we will need to implementation this
@@ -43,3 +49,67 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil return nil
} }
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
c.Set(string(key), value)
}
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
return c.Get(string(key))
}
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
return c.GetString(string(key))
}
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
return c.GetInt(string(key))
}
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
return c.GetBool(string(key))
}
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
return c.GetStringSlice(string(key))
}
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
return c.GetStringMap(string(key))
}
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
return c.GetTime(string(key))
}
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
if value, ok := c.Get(string(key)); ok {
if v, ok := value.(T); ok {
return v, true
}
}
var t T
return t, false
}
func ApiError(c *gin.Context, err error) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
func ApiErrorMsg(c *gin.Context, msg string) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": msg,
})
}
func ApiSuccess(c *gin.Context, data any) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": data,
})
}

34
common/hash.go Normal file
View File

@@ -0,0 +1,34 @@
package common
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
)
func Sha256Raw(data []byte) []byte {
h := sha256.New()
h.Write(data)
return h.Sum(nil)
}
func Sha1Raw(data []byte) []byte {
h := sha1.New()
h.Write(data)
return h.Sum(nil)
}
func Sha1(data []byte) string {
return hex.EncodeToString(Sha1Raw(data))
}
func HmacSha256Raw(message, key []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(message)
return h.Sum(nil)
}
func HmacSha256(message, key string) string {
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
}

View File

@@ -4,6 +4,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"one-api/constant"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@@ -24,7 +25,7 @@ func printHelp() {
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]") fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
} }
func LoadEnv() { func InitEnv() {
flag.Parse() flag.Parse()
if *PrintVersion { if *PrintVersion {
@@ -95,4 +96,25 @@ func LoadEnv() {
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true) GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180)) GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
initConstantEnv()
}
func initConstantEnv() {
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
} }

View File

@@ -5,14 +5,18 @@ import (
"encoding/json" "encoding/json"
) )
func DecodeJson(data []byte, v any) error { func Unmarshal(data []byte, v any) error {
return json.NewDecoder(bytes.NewReader(data)).Decode(v) return json.Unmarshal(data, v)
} }
func DecodeJsonStr(data string, v any) error { func UnmarshalJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v) return json.Unmarshal(StringToByteSlice(data), v)
} }
func EncodeJson(v any) ([]byte, error) { func DecodeJson(reader *bytes.Reader, v any) error {
return json.NewDecoder(reader).Decode(v)
}
func Marshal(v any) ([]byte, error) {
return json.Marshal(v) return json.Marshal(v)
} }

42
common/model.go Normal file
View File

@@ -0,0 +1,42 @@
package common
import "strings"
var (
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
OpenAIResponseOnlyModels = []string{
"o3-pro",
"o3-deep-research",
"o4-mini-deep-research",
}
ImageGenerationModels = []string{
"dall-e-3",
"dall-e-2",
"gpt-image-1",
"prefix:imagen-",
"flux-",
"flux.1-",
}
)
func IsOpenAIResponseOnlyModel(modelName string) bool {
for _, m := range OpenAIResponseOnlyModels {
if strings.Contains(modelName, m) {
return true
}
}
return false
}
func IsImageGenerationModel(modelName string) bool {
modelName = strings.ToLower(modelName)
for _, m := range ImageGenerationModels {
if strings.Contains(modelName, m) {
return true
}
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
return true
}
}
return false
}

82
common/page_info.go Normal file
View File

@@ -0,0 +1,82 @@
package common
import (
"strconv"
"github.com/gin-gonic/gin"
)
type PageInfo struct {
Page int `json:"page"` // page num 页码
PageSize int `json:"page_size"` // page size 页大小
Total int `json:"total"` // 总条数,后设置
Items any `json:"items"` // 数据,后设置
}
func (p *PageInfo) GetStartIdx() int {
return (p.Page - 1) * p.PageSize
}
func (p *PageInfo) GetEndIdx() int {
return p.Page * p.PageSize
}
func (p *PageInfo) GetPageSize() int {
return p.PageSize
}
func (p *PageInfo) GetPage() int {
return p.Page
}
func (p *PageInfo) SetTotal(total int) {
p.Total = total
}
func (p *PageInfo) SetItems(items any) {
p.Items = items
}
func GetPageQuery(c *gin.Context) *PageInfo {
pageInfo := &PageInfo{}
// 手动获取并处理每个参数
if page, err := strconv.Atoi(c.Query("p")); err == nil {
pageInfo.Page = page
}
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
pageInfo.PageSize = pageSize
}
if pageInfo.Page < 1 {
// 兼容
page, _ := strconv.Atoi(c.Query("p"))
if page != 0 {
pageInfo.Page = page
} else {
pageInfo.Page = 1
}
}
if pageInfo.PageSize == 0 {
// 兼容
pageSize, _ := strconv.Atoi(c.Query("ps"))
if pageSize != 0 {
pageInfo.PageSize = pageSize
}
if pageInfo.PageSize == 0 {
pageSize, _ = strconv.Atoi(c.Query("size")) // token page
if pageSize != 0 {
pageInfo.PageSize = pageSize
}
}
if pageInfo.PageSize == 0 {
pageInfo.PageSize = ItemsPerPage
}
}
if pageInfo.PageSize > 100 {
pageInfo.PageSize = 100
}
return pageInfo
}

5
common/quota.go Normal file
View File

@@ -0,0 +1,5 @@
package common
func GetTrustQuota() int {
return int(10 * QuotaPerUnit)
}

View File

@@ -16,6 +16,10 @@ import (
var RDB *redis.Client var RDB *redis.Client
var RedisEnabled = true var RedisEnabled = true
func RedisKeyCacheSeconds() int {
return SyncFrequency
}
// InitRedisClient This function is called after init() // InitRedisClient This function is called after init()
func InitRedisClient() (err error) { func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" { if os.Getenv("REDIS_CONN_STRING") == "" {
@@ -92,12 +96,12 @@ func RedisDel(key string) error {
return RDB.Del(ctx, key).Err() return RDB.Del(ctx, key).Err()
} }
func RedisHDelObj(key string) error { func RedisDelKey(key string) error {
if DebugEnabled { if DebugEnabled {
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key)) SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
} }
ctx := context.Background() ctx := context.Background()
return RDB.HDel(ctx, key).Err() return RDB.Del(ctx, key).Err()
} }
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
@@ -141,7 +145,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
txn := RDB.TxPipeline() txn := RDB.TxPipeline()
txn.HSet(ctx, key, data) txn.HSet(ctx, key, data)
txn.Expire(ctx, key, expiration)
// 只有在 expiration 大于 0 时才设置过期时间
if expiration > 0 {
txn.Expire(ctx, key, expiration)
}
_, err := txn.Exec(ctx) _, err := txn.Exec(ctx)
if err != nil { if err != nil {

View File

@@ -1,9 +1,13 @@
package common package common
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"math/rand" "math/rand"
"net/url"
"regexp"
"strconv" "strconv"
"strings"
"unsafe" "unsafe"
) )
@@ -31,16 +35,30 @@ func MapToJsonStr(m map[string]interface{}) string {
return string(bytes) return string(bytes)
} }
func StrToMap(str string) map[string]interface{} { func StrToMap(str string) (map[string]interface{}, error) {
m := make(map[string]interface{}) m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m) err := Unmarshal([]byte(str), &m)
if err != nil { if err != nil {
return nil return nil, err
} }
return m return m, nil
} }
func IsJsonStr(str string) bool { func StrToJsonArray(str string) ([]interface{}, error) {
var js []interface{}
err := json.Unmarshal([]byte(str), &js)
if err != nil {
return nil, err
}
return js, nil
}
func IsJsonArray(str string) bool {
var js []interface{}
return json.Unmarshal([]byte(str), &js) == nil
}
func IsJsonObject(str string) bool {
var js map[string]interface{} var js map[string]interface{}
return json.Unmarshal([]byte(str), &js) == nil return json.Unmarshal([]byte(str), &js) == nil
} }
@@ -68,3 +86,152 @@ func StringToByteSlice(s string) []byte {
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
return *(*[]byte)(unsafe.Pointer(&tmp2)) return *(*[]byte)(unsafe.Pointer(&tmp2))
} }
func EncodeBase64(str string) string {
return base64.StdEncoding.EncodeToString([]byte(str))
}
func GetJsonString(data any) string {
if data == nil {
return ""
}
b, _ := json.Marshal(data)
return string(b)
}
// MaskEmail masks a user email to prevent PII leakage in logs
// Returns "***masked***" if email is empty, otherwise shows only the domain part
func MaskEmail(email string) string {
if email == "" {
return "***masked***"
}
// Find the @ symbol
atIndex := strings.Index(email, "@")
if atIndex == -1 {
// No @ symbol found, return masked
return "***masked***"
}
// Return only the domain part with @ symbol
return "***@" + email[atIndex+1:]
}
// maskHostTail returns the tail parts of a domain/host that should be preserved.
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
func maskHostTail(parts []string) []string {
if len(parts) < 2 {
return parts
}
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
return []string{secondLastPart, lastPart}
}
return []string{lastPart}
}
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
func maskHostForURL(host string) string {
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "***"
}
tail := maskHostTail(parts)
return "***." + strings.Join(tail, ".")
}
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
func maskHostForPlainDomain(domain string) string {
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
tail := maskHostTail(parts)
numStars := len(parts) - len(tail)
if numStars < 1 {
numStars = 1
}
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
return stars + "." + strings.Join(tail, ".")
}
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
// Example:
// http://example.com -> http://***.com
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
// 192.168.1.1 -> ***.***.***.***
// openai.com -> ***.com
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
func MaskSensitiveInfo(str string) string {
// Mask URLs
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
u, err := url.Parse(urlStr)
if err != nil {
return urlStr
}
host := u.Host
if host == "" {
return urlStr
}
// Mask host with unified logic
maskedHost := maskHostForURL(host)
result := u.Scheme + "://" + maskedHost
// Mask path
if u.Path != "" && u.Path != "/" {
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
maskedPathParts := make([]string, len(pathParts))
for i := range pathParts {
if pathParts[i] != "" {
maskedPathParts[i] = "***"
}
}
if len(maskedPathParts) > 0 {
result += "/" + strings.Join(maskedPathParts, "/")
}
} else if u.Path == "/" {
result += "/"
}
// Mask query parameters
if u.RawQuery != "" {
values, err := url.ParseQuery(u.RawQuery)
if err != nil {
// If can't parse query, just mask the whole query string
result += "?***"
} else {
maskedParams := make([]string, 0, len(values))
for key := range values {
maskedParams = append(maskedParams, key+"=***")
}
if len(maskedParams) > 0 {
result += "?" + strings.Join(maskedParams, "&")
}
}
}
return result
})
// Mask domain names without protocol (like openai.com, www.openai.com)
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
return maskHostForPlainDomain(domain)
})
// Mask IP addresses
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
return str
}

24
common/sys_log.go Normal file
View File

@@ -0,0 +1,24 @@
package common
import (
"fmt"
"github.com/gin-gonic/gin"
"os"
"time"
)
func SysLog(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysError(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}

150
common/totp.go Normal file
View File

@@ -0,0 +1,150 @@
package common
import (
"crypto/rand"
"fmt"
"os"
"strconv"
"strings"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
)
const (
// 备用码配置
BackupCodeLength = 8 // 备用码长度
BackupCodeCount = 4 // 生成备用码数量
// 限制配置
MaxFailAttempts = 5 // 最大失败尝试次数
LockoutDuration = 300 // 锁定时间(秒)
)
// GenerateTOTPSecret 生成TOTP密钥和配置
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
issuer := Get2FAIssuer()
return totp.Generate(totp.GenerateOpts{
Issuer: issuer,
AccountName: accountName,
Period: 30,
Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1,
})
}
// ValidateTOTPCode 验证TOTP验证码
func ValidateTOTPCode(secret, code string) bool {
// 清理验证码格式
cleanCode := strings.ReplaceAll(code, " ", "")
if len(cleanCode) != 6 {
return false
}
// 验证验证码
return totp.Validate(cleanCode, secret)
}
// GenerateBackupCodes 生成备用恢复码
func GenerateBackupCodes() ([]string, error) {
codes := make([]string, BackupCodeCount)
for i := 0; i < BackupCodeCount; i++ {
code, err := generateRandomBackupCode()
if err != nil {
return nil, err
}
codes[i] = code
}
return codes, nil
}
// generateRandomBackupCode 生成单个备用码
func generateRandomBackupCode() (string, error) {
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
code := make([]byte, BackupCodeLength)
for i := range code {
randomBytes := make([]byte, 1)
_, err := rand.Read(randomBytes)
if err != nil {
return "", err
}
code[i] = charset[int(randomBytes[0])%len(charset)]
}
// 格式化为 XXXX-XXXX 格式
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
}
// ValidateBackupCode 验证备用码格式
func ValidateBackupCode(code string) bool {
// 移除所有分隔符并转为大写
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
if len(cleanCode) != BackupCodeLength {
return false
}
// 检查字符是否合法
for _, char := range cleanCode {
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
return false
}
}
return true
}
// NormalizeBackupCode 标准化备用码格式
func NormalizeBackupCode(code string) string {
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
if len(cleanCode) == BackupCodeLength {
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
}
return code
}
// HashBackupCode 对备用码进行哈希
func HashBackupCode(code string) (string, error) {
normalizedCode := NormalizeBackupCode(code)
return Password2Hash(normalizedCode)
}
// Get2FAIssuer 获取2FA发行者名称
func Get2FAIssuer() string {
return SystemName
}
// getEnvOrDefault 获取环境变量或默认值
func getEnvOrDefault(key, defaultValue string) string {
if value, exists := os.LookupEnv(key); exists {
return value
}
return defaultValue
}
// ValidateNumericCode 验证数字验证码格式
func ValidateNumericCode(code string) (string, error) {
// 移除空格
code = strings.ReplaceAll(code, " ", "")
if len(code) != 6 {
return "", fmt.Errorf("验证码必须是6位数字")
}
// 检查是否为纯数字
if _, err := strconv.Atoi(code); err != nil {
return "", fmt.Errorf("验证码只能包含数字")
}
return code, nil
}
// GenerateQRCodeData 生成二维码数据
func GenerateQRCodeData(secret, username string) string {
issuer := Get2FAIssuer()
accountName := fmt.Sprintf("%s (%s)", username, issuer)
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
issuer, accountName, secret, issuer)
}

View File

@@ -13,6 +13,7 @@ import (
"math/big" "math/big"
"math/rand" "math/rand"
"net" "net"
"net/url"
"os" "os"
"os/exec" "os/exec"
"runtime" "runtime"
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
} }
// GetAudioDuration returns the duration of an audio file in seconds. // GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string) (float64, error) { func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
output, err := c.Output() output, err := c.Output()
if err != nil { if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration") return 0, errors.Wrap(err, "failed to get audio duration")
} }
durationStr := string(bytes.TrimSpace(output))
if durationStr == "N/A" {
// Create a temporary output file name
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.Wrap(err, "failed to create temporary file")
}
tmpName := tmpFp.Name()
// Close immediately so ffmpeg can open the file on Windows.
_ = tmpFp.Close()
defer os.Remove(tmpName)
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) // ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
if err := ffmpegCmd.Run(); err != nil {
return 0, errors.Wrap(err, "failed to run ffmpeg")
}
// Recalculate the duration of the new file
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
output, err := c.Output()
if err != nil {
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
}
durationStr = string(bytes.TrimSpace(output))
}
return strconv.ParseFloat(durationStr, 64)
}
// BuildURL concatenates base and endpoint, returns the complete url string
func BuildURL(base string, endpoint string) string {
u, err := url.Parse(base)
if err != nil {
return base + endpoint
}
end := endpoint
if end == "" {
end = "/"
}
ref, err := url.Parse(end)
if err != nil {
return base + endpoint
}
return u.ResolveReference(ref).String()
} }

26
constant/README.md Normal file
View File

@@ -0,0 +1,26 @@
# constant 包 (`/constant`)
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
## 当前文件
| 文件 | 说明 |
|----------------------|---------------------------------------------------------------------|
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
| `channel_setting.go` | Channel 级别的设置键,如 `proxy``force_format` 等。 |
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量请求时间、Token/Channel/User 相关信息等)。 |
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
## 使用约定
1. `constant` 包**只能被其他包引用**import**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。

36
constant/api_type.go Normal file
View File

@@ -0,0 +1,36 @@
package constant
const (
APITypeOpenAI = iota
APITypeAnthropic
APITypePaLM
APITypeBaidu
APITypeZhipu
APITypeAli
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeZhipuV4
APITypeOllama
APITypePerplexity
APITypeAws
APITypeCohere
APITypeDify
APITypeJina
APITypeCloudflare
APITypeSiliconFlow
APITypeVertexAi
APITypeMistral
APITypeDeepSeek
APITypeMokaAI
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeXai
APITypeCoze
APITypeJimeng
APITypeMoonshot // this one is only for count, do not add any channel after this
APITypeDummy // this one is only for count, do not add any channel after this
)

View File

@@ -1,14 +1,5 @@
package constant package constant
import "one-api/common"
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
// Cache keys // Cache keys
const ( const (
UserGroupKeyFmt = "user_group:%d" UserGroupKeyFmt = "user_group:%d"

111
constant/channel.go Normal file
View File

@@ -0,0 +1,111 @@
package constant
const (
ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1
ChannelTypeMidjourney = 2
ChannelTypeAzure = 3
ChannelTypeOllama = 4
ChannelTypeMidjourneyPlus = 5
ChannelTypeOpenAIMax = 6
ChannelTypeOhMyGPT = 7
ChannelTypeCustom = 8
ChannelTypeAILS = 9
ChannelTypeAIProxy = 10
ChannelTypePaLM = 11
ChannelTypeAPI2GPT = 12
ChannelTypeAIGC2D = 13
ChannelTypeAnthropic = 14
ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
ChannelTypeAli = 17
ChannelTypeXunfei = 18
ChannelType360 = 19
ChannelTypeOpenRouter = 20
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeZhipu_v4 = 26
ChannelTypePerplexity = 27
ChannelTypeLingYiWanWu = 31
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSunoAPI = 36
ChannelTypeDify = 37
ChannelTypeJina = 38
ChannelCloudflare = 39
ChannelTypeSiliconFlow = 40
ChannelTypeVertexAi = 41
ChannelTypeMistral = 42
ChannelTypeDeepSeek = 43
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeXai = 48
ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeJimeng = 51
ChannelTypeVidu = 52
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"http://localhost:11434", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://api.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.tencentcloudapi.com", //23
"https://generativelanguage.googleapis.com", //24
"https://api.moonshot.cn", //25
"https://open.bigmodel.cn", //26
"https://api.perplexity.ai", //27
"", //28
"", //29
"", //30
"https://api.lingyiwanwu.com", //31
"", //32
"", //33
"https://api.cohere.ai", //34
"https://api.minimax.chat", //35
"", //36
"https://api.dify.ai", //37
"https://api.jina.ai", //38
"https://api.cloudflare.com", //39
"https://api.siliconflow.cn", //40
"", //41
"https://api.mistral.ai", //42
"https://api.deepseek.com", //43
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
"https://api.x.ai", //48
"https://api.coze.cn", //49
"https://api.klingai.com", //50
"https://visual.volcengineapi.com", //51
"https://api.vidu.cn", //52
}

View File

@@ -1,7 +0,0 @@
package constant
var (
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
ChanelSettingProxy = "proxy" // Proxy 代理
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
)

View File

@@ -1,10 +1,49 @@
package constant package constant
type ContextKey string
const ( const (
ContextKeyRequestStartTime = "request_start_time" ContextKeyTokenCountMeta ContextKey = "token_count_meta"
ContextKeyUserSetting = "user_setting" ContextKeyPromptTokens ContextKey = "prompt_tokens"
ContextKeyUserQuota = "user_quota"
ContextKeyUserStatus = "user_status" ContextKeyOriginalModel ContextKey = "original_model"
ContextKeyUserEmail = "user_email" ContextKeyRequestStartTime ContextKey = "request_start_time"
ContextKeyUserGroup = "user_group"
/* token related keys */
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
ContextKeyTokenKey ContextKey = "token_key"
ContextKeyTokenId ContextKey = "token_id"
ContextKeyTokenGroup ContextKey = "token_group"
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
/* channel related keys */
ContextKeyChannelId ContextKey = "channel_id"
ContextKeyChannelName ContextKey = "channel_name"
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
ContextKeyChannelBaseUrl ContextKey = "base_url"
ContextKeyChannelType ContextKey = "channel_type"
ContextKeyChannelSetting ContextKey = "channel_setting"
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
ContextKeyChannelParamOverride ContextKey = "param_override"
ContextKeyChannelOrganization ContextKey = "channel_organization"
ContextKeyChannelAutoBan ContextKey = "auto_ban"
ContextKeyChannelModelMapping ContextKey = "model_mapping"
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
ContextKeyChannelKey ContextKey = "channel_key"
/* user related keys */
ContextKeyUserId ContextKey = "id"
ContextKeyUserSetting ContextKey = "user_setting"
ContextKeyUserQuota ContextKey = "user_quota"
ContextKeyUserStatus ContextKey = "user_status"
ContextKeyUserEmail ContextKey = "user_email"
ContextKeyUserGroup ContextKey = "user_group"
ContextKeyUsingGroup ContextKey = "group"
ContextKeyUserName ContextKey = "username"
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
) )

16
constant/endpoint_type.go Normal file
View File

@@ -0,0 +1,16 @@
package constant
type EndpointType string
const (
EndpointTypeOpenAI EndpointType = "openai"
EndpointTypeOpenAIResponse EndpointType = "openai-response"
EndpointTypeAnthropic EndpointType = "anthropic"
EndpointTypeGemini EndpointType = "gemini"
EndpointTypeJinaRerank EndpointType = "jina-rerank"
EndpointTypeImageGeneration EndpointType = "image-generation"
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
//EndpointTypeSuno EndpointType = "suno-proxy"
//EndpointTypeKling EndpointType = "kling"
//EndpointTypeJimeng EndpointType = "jimeng"
)

View File

@@ -1,9 +1,5 @@
package constant package constant
import (
"one-api/common"
)
var StreamingTimeout int var StreamingTimeout int
var DifyDebug bool var DifyDebug bool
var MaxFileDownloadMB int var MaxFileDownloadMB int
@@ -17,39 +13,3 @@ var NotifyLimitCount int
var NotificationLimitDurationMinute int var NotificationLimitDurationMinute int
var GenerateDefaultToken bool var GenerateDefaultToken bool
var ErrorLogEnabled bool var ErrorLogEnabled bool
//var GeminiModelMap = map[string]string{
// "gemini-1.0-pro": "v1",
//}
func InitEnv() {
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
//if modelVersionMapStr == "" {
// return
//}
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
// parts := strings.Split(pair, ":")
// if len(parts) == 2 {
// GeminiModelMap[parts[0]] = parts[1]
// } else {
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
// }
//}
}

View File

@@ -22,6 +22,8 @@ const (
MjActionPan = "PAN" MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE" MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD" MjActionUpload = "UPLOAD"
MjActionVideo = "VIDEO"
MjActionEdits = "EDITS"
) )
var MidjourneyModel2Action = map[string]string{ var MidjourneyModel2Action = map[string]string{
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
"mj_pan": MjActionPan, "mj_pan": MjActionPan,
"swap_face": MjActionSwapFace, "swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload, "mj_upload": MjActionUpload,
"mj_video": MjActionVideo,
"mj_edits": MjActionEdits,
} }

View File

@@ -0,0 +1,8 @@
package constant
type MultiKeyMode string
const (
MultiKeyModeRandom MultiKeyMode = "random" // 随机
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
)

View File

@@ -10,6 +10,9 @@ const (
const ( const (
SunoActionMusic = "MUSIC" SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS" SunoActionLyrics = "LYRICS"
TaskActionGenerate = "generate"
TaskActionTextGenerate = "textGenerate"
) )
var SunoModel2Action = map[string]string{ var SunoModel2Action = map[string]string{

View File

@@ -1,15 +0,0 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
)
var (
NotifyTypeEmail = "email" // Email 邮件
NotifyTypeWebhook = "webhook" // Webhook
)

View File

@@ -7,11 +7,16 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting"
"one-api/types"
"strconv" "strconv"
"time" "time"
"github.com/shopspring/decimal"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -130,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers { for k := range headers {
req.Header.Add(k, headers.Get(k)) req.Header.Add(k, headers.Get(k))
} }
res, err := service.GetHttpClient().Do(req) client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
if err != nil {
return nil, err
}
res, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -304,34 +313,70 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
return balance, nil return balance, nil
} }
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
url := "https://api.moonshot.cn/v1/users/me/balance"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
type MoonshotBalanceData struct {
AvailableBalance float64 `json:"available_balance"`
VoucherBalance float64 `json:"voucher_balance"`
CashBalance float64 `json:"cash_balance"`
}
type MoonshotBalanceResponse struct {
Code int `json:"code"`
Data MoonshotBalanceData `json:"data"`
Scode string `json:"scode"`
Status bool `json:"status"`
}
response := MoonshotBalanceResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if !response.Status || response.Code != 0 {
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
}
availableBalanceCny := response.Data.AvailableBalance
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
channel.UpdateBalance(availableBalanceUsd)
return availableBalanceUsd, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type] baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" { if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL channel.BaseURL = &baseURL
} }
switch channel.Type { switch channel.Type {
case common.ChannelTypeOpenAI: case constant.ChannelTypeOpenAI:
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
} }
case common.ChannelTypeAzure: case constant.ChannelTypeAzure:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
case common.ChannelTypeCustom: case constant.ChannelTypeCustom:
baseURL = channel.GetBaseURL() baseURL = channel.GetBaseURL()
//case common.ChannelTypeOpenAISB: //case common.ChannelTypeOpenAISB:
// return updateChannelOpenAISBBalance(channel) // return updateChannelOpenAISBBalance(channel)
case common.ChannelTypeAIProxy: case constant.ChannelTypeAIProxy:
return updateChannelAIProxyBalance(channel) return updateChannelAIProxyBalance(channel)
case common.ChannelTypeAPI2GPT: case constant.ChannelTypeAPI2GPT:
return updateChannelAPI2GPTBalance(channel) return updateChannelAPI2GPTBalance(channel)
case common.ChannelTypeAIGC2D: case constant.ChannelTypeAIGC2D:
return updateChannelAIGC2DBalance(channel) return updateChannelAIGC2DBalance(channel)
case common.ChannelTypeSiliconFlow: case constant.ChannelTypeSiliconFlow:
return updateChannelSiliconFlowBalance(channel) return updateChannelSiliconFlowBalance(channel)
case common.ChannelTypeDeepSeek: case constant.ChannelTypeDeepSeek:
return updateChannelDeepSeekBalance(channel) return updateChannelDeepSeekBalance(channel)
case common.ChannelTypeOpenRouter: case constant.ChannelTypeOpenRouter:
return updateChannelOpenRouterBalance(channel) return updateChannelOpenRouterBalance(channel)
case constant.ChannelTypeMoonshot:
return updateChannelMoonshotBalance(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
} }
@@ -370,26 +415,24 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
func UpdateChannelBalance(c *gin.Context) { func UpdateChannelBalance(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
channel, err := model.GetChannelById(id, true) channel, err := model.CacheGetChannel(id)
if err != nil { if err != nil {
common.ApiError(c, err)
return
}
if channel.ChannelInfo.IsMultiKey {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": "多密钥渠道不支持余额查询",
}) })
return return
} }
balance, err := updateChannelBalance(channel) balance, err := updateChannelBalance(channel)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -397,7 +440,6 @@ func UpdateChannelBalance(c *gin.Context) {
"message": "", "message": "",
"balance": balance, "balance": balance,
}) })
return
} }
func updateAllChannelsBalance() error { func updateAllChannelsBalance() error {
@@ -409,6 +451,9 @@ func updateAllChannelsBalance() error {
if channel.Status != common.ChannelStatusEnabled { if channel.Status != common.ChannelStatusEnabled {
continue continue
} }
if channel.ChannelInfo.IsMultiKey {
continue // skip multi-key channels
}
// TODO: support Azure // TODO: support Azure
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { //if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
// continue // continue
@@ -419,7 +464,7 @@ func updateAllChannelsBalance() error {
} else { } else {
// err is nil & balance <= 0 means quota is used up // err is nil & balance <= 0 means quota is used up
if balance <= 0 { if balance <= 0 {
service.DisableChannel(channel.Id, channel.Name, "余额不足") service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
} }
} }
time.Sleep(common.RequestInterval) time.Sleep(common.RequestInterval)
@@ -431,10 +476,7 @@ func UpdateAllChannelsBalance(c *gin.Context) {
// TODO: make it async // TODO: make it async
err := updateAllChannelsBalance() err := updateAllChannelsBalance()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -11,14 +11,16 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/types"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -29,16 +31,49 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { type testResult struct {
context *gin.Context
localErr error
newAPIError *types.NewAPIError
}
func testChannel(channel *model.Channel, testModel string) testResult {
tik := time.Now() tik := time.Now()
if channel.Type == common.ChannelTypeMidjourney { if channel.Type == constant.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil return testResult{
localErr: errors.New("midjourney channel test is not supported"),
newAPIError: nil,
}
} }
if channel.Type == common.ChannelTypeMidjourneyPlus { if channel.Type == constant.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported!!!"), nil return testResult{
localErr: errors.New("midjourney plus channel test is not supported"),
newAPIError: nil,
}
} }
if channel.Type == common.ChannelTypeSunoAPI { if channel.Type == constant.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil return testResult{
localErr: errors.New("suno channel test is not supported"),
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeKling {
return testResult{
localErr: errors.New("kling channel test is not supported"),
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeJimeng {
return testResult{
localErr: errors.New("jimeng channel test is not supported"),
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeVidu {
return testResult{
localErr: errors.New("vidu channel test is not supported"),
newAPIError: nil,
}
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -50,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
strings.Contains(testModel, "bge-") || // bge 系列模型 strings.Contains(testModel, "bge-") || // bge 系列模型
strings.Contains(testModel, "embed") || strings.Contains(testModel, "embed") ||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型 channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
requestPath = "/v1/embeddings" // 修改请求路径 requestPath = "/v1/embeddings" // 修改请求路径
} }
@@ -75,80 +110,162 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
cache, err := model.GetUserCache(1) cache, err := model.GetUserCache(1)
if err != nil { if err != nil {
return err, nil return testResult{
localErr: err,
newAPIError: nil,
}
} }
cache.WriteContext(c) cache.WriteContext(c)
c.Request.Header.Set("Authorization", "Bearer "+channel.Key) //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
group, _ := model.GetUserGroup(1, false) group, _ := model.GetUserGroup(1, false)
c.Set("group", group) c.Set("group", group)
middleware.SetupContextForSelectedChannel(c, channel, testModel) newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
if newAPIError != nil {
info := relaycommon.GenRelayInfo(c) return testResult{
context: c,
err = helper.ModelMappedHelper(c, info) localErr: newAPIError,
if err != nil { newAPIError: newAPIError,
return err, nil }
} }
testModel = info.UpstreamModelName request := buildTestRequest(testModel)
apiType, _ := constant.ChannelType2APIType(channel.Type) // Determine relay format based on request path
relayFormat := types.RelayFormatOpenAI
if c.Request.URL.Path == "/v1/embeddings" {
relayFormat = types.RelayFormatEmbedding
}
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
if err != nil {
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
}
}
info.InitChannelMeta(c)
err = helper.ModelMappedHelper(c, info, request)
if err != nil {
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
}
}
testModel = info.UpstreamModelName
request.Model = testModel
apiType, _ := common.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return testResult{
context: c,
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
}
} }
request := buildTestRequest(testModel) //// 创建一个用于日志的 info 副本,移除 ApiKey
// 创建一个用于日志的 info 副本,移除 ApiKey //logInfo := info
logInfo := *info //logInfo.ApiKey = ""
logInfo.ApiKey = "" common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
if err != nil { if err != nil {
return err, nil return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
}
} }
adaptor.Init(info) adaptor.Init(info)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) var convertedRequest any
// 根据 RelayMode 选择正确的转换函数
if info.RelayMode == relayconstant.RelayModeEmbeddings {
// 创建一个 EmbeddingRequest
embeddingRequest := dto.EmbeddingRequest{
Input: request.Input,
Model: request.Model,
}
// 调用专门用于 Embedding 的转换函数
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
} else {
// 对其他所有请求类型(如 Chat保持原有逻辑
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
}
if err != nil { if err != nil {
return err, nil return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
}
} }
jsonData, err := json.Marshal(convertedRequest) jsonData, err := json.Marshal(convertedRequest)
if err != nil { if err != nil {
return err, nil return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
}
} }
requestBody := bytes.NewBuffer(jsonData) requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody) c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, info, requestBody) resp, err := adaptor.DoRequest(c, info, requestBody)
if err != nil { if err != nil {
return err, nil return testResult{
context: c,
localErr: err,
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
}
} }
var httpResp *http.Response var httpResp *http.Response
if resp != nil { if resp != nil {
httpResp = resp.(*http.Response) httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp, true) err := service.RelayErrorHandler(httpResp, true)
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err return testResult{
context: c,
localErr: err,
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
}
} }
} }
usageA, respErr := adaptor.DoResponse(c, httpResp, info) usageA, respErr := adaptor.DoResponse(c, httpResp, info)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), respErr return testResult{
context: c,
localErr: respErr,
newAPIError: respErr,
}
} }
if usageA == nil { if usageA == nil {
return errors.New("usage is nil"), nil return testResult{
context: c,
localErr: errors.New("usage is nil"),
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
} }
usage := usageA.(*dto.Usage) usage := usageA.(*dto.Usage)
result := w.Result() result := w.Result()
respBody, err := io.ReadAll(result.Body) respBody, err := io.ReadAll(result.Body)
if err != nil { if err != nil {
return err, nil return testResult{
context: c,
localErr: err,
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
} }
info.PromptTokens = usage.PromptTokens info.PromptTokens = usage.PromptTokens
@@ -165,12 +282,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice) usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) ChannelId: channel.Id,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
ModelName: info.OriginModelName,
TokenName: "模型测试",
Quota: quota,
Content: "模型测试",
UseTimeSeconds: int(consumedTime),
IsStream: info.IsStream,
Group: info.UsingGroup,
Other: other,
})
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil return testResult{
context: c,
localErr: nil,
newAPIError: nil,
}
} }
func buildTestRequest(model string) *dto.GeneralOpenAIRequest { func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
@@ -185,7 +317,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
strings.Contains(model, "bge-") { strings.Contains(model, "bge-") {
testRequest.Model = model testRequest.Model = model
// Embedding 请求 // Embedding 请求
testRequest.Input = []string{"hello world"} testRequest.Input = []any{"hello world"} // 修改为any因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
return testRequest return testRequest
} }
// 并非Embedding 模型 // 并非Embedding 模型
@@ -196,14 +328,14 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
testRequest.MaxTokens = 50 testRequest.MaxTokens = 50
} }
} else if strings.Contains(model, "gemini") { } else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 300 testRequest.MaxTokens = 3000
} else { } else {
testRequest.MaxTokens = 10 testRequest.MaxTokens = 10
} }
content, _ := json.Marshal("hi")
testMessage := dto.Message{ testMessage := dto.Message{
Role: "user", Role: "user",
Content: content, Content: "hi",
} }
testRequest.Model = model testRequest.Model = model
testRequest.Messages = append(testRequest.Messages, testMessage) testRequest.Messages = append(testRequest.Messages, testMessage)
@@ -213,31 +345,41 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
func TestChannel(c *gin.Context) { func TestChannel(c *gin.Context) {
channelId, err := strconv.Atoi(c.Param("id")) channelId, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
channel, err := model.GetChannelById(channelId, true) channel, err := model.CacheGetChannel(channelId)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ channel, err = model.GetChannelById(channelId, true)
"success": false, if err != nil {
"message": err.Error(), common.ApiError(c, err)
}) return
return }
} }
//defer func() {
// if channel.ChannelInfo.IsMultiKey {
// go func() { _ = channel.SaveChannelInfo() }()
// }
//}()
testModel := c.Query("model") testModel := c.Query("model")
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel, testModel) result := testChannel(channel, testModel)
if result.localErr != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": result.localErr.Error(),
"time": 0.0,
})
return
}
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if err != nil { if result.newAPIError != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": result.newAPIError.Error(),
"time": consumedTime, "time": consumedTime,
}) })
return return
@@ -262,52 +404,59 @@ func testAllChannels(notify bool) error {
} }
testAllChannelsRunning = true testAllChannelsRunning = true
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true, false) channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
if err != nil { if getChannelErr != nil {
return err return getChannelErr
} }
var disableThreshold = int64(common.ChannelDisableThreshold * 1000) var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
} }
gopool.Go(func() { gopool.Go(func() {
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
defer func() {
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
}()
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == common.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiWithStatusErr := testChannel(channel, "") result := testChannel(channel, "")
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
shouldBanChannel := false shouldBanChannel := false
newAPIError := result.newAPIError
// request error disables the channel // request error disables the channel
if openaiWithStatusErr != nil { if newAPIError != nil {
oaiErr := openaiWithStatusErr.Error shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
} }
if milliseconds > disableThreshold { // 当错误检查通过,才检查响应时间
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
shouldBanChannel = true if milliseconds > disableThreshold {
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
shouldBanChannel = true
}
} }
// disable channel // disable channel
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
service.DisableChannel(channel.Id, channel.Name, err.Error()) processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
} }
// enable channel // enable channel
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) { if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
service.EnableChannel(channel.Id, channel.Name) service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
time.Sleep(common.RequestInterval) time.Sleep(common.RequestInterval)
} }
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify { if notify {
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
} }
@@ -318,10 +467,7 @@ func testAllChannels(notify bool) error {
func TestAllChannels(c *gin.Context) { func TestAllChannels(c *gin.Context) {
err := testAllChannels(true) err := testAllChannels(true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -332,6 +478,10 @@ func TestAllChannels(c *gin.Context) {
} }
func AutomaticallyTestChannels(frequency int) { func AutomaticallyTestChannels(frequency int) {
if frequency <= 0 {
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
return
}
for { for {
time.Sleep(time.Duration(frequency) * time.Minute) time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels") common.SysLog("testing all channels")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,104 @@
// 用于迁移检测的旧键,该文件下个版本会删除
package controller
import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
func MigrateConsoleSetting(c *gin.Context) {
// 读取全部 option
opts, err := model.AllOption()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
return
}
// 建立 map
valMap := map[string]string{}
for _, o := range opts {
valMap[o.Key] = o.Value
}
// 处理 APIInfo
if v := valMap["ApiInfo"]; v != "" {
var arr []map[string]interface{}
if err := json.Unmarshal([]byte(v), &arr); err == nil {
if len(arr) > 50 {
arr = arr[:50]
}
bytes, _ := json.Marshal(arr)
model.UpdateOption("console_setting.api_info", string(bytes))
}
model.UpdateOption("ApiInfo", "")
}
// Announcements 直接搬
if v := valMap["Announcements"]; v != "" {
model.UpdateOption("console_setting.announcements", v)
model.UpdateOption("Announcements", "")
}
// FAQ 转换
if v := valMap["FAQ"]; v != "" {
var arr []map[string]interface{}
if err := json.Unmarshal([]byte(v), &arr); err == nil {
out := []map[string]interface{}{}
for _, item := range arr {
q, _ := item["question"].(string)
if q == "" {
q, _ = item["title"].(string)
}
a, _ := item["answer"].(string)
if a == "" {
a, _ = item["content"].(string)
}
if q != "" && a != "" {
out = append(out, map[string]interface{}{"question": q, "answer": a})
}
}
if len(out) > 50 {
out = out[:50]
}
bytes, _ := json.Marshal(out)
model.UpdateOption("console_setting.faq", string(bytes))
}
model.UpdateOption("FAQ", "")
}
// Uptime Kuma 迁移到新的 groups 结构console_setting.uptime_kuma_groups
url := valMap["UptimeKumaUrl"]
slug := valMap["UptimeKumaSlug"]
if url != "" && slug != "" {
// 仅当同时存在 URL 与 Slug 时才进行迁移
groups := []map[string]interface{}{
{
"id": 1,
"categoryName": "old",
"url": url,
"slug": slug,
"description": "",
},
}
bytes, _ := json.Marshal(groups)
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
}
// 清空旧键内容
if url != "" {
model.UpdateOption("UptimeKumaUrl", "")
}
if slug != "" {
model.UpdateOption("UptimeKumaSlug", "")
}
// 删除旧键记录
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
// 重新加载 OptionMap
model.InitOptionMap()
common.SysLog("console setting migrated")
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
}

View File

@@ -5,13 +5,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"time" "time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
) )
type GitHubOAuthResponse struct { type GitHubOAuthResponse struct {
@@ -103,10 +104,7 @@ func GitHubOAuth(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code) githubUser, err := getGitHubUserInfoByCode(code)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user := model.User{ user := model.User{
@@ -185,10 +183,7 @@ func GitHubBind(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code) githubUser, err := getGitHubUserInfoByCode(code)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user := model.User{ user := model.User{
@@ -207,19 +202,13 @@ func GitHubBind(c *gin.Context) {
user.Id = id.(int) user.Id = id.(int)
err = user.FillUserById() err = user.FillUserById()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user.GitHubId = githubUser.Login user.GitHubId = githubUser.Login
err = user.Update(false) err = user.Update(false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -239,10 +228,7 @@ func GenerateOAuthCode(c *gin.Context) {
session.Set("oauth_state", state) session.Set("oauth_state", state)
err := session.Save() err := session.Save()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -1,15 +1,17 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName, _ := range setting.GetGroupRatioCopy() { for groupName := range ratio_setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
userGroup := "" userGroup := ""
userId := c.GetInt("id") userId := c.GetInt("id")
userGroup, _ = model.GetUserGroup(userId, false) userGroup, _ = model.GetUserGroup(userId, false)
for groupName, ratio := range setting.GetGroupRatioCopy() { for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use // UserUsableGroups contains the groups that the user can use
userUsableGroups := setting.GetUserUsableGroups(userGroup) userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok { if desc, ok := userUsableGroups[groupName]; ok {
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
} }
} }
} }
if setting.GroupInUserUsableGroups("auto") {
usableGroups["auto"] = map[string]interface{}{
"ratio": "自动",
"desc": setting.GetUsableGroupDescription("auto"),
}
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",

View File

@@ -38,10 +38,7 @@ func LinuxDoBind(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c) linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
@@ -63,20 +60,14 @@ func LinuxDoBind(c *gin.Context) {
err = user.FillUserById() err = user.FillUserById()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id) user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
err = user.Update(false) err = user.Update(false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
@@ -202,10 +193,7 @@ func LinuxdoOAuth(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c) linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
@@ -232,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) {
} }
} else { } else {
if common.RegisterEnabled { if common.RegisterEnabled {
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1) if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
user.DisplayName = linuxdoUser.Name user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
user.Role = common.RoleCommonUser user.DisplayName = linuxdoUser.Name
user.Status = common.UserStatusEnabled user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff") affCode := session.Get("aff")
inviterId := 0 inviterId := 0
if affCode != nil { if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string)) inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
} }
if err := user.Insert(inviterId); err != nil { if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
}) })
return return
} }

View File

@@ -10,14 +10,7 @@ import (
) )
func GetAllLogs(c *gin.Context) { func GetAllLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
logType, _ := strconv.Atoi(c.Query("type")) logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
@@ -26,38 +19,19 @@ func GetAllLogs(c *gin.Context) {
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group") group := c.Query("group")
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group) logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ pageInfo.SetTotal(int(total))
"success": true, pageInfo.SetItems(logs)
"message": "", common.ApiSuccess(c, pageInfo)
"data": map[string]any{ return
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
})
} }
func GetUserLogs(c *gin.Context) { func GetUserLogs(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
if pageSize > 100 {
pageSize = 100
}
userId := c.GetInt("id") userId := c.GetInt("id")
logType, _ := strconv.Atoi(c.Query("type")) logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
@@ -65,24 +39,14 @@ func GetUserLogs(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
group := c.Query("group") group := c.Query("group")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group) logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ pageInfo.SetTotal(int(total))
"success": true, pageInfo.SetItems(logs)
"message": "", common.ApiSuccess(c, pageInfo)
"data": map[string]any{
"items": logs,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return return
} }
@@ -90,10 +54,7 @@ func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword) logs, err := model.SearchAllLogs(keyword)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -109,10 +70,7 @@ func SearchUserLogs(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword) logs, err := model.SearchUserLogs(userId, keyword)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -198,10 +156,7 @@ func DeleteHistoryLogs(c *gin.Context) {
} }
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100) count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -5,17 +5,17 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
"strconv"
"time" "time"
"github.com/gin-gonic/gin"
) )
func UpdateMidjourneyTaskBulk() { func UpdateMidjourneyTaskBulk() {
@@ -29,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
continue continue
} }
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string) taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney) taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0) nullTaskIds := make([]int, 0)
@@ -48,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else { } else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
} }
} }
if len(taskChannelM) == 0 { if len(taskChannelM) == 0 {
@@ -58,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
} }
for channelId, taskIds := range taskChannelM { for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 { if len(taskIds) == 0 {
continue continue
} }
midjourneyChannel, err := model.CacheGetChannel(channelId) midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{ err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId), "fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE", "status": "FAILURE",
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
} }
continue continue
} }
@@ -82,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
}) })
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue continue
} }
// 设置超时时间 // 设置超时时间
@@ -94,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
req.Header.Set("mj-api-secret", midjourneyChannel.Key) req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := service.GetHttpClient().Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue continue
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue continue
} }
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue continue
} }
var responseItems []dto.MidjourneyDto var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems) err = json.Unmarshal(responseBody, &responseItems)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue continue
} }
resp.Body.Close() resp.Body.Close()
@@ -146,9 +146,25 @@ func UpdateMidjourneyTaskBulk() {
buttonStr, _ := json.Marshal(responseItem.Buttons) buttonStr, _ := json.Marshal(responseItem.Buttons)
task.Buttons = string(buttonStr) task.Buttons = string(buttonStr)
} }
// 映射 VideoUrl
task.VideoUrl = responseItem.VideoUrl
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
task.VideoUrls = "[]" // 失败时设置为空数组
} else {
task.VideoUrls = string(videoUrlsStr)
}
} else {
task.VideoUrls = "" // 空值时清空字段
}
shouldReturnQuota := false shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%" task.Progress = "100%"
if task.Quota != 0 { if task.Quota != 0 {
shouldReturnQuota = true shouldReturnQuota = true
@@ -156,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
} }
err = task.Update() err = task.Update()
if err != nil { if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else { } else {
if shouldReturnQuota { if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota, false) err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) logger.LogError(ctx, "fail to increase user quota: "+err.Error())
} }
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(task.Quota)) logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, logger.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent) model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
} }
} }
@@ -209,15 +225,26 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
if oldTask.Progress != "100%" && newTask.FailReason != "" { if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true return true
} }
// 检查 VideoUrl 是否需要更新
if oldTask.VideoUrl != newTask.VideoUrl {
return true
}
// 检查 VideoUrls 是否需要更新
if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
if oldTask.VideoUrls != string(newVideoUrlsStr) {
return true
}
} else if oldTask.VideoUrls != "" {
// 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
return true
}
return false return false
} }
func GetAllMidjourney(c *gin.Context) { func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
if p < 0 {
p = 0
}
// 解析其他查询参数 // 解析其他查询参数
queryParams := model.TaskQueryParams{ queryParams := model.TaskQueryParams{
@@ -227,31 +254,24 @@ func GetAllMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"), EndTimestamp: c.Query("end_timestamp"),
} }
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
if logs == nil { total := model.CountAllTasks(queryParams)
logs = make([]*model.Midjourney, 0)
}
if setting.MjForwardUrlEnabled { if setting.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney items[i] = midjourney
} }
} }
c.JSON(200, gin.H{ pageInfo.SetTotal(int(total))
"success": true, pageInfo.SetItems(items)
"message": "", common.ApiSuccess(c, pageInfo)
"data": logs,
})
} }
func GetUserMidjourney(c *gin.Context) { func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
if p < 0 {
p = 0
}
userId := c.GetInt("id") userId := c.GetInt("id")
log.Printf("userId = %d \n", userId)
queryParams := model.TaskQueryParams{ queryParams := model.TaskQueryParams{
MjID: c.Query("mj_id"), MjID: c.Query("mj_id"),
@@ -259,19 +279,16 @@ func GetUserMidjourney(c *gin.Context) {
EndTimestamp: c.Query("end_timestamp"), EndTimestamp: c.Query("end_timestamp"),
} }
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
if logs == nil { total := model.CountAllUserTask(userId, queryParams)
logs = make([]*model.Midjourney, 0)
}
if setting.MjForwardUrlEnabled { if setting.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range items {
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney items[i] = midjourney
} }
} }
c.JSON(200, gin.H{ pageInfo.SetTotal(int(total))
"success": true, pageInfo.SetItems(items)
"message": "", common.ApiSuccess(c, pageInfo)
"data": logs,
})
} }

View File

@@ -6,8 +6,10 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/middleware"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/operation_setting" "one-api/setting/operation_setting"
"one-api/setting/system_setting" "one-api/setting/system_setting"
"strings" "strings"
@@ -24,57 +26,90 @@ func TestStatus(c *gin.Context) {
}) })
return return
} }
// 获取HTTP统计信息
httpStats := middleware.GetStats()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "Server is running", "message": "Server is running",
"http_stats": httpStats,
}) })
return return
} }
func GetStatus(c *gin.Context) { func GetStatus(c *gin.Context) {
cs := console_setting.GetConsoleSetting()
data := gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"stripe_unit_price": setting.StripeUnitPrice,
"min_topup": setting.MinTopUp,
"stripe_min_topup": setting.StripeMinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"default_use_auto_group": setting.DefaultUseAutoGroup,
"pay_methods": setting.PayMethods,
"usd_exchange_rate": setting.USDExchangeRate,
// 面板启用开关
"api_info_enabled": cs.ApiInfoEnabled,
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
"announcements_enabled": cs.AnnouncementsEnabled,
"faq_enabled": cs.FAQEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
"setup": constant.Setup,
}
// 根据启用状态注入可选内容
if cs.ApiInfoEnabled {
data["api_info"] = console_setting.GetApiInfo()
}
if cs.AnnouncementsEnabled {
data["announcements"] = console_setting.GetAnnouncements()
}
if cs.FAQEnabled {
data["faq"] = console_setting.GetFAQ()
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": data,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
"setup": constant.Setup,
},
}) })
return return
} }
@@ -184,10 +219,7 @@ func SendEmailVerification(c *gin.Context) {
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes) "<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := common.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -223,10 +255,7 @@ func SendPasswordResetEmail(c *gin.Context) {
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes) "<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := common.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -261,10 +290,7 @@ func ResetPassword(c *gin.Context) {
password := common.GenerateVerificationCode(12) password := common.GenerateVerificationCode(12)
err = model.ResetUserPasswordByEmail(req.Email, password) err = model.ResetUserPasswordByEmail(req.Email, password)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
common.DeleteKey(req.Email, common.PasswordResetPurpose) common.DeleteKey(req.Email, common.PasswordResetPurpose)

View File

@@ -0,0 +1,27 @@
package controller
import (
"net/http"
"one-api/model"
"github.com/gin-gonic/gin"
)
// GetMissingModels returns the list of model names that are referenced by channels
// but do not have corresponding records in the models meta table.
// This helps administrators quickly discover models that need configuration.
func GetMissingModels(c *gin.Context) {
missing, err := model.GetMissingModels()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": missing,
})
}

View File

@@ -3,6 +3,7 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/samber/lo"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
@@ -14,7 +15,8 @@ import (
"one-api/relay/channel/minimax" "one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot" "one-api/relay/channel/moonshot"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" "one-api/setting"
"time"
) )
// https://platform.openai.com/docs/api-reference/models/list // https://platform.openai.com/docs/api-reference/models/list
@@ -23,30 +25,10 @@ var openAIModels []dto.OpenAIModels
var openAIModelsMap map[string]dto.OpenAIModels var openAIModelsMap map[string]dto.OpenAIModels
var channelId2Models map[int][]string var channelId2Models map[int][]string
func getPermission() []dto.OpenAIModelPermission {
var permission []dto.OpenAIModelPermission
permission = append(permission, dto.OpenAIModelPermission{
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
Object: "model_permission",
Created: 1626777600,
AllowCreateEngine: true,
AllowSampling: true,
AllowLogprobs: true,
AllowSearchIndices: false,
AllowView: true,
AllowFineTuning: false,
Organization: "*",
Group: nil,
IsBlocking: false,
})
return permission
}
func init() { func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
permission := getPermission() for i := 0; i < constant.APITypeDummy; i++ {
for i := 0; i < relayconstant.APITypeDummy; i++ { if i == constant.APITypeAIProxyLibrary {
if i == relayconstant.APITypeAIProxyLibrary {
continue continue
} }
adaptor := relay.GetAdaptor(i) adaptor := relay.GetAdaptor(i)
@@ -54,69 +36,51 @@ func init() {
modelNames := adaptor.GetModelList() modelNames := adaptor.GetModelList()
for _, modelName := range modelNames { for _, modelName := range modelNames {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: channelName, OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
}) })
} }
} }
for _, modelName := range ai360.ModelList { for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: ai360.ChannelName, OwnedBy: ai360.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
}) })
} }
for _, modelName := range moonshot.ModelList { for _, modelName := range moonshot.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: moonshot.ChannelName, OwnedBy: moonshot.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
}) })
} }
for _, modelName := range lingyiwanwu.ModelList { for _, modelName := range lingyiwanwu.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: lingyiwanwu.ChannelName, OwnedBy: lingyiwanwu.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
}) })
} }
for _, modelName := range minimax.ModelList { for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: minimax.ChannelName, OwnedBy: minimax.ChannelName,
Permission: permission,
Root: modelName,
Parent: nil,
}) })
} }
for modelName, _ := range constant.MidjourneyModel2Action { for modelName, _ := range constant.MidjourneyModel2Action {
openAIModels = append(openAIModels, dto.OpenAIModels{ openAIModels = append(openAIModels, dto.OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: "midjourney", OwnedBy: "midjourney",
Permission: permission,
Root: modelName,
Parent: nil,
}) })
} }
openAIModelsMap = make(map[string]dto.OpenAIModels) openAIModelsMap = make(map[string]dto.OpenAIModels)
@@ -124,25 +88,29 @@ func init() {
openAIModelsMap[aiModel.Id] = aiModel openAIModelsMap[aiModel.Id] = aiModel
} }
channelId2Models = make(map[int][]string) channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ { for i := 1; i <= constant.ChannelTypeDummy; i++ {
apiType, success := relayconstant.ChannelType2APIType(i) apiType, success := common.ChannelType2APIType(i)
if !success || apiType == relayconstant.APITypeAIProxyLibrary { if !success || apiType == constant.APITypeAIProxyLibrary {
continue continue
} }
meta := &relaycommon.RelayInfo{ChannelType: i} meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
ChannelType: i,
}}
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta) adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList() channelId2Models[i] = adaptor.GetModelList()
} }
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
return m.Id
})
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context, modelType int) {
userOpenAiModels := make([]dto.OpenAIModels, 0) userOpenAiModels := make([]dto.OpenAIModels, 0)
permission := getPermission()
modelLimitEnable := c.GetBool("token_model_limit_enabled") modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
if modelLimitEnable { if modelLimitEnable {
s, ok := c.Get("token_model_limit") s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
var tokenModelLimit map[string]bool var tokenModelLimit map[string]bool
if ok { if ok {
tokenModelLimit = s.(map[string]bool) tokenModelLimit = s.(map[string]bool)
@@ -150,23 +118,22 @@ func ListModels(c *gin.Context) {
tokenModelLimit = map[string]bool{} tokenModelLimit = map[string]bool{}
} }
for allowModel, _ := range tokenModelLimit { for allowModel, _ := range tokenModelLimit {
if _, ok := openAIModelsMap[allowModel]; ok { if oaiModel, ok := openAIModelsMap[allowModel]; ok {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel]) oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else { } else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: allowModel, Id: allowModel,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: "custom", OwnedBy: "custom",
Permission: permission, SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
Root: allowModel,
Parent: nil,
}) })
} }
} }
} else { } else {
userId := c.GetInt("id") userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId, true) userGroup, err := model.GetUserGroup(userId, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -175,31 +142,73 @@ func ListModels(c *gin.Context) {
return return
} }
group := userGroup group := userGroup
tokenGroup := c.GetString("token_group") tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
if tokenGroup != "" { if tokenGroup != "" {
group = tokenGroup group = tokenGroup
} }
models := model.GetGroupModels(group) var models []string
for _, s := range models { if tokenGroup == "auto" {
if _, ok := openAIModelsMap[s]; ok { for _, autoGroup := range setting.AutoGroups {
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) groupModels := model.GetGroupEnabledModels(autoGroup)
for _, g := range groupModels {
if !common.StringsContains(models, g) {
models = append(models, g)
}
}
}
} else {
models = model.GetGroupEnabledModels(group)
}
for _, modelName := range models {
if oaiModel, ok := openAIModelsMap[modelName]; ok {
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
userOpenAiModels = append(userOpenAiModels, oaiModel)
} else { } else {
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
Id: s, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
OwnedBy: "custom", OwnedBy: "custom",
Permission: permission, SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
Root: s,
Parent: nil,
}) })
} }
} }
} }
c.JSON(200, gin.H{ switch modelType {
"success": true, case constant.ChannelTypeAnthropic:
"data": userOpenAiModels, useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
}) for i, model := range userOpenAiModels {
useranthropicModels[i] = dto.AnthropicModel{
ID: model.Id,
CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
DisplayName: model.Id,
Type: "model",
}
}
c.JSON(200, gin.H{
"data": useranthropicModels,
"first_id": useranthropicModels[0].ID,
"has_more": false,
"last_id": useranthropicModels[len(useranthropicModels)-1].ID,
})
case constant.ChannelTypeGemini:
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
for i, model := range userOpenAiModels {
userGeminiModels[i] = dto.GeminiModel{
Name: model.Id,
DisplayName: model.Id,
}
}
c.JSON(200, gin.H{
"models": userGeminiModels,
"nextPageToken": nil,
})
default:
c.JSON(200, gin.H{
"success": true,
"data": userOpenAiModels,
})
}
} }
func ChannelListModels(c *gin.Context) { func ChannelListModels(c *gin.Context) {
@@ -223,10 +232,20 @@ func EnabledListModels(c *gin.Context) {
}) })
} }
func RetrieveModel(c *gin.Context) { func RetrieveModel(c *gin.Context, modelType int) {
modelId := c.Param("model") modelId := c.Param("model")
if aiModel, ok := openAIModelsMap[modelId]; ok { if aiModel, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, aiModel) switch modelType {
case constant.ChannelTypeAnthropic:
c.JSON(200, dto.AnthropicModel{
ID: aiModel.Id,
CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
DisplayName: aiModel.Id,
Type: "model",
})
default:
c.JSON(200, aiModel)
}
} else { } else {
openAIError := dto.OpenAIError{ openAIError := dto.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId), Message: fmt.Sprintf("The model '%s' does not exist", modelId),

330
controller/model_meta.go Normal file
View File

@@ -0,0 +1,330 @@
package controller
import (
"encoding/json"
"sort"
"strconv"
"strings"
"one-api/common"
"one-api/constant"
"one-api/model"
"github.com/gin-gonic/gin"
)
// GetAllModelsMeta 获取模型列表(分页)
func GetAllModelsMeta(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
}
// 批量填充附加字段,提升列表接口性能
enrichModels(modelsMeta)
var total int64
model.DB.Model(&model.Model{}).Count(&total)
// 统计供应商计数(全部数据,不受分页影响)
vendorCounts, _ := model.GetVendorModelCounts()
pageInfo.SetTotal(int(total))
pageInfo.SetItems(modelsMeta)
common.ApiSuccess(c, gin.H{
"items": modelsMeta,
"total": total,
"page": pageInfo.GetPage(),
"page_size": pageInfo.GetPageSize(),
"vendor_counts": vendorCounts,
})
}
// SearchModelsMeta 搜索模型列表
func SearchModelsMeta(c *gin.Context) {
keyword := c.Query("keyword")
vendor := c.Query("vendor")
pageInfo := common.GetPageQuery(c)
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
}
// 批量填充附加字段,提升列表接口性能
enrichModels(modelsMeta)
pageInfo.SetTotal(int(total))
pageInfo.SetItems(modelsMeta)
common.ApiSuccess(c, pageInfo)
}
// GetModelMeta 根据 ID 获取单条模型信息
func GetModelMeta(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiError(c, err)
return
}
var m model.Model
if err := model.DB.First(&m, id).Error; err != nil {
common.ApiError(c, err)
return
}
enrichModels([]*model.Model{&m})
common.ApiSuccess(c, &m)
}
// CreateModelMeta 新建模型
func CreateModelMeta(c *gin.Context) {
var m model.Model
if err := c.ShouldBindJSON(&m); err != nil {
common.ApiError(c, err)
return
}
if m.ModelName == "" {
common.ApiErrorMsg(c, "模型名称不能为空")
return
}
// 名称冲突检查
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
common.ApiError(c, err)
return
} else if dup {
common.ApiErrorMsg(c, "模型名称已存在")
return
}
if err := m.Insert(); err != nil {
common.ApiError(c, err)
return
}
model.RefreshPricing()
common.ApiSuccess(c, &m)
}
// UpdateModelMeta 更新模型
func UpdateModelMeta(c *gin.Context) {
statusOnly := c.Query("status_only") == "true"
var m model.Model
if err := c.ShouldBindJSON(&m); err != nil {
common.ApiError(c, err)
return
}
if m.Id == 0 {
common.ApiErrorMsg(c, "缺少模型 ID")
return
}
if statusOnly {
// 只更新状态,防止误清空其他字段
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
common.ApiError(c, err)
return
}
} else {
// 名称冲突检查
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
common.ApiError(c, err)
return
} else if dup {
common.ApiErrorMsg(c, "模型名称已存在")
return
}
if err := m.Update(); err != nil {
common.ApiError(c, err)
return
}
}
model.RefreshPricing()
common.ApiSuccess(c, &m)
}
// DeleteModelMeta 删除模型
func DeleteModelMeta(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiError(c, err)
return
}
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
common.ApiError(c, err)
return
}
model.RefreshPricing()
common.ApiSuccess(c, nil)
}
// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
func enrichModels(models []*model.Model) {
if len(models) == 0 {
return
}
// 1) 拆分精确与规则匹配
exactNames := make([]string, 0)
exactIdx := make(map[string][]int) // modelName -> indices in models
ruleIndices := make([]int, 0)
for i, m := range models {
if m == nil {
continue
}
if m.NameRule == model.NameRuleExact {
exactNames = append(exactNames, m.ModelName)
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
} else {
ruleIndices = append(ruleIndices, i)
}
}
// 2) 批量查询精确模型的绑定渠道
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
// 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
for name, indices := range exactIdx {
chs := channelsByModel[name]
for _, idx := range indices {
mm := models[idx]
if mm.Endpoints == "" {
eps := model.GetModelSupportEndpointTypes(mm.ModelName)
if b, err := json.Marshal(eps); err == nil {
mm.Endpoints = string(b)
}
}
mm.BoundChannels = chs
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
}
}
if len(ruleIndices) == 0 {
return
}
// 4) 一次性读取定价缓存,内存匹配所有规则模型
pricings := model.GetPricing()
// 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
matchedNamesByIdx := make(map[int][]string)
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
groupSetByIdx := make(map[int]map[string]struct{})
quotaSetByIdx := make(map[int]map[int]struct{})
for _, p := range pricings {
for _, idx := range ruleIndices {
mm := models[idx]
var matched bool
switch mm.NameRule {
case model.NameRulePrefix:
matched = strings.HasPrefix(p.ModelName, mm.ModelName)
case model.NameRuleSuffix:
matched = strings.HasSuffix(p.ModelName, mm.ModelName)
case model.NameRuleContains:
matched = strings.Contains(p.ModelName, mm.ModelName)
}
if !matched {
continue
}
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
es := endpointSetByIdx[idx]
if es == nil {
es = make(map[constant.EndpointType]struct{})
endpointSetByIdx[idx] = es
}
for _, et := range p.SupportedEndpointTypes {
es[et] = struct{}{}
}
gs := groupSetByIdx[idx]
if gs == nil {
gs = make(map[string]struct{})
groupSetByIdx[idx] = gs
}
for _, g := range p.EnableGroup {
gs[g] = struct{}{}
}
qs := quotaSetByIdx[idx]
if qs == nil {
qs = make(map[int]struct{})
quotaSetByIdx[idx] = qs
}
qs[p.QuotaType] = struct{}{}
}
}
// 5) 汇总所有匹配到的模型名称,批量查询一次渠道
allMatchedSet := make(map[string]struct{})
for _, names := range matchedNamesByIdx {
for _, n := range names {
allMatchedSet[n] = struct{}{}
}
}
allMatched := make([]string, 0, len(allMatchedSet))
for n := range allMatchedSet {
allMatched = append(allMatched, n)
}
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
// 6) 回填每个规则模型的并集信息
for _, idx := range ruleIndices {
mm := models[idx]
// 端点并集 -> 序列化
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
eps := make([]constant.EndpointType, 0, len(es))
for et := range es {
eps = append(eps, et)
}
if b, err := json.Marshal(eps); err == nil {
mm.Endpoints = string(b)
}
}
// 分组并集
if gs, ok := groupSetByIdx[idx]; ok {
groups := make([]string, 0, len(gs))
for g := range gs {
groups = append(groups, g)
}
mm.EnableGroups = groups
}
// 配额类型集合(保持去重并排序)
if qs, ok := quotaSetByIdx[idx]; ok {
arr := make([]int, 0, len(qs))
for k := range qs {
arr = append(arr, k)
}
sort.Ints(arr)
mm.QuotaTypes = arr
}
// 渠道并集
names := matchedNamesByIdx[idx]
channelSet := make(map[string]model.BoundChannel)
for _, n := range names {
for _, ch := range matchedChannelsByModel[n] {
key := ch.Name + "_" + strconv.Itoa(ch.Type)
channelSet[key] = ch
}
}
if len(channelSet) > 0 {
chs := make([]model.BoundChannel, 0, len(channelSet))
for _, ch := range channelSet {
chs = append(chs, ch)
}
mm.BoundChannels = chs
}
// 匹配信息
mm.MatchedModels = names
mm.MatchedCount = len(names)
}
}

View File

@@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
} }
if oidcResponse.AccessToken == "" { if oidcResponse.AccessToken == "" {
common.SysError("OIDC 获取 Token 失败,请检查设置!") common.SysLog("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
} }
@@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
} }
defer res2.Body.Close() defer res2.Body.Close()
if res2.StatusCode != http.StatusOK { if res2.StatusCode != http.StatusOK {
common.SysError("OIDC 获取用户信息失败!请检查设置!") common.SysLog("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
} }
@@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
return nil, err return nil, err
} }
if oidcUser.OpenID == "" || oidcUser.Email == "" { if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysError("OIDC 获取用户信息为空!请检查设置!") common.SysLog("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
} }
return &oidcUser, nil return &oidcUser, nil
@@ -126,10 +126,7 @@ func OidcAuth(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code) oidcUser, err := getOidcUserInfoByCode(code)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user := model.User{ user := model.User{
@@ -195,10 +192,7 @@ func OidcBind(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code) oidcUser, err := getOidcUserInfoByCode(code)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user := model.User{ user := model.User{
@@ -217,19 +211,13 @@ func OidcBind(c *gin.Context) {
user.Id = id.(int) user.Id = id.(int)
err = user.FillUserById() err = user.FillUserById()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user.OidcId = oidcUser.OpenID user.OidcId = oidcUser.OpenID
err = user.Update(false) err = user.Update(false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -6,6 +6,8 @@ import (
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"one-api/setting/console_setting"
"one-api/setting/ratio_setting"
"one-api/setting/system_setting" "one-api/setting/system_setting"
"strings" "strings"
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
return return
} }
case "GroupRatio": case "GroupRatio":
err = setting.CheckGroupRatio(option.Value) err = ratio_setting.CheckGroupRatio(option.Value)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -119,14 +121,46 @@ func UpdateOption(c *gin.Context) {
}) })
return return
} }
case "console_setting.api_info":
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
case "console_setting.announcements":
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
case "console_setting.faq":
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
case "console_setting.uptime_kuma_groups":
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} }
err = model.UpdateOption(option.Key, option.Value) err = model.UpdateOption(option.Key, option.Value)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -3,67 +3,58 @@ package controller
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/service" "one-api/types"
"one-api/setting"
"time" "time"
"github.com/gin-gonic/gin"
) )
func Playground(c *gin.Context) { func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode var newAPIError *types.NewAPIError
defer func() { defer func() {
if openaiErr != nil { if newAPIError != nil {
c.JSON(openaiErr.StatusCode, gin.H{ c.JSON(newAPIError.StatusCode, gin.H{
"error": openaiErr.Error, "error": newAPIError.ToOpenAIError(),
}) })
} }
}() }()
useAccessToken := c.GetBool("use_access_token") useAccessToken := c.GetBool("use_access_token")
if useAccessToken { if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest) newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
return return
} }
playgroundRequest := &dto.PlayGroundRequest{} group := c.GetString("group")
err := common.UnmarshalBodyReusable(c, playgroundRequest) modelName := c.GetString("original_model")
userId := c.GetInt("id")
// Write user context to ensure acceptUnsetRatio is available
userCache, err := model.GetUserCache(userId)
if err != nil { if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest) newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
return return
} }
userCache.WriteContext(c)
if playgroundRequest.Model == "" { tempToken := &model.Token{
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest) UserId: userId,
Name: fmt.Sprintf("playground-%s", group),
Group: group,
}
_ = middleware.SetupContextForToken(c, tempToken)
_, newAPIError = getChannel(c, group, modelName, 0)
if newAPIError != nil {
return return
} }
c.Set("original_model", playgroundRequest.Model) //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
group := playgroundRequest.Group common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
userGroup := c.GetString("group")
if group == "" { Relay(c, types.RelayFormatOpenAI)
group = userGroup
} else {
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
c.Set(constant.ContextKeyRequestStartTime, time.Now())
Relay(c)
} }

View File

@@ -0,0 +1,90 @@
package controller
import (
"strconv"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
func GetPrefillGroups(c *gin.Context) {
groupType := c.Query("type")
groups, err := model.GetAllPrefillGroups(groupType)
if err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, groups)
}
// CreatePrefillGroup 创建新的预填组
func CreatePrefillGroup(c *gin.Context) {
var g model.PrefillGroup
if err := c.ShouldBindJSON(&g); err != nil {
common.ApiError(c, err)
return
}
if g.Name == "" || g.Type == "" {
common.ApiErrorMsg(c, "组名称和类型不能为空")
return
}
// 创建前检查名称
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
common.ApiError(c, err)
return
} else if dup {
common.ApiErrorMsg(c, "组名称已存在")
return
}
if err := g.Insert(); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, &g)
}
// UpdatePrefillGroup 更新预填组
func UpdatePrefillGroup(c *gin.Context) {
var g model.PrefillGroup
if err := c.ShouldBindJSON(&g); err != nil {
common.ApiError(c, err)
return
}
if g.Id == 0 {
common.ApiErrorMsg(c, "缺少组 ID")
return
}
// 名称冲突检查
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
common.ApiError(c, err)
return
} else if dup {
common.ApiErrorMsg(c, "组名称已存在")
return
}
if err := g.Update(); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, &g)
}
// DeletePrefillGroup 删除预填组
func DeletePrefillGroup(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiError(c, err)
return
}
if err := model.DeletePrefillGroupByID(id); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, nil)
}

View File

@@ -1,10 +1,11 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"one-api/setting/operation_setting" "one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
) )
func GetPricing(c *gin.Context) { func GetPricing(c *gin.Context) {
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
userId, exists := c.Get("id") userId, exists := c.Get("id")
usableGroup := map[string]string{} usableGroup := map[string]string{}
groupRatio := map[string]float64{} groupRatio := map[string]float64{}
for s, f := range setting.GetGroupRatioCopy() { for s, f := range ratio_setting.GetGroupRatioCopy() {
groupRatio[s] = f groupRatio[s] = f
} }
var group string var group string
@@ -20,27 +21,36 @@ func GetPricing(c *gin.Context) {
user, err := model.GetUserCache(userId.(int)) user, err := model.GetUserCache(userId.(int))
if err == nil { if err == nil {
group = user.Group group = user.Group
for g := range groupRatio {
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
if ok {
groupRatio[g] = ratio
}
}
} }
} }
usableGroup = setting.GetUserUsableGroups(group) usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup // check groupRatio contains usableGroup
for group := range setting.GetGroupRatioCopy() { for group := range ratio_setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok { if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group) delete(groupRatio, group)
} }
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"data": pricing, "data": pricing,
"group_ratio": groupRatio, "vendors": model.GetVendors(),
"usable_group": usableGroup, "group_ratio": groupRatio,
"usable_group": usableGroup,
"supported_endpoint": model.GetSupportedEndpointMap(),
"auto_groups": setting.AutoGroups,
}) })
} }
func ResetModelRatio(c *gin.Context) { func ResetModelRatio(c *gin.Context) {
defaultStr := operation_setting.DefaultModelRatio2JSONString() defaultStr := ratio_setting.DefaultModelRatio2JSONString()
err := model.UpdateOption("ModelRatio", defaultStr) err := model.UpdateOption("ModelRatio", defaultStr)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
@@ -49,7 +59,7 @@ func ResetModelRatio(c *gin.Context) {
}) })
return return
} }
err = operation_setting.UpdateModelRatioByJSONString(defaultStr) err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": false, "success": false,

View File

@@ -0,0 +1,24 @@
package controller
import (
"net/http"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetRatioConfig(c *gin.Context) {
if !ratio_setting.IsExposeRatioEnabled() {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "倍率配置接口未启用",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ratio_setting.GetExposedData(),
})
}

474
controller/ratio_sync.go Normal file
View File

@@ -0,0 +1,474 @@
package controller
import (
"context"
"encoding/json"
"fmt"
"net/http"
"one-api/logger"
"strings"
"sync"
"time"
"one-api/dto"
"one-api/model"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
const (
defaultTimeoutSeconds = 10
defaultEndpoint = "/api/ratio_config"
maxConcurrentFetches = 8
)
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
type upstreamResult struct {
Name string `json:"name"`
Data map[string]any `json:"data,omitempty"`
Err string `json:"err,omitempty"`
}
func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}
if req.Timeout <= 0 {
req.Timeout = defaultTimeoutSeconds
}
var upstreams []dto.UpstreamDTO
if len(req.Upstreams) > 0 {
for _, u := range req.Upstreams {
if strings.HasPrefix(u.BaseURL, "http") {
if u.Endpoint == "" {
u.Endpoint = defaultEndpoint
}
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
upstreams = append(upstreams, u)
}
}
} else if len(req.ChannelIDs) > 0 {
intIds := make([]int, 0, len(req.ChannelIDs))
for _, id64 := range req.ChannelIDs {
intIds = append(intIds, int(id64))
}
dbChannels, err := model.GetChannelsByIds(intIds)
if err != nil {
logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
return
}
for _, ch := range dbChannels {
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
upstreams = append(upstreams, dto.UpstreamDTO{
ID: ch.Id,
Name: ch.Name,
BaseURL: strings.TrimRight(base, "/"),
Endpoint: "",
})
}
}
}
if len(upstreams) == 0 {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
return
}
var wg sync.WaitGroup
ch := make(chan upstreamResult, len(upstreams))
sem := make(chan struct{}, maxConcurrentFetches)
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
for _, chn := range upstreams {
wg.Add(1)
go func(chItem dto.UpstreamDTO) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
endpoint := chItem.Endpoint
if endpoint == "" {
endpoint = defaultEndpoint
} else if !strings.HasPrefix(endpoint, "/") {
endpoint = "/" + endpoint
}
fullURL := chItem.BaseURL + endpoint
uniqueName := chItem.Name
if chItem.ID != 0 {
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
}
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return
}
resp, err := client.Do(httpReq)
if err != nil {
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
return
}
// 兼容两种上游接口格式:
// type1: /api/ratio_config -> data 为 map[string]any包含 model_ratio/completion_ratio/cache_ratio/model_price
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
var body struct {
Success bool `json:"success"`
Data json.RawMessage `json:"data"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return
}
if !body.Success {
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
return
}
// 尝试按 type1 解析
var type1Data map[string]any
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
isType1 := false
for _, rt := range ratioTypes {
if _, ok := type1Data[rt]; ok {
isType1 = true
break
}
}
if isType1 {
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
return
}
}
// 如果不是 type1则尝试按 type2 (/api/pricing) 解析
var pricingItems []struct {
ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"`
CompletionRatio float64 `json:"completion_ratio"`
}
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
return
}
modelRatioMap := make(map[string]float64)
completionRatioMap := make(map[string]float64)
modelPriceMap := make(map[string]float64)
for _, item := range pricingItems {
if item.QuotaType == 1 {
modelPriceMap[item.ModelName] = item.ModelPrice
} else {
modelRatioMap[item.ModelName] = item.ModelRatio
// completionRatio 可能为 0此时也直接赋值保持与上游一致
completionRatioMap[item.ModelName] = item.CompletionRatio
}
}
converted := make(map[string]any)
if len(modelRatioMap) > 0 {
ratioAny := make(map[string]any, len(modelRatioMap))
for k, v := range modelRatioMap {
ratioAny[k] = v
}
converted["model_ratio"] = ratioAny
}
if len(completionRatioMap) > 0 {
compAny := make(map[string]any, len(completionRatioMap))
for k, v := range completionRatioMap {
compAny[k] = v
}
converted["completion_ratio"] = compAny
}
if len(modelPriceMap) > 0 {
priceAny := make(map[string]any, len(modelPriceMap))
for k, v := range modelPriceMap {
priceAny[k] = v
}
converted["model_price"] = priceAny
}
ch <- upstreamResult{Name: uniqueName, Data: converted}
}(chn)
}
wg.Wait()
close(ch)
localData := ratio_setting.GetExposedData()
var testResults []dto.TestResult
var successfulChannels []struct {
name string
data map[string]any
}
for r := range ch {
if r.Err != "" {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "error",
Error: r.Err,
})
} else {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "success",
})
successfulChannels = append(successfulChannels, struct {
name string
data map[string]any
}{name: r.Name, data: r.Data})
}
}
differences := buildDifferences(localData, successfulChannels)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"differences": differences,
"test_results": testResults,
},
})
}
func buildDifferences(localData map[string]any, successfulChannels []struct {
name string
data map[string]any
}) map[string]map[string]dto.DifferenceItem {
differences := make(map[string]map[string]dto.DifferenceItem)
allModels := make(map[string]struct{})
for _, ratioType := range ratioTypes {
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
for modelName := range localRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for _, channel := range successfulChannels {
for _, ratioType := range ratioTypes {
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
for modelName := range upstreamRatio {
allModels[modelName] = struct{}{}
}
}
}
}
confidenceMap := make(map[string]map[string]bool)
// 预处理阶段检查pricing接口的可信度
for _, channel := range successfulChannels {
confidenceMap[channel.name] = make(map[string]bool)
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
if hasModelRatio && hasCompletionRatio {
// 遍历所有模型,检查是否满足不可信条件
for modelName := range allModels {
// 默认为可信
confidenceMap[channel.name][modelName] = true
// 检查是否满足不可信条件model_ratio为37.5且completion_ratio为1
if modelRatioVal, ok := modelRatios[modelName]; ok {
if completionRatioVal, ok := completionRatios[modelName]; ok {
// 转换为float64进行比较
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
confidenceMap[channel.name][modelName] = false
}
}
}
}
}
}
} else {
// 如果不是从pricing接口获取的数据则全部标记为可信
for modelName := range allModels {
confidenceMap[channel.name][modelName] = true
}
}
}
for modelName := range allModels {
for _, ratioType := range ratioTypes {
var localValue interface{} = nil
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
if val, exists := localRatio[modelName]; exists {
localValue = val
}
}
}
upstreamValues := make(map[string]interface{})
confidenceValues := make(map[string]bool)
hasUpstreamValue := false
hasDifference := false
for _, channel := range successfulChannels {
var upstreamValue interface{} = nil
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
if val, exists := upstreamRatio[modelName]; exists {
upstreamValue = val
hasUpstreamValue = true
if localValue != nil && localValue != val {
hasDifference = true
} else if localValue == val {
upstreamValue = "same"
}
}
}
if upstreamValue == nil && localValue == nil {
upstreamValue = "same"
}
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
hasDifference = true
}
upstreamValues[channel.name] = upstreamValue
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
}
shouldInclude := false
if localValue != nil {
if hasDifference {
shouldInclude = true
}
} else {
if hasUpstreamValue {
shouldInclude = true
}
}
if shouldInclude {
if differences[modelName] == nil {
differences[modelName] = make(map[string]dto.DifferenceItem)
}
differences[modelName][ratioType] = dto.DifferenceItem{
Current: localValue,
Upstreams: upstreamValues,
Confidence: confidenceValues,
}
}
}
}
channelHasDiff := make(map[string]bool)
for _, ratioMap := range differences {
for _, item := range ratioMap {
for chName, val := range item.Upstreams {
if val != nil && val != "same" {
channelHasDiff[chName] = true
}
}
}
}
for modelName, ratioMap := range differences {
for ratioType, item := range ratioMap {
for chName := range item.Upstreams {
if !channelHasDiff[chName] {
delete(item.Upstreams, chName)
delete(item.Confidence, chName)
}
}
allSame := true
for _, v := range item.Upstreams {
if v != "same" {
allSame = false
break
}
}
if len(item.Upstreams) == 0 || allSame {
delete(ratioMap, ratioType)
} else {
differences[modelName][ratioType] = item
}
}
if len(ratioMap) == 0 {
delete(differences, modelName)
}
}
return differences
}
func GetSyncableChannels(c *gin.Context) {
channels, err := model.GetAllChannels(0, 0, true, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var syncableChannels []dto.SyncableChannel
for _, channel := range channels {
if channel.GetBaseURL() != "" {
syncableChannels = append(syncableChannels, dto.SyncableChannel{
ID: channel.Id,
Name: channel.Name,
BaseURL: channel.GetBaseURL(),
Status: channel.Status,
})
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": syncableChannels,
})
}

View File

@@ -1,90 +1,52 @@
package controller package controller
import ( import (
"errors"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"unicode/utf8"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func GetAllRedemptions(c *gin.Context) { func GetAllRedemptions(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
pageSize, _ := strconv.Atoi(c.Query("page_size")) redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if p < 0 {
p = 0
}
if pageSize < 1 {
pageSize = common.ItemsPerPage
}
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ pageInfo.SetTotal(int(total))
"success": true, pageInfo.SetItems(redemptions)
"message": "", common.ApiSuccess(c, pageInfo)
"data": gin.H{
"items": redemptions,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return return
} }
func SearchRedemptions(c *gin.Context) { func SearchRedemptions(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
pageSize, _ := strconv.Atoi(c.Query("page_size")) redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if p < 0 {
p = 0
}
if pageSize < 1 {
pageSize = common.ItemsPerPage
}
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ pageInfo.SetTotal(int(total))
"success": true, pageInfo.SetItems(redemptions)
"message": "", common.ApiSuccess(c, pageInfo)
"data": gin.H{
"items": redemptions,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return return
} }
func GetRedemption(c *gin.Context) { func GetRedemption(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
redemption, err := model.GetRedemptionById(id) redemption, err := model.GetRedemptionById(id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -99,13 +61,10 @@ func AddRedemption(c *gin.Context) {
redemption := model.Redemption{} redemption := model.Redemption{}
err := c.ShouldBindJSON(&redemption) err := c.ShouldBindJSON(&redemption)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if len(redemption.Name) == 0 || len(redemption.Name) > 20 { if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "兑换码名称长度必须在1-20之间", "message": "兑换码名称长度必须在1-20之间",
@@ -126,6 +85,10 @@ func AddRedemption(c *gin.Context) {
}) })
return return
} }
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
var keys []string var keys []string
for i := 0; i < redemption.Count; i++ { for i := 0; i < redemption.Count; i++ {
key := common.GetUUID() key := common.GetUUID()
@@ -135,6 +98,7 @@ func AddRedemption(c *gin.Context) {
Key: key, Key: key,
CreatedTime: common.GetTimestamp(), CreatedTime: common.GetTimestamp(),
Quota: redemption.Quota, Quota: redemption.Quota,
ExpiredTime: redemption.ExpiredTime,
} }
err = cleanRedemption.Insert() err = cleanRedemption.Insert()
if err != nil { if err != nil {
@@ -159,10 +123,7 @@ func DeleteRedemption(c *gin.Context) {
id, _ := strconv.Atoi(c.Param("id")) id, _ := strconv.Atoi(c.Param("id"))
err := model.DeleteRedemptionById(id) err := model.DeleteRedemptionById(id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -177,33 +138,30 @@ func UpdateRedemption(c *gin.Context) {
redemption := model.Redemption{} redemption := model.Redemption{}
err := c.ShouldBindJSON(&redemption) err := c.ShouldBindJSON(&redemption)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
cleanRedemption, err := model.GetRedemptionById(redemption.Id) cleanRedemption, err := model.GetRedemptionById(redemption.Id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if statusOnly != "" { if statusOnly == "" {
cleanRedemption.Status = redemption.Status if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
} else { c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
// If you add more fields, please also update redemption.Update() // If you add more fields, please also update redemption.Update()
cleanRedemption.Name = redemption.Name cleanRedemption.Name = redemption.Name
cleanRedemption.Quota = redemption.Quota cleanRedemption.Quota = redemption.Quota
cleanRedemption.ExpiredTime = redemption.ExpiredTime
}
if statusOnly != "" {
cleanRedemption.Status = redemption.Status
} }
err = cleanRedemption.Update() err = cleanRedemption.Update()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -213,3 +171,24 @@ func UpdateRedemption(c *gin.Context) {
}) })
return return
} }
func DeleteInvalidRedemption(c *gin.Context) {
rows, err := model.DeleteInvalidRedemptions()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": rows,
})
return
}
func validateExpiredTime(expired int64) error {
if expired != 0 && expired < common.GetTimestamp() {
return errors.New("过期时间不能早于当前时间")
}
return nil
}

View File

@@ -2,114 +2,192 @@ package controller
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"io" "io"
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
constant2 "one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
"one-api/relay/constant" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
"one-api/types"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
var err *dto.OpenAIErrorWithStatusCode var err *types.NewAPIError
switch relayMode { switch info.RelayMode {
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err = relay.ImageHelper(c) err = relay.ImageHelper(c, info)
case relayconstant.RelayModeAudioSpeech: case relayconstant.RelayModeAudioSpeech:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranslation: case relayconstant.RelayModeAudioTranslation:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranscription: case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c) err = relay.AudioHelper(c, info)
case relayconstant.RelayModeRerank: case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode) err = relay.RerankHelper(c, info)
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c) err = relay.EmbeddingHelper(c, info)
case relayconstant.RelayModeResponses: case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c) err = relay.ResponsesHelper(c, info)
case relayconstant.RelayModeGemini:
err = relay.GeminiHelper(c)
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c, info)
} }
if constant2.ErrorLogEnabled && err != nil {
// 保存错误日志到mysql中
userId := c.GetInt("id")
tokenName := c.GetString("token_name")
modelName := c.GetString("original_model")
tokenId := c.GetInt("token_id")
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = err.Error.Type
other["error_code"] = err.Error.Code
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
}
return err return err
} }
func Relay(c *gin.Context) { func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
relayMode := constant.Path2RelayMode(c.Request.URL.Path) var err *types.NewAPIError
if strings.Contains(c.Request.URL.Path, "embed") {
err = relay.GeminiEmbeddingHandler(c, info)
} else {
err = relay.GeminiHelper(c, info)
}
return err
}
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group") group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
originalModel := c.GetString("original_model") originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
var openaiErr *dto.OpenAIErrorWithStatusCode
var (
newAPIError *types.NewAPIError
ws *websocket.Conn
)
if relayFormat == types.RelayFormatOpenAIRealtime {
var err error
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
return
}
defer ws.Close()
}
defer func() {
if newAPIError != nil {
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
switch relayFormat {
case types.RelayFormatOpenAIRealtime:
helper.WssError(c, ws, newAPIError.ToOpenAIError())
case types.RelayFormatClaude:
c.JSON(newAPIError.StatusCode, gin.H{
"type": "error",
"error": newAPIError.ToClaudeError(),
})
default:
c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(),
})
}
}
}()
request, err := helper.GetAndValidateRequest(c, relayFormat)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
return
}
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
return
}
meta := request.GetTokenCountMeta()
if setting.ShouldCheckPromptSensitive() {
contains, words := service.CheckSensitiveText(meta.CombineText)
if contains {
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return
}
}
tokens, err := service.CountRequestToken(c, meta, relayInfo)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
return
}
relayInfo.SetPromptTokens(tokens)
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
return
}
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return
}
defer func() {
// Only return quota if downstream failed and quota was actually pre-consumed
if newAPIError != nil && preConsumedQuota != 0 {
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
}
}()
for i := 0; i <= common.RetryTimes; i++ { for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i) channel, err := getChannel(c, group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c, err.Error()) logger.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) newAPIError = err
break break
} }
openaiErr = relayRequest(c, relayMode, channel) addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
if openaiErr == nil { switch relayFormat {
return // 成功处理请求,直接返回 case types.RelayFormatOpenAIRealtime:
newAPIError = relay.WssHelper(c, relayInfo)
case types.RelayFormatClaude:
newAPIError = relay.ClaudeHelper(c, relayInfo)
case types.RelayFormatGemini:
newAPIError = geminiRelayHandler(c, relayInfo)
default:
newAPIError = relayHandler(c, relayInfo)
} }
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) if newAPIError == nil {
return
}
if !shouldRetry(c, openaiErr, common.RetryTimes-i) { processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break break
} }
} }
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 { if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr) logger.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
} }
} }
@@ -120,132 +198,13 @@ var upgrader = websocket.Upgrader{
}, },
} }
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
defer ws.Close()
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
helper.WssError(c, ws, openaiErr.Error)
return
}
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = wssRequest(c, ws, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
helper.WssError(c, ws, openaiErr.Error)
}
}
func RelayClaude(c *gin.Context) {
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var claudeErr *dto.ClaudeErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
claudeErr = claudeRequest(c, channel)
if claudeErr == nil {
return // 成功处理请求,直接返回
}
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if claudeErr != nil {
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
c.JSON(claudeErr.StatusCode, gin.H{
"type": "error",
"error": claudeErr.Error,
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.WssHelper(c, ws)
}
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.ClaudeHelper(c)
}
func addUsedChannel(c *gin.Context, channelId int) { func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel) c.Set("use_channel", useChannel)
} }
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
if retryCount == 0 { if retryCount == 0 {
autoBan := c.GetBool("auto_ban") autoBan := c.GetBool("auto_ban")
autoBanInt := 1 autoBanInt := 1
@@ -259,19 +218,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
AutoBan: &autoBanInt, AutoBan: &autoBanInt,
}, nil }, nil
} }
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil { if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
if channel == nil {
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在数据库一致性已被破坏retry", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {
return nil, newAPIError
} }
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil return channel, nil
} }
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
if openaiErr == nil { if openaiErr == nil {
return false return false
} }
if openaiErr.LocalError { if types.IsChannelError(openaiErr) {
return true
}
if types.IsSkipRetryError(openaiErr) {
return false return false
} }
if retryTimes <= 0 { if retryTimes <= 0 {
@@ -294,10 +262,6 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
return true return true
} }
if openaiErr.StatusCode == http.StatusBadRequest { if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false return false
} }
if openaiErr.StatusCode == 408 { if openaiErr.StatusCode == 408 {
@@ -310,45 +274,85 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
return true return true
} }
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) { func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况 logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) gopool.Go(func() {
if service.ShouldDisableChannel(channelType, err) && autoBan { // 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
service.DisableChannel(channelId, channelName, err.Error.Message) // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
service.DisableChannel(channelError, err.Error())
}
})
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
// 保存错误日志到mysql中
userId := c.GetInt("id")
tokenName := c.GetString("token_name")
modelName := c.GetString("original_model")
tokenId := c.GetInt("token_id")
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = err.GetErrorType()
other["error_code"] = err.GetErrorCode()
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
if isMultiKey {
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
} }
} }
func RelayMidjourney(c *gin.Context) { func RelayMidjourney(c *gin.Context) {
relayMode := c.GetInt("relay_mode") relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
var err *dto.MidjourneyResponse
switch relayMode { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
"type": "upstream_error",
"code": 4,
})
return
}
var mjErr *dto.MidjourneyResponse
switch relayInfo.RelayMode {
case relayconstant.RelayModeMidjourneyNotify: case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c) mjErr = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode) mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed: case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c) mjErr = relay.RelayMidjourneyTaskImageSeed(c)
case relayconstant.RelayModeSwapFace: case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c) mjErr = relay.RelaySwapFace(c, relayInfo)
default: default:
err = relay.RelayMidjourneySubmit(c, relayMode) mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
} }
//err = relayMidjourneySubmit(c, relayMode) //err = relayMidjourneySubmit(c, relayMode)
log.Println(err) log.Println(mjErr)
if err != nil { if mjErr != nil {
statusCode := http.StatusBadRequest statusCode := http.StatusBadRequest
if err.Code == 30 { if mjErr.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests statusCode = http.StatusTooManyRequests
} }
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result), "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
"type": "upstream_error", "type": "upstream_error",
"code": err.Code, "code": mjErr.Code,
}) })
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
} }
} }
@@ -388,26 +392,27 @@ func RelayTask(c *gin.Context) {
retryTimes = 0 retryTimes = 0
} }
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) channel, newAPIError := getChannel(c, group, originalModel, i)
if err != nil { if newAPIError != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
break break
} }
channelId = channel.Id channelId = channel.Id
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel) c.Set("use_channel", useChannel)
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel) //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode) taskErr = taskRelayHandler(c, relayMode)
} }
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 { if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr) logger.LogInfo(c, retryLogStr)
} }
if taskErr != nil { if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests { if taskErr.StatusCode == http.StatusTooManyRequests {
@@ -420,7 +425,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError var err *dto.TaskError
switch relayMode { switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
err = relay.RelayTaskFetch(c, relayMode) err = relay.RelayTaskFetch(c, relayMode)
default: default:
err = relay.RelayTaskSubmit(c, relayMode) err = relay.RelayTaskSubmit(c, relayMode)

View File

@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
// If root doesn't exist, validate and create admin account // If root doesn't exist, validate and create admin account
if !rootExists { if !rootExists {
// Validate username length: max 12 characters to align with model.User validation
if len(req.Username) > 12 {
c.JSON(400, gin.H{
"success": false,
"message": "用户名长度不能超过12个字符",
})
return
}
// Validate password // Validate password
if req.Password != req.ConfirmPassword { if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{ c.JSON(400, gin.H{

136
controller/swag_video.go Normal file
View File

@@ -0,0 +1,136 @@
package controller
import (
"github.com/gin-gonic/gin"
)
// VideoGenerations
// @Summary 生成视频
// @Description 调用视频生成接口生成视频
// @Description 支持多种视频生成服务:
// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
// @Tags Video
// @Accept json
// @Produce json
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
// @Param request body dto.VideoRequest true "视频生成请求参数"
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
// @Failure 401 {object} dto.OpenAIError "未授权"
// @Failure 403 {object} dto.OpenAIError "无权限"
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
// @Router /v1/video/generations [post]
func VideoGenerations(c *gin.Context) {
}
// VideoGenerationsTaskId
// @Summary 查询视频
// @Description 根据任务ID查询视频生成任务的状态和结果
// @Tags Video
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param task_id path string true "Task ID"
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
// @Failure 401 {object} dto.OpenAIError "未授权"
// @Failure 403 {object} dto.OpenAIError "无权限"
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
// @Router /v1/video/generations/{task_id} [get]
func VideoGenerationsTaskId(c *gin.Context) {
}
// KlingText2VideoGenerations
// @Summary 可灵文生视频
// @Description 调用可灵AI文生视频接口生成视频内容
// @Tags Video
// @Accept json
// @Produce json
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
// @Param request body KlingText2VideoRequest true "视频生成请求参数"
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
// @Failure 401 {object} dto.OpenAIError "未授权"
// @Failure 403 {object} dto.OpenAIError "无权限"
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
// @Router /kling/v1/videos/text2video [post]
func KlingText2VideoGenerations(c *gin.Context) {
}
type KlingText2VideoRequest struct {
ModelName string `json:"model_name,omitempty" example:"kling-v1"`
Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
Mode string `json:"mode,omitempty" example:"std"`
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
Duration string `json:"duration,omitempty" example:"5"`
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
}
type KlingCameraControl struct {
Type string `json:"type,omitempty" example:"simple"`
Config *KlingCameraConfig `json:"config,omitempty"`
}
type KlingCameraConfig struct {
Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
Vertical float64 `json:"vertical,omitempty" example:"0"`
Pan float64 `json:"pan,omitempty" example:"0"`
Tilt float64 `json:"tilt,omitempty" example:"0"`
Roll float64 `json:"roll,omitempty" example:"0"`
Zoom float64 `json:"zoom,omitempty" example:"0"`
}
// KlingImage2VideoGenerations
// @Summary 可灵官方-图生视频
// @Description 调用可灵AI图生视频接口生成视频内容
// @Tags Video
// @Accept json
// @Produce json
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
// @Failure 401 {object} dto.OpenAIError "未授权"
// @Failure 403 {object} dto.OpenAIError "无权限"
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
// @Router /kling/v1/videos/image2video [post]
func KlingImage2VideoGenerations(c *gin.Context) {
}
type KlingImage2VideoRequest struct {
ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
Mode string `json:"mode,omitempty" example:"std"`
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
Duration string `json:"duration,omitempty" example:"5"`
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
}
// KlingImage2videoTaskId godoc
// @Summary 可灵任务查询--图生视频
// @Description Query the status and result of a Kling video generation task by task ID
// @Tags Origin
// @Accept json
// @Produce json
// @Param task_id path string true "Task ID"
// @Router /kling/v1/videos/image2video/{task_id} [get]
func KlingImage2videoTaskId(c *gin.Context) {}
// KlingText2videoTaskId godoc
// @Summary 可灵任务查询--文生视频
// @Description Query the status and result of a Kling text-to-video generation task by task ID
// @Tags Origin
// @Accept json
// @Produce json
// @Param task_id path string true "Task ID"
// @Router /kling/v1/videos/text2video/{task_id} [get]
func KlingText2videoTaskId(c *gin.Context) {}

View File

@@ -5,18 +5,20 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
"sort" "sort"
"strconv" "strconv"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
) )
func UpdateTaskBulk() { func UpdateTaskBulk() {
@@ -53,9 +55,9 @@ func UpdateTaskBulk() {
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
} else { } else {
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
} }
} }
if len(taskChannelM) == 0 { if len(taskChannelM) == 0 {
@@ -75,7 +77,9 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
case constant.TaskPlatformSuno: case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
default: default:
common.SysLog("未知平台") if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
}
} }
} }
@@ -83,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
for channelId, taskIds := range taskChannelM { for channelId, taskIds := range taskChannelM {
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
} }
} }
return nil return nil
} }
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 { if len(taskIds) == 0 {
return nil return nil
} }
@@ -103,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
} }
return err return err
} }
@@ -115,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
"ids": taskIds, "ids": taskIds,
}) })
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
return err return err
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
} }
defer resp.Body.Close() defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
return err return err
} }
var responseItems dto.TaskResponse[[]dto.SunoDataResponse] var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = json.Unmarshal(responseBody, &responseItems) err = json.Unmarshal(responseBody, &responseItems)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
return err return err
} }
if !responseItems.IsSuccess() { if !responseItems.IsSuccess() {
@@ -151,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%" task.Progress = "100%"
//err = model.CacheUpdateUserQuota(task.UserId) ? //err = model.CacheUpdateUserQuota(task.UserId) ?
if err != nil { if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error()) logger.LogError(ctx, "error update user quota cache: "+err.Error())
} else { } else {
quota := task.Quota quota := task.Quota
if quota != 0 { if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota, false) err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) logger.LogError(ctx, "fail to increase user quota: "+err.Error())
} }
logContent := fmt.Sprintf("异步任务执行失败 %s补偿 %s", task.TaskID, common.LogQuota(quota)) logContent := fmt.Sprintf("异步任务执行失败 %s补偿 %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent) model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
} }
} }
@@ -175,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
err = task.Update() err = task.Update()
if err != nil { if err != nil {
common.SysError("UpdateMidjourneyTask task error: " + err.Error()) common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
} }
} }
return nil return nil
@@ -223,10 +227,8 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
} }
func GetAllTask(c *gin.Context) { func GetAllTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
if p < 0 {
p = 0
}
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 解析其他查询参数 // 解析其他查询参数
@@ -237,25 +239,18 @@ func GetAllTask(c *gin.Context) {
Action: c.Query("action"), Action: c.Query("action"),
StartTimestamp: startTimestamp, StartTimestamp: startTimestamp,
EndTimestamp: endTimestamp, EndTimestamp: endTimestamp,
ChannelID: c.Query("channel_id"),
} }
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
if logs == nil { total := model.TaskCountAllTasks(queryParams)
logs = make([]*model.Task, 0) pageInfo.SetTotal(int(total))
} pageInfo.SetItems(items)
common.ApiSuccess(c, pageInfo)
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
} }
func GetUserTask(c *gin.Context) { func GetUserTask(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
if p < 0 {
p = 0
}
userId := c.GetInt("id") userId := c.GetInt("id")
@@ -271,14 +266,9 @@ func GetUserTask(c *gin.Context) {
EndTimestamp: endTimestamp, EndTimestamp: endTimestamp,
} }
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
if logs == nil { total := model.TaskCountAllUserTask(userId, queryParams)
logs = make([]*model.Task, 0) pageInfo.SetTotal(int(total))
} pageInfo.SetItems(items)
common.ApiSuccess(c, pageInfo)
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
} }

148
controller/task_video.go Normal file
View File

@@ -0,0 +1,148 @@
package controller
import (
"context"
"encoding/json"
"fmt"
"io"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"time"
)
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := relay.GetTaskAdaptor(platform)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
task := taskM[taskId]
if task == nil {
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
"action": task.Action,
})
if err != nil {
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
}
//if resp.StatusCode != http.StatusOK {
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
//}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
}
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)
taskResult.Url = t.FailReason
taskResult.Progress = t.Progress
taskResult.Reason = t.FailReason
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
} else {
task.Data = responseBody
}
now := time.Now().Unix()
if taskResult.Status == "" {
return fmt.Errorf("task %s status is empty", taskId)
}
task.Status = model.TaskStatus(taskResult.Status)
switch taskResult.Status {
case model.TaskStatusSubmitted:
task.Progress = "10%"
case model.TaskStatusQueued:
task.Progress = "20%"
case model.TaskStatusInProgress:
task.Progress = "30%"
if task.StartTime == 0 {
task.StartTime = now
}
case model.TaskStatusSuccess:
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Url
case model.TaskStatusFailure:
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if task.FinishTime == 0 {
task.FinishTime = now
}
task.FailReason = taskResult.Reason
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
default:
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
}
if taskResult.Progress != "" {
task.Progress = taskResult.Progress
}
if err := task.Update(); err != nil {
common.SysLog("UpdateVideoTask task error: " + err.Error())
}
return nil
}

View File

@@ -12,29 +12,16 @@ import (
func GetAllTokens(c *gin.Context) { func GetAllTokens(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
size, _ := strconv.Atoi(c.Query("size")) tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if p < 0 {
p = 0
}
if size <= 0 {
size = common.ItemsPerPage
} else if size > 100 {
size = 100
}
tokens, err := model.GetAllUserTokens(userId, p*size, size)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ total, _ := model.CountUserTokens(userId)
"success": true, pageInfo.SetTotal(int(total))
"message": "", pageInfo.SetItems(tokens)
"data": tokens, common.ApiSuccess(c, pageInfo)
})
return return
} }
@@ -44,10 +31,7 @@ func SearchTokens(c *gin.Context) {
token := c.Query("token") token := c.Query("token")
tokens, err := model.SearchUserTokens(userId, keyword, token) tokens, err := model.SearchUserTokens(userId, keyword, token)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -62,18 +46,12 @@ func GetToken(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
userId := c.GetInt("id") userId := c.GetInt("id")
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
token, err := model.GetTokenByIds(id, userId) token, err := model.GetTokenByIds(id, userId)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -89,10 +67,7 @@ func GetTokenStatus(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
token, err := model.GetTokenByIds(tokenId, userId) token, err := model.GetTokenByIds(tokenId, userId)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
expiredAt := token.ExpiredTime expiredAt := token.ExpiredTime
@@ -162,10 +137,7 @@ func AddToken(c *gin.Context) {
token := model.Token{} token := model.Token{}
err := c.ShouldBindJSON(&token) err := c.ShouldBindJSON(&token)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if len(token.Name) > 30 { if len(token.Name) > 30 {
@@ -181,7 +153,7 @@ func AddToken(c *gin.Context) {
"success": false, "success": false,
"message": "生成令牌失败", "message": "生成令牌失败",
}) })
common.SysError("failed to generate token key: " + err.Error()) common.SysLog("failed to generate token key: " + err.Error())
return return
} }
cleanToken := model.Token{ cleanToken := model.Token{
@@ -200,10 +172,7 @@ func AddToken(c *gin.Context) {
} }
err = cleanToken.Insert() err = cleanToken.Insert()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -218,10 +187,7 @@ func DeleteToken(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
err := model.DeleteTokenById(id, userId) err := model.DeleteTokenById(id, userId)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -237,10 +203,7 @@ func UpdateToken(c *gin.Context) {
token := model.Token{} token := model.Token{}
err := c.ShouldBindJSON(&token) err := c.ShouldBindJSON(&token)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if len(token.Name) > 30 { if len(token.Name) > 30 {
@@ -252,10 +215,7 @@ func UpdateToken(c *gin.Context) {
} }
cleanToken, err := model.GetTokenByIds(token.Id, userId) cleanToken, err := model.GetTokenByIds(token.Id, userId)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == common.TokenStatusEnabled {
@@ -289,10 +249,7 @@ func UpdateToken(c *gin.Context) {
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -302,3 +259,29 @@ func UpdateToken(c *gin.Context) {
}) })
return return
} }
type TokenBatch struct {
Ids []int `json:"ids"`
}
func DeleteTokenBatch(c *gin.Context) {
tokenBatch := TokenBatch{}
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
}

View File

@@ -5,6 +5,7 @@ import (
"log" "log"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
@@ -97,16 +98,14 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return return
} }
payType := "wxpay"
if req.PaymentMethod == "zfb" { if !setting.ContainsPayMethod(req.PaymentMethod) {
payType = "alipay" c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
} return
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = "wxpay"
} }
callBackAddress := service.GetCallbackAddress() callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(setting.ServerAddress + "/log") returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo) tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
@@ -116,7 +115,7 @@ func RequestEpay(c *gin.Context) {
return return
} }
uri, params, err := client.Purchase(&epay.PurchaseArgs{ uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType, Type: req.PaymentMethod,
ServiceTradeNo: tradeNo, ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount), Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64), Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
@@ -233,7 +232,7 @@ func EpayNotify(c *gin.Context) {
return return
} }
log.Printf("易支付回调更新用户成功 %v", topUp) log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(quotaToAdd), topUp.Money)) model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", logger.LogQuota(quotaToAdd), topUp.Money))
} }
} else { } else {
log.Printf("易支付异常回调: %v", verifyInfo) log.Printf("易支付异常回调: %v", verifyInfo)

275
controller/topup_stripe.go Normal file
View File

@@ -0,0 +1,275 @@
package controller
import (
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/stripe/stripe-go/v81"
"github.com/stripe/stripe-go/v81/checkout/session"
"github.com/stripe/stripe-go/v81/webhook"
"github.com/thanhpk/randstr"
)
const (
PaymentMethodStripe = "stripe"
)
var stripeAdaptor = &StripeAdaptor{}
type StripePayRequest struct {
Amount int64 `json:"amount"`
PaymentMethod string `json:"payment_method"`
}
type StripeAdaptor struct {
}
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
return
}
id := c.GetInt("id")
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return
}
payMoney := getStripePayMoney(float64(req.Amount), group)
if payMoney <= 0.01 {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return
}
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
}
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
if req.PaymentMethod != PaymentMethodStripe {
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
return
}
if req.Amount < getStripeMinTopup() {
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
return
}
if req.Amount > 10000 {
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
return
}
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
referenceId := "ref_" + common.Sha1([]byte(reference))
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
if err != nil {
log.Println("获取Stripe Checkout支付链接失败", err)
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
topUp := &model.TopUp{
UserId: id,
Amount: req.Amount,
Money: chargedMoney,
TradeNo: referenceId,
CreateTime: time.Now().Unix(),
Status: common.TopUpStatusPending,
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{
"message": "success",
"data": gin.H{
"pay_link": payLink,
},
})
}
func RequestStripeAmount(c *gin.Context) {
var req StripePayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
stripeAdaptor.RequestAmount(c, &req)
}
func RequestStripePay(c *gin.Context) {
var req StripePayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
return
}
stripeAdaptor.RequestPay(c, &req)
}
func StripeWebhook(c *gin.Context) {
payload, err := io.ReadAll(c.Request.Body)
if err != nil {
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
c.AbortWithStatus(http.StatusServiceUnavailable)
return
}
signature := c.GetHeader("Stripe-Signature")
endpointSecret := setting.StripeWebhookSecret
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
IgnoreAPIVersionMismatch: true,
})
if err != nil {
log.Printf("Stripe Webhook验签失败: %v\n", err)
c.AbortWithStatus(http.StatusBadRequest)
return
}
switch event.Type {
case stripe.EventTypeCheckoutSessionCompleted:
sessionCompleted(event)
case stripe.EventTypeCheckoutSessionExpired:
sessionExpired(event)
default:
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
}
c.Status(http.StatusOK)
}
func sessionCompleted(event stripe.Event) {
customerId := event.GetObjectValue("customer")
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "complete" != status {
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
return
}
err := model.Recharge(referenceId, customerId)
if err != nil {
log.Println(err.Error(), referenceId)
return
}
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
currency := strings.ToUpper(event.GetObjectValue("currency"))
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
}
func sessionExpired(event stripe.Event) {
referenceId := event.GetObjectValue("client_reference_id")
status := event.GetObjectValue("status")
if "expired" != status {
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
return
}
if len(referenceId) == 0 {
log.Println("未提供支付单号")
return
}
topUp := model.GetTopUpByTradeNo(referenceId)
if topUp == nil {
log.Println("充值订单不存在", referenceId)
return
}
if topUp.Status != common.TopUpStatusPending {
log.Println("充值订单状态错误", referenceId)
}
topUp.Status = common.TopUpStatusExpired
err := topUp.Update()
if err != nil {
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
return
}
log.Println("充值订单已过期", referenceId)
}
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
return "", fmt.Errorf("无效的Stripe API密钥")
}
stripe.Key = setting.StripeApiSecret
params := &stripe.CheckoutSessionParams{
ClientReferenceID: stripe.String(referenceId),
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
LineItems: []*stripe.CheckoutSessionLineItemParams{
{
Price: stripe.String(setting.StripePriceId),
Quantity: stripe.Int64(amount),
},
},
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
}
if "" == customerId {
if "" != email {
params.CustomerEmail = stripe.String(email)
}
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
} else {
params.Customer = stripe.String(customerId)
}
result, err := session.New(params)
if err != nil {
return "", err
}
return result.URL, nil
}
func GetChargedAmount(count float64, user model.User) float64 {
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
if topUpGroupRatio == 0 {
topUpGroupRatio = 1
}
return count * topUpGroupRatio
}
func getStripePayMoney(amount float64, group string) float64 {
if !common.DisplayInCurrencyEnabled {
amount = amount / common.QuotaPerUnit
}
// Using float64 for monetary calculations is acceptable here due to the small amounts involved
topupGroupRatio := common.GetTopupGroupRatio(group)
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
return payMoney
}
func getStripeMinTopup() int64 {
minTopup := setting.StripeMinTopUp
if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit)
}
return int64(minTopup)
}

553
controller/twofa.go Normal file
View File

@@ -0,0 +1,553 @@
package controller
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// Setup2FARequest 设置2FA请求结构
type Setup2FARequest struct {
Code string `json:"code" binding:"required"`
}
// Verify2FARequest 验证2FA请求结构
type Verify2FARequest struct {
Code string `json:"code" binding:"required"`
}
// Setup2FAResponse 设置2FA响应结构
type Setup2FAResponse struct {
Secret string `json:"secret"`
QRCodeData string `json:"qr_code_data"`
BackupCodes []string `json:"backup_codes"`
}
// Setup2FA 初始化2FA设置
func Setup2FA(c *gin.Context) {
userId := c.GetInt("id")
// 检查用户是否已经启用2FA
existing, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if existing != nil && existing.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已启用2FA请先禁用后重新设置",
})
return
}
// 如果存在已禁用的2FA记录先删除它
if existing != nil && !existing.IsEnabled {
if err := existing.Delete(); err != nil {
common.ApiError(c, err)
return
}
existing = nil // 重置为nil后续将创建新记录
}
// 获取用户信息
user, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
// 生成TOTP密钥
key, err := common.GenerateTOTPSecret(user.Username)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成2FA密钥失败",
})
common.SysLog("生成TOTP密钥失败: " + err.Error())
return
}
// 生成备用码
backupCodes, err := common.GenerateBackupCodes()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成备用码失败",
})
common.SysLog("生成备用码失败: " + err.Error())
return
}
// 生成二维码数据
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
// 创建或更新2FA记录暂未启用
twoFA := &model.TwoFA{
UserId: userId,
Secret: key.Secret(),
IsEnabled: false,
}
if existing != nil {
// 更新现有记录
twoFA.Id = existing.Id
err = twoFA.Update()
} else {
// 创建新记录
err = twoFA.Create()
}
if err != nil {
common.ApiError(c, err)
return
}
// 创建备用码记录
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "保存备用码失败",
})
common.SysLog("保存备用码失败: " + err.Error())
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "2FA设置初始化成功请使用认证器扫描二维码并输入验证码完成设置",
"data": Setup2FAResponse{
Secret: key.Secret(),
QRCodeData: qrCodeData,
BackupCodes: backupCodes,
},
})
}
// Enable2FA 启用2FA
func Enable2FA(c *gin.Context) {
var req Setup2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "请先完成2FA初始化设置",
})
return
}
if twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "2FA已经启用",
})
return
}
// 验证TOTP验证码
cleanCode, err := common.ValidateNumericCode(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 启用2FA
if err := twoFA.Enable(); err != nil {
common.ApiError(c, err)
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "两步验证启用成功",
})
}
// Disable2FA 禁用2FA
func Disable2FA(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码或备用码
cleanCode, err := common.ValidateNumericCode(req.Code)
isValidTOTP := false
isValidBackup := false
if err == nil {
// 尝试验证TOTP
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
}
if !isValidTOTP {
// 尝试验证备用码
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
if !isValidTOTP && !isValidBackup {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 禁用2FA
if err := model.DisableTwoFA(userId); err != nil {
common.ApiError(c, err)
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "两步验证已禁用",
})
}
// Get2FAStatus 获取用户2FA状态
func Get2FAStatus(c *gin.Context) {
userId := c.GetInt("id")
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
status := map[string]interface{}{
"enabled": false,
"locked": false,
}
if twoFA != nil {
status["enabled"] = twoFA.IsEnabled
status["locked"] = twoFA.IsLocked()
if twoFA.IsEnabled {
// 获取剩余备用码数量
backupCount, err := model.GetUnusedBackupCodeCount(userId)
if err != nil {
common.SysLog("获取备用码数量失败: " + err.Error())
} else {
status["backup_codes_remaining"] = backupCount
}
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": status,
})
}
// RegenerateBackupCodes 重新生成备用码
func RegenerateBackupCodes(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码
cleanCode, err := common.ValidateNumericCode(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if !valid {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 生成新的备用码
backupCodes, err := common.GenerateBackupCodes()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成备用码失败",
})
common.SysLog("生成备用码失败: " + err.Error())
return
}
// 保存新的备用码
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "保存备用码失败",
})
common.SysLog("保存备用码失败: " + err.Error())
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "备用码重新生成成功",
"data": map[string]interface{}{
"backup_codes": backupCodes,
},
})
}
// Verify2FALogin 登录时验证2FA
func Verify2FALogin(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
// 从会话中获取pending用户信息
session := sessions.Default(c)
pendingUserId := session.Get("pending_user_id")
if pendingUserId == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "会话已过期,请重新登录",
})
return
}
userId, ok := pendingUserId.(int)
if !ok {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "会话数据无效,请重新登录",
})
return
}
// 获取用户信息
user, err := model.GetUserById(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户不存在",
})
return
}
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(user.Id)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码或备用码
cleanCode, err := common.ValidateNumericCode(req.Code)
isValidTOTP := false
isValidBackup := false
if err == nil {
// 尝试验证TOTP
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
}
if !isValidTOTP {
// 尝试验证备用码
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
if !isValidTOTP && !isValidBackup {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 2FA验证成功清理pending会话信息并完成登录
session.Delete("pending_username")
session.Delete("pending_user_id")
session.Save()
setupLogin(user, c)
}
// Admin2FAStats 管理员获取2FA统计信息
func Admin2FAStats(c *gin.Context) {
stats, err := model.GetTwoFAStats()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": stats,
})
}
// AdminDisable2FA 管理员强制禁用用户2FA
func AdminDisable2FA(c *gin.Context) {
userIdStr := c.Param("id")
userId, err := strconv.Atoi(userIdStr)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户ID格式错误",
})
return
}
// 检查目标用户权限
targetUser, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
myRole := c.GetInt("role")
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权操作同级或更高级用户的2FA设置",
})
return
}
// 禁用2FA
if err := model.DisableTwoFA(userId); err != nil {
if errors.Is(err, model.ErrTwoFANotEnabled) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
common.ApiError(c, err)
return
}
// 记录操作日志
adminId := c.GetInt("id")
model.RecordLog(userId, model.LogTypeManage,
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "用户2FA已被强制禁用",
})
}

154
controller/uptime_kuma.go Normal file
View File

@@ -0,0 +1,154 @@
package controller
import (
"context"
"encoding/json"
"errors"
"net/http"
"one-api/setting/console_setting"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
)
const (
requestTimeout = 30 * time.Second
httpTimeout = 10 * time.Second
uptimeKeySuffix = "_24"
apiStatusPath = "/api/status-page/"
apiHeartbeatPath = "/api/status-page/heartbeat/"
)
type Monitor struct {
Name string `json:"name"`
Uptime float64 `json:"uptime"`
Status int `json:"status"`
Group string `json:"group,omitempty"`
}
type UptimeGroupResult struct {
CategoryName string `json:"categoryName"`
Monitors []Monitor `json:"monitors"`
}
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New("non-200 status")
}
return json.NewDecoder(resp.Body).Decode(dest)
}
func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
url, _ := groupConfig["url"].(string)
slug, _ := groupConfig["slug"].(string)
categoryName, _ := groupConfig["categoryName"].(string)
result := UptimeGroupResult{
CategoryName: categoryName,
Monitors: []Monitor{},
}
if url == "" || slug == "" {
return result
}
baseURL := strings.TrimSuffix(url, "/")
var statusData struct {
PublicGroupList []struct {
ID int `json:"id"`
Name string `json:"name"`
MonitorList []struct {
ID int `json:"id"`
Name string `json:"name"`
} `json:"monitorList"`
} `json:"publicGroupList"`
}
var heartbeatData struct {
HeartbeatList map[string][]struct {
Status int `json:"status"`
} `json:"heartbeatList"`
UptimeList map[string]float64 `json:"uptimeList"`
}
g, gCtx := errgroup.WithContext(ctx)
g.Go(func() error {
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
})
g.Go(func() error {
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
})
if g.Wait() != nil {
return result
}
for _, pg := range statusData.PublicGroupList {
if len(pg.MonitorList) == 0 {
continue
}
for _, m := range pg.MonitorList {
monitor := Monitor{
Name: m.Name,
Group: pg.Name,
}
monitorID := strconv.Itoa(m.ID)
if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
monitor.Uptime = uptime
}
if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
monitor.Status = heartbeats[0].Status
}
result.Monitors = append(result.Monitors, monitor)
}
}
return result
}
func GetUptimeKumaStatus(c *gin.Context) {
groups := console_setting.GetUptimeKumaGroups()
if len(groups) == 0 {
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
return
}
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
defer cancel()
client := &http.Client{Timeout: httpTimeout}
results := make([]UptimeGroupResult, len(groups))
g, gCtx := errgroup.WithContext(ctx)
for i, group := range groups {
i, group := i, group
g.Go(func() error {
results[i] = fetchGroupData(gCtx, client, group)
return nil
})
}
g.Wait()
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
}

View File

@@ -1,10 +1,12 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"github.com/gin-gonic/gin"
) )
func GetAllQuotaDates(c *gin.Context) { func GetAllQuotaDates(c *gin.Context) {
@@ -13,10 +15,7 @@ func GetAllQuotaDates(c *gin.Context) {
username := c.Query("username") username := c.Query("username")
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username) dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -41,10 +40,7 @@ func GetUserQuotaDates(c *gin.Context) {
} }
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp) dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -6,6 +6,8 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"strconv" "strconv"
@@ -61,6 +63,32 @@ func Login(c *gin.Context) {
}) })
return return
} }
// 检查是否启用2FA
if model.IsTwoFAEnabled(user.Id) {
// 设置pending session等待2FA验证
session := sessions.Default(c)
session.Set("pending_username", user.Username)
session.Set("pending_user_id", user.Id)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"message": "无法保存会话信息,请重试",
"success": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "请输入两步验证码",
"success": true,
"data": map[string]interface{}{
"require_2fa": true,
},
})
return
}
setupLogin(&user, c) setupLogin(&user, c)
} }
@@ -165,7 +193,7 @@ func Register(c *gin.Context) {
"success": false, "success": false,
"message": "数据库错误,请稍后重试", "message": "数据库错误,请稍后重试",
}) })
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return return
} }
if exist { if exist {
@@ -187,10 +215,7 @@ func Register(c *gin.Context) {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(inviterId); err != nil { if err := cleanUser.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
@@ -211,7 +236,7 @@ func Register(c *gin.Context) {
"success": false, "success": false,
"message": "生成默认令牌失败", "message": "生成默认令牌失败",
}) })
common.SysError("failed to generate token key: " + err.Error()) common.SysLog("failed to generate token key: " + err.Error())
return return
} }
// 生成默认令牌 // 生成默认令牌
@@ -226,6 +251,9 @@ func Register(c *gin.Context) {
UnlimitedQuota: true, UnlimitedQuota: true,
ModelLimitsEnabled: false, ModelLimitsEnabled: false,
} }
if setting.DefaultUseAutoGroup {
token.Group = "auto"
}
if err := token.Insert(); err != nil { if err := token.Insert(); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -243,83 +271,45 @@ func Register(c *gin.Context) {
} }
func GetAllUsers(c *gin.Context) { func GetAllUsers(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
pageSize, _ := strconv.Atoi(c.Query("page_size")) users, total, err := model.GetAllUsers(pageInfo)
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{
"success": true, pageInfo.SetTotal(int(total))
"message": "", pageInfo.SetItems(users)
"data": gin.H{
"items": users, common.ApiSuccess(c, pageInfo)
"total": total,
"page": p,
"page_size": pageSize,
},
})
return return
} }
func SearchUsers(c *gin.Context) { func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
group := c.Query("group") group := c.Query("group")
p, _ := strconv.Atoi(c.Query("p")) pageInfo := common.GetPageQuery(c)
pageSize, _ := strconv.Atoi(c.Query("page_size")) users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
startIdx := (p - 1) * pageSize
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{
"success": true, pageInfo.SetTotal(int(total))
"message": "", pageInfo.SetItems(users)
"data": gin.H{ common.ApiSuccess(c, pageInfo)
"items": users,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return return
} }
func GetUser(c *gin.Context) { func GetUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user, err := model.GetUserById(id, false) user, err := model.GetUserById(id, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
@@ -342,10 +332,7 @@ func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, err := model.GetUserById(id, true) user, err := model.GetUserById(id, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
// get rand int 28-32 // get rand int 28-32
@@ -356,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) {
"success": false, "success": false,
"message": "生成失败", "message": "生成失败",
}) })
common.SysError("failed to generate key: " + err.Error()) common.SysLog("failed to generate key: " + err.Error())
return return
} }
user.SetAccessToken(key) user.SetAccessToken(key)
@@ -370,10 +357,7 @@ func GenerateAccessToken(c *gin.Context) {
} }
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
@@ -393,18 +377,12 @@ func TransferAffQuota(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, err := model.GetUserById(id, true) user, err := model.GetUserById(id, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
tran := TransferAffQuotaRequest{} tran := TransferAffQuotaRequest{}
if err := c.ShouldBindJSON(&tran); err != nil { if err := c.ShouldBindJSON(&tran); err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
err = user.TransferAffQuotaToQuota(tran.Quota) err = user.TransferAffQuotaToQuota(tran.Quota)
@@ -425,10 +403,7 @@ func GetAffCode(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, err := model.GetUserById(id, true) user, err := model.GetUserById(id, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if user.AffCode == "" { if user.AffCode == "" {
@@ -453,12 +428,12 @@ func GetSelf(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, err := model.GetUserById(id, false) user, err := model.GetUserById(id, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
user.Remark = ""
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -474,16 +449,13 @@ func GetUserModels(c *gin.Context) {
} }
user, err := model.GetUserCache(id) user, err := model.GetUserCache(id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
groups := setting.GetUserUsableGroups(user.Group) groups := setting.GetUserUsableGroups(user.Group)
var models []string var models []string
for group := range groups { for group := range groups {
for _, g := range model.GetGroupModels(group) { for _, g := range model.GetGroupEnabledModels(group) {
if !common.StringsContains(models, g) { if !common.StringsContains(models, g) {
models = append(models, g) models = append(models, g)
} }
@@ -519,10 +491,7 @@ func UpdateUser(c *gin.Context) {
} }
originUser, err := model.GetUserById(updatedUser.Id, false) originUser, err := model.GetUserById(updatedUser.Id, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
@@ -545,14 +514,11 @@ func UpdateUser(c *gin.Context) {
} }
updatePassword := updatedUser.Password != "" updatePassword := updatedUser.Password != ""
if err := updatedUser.Edit(updatePassword); err != nil { if err := updatedUser.Edit(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if originUser.Quota != updatedUser.Quota { if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
@@ -594,17 +560,11 @@ func UpdateSelf(c *gin.Context) {
} }
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id) updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
if err := cleanUser.Update(updatePassword); err != nil { if err := cleanUser.Update(updatePassword); err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
@@ -635,18 +595,12 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
func DeleteUser(c *gin.Context) { func DeleteUser(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
originUser, err := model.GetUserById(id, false) originUser, err := model.GetUserById(id, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
@@ -681,10 +635,7 @@ func DeleteSelf(c *gin.Context) {
err := model.DeleteUserById(id) err := model.DeleteUserById(id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -730,10 +681,7 @@ func CreateUser(c *gin.Context) {
DisplayName: user.DisplayName, DisplayName: user.DisplayName,
} }
if err := cleanUser.Insert(0); err != nil { if err := cleanUser.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
@@ -843,10 +791,7 @@ func ManageUser(c *gin.Context) {
} }
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
clearUser := model.User{ clearUser := model.User{
@@ -878,20 +823,14 @@ func EmailBind(c *gin.Context) {
} }
err := user.FillUserById() err := user.FillUserById()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user.Email = email user.Email = email
// no need to check if this email already taken, because we have used verification code to check it // no need to check if this email already taken, because we have used verification code to check it
err = user.Update(false) err = user.Update(false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -905,27 +844,67 @@ type topUpRequest struct {
Key string `json:"key"` Key string `json:"key"`
} }
var topUpLock = sync.Mutex{} var topUpLocks sync.Map
var topUpCreateLock sync.Mutex
type topUpTryLock struct {
ch chan struct{}
}
func newTopUpTryLock() *topUpTryLock {
return &topUpTryLock{ch: make(chan struct{}, 1)}
}
func (l *topUpTryLock) TryLock() bool {
select {
case l.ch <- struct{}{}:
return true
default:
return false
}
}
func (l *topUpTryLock) Unlock() {
select {
case <-l.ch:
default:
}
}
func getTopUpLock(userID int) *topUpTryLock {
if v, ok := topUpLocks.Load(userID); ok {
return v.(*topUpTryLock)
}
topUpCreateLock.Lock()
defer topUpCreateLock.Unlock()
if v, ok := topUpLocks.Load(userID); ok {
return v.(*topUpTryLock)
}
l := newTopUpTryLock()
topUpLocks.Store(userID, l)
return l
}
func TopUp(c *gin.Context) { func TopUp(c *gin.Context) {
topUpLock.Lock() id := c.GetInt("id")
defer topUpLock.Unlock() lock := getTopUpLock(id)
req := topUpRequest{} if !lock.TryLock() {
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": "充值处理中,请稍后重试",
}) })
return return
} }
id := c.GetInt("id") defer lock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
common.ApiError(c, err)
return
}
quota, err := model.Redeem(req.Key, id) quota, err := model.Redeem(req.Key, id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -933,7 +912,6 @@ func TopUp(c *gin.Context) {
"message": "", "message": "",
"data": quota, "data": quota,
}) })
return
} }
type UpdateUserSettingRequest struct { type UpdateUserSettingRequest struct {
@@ -943,6 +921,7 @@ type UpdateUserSettingRequest struct {
WebhookSecret string `json:"webhook_secret,omitempty"` WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"` NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"` AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
RecordIpLog bool `json:"record_ip_log"`
} }
func UpdateUserSetting(c *gin.Context) { func UpdateUserSetting(c *gin.Context) {
@@ -956,7 +935,7 @@ func UpdateUserSetting(c *gin.Context) {
} }
// 验证预警类型 // 验证预警类型
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook { if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无效的预警类型", "message": "无效的预警类型",
@@ -974,7 +953,7 @@ func UpdateUserSetting(c *gin.Context) {
} }
// 如果是webhook类型,验证webhook地址 // 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == constant.NotifyTypeWebhook { if req.QuotaWarningType == dto.NotifyTypeWebhook {
if req.WebhookUrl == "" { if req.WebhookUrl == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -993,7 +972,7 @@ func UpdateUserSetting(c *gin.Context) {
} }
// 如果是邮件类型,验证邮箱地址 // 如果是邮件类型,验证邮箱地址
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式 // 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") { if !strings.Contains(req.NotificationEmail, "@") {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -1007,31 +986,29 @@ func UpdateUserSetting(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
user, err := model.GetUserById(userId, true) user, err := model.GetUserById(userId, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
// 构建设置 // 构建设置
settings := map[string]interface{}{ settings := dto.UserSetting{
constant.UserSettingNotifyType: req.QuotaWarningType, NotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold, QuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel, AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
RecordIpLog: req.RecordIpLog,
} }
// 如果是webhook类型,添加webhook相关设置 // 如果是webhook类型,添加webhook相关设置
if req.QuotaWarningType == constant.NotifyTypeWebhook { if req.QuotaWarningType == dto.NotifyTypeWebhook {
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl settings.WebhookUrl = req.WebhookUrl
if req.WebhookSecret != "" { if req.WebhookSecret != "" {
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret settings.WebhookSecret = req.WebhookSecret
} }
} }
// 如果提供了通知邮箱,添加到设置中 // 如果提供了通知邮箱,添加到设置中
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" { if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail settings.NotificationEmail = req.NotificationEmail
} }
// 更新用户设置 // 更新用户设置

124
controller/vendor_meta.go Normal file
View File

@@ -0,0 +1,124 @@
package controller
import (
"strconv"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
// GetAllVendors 获取供应商列表(分页)
func GetAllVendors(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
}
var total int64
model.DB.Model(&model.Vendor{}).Count(&total)
pageInfo.SetTotal(int(total))
pageInfo.SetItems(vendors)
common.ApiSuccess(c, pageInfo)
}
// SearchVendors 搜索供应商
func SearchVendors(c *gin.Context) {
keyword := c.Query("keyword")
pageInfo := common.GetPageQuery(c)
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
}
pageInfo.SetTotal(int(total))
pageInfo.SetItems(vendors)
common.ApiSuccess(c, pageInfo)
}
// GetVendorMeta 根据 ID 获取供应商
func GetVendorMeta(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiError(c, err)
return
}
v, err := model.GetVendorByID(id)
if err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, v)
}
// CreateVendorMeta 新建供应商
func CreateVendorMeta(c *gin.Context) {
var v model.Vendor
if err := c.ShouldBindJSON(&v); err != nil {
common.ApiError(c, err)
return
}
if v.Name == "" {
common.ApiErrorMsg(c, "供应商名称不能为空")
return
}
// 创建前先检查名称
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
common.ApiError(c, err)
return
} else if dup {
common.ApiErrorMsg(c, "供应商名称已存在")
return
}
if err := v.Insert(); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, &v)
}
// UpdateVendorMeta 更新供应商
func UpdateVendorMeta(c *gin.Context) {
var v model.Vendor
if err := c.ShouldBindJSON(&v); err != nil {
common.ApiError(c, err)
return
}
if v.Id == 0 {
common.ApiErrorMsg(c, "缺少供应商 ID")
return
}
// 名称冲突检查
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
common.ApiError(c, err)
return
} else if dup {
common.ApiErrorMsg(c, "供应商名称已存在")
return
}
if err := v.Update(); err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, &v)
}
// DeleteVendorMeta 删除供应商
func DeleteVendorMeta(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiError(c, err)
return
}
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, nil)
}

View File

@@ -4,13 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"time" "time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
) )
type wechatLoginResponse struct { type wechatLoginResponse struct {
@@ -150,19 +151,13 @@ func WeChatBind(c *gin.Context) {
} }
err = user.FillUserById() err = user.FillUserById()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
user.WeChatId = wechatId user.WeChatId = wechatId
err = user.Update(false) err = user.Update(false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ common.ApiError(c, err)
"success": false,
"message": err.Error(),
})
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -16,7 +16,7 @@ services:
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录 - ERROR_LOG_ENABLED=true # 是否启用错误日志记录
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache请取消注释 # - STREAMING_TIMEOUT=300 # 流模式无响应超时时间单位秒默认120秒如果出现空补全可以尝试改为更大值
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!! # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed

View File

197
docs/api/web_api.md Normal file
View File

@@ -0,0 +1,197 @@
# New API Web 界面后端接口文档
> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
>
> 接口前缀统一为 `https://<your-domain>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
>
> 鉴权级别说明:
> * **公开** 不需要登录即可调用
> * **用户** 需携带用户 Token`middleware.UserAuth`
> * **管理员** 需管理员 Token`middleware.AdminAuth`
> * **Root** 仅限最高权限 Root 用户(`middleware.RootAuth`
---
## 1. 初始化 / 系统状态
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/setup | 公开 | 获取系统初始化状态 |
| POST | /api/setup | 公开 | 完成首次安装向导 |
| GET | /api/status | 公开 | 获取运行状态摘要 |
| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
## 2. 公共信息
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/models | 用户 | 获取前端可用模型列表 |
| GET | /api/notice | 公开 | 获取公告栏内容 |
| GET | /api/about | 公开 | 关于页面信息 |
| GET | /api/home_page_content | 公开 | 首页自定义内容 |
| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
## 3. 邮件 / 身份验证
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
| POST | /api/user/reset | 公开 | 提交重置密码请求 |
## 4. OAuth / 第三方登录
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
| GET | /api/oauth/state | 公开 | 获取随机 state防 CSRF |
## 5. 用户模块
### 5.1 账号注册/登录
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| POST | /api/user/register | 公开 | 注册新账号 |
| POST | /api/user/login | 公开 | 用户登录 |
| GET | /api/user/logout | 用户 | 退出登录 |
| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
### 5.2 用户自身操作 (需登录)
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
| GET | /api/user/self | 用户 | 获取个人资料 |
| GET | /api/user/models | 用户 | 获取模型可见性 |
| PUT | /api/user/self | 用户 | 修改个人资料 |
| DELETE | /api/user/self | 用户 | 注销账号 |
| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
| GET | /api/user/aff | 用户 | 获取推广码信息 |
| POST | /api/user/topup | 用户 | 余额直充 |
| POST | /api/user/pay | 用户 | 提交支付订单 |
| POST | /api/user/amount | 用户 | 余额支付 |
| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
| PUT | /api/user/setting | 用户 | 更新用户设置 |
### 5.3 管理员用户管理
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/user/ | 管理员 | 获取全部用户列表 |
| GET | /api/user/search | 管理员 | 搜索用户 |
| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
| POST | /api/user/ | 管理员 | 创建用户 |
| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
| PUT | /api/user/ | 管理员 | 更新用户 |
| DELETE | /api/user/:id | 管理员 | 删除用户 |
## 6. 站点选项 (Root)
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/option/ | Root | 获取全局配置 |
| PUT | /api/option/ | Root | 更新全局配置 |
| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
## 7. 模型倍率同步 (Root)
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
## 8. 渠道管理 (管理员)
| 方法 | 路径 | 说明 |
|------|------|------|
| GET | /api/channel/ | 获取渠道列表 |
| GET | /api/channel/search | 搜索渠道 |
| GET | /api/channel/models | 查询渠道模型能力 |
| GET | /api/channel/models_enabled | 查询启用模型能力 |
| GET | /api/channel/:id | 获取单个渠道 |
| GET | /api/channel/test | 批量测试渠道连通性 |
| GET | /api/channel/test/:id | 单个渠道测试 |
| GET | /api/channel/update_balance | 批量刷新余额 |
| GET | /api/channel/update_balance/:id | 单个刷新余额 |
| POST | /api/channel/ | 新增渠道 |
| PUT | /api/channel/ | 更新渠道 |
| DELETE | /api/channel/disabled | 删除已禁用渠道 |
| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
| PUT | /api/channel/tag | 编辑渠道标签 |
| DELETE | /api/channel/:id | 删除渠道 |
| POST | /api/channel/batch | 批量删除渠道 |
| POST | /api/channel/fix | 修复渠道能力表 |
| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
| POST | /api/channel/batch/tag | 批量设置渠道标签 |
| GET | /api/channel/tag/models | 根据标签获取模型 |
| POST | /api/channel/copy/:id | 复制渠道 |
## 9. Token 管理
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/token/ | 用户 | 获取全部 Token |
| GET | /api/token/search | 用户 | 搜索 Token |
| GET | /api/token/:id | 用户 | 获取单个 Token |
| POST | /api/token/ | 用户 | 创建 Token |
| PUT | /api/token/ | 用户 | 更新 Token |
| DELETE | /api/token/:id | 用户 | 删除 Token |
| POST | /api/token/batch | 用户 | 批量删除 Token |
## 10. 兑换码管理 (管理员)
| 方法 | 路径 | 说明 |
|------|------|------|
| GET | /api/redemption/ | 获取兑换码列表 |
| GET | /api/redemption/search | 搜索兑换码 |
| GET | /api/redemption/:id | 获取单个兑换码 |
| POST | /api/redemption/ | 创建兑换码 |
| PUT | /api/redemption/ | 更新兑换码 |
| DELETE | /api/redemption/invalid | 删除无效兑换码 |
| DELETE | /api/redemption/:id | 删除兑换码 |
## 11. 日志
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/log/ | 管理员 | 获取全部日志 |
| DELETE | /api/log/ | 管理员 | 删除历史日志 |
| GET | /api/log/stat | 管理员 | 日志统计 |
| GET | /api/log/self/stat | 用户 | 我的日志统计 |
| GET | /api/log/search | 管理员 | 搜索全部日志 |
| GET | /api/log/self | 用户 | 获取我的日志 |
| GET | /api/log/self/search | 用户 | 搜索我的日志 |
| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS |
## 12. 数据统计
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
| GET | /api/data/self | 用户 | 我的用量按日期统计 |
## 13. 分组
| GET | /api/group/ | 管理员 | 获取全部分组列表 |
## 14. Midjourney 任务
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
## 15. 任务中心
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/task/self | 用户 | 获取我的任务 |
| GET | /api/task/ | 管理员 | 获取全部任务 |
## 16. 账户计费面板 (Dashboard)
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
---
> **更新日期**2025.07.17

BIN
docs/images/aliyun.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

BIN
docs/images/io-net.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

BIN
docs/images/pku.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

BIN
docs/images/ucloud.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

View File

@@ -1,5 +1,11 @@
package dto package dto
import (
"one-api/types"
"github.com/gin-gonic/gin"
)
type AudioRequest struct { type AudioRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input string `json:"input"` Input string `json:"input"`
@@ -8,6 +14,24 @@ type AudioRequest struct {
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
} }
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
meta := &types.TokenCountMeta{
CombineText: r.Input,
TokenType: types.TokenTypeTextNumber,
}
return meta
}
func (r *AudioRequest) IsStream(c *gin.Context) bool {
return false
}
func (r *AudioRequest) SetModelName(modelName string) {
if modelName != "" {
r.Model = modelName
}
}
type AudioResponse struct { type AudioResponse struct {
Text string `json:"text"` Text string `json:"text"`
} }

14
dto/channel_settings.go Normal file
View File

@@ -0,0 +1,14 @@
package dto
type ChannelSettings struct {
ForceFormat bool `json:"force_format,omitempty"`
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
Proxy string `json:"proxy"`
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
SystemPrompt string `json:"system_prompt,omitempty"`
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
}
type ChannelOtherSettings struct {
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
}

View File

@@ -1,6 +1,14 @@
package dto package dto
import "encoding/json" import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type ClaudeMetadata struct { type ClaudeMetadata struct {
UserId string `json:"user_id"` UserId string `json:"user_id"`
@@ -20,11 +28,11 @@ type ClaudeMediaMessage struct {
Delta string `json:"delta,omitempty"` Delta string `json:"delta,omitempty"`
CacheControl json.RawMessage `json:"cache_control,omitempty"` CacheControl json.RawMessage `json:"cache_control,omitempty"`
// tool_calls // tool_calls
Id string `json:"id,omitempty"` Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"` Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"` Input any `json:"input,omitempty"`
Content json.RawMessage `json:"content,omitempty"` Content any `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"` ToolUseId string `json:"tool_use_id,omitempty"`
} }
func (c *ClaudeMediaMessage) SetText(s string) { func (c *ClaudeMediaMessage) SetText(s string) {
@@ -39,34 +47,54 @@ func (c *ClaudeMediaMessage) GetText() string {
} }
func (c *ClaudeMediaMessage) IsStringContent() bool { func (c *ClaudeMediaMessage) IsStringContent() bool {
var content string if c.Content == nil {
return json.Unmarshal(c.Content, &content) == nil return false
}
_, ok := c.Content.(string)
if ok {
return true
}
return false
} }
func (c *ClaudeMediaMessage) GetStringContent() string { func (c *ClaudeMediaMessage) GetStringContent() string {
var content string if c.Content == nil {
if err := json.Unmarshal(c.Content, &content); err == nil { return ""
return content
} }
switch c.Content.(type) {
case string:
return c.Content.(string)
case []any:
var contentStr string
for _, contentItem := range c.Content.([]any) {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return "" return ""
} }
func (c *ClaudeMediaMessage) GetJsonRowString() string { func (c *ClaudeMediaMessage) GetJsonRowString() string {
jsonContent, _ := json.Marshal(c) jsonContent, _ := common.Marshal(c)
return string(jsonContent) return string(jsonContent)
} }
func (c *ClaudeMediaMessage) SetContent(content any) { func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content) c.Content = content
c.Content = jsonContent
} }
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage { func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
var mediaContent []ClaudeMediaMessage mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
if err := json.Unmarshal(c.Content, &mediaContent); err == nil { return mediaContent
return mediaContent
}
return make([]ClaudeMediaMessage, 0)
} }
type ClaudeMessageSource struct { type ClaudeMessageSource struct {
@@ -82,14 +110,36 @@ type ClaudeMessage struct {
} }
func (c *ClaudeMessage) IsStringContent() bool { func (c *ClaudeMessage) IsStringContent() bool {
if c.Content == nil {
return false
}
_, ok := c.Content.(string) _, ok := c.Content.(string)
return ok return ok
} }
func (c *ClaudeMessage) GetStringContent() string { func (c *ClaudeMessage) GetStringContent() string {
if c.IsStringContent() { if c.Content == nil {
return c.Content.(string) return ""
} }
switch c.Content.(type) {
case string:
return c.Content.(string)
case []any:
var contentStr string
for _, contentItem := range c.Content.([]any) {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return "" return ""
} }
@@ -98,15 +148,7 @@ func (c *ClaudeMessage) SetStringContent(content string) {
} }
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) { func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
// map content to []ClaudeMediaMessage return common.Any2Type[[]ClaudeMediaMessage](c.Content)
// parse to json
jsonContent, _ := json.Marshal(c.Content)
var contentList []ClaudeMediaMessage
err := json.Unmarshal(jsonContent, &contentList)
if err != nil {
return make([]ClaudeMediaMessage, 0), err
}
return contentList, nil
} }
type Tool struct { type Tool struct {
@@ -121,6 +163,27 @@ type InputSchema struct {
Required any `json:"required,omitempty"` Required any `json:"required,omitempty"`
} }
type ClaudeWebSearchTool struct {
Type string `json:"type"`
Name string `json:"name"`
MaxUses int `json:"max_uses,omitempty"`
UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
}
type ClaudeWebSearchUserLocation struct {
Type string `json:"type"`
Timezone string `json:"timezone,omitempty"`
Country string `json:"country,omitempty"`
Region string `json:"region,omitempty"`
City string `json:"city,omitempty"`
}
type ClaudeToolChoice struct {
Type string `json:"type"`
Name string `json:"name,omitempty"`
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
}
type ClaudeRequest struct { type ClaudeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@@ -139,9 +202,210 @@ type ClaudeRequest struct {
Thinking *Thinking `json:"thinking,omitempty"` Thinking *Thinking `json:"thinking,omitempty"`
} }
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
var tokenCountMeta = types.TokenCountMeta{
TokenType: types.TokenTypeTokenizer,
MaxTokens: int(c.MaxTokens),
}
var texts = make([]string, 0)
var fileMeta = make([]*types.FileMeta, 0)
// system
if c.System != nil {
if c.IsStringSystem() {
sys := c.GetStringSystem()
if sys != "" {
texts = append(texts, sys)
}
} else {
systemMedia := c.ParseSystem()
for _, media := range systemMedia {
switch media.Type {
case "text":
texts = append(texts, media.GetText())
case "image":
if media.Source != nil {
data := media.Source.Url
if data == "" {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
}
}
}
}
}
}
// messages
for _, message := range c.Messages {
tokenCountMeta.MessagesCount++
texts = append(texts, message.Role)
if message.IsStringContent() {
content := message.GetStringContent()
if content != "" {
texts = append(texts, content)
}
continue
}
content, _ := message.ParseContent()
for _, media := range content {
switch media.Type {
case "text":
texts = append(texts, media.GetText())
case "image":
if media.Source != nil {
data := media.Source.Url
if data == "" {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
}
}
case "tool_use":
if media.Name != "" {
texts = append(texts, media.Name)
}
if media.Input != nil {
b, _ := common.Marshal(media.Input)
texts = append(texts, string(b))
}
case "tool_result":
if media.Content != nil {
b, _ := common.Marshal(media.Content)
texts = append(texts, string(b))
}
}
}
}
// tools
if c.Tools != nil {
tools := c.GetTools()
normalTools, webSearchTools := ProcessTools(tools)
if normalTools != nil {
for _, t := range normalTools {
tokenCountMeta.ToolsCount++
if t.Name != "" {
texts = append(texts, t.Name)
}
if t.Description != "" {
texts = append(texts, t.Description)
}
if t.InputSchema != nil {
b, _ := common.Marshal(t.InputSchema)
texts = append(texts, string(b))
}
}
}
if webSearchTools != nil {
for _, t := range webSearchTools {
tokenCountMeta.ToolsCount++
if t.Name != "" {
texts = append(texts, t.Name)
}
if t.UserLocation != nil {
b, _ := common.Marshal(t.UserLocation)
texts = append(texts, string(b))
}
}
}
}
tokenCountMeta.CombineText = strings.Join(texts, "\n")
tokenCountMeta.Files = fileMeta
return &tokenCountMeta
}
func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
return c.Stream
}
func (c *ClaudeRequest) SetModelName(modelName string) {
if modelName != "" {
c.Model = modelName
}
}
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
for _, message := range c.Messages {
content, _ := message.ParseContent()
for _, mediaMessage := range content {
if mediaMessage.Id == toolCallId {
return mediaMessage.Name
}
}
}
return ""
}
// AddTool 添加工具到请求中
func (c *ClaudeRequest) AddTool(tool any) {
if c.Tools == nil {
c.Tools = make([]any, 0)
}
switch tools := c.Tools.(type) {
case []any:
c.Tools = append(tools, tool)
default:
// 如果Tools不是[]any类型重新初始化为[]any
c.Tools = []any{tool}
}
}
// GetTools 获取工具列表
func (c *ClaudeRequest) GetTools() []any {
if c.Tools == nil {
return nil
}
switch tools := c.Tools.(type) {
case []any:
return tools
default:
return nil
}
}
// ProcessTools 处理工具列表,支持类型断言
func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
var normalTools []*Tool
var webSearchTools []*ClaudeWebSearchTool
for _, tool := range tools {
switch t := tool.(type) {
case *Tool:
normalTools = append(normalTools, t)
case *ClaudeWebSearchTool:
webSearchTools = append(webSearchTools, t)
case Tool:
normalTools = append(normalTools, &t)
case ClaudeWebSearchTool:
webSearchTools = append(webSearchTools, &t)
default:
// 未知类型,跳过
continue
}
}
return normalTools, webSearchTools
}
type Thinking struct { type Thinking struct {
Type string `json:"type"` Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"` BudgetTokens *int `json:"budget_tokens,omitempty"`
}
func (c *Thinking) GetBudgetTokens() int {
if c.BudgetTokens == nil {
return 0
}
return *c.BudgetTokens
} }
func (c *ClaudeRequest) IsStringSystem() bool { func (c *ClaudeRequest) IsStringSystem() bool {
@@ -161,24 +425,13 @@ func (c *ClaudeRequest) SetStringSystem(system string) {
} }
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage { func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
// map content to []ClaudeMediaMessage mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
// parse to json return mediaContent
jsonContent, _ := json.Marshal(c.System)
var contentList []ClaudeMediaMessage
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
return contentList
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeError struct {
Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
} }
type ClaudeErrorWithStatusCode struct { type ClaudeErrorWithStatusCode struct {
Error ClaudeError `json:"error"` Error types.ClaudeError `json:"error"`
StatusCode int `json:"status_code"` StatusCode int `json:"status_code"`
LocalError bool LocalError bool
} }
@@ -190,7 +443,7 @@ type ClaudeResponse struct {
Completion string `json:"completion,omitempty"` Completion string `json:"completion,omitempty"`
StopReason string `json:"stop_reason,omitempty"` StopReason string `json:"stop_reason,omitempty"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Error *ClaudeError `json:"error,omitempty"` Error any `json:"error,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"`
Index *int `json:"index,omitempty"` Index *int `json:"index,omitempty"`
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"` ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
@@ -211,9 +464,50 @@ func (c *ClaudeResponse) GetIndex() int {
return *c.Index return *c.Index
} }
type ClaudeUsage struct { // GetClaudeError 从动态错误类型中提取ClaudeError结构
InputTokens int `json:"input_tokens"` func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
CacheCreationInputTokens int `json:"cache_creation_input_tokens"` if c.Error == nil {
CacheReadInputTokens int `json:"cache_read_input_tokens"` return nil
OutputTokens int `json:"output_tokens"` }
switch err := c.Error.(type) {
case types.ClaudeError:
return &err
case *types.ClaudeError:
return err
case map[string]interface{}:
// 处理从JSON解析来的map结构
claudeErr := &types.ClaudeError{}
if errType, ok := err["type"].(string); ok {
claudeErr.Type = errType
}
if errMsg, ok := err["message"].(string); ok {
claudeErr.Message = errMsg
}
return claudeErr
case string:
// 处理简单字符串错误
return &types.ClaudeError{
Type: "error",
Message: err,
}
default:
// 未知类型,尝试转换为字符串
return &types.ClaudeError{
Type: "unknown_error",
Message: fmt.Sprintf("%v", err),
}
}
}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
}
type ClaudeServerToolUse struct {
WebSearchRequests int `json:"web_search_requests"`
} }

View File

@@ -1,28 +0,0 @@
package dto
import "encoding/json"
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background string `json:"background,omitempty"`
Moderation string `json:"moderation,omitempty"`
OutputFormat string `json:"output_format,omitempty"`
}
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
}
type ImageData struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}

View File

@@ -1,5 +1,12 @@
package dto package dto
import (
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type EmbeddingOptions struct { type EmbeddingOptions struct {
Seed int `json:"seed,omitempty"` Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
@@ -24,9 +31,32 @@ type EmbeddingRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
} }
func (r EmbeddingRequest) ParseInput() []string { func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
var texts = make([]string, 0)
inputs := r.ParseInput()
for _, input := range inputs {
texts = append(texts, input)
}
return &types.TokenCountMeta{
CombineText: strings.Join(texts, "\n"),
}
}
func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
return false
}
func (r *EmbeddingRequest) SetModelName(modelName string) {
if modelName != "" {
r.Model = modelName
}
}
func (r *EmbeddingRequest) ParseInput() []string {
if r.Input == nil { if r.Input == nil {
return nil return make([]string, 0)
} }
var input []string var input []string
switch r.Input.(type) { switch r.Input.(type) {

View File

@@ -1,5 +1,7 @@
package dto package dto
import "one-api/types"
type OpenAIError struct { type OpenAIError struct {
Message string `json:"message"` Message string `json:"message"`
Type string `json:"type"` Type string `json:"type"`
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
} }
type GeneralErrorResponse struct { type GeneralErrorResponse struct {
Error OpenAIError `json:"error"` Error types.OpenAIError `json:"error"`
Message string `json:"message"` Message string `json:"message"`
Msg string `json:"msg"` Msg string `json:"msg"`
Err string `json:"err"` Err string `json:"err"`
ErrorMsg string `json:"error_msg"` ErrorMsg string `json:"error_msg"`
Header struct { Header struct {
Message string `json:"message"` Message string `json:"message"`
} `json:"header"` } `json:"header"`

384
dto/gemini.go Normal file
View File

@@ -0,0 +1,384 @@
package dto
import (
"encoding/json"
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/logger"
"one-api/types"
"strings"
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
}
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
var files []*types.FileMeta = make([]*types.FileMeta, 0)
var maxTokens int
if r.GenerationConfig.MaxOutputTokens > 0 {
maxTokens = int(r.GenerationConfig.MaxOutputTokens)
}
var inputTexts []string
for _, content := range r.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
if part.InlineData != nil && part.InlineData.Data != "" {
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeImage,
OriginData: part.InlineData.Data,
})
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeAudio,
OriginData: part.InlineData.Data,
})
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeVideo,
OriginData: part.InlineData.Data,
})
} else {
files = append(files, &types.FileMeta{
FileType: types.FileTypeFile,
OriginData: part.InlineData.Data,
})
}
}
}
}
inputText := strings.Join(inputTexts, "\n")
return &types.TokenCountMeta{
CombineText: inputText,
Files: files,
MaxTokens: maxTokens,
}
}
func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
if c.Query("alt") == "sse" {
return true
}
return false
}
func (r *GeminiChatRequest) SetModelName(modelName string) {
// GeminiChatRequest does not have a model field, so this method does nothing.
}
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
var tools []GeminiChatTool
if strings.HasSuffix(string(r.Tools), "[") {
// is array
if err := common.Unmarshal(r.Tools, &tools); err != nil {
logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
return nil
}
} else if strings.HasPrefix(string(r.Tools), "{") {
// is object
singleTool := GeminiChatTool{}
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
return nil
}
tools = []GeminiChatTool{singleTool}
}
return tools
}
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
if len(tools) == 0 {
r.Tools = json.RawMessage("[]")
return
}
// Marshal the tools to JSON
data, err := common.Marshal(tools)
if err != nil {
logger.LogError(nil, "error_marshalling_tools: "+err.Error())
return
}
r.Tools = data
}
type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
}
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
c.ThinkingBudget = &budget
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
type Alias GeminiInlineData // Use type alias to avoid recursion
var aux struct {
Alias
MimeTypeSnake string `json:"mime_type"`
}
if err := common.Unmarshal(data, &aux); err != nil {
return err
}
*g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
// Prioritize snake_case if present
if aux.MimeTypeSnake != "" {
g.MimeType = aux.MimeTypeSnake
} else if aux.MimeType != "" { // Fallback to camelCase from Alias
g.MimeType = aux.MimeType
}
// g.Data would be populated by aux.Alias.Data
return nil
}
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
type GeminiFunctionResponse struct {
Name string `json:"name"`
Response map[string]interface{} `json:"response"`
}
type GeminiPartExecutableCode struct {
Language string `json:"language,omitempty"`
Code string `json:"code,omitempty"`
}
type GeminiPartCodeExecutionResult struct {
Outcome string `json:"outcome,omitempty"`
Output string `json:"output,omitempty"`
}
type GeminiFileData struct {
MimeType string `json:"mimeType,omitempty"`
FileUri string `json:"fileUri,omitempty"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
FileData *GeminiFileData `json:"fileData,omitempty"`
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
}
// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
func (p *GeminiPart) UnmarshalJSON(data []byte) error {
// Alias to avoid recursion during unmarshalling
type Alias GeminiPart
var aux struct {
Alias
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
}
if err := common.Unmarshal(data, &aux); err != nil {
return err
}
// Assign fields from alias
*p = GeminiPart(aux.Alias)
// Prioritize snake_case for InlineData if present
if aux.InlineDataSnake != nil {
p.InlineData = aux.InlineDataSnake
} else if aux.InlineData != nil { // Fallback to camelCase from Alias
p.InlineData = aux.InlineData
}
// Other fields like Text, FunctionCall etc. are already populated via aux.Alias
return nil
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTool struct {
GoogleSearch any `json:"googleSearch,omitempty"`
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
CodeExecution any `json:"codeExecution,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason *string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
}
type GeminiPromptTokensDetails struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
// Imagen related structs
type GeminiImageRequest struct {
Instances []GeminiImageInstance `json:"instances"`
Parameters GeminiImageParameters `json:"parameters"`
}
type GeminiImageInstance struct {
Prompt string `json:"prompt"`
}
type GeminiImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type GeminiImageResponse struct {
Predictions []GeminiImagePrediction `json:"predictions"`
}
type GeminiImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}
// Embedding related structs
type GeminiEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Content GeminiChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
// Gemini embedding requests are not streamed
return false
}
func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
var inputTexts []string
for _, part := range r.Content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
}
inputText := strings.Join(inputTexts, "\n")
return &types.TokenCountMeta{
CombineText: inputText,
}
}
func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
if modelName != "" {
r.Model = modelName
}
}
type GeminiBatchEmbeddingRequest struct {
Requests []*GeminiEmbeddingRequest `json:"requests"`
}
func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
// Gemini batch embedding requests are not streamed
return false
}
func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
var inputTexts []string
for _, request := range r.Requests {
meta := request.GetTokenCountMeta()
if meta != nil && meta.CombineText != "" {
inputTexts = append(inputTexts, meta.CombineText)
}
}
inputText := strings.Join(inputTexts, "\n")
return &types.TokenCountMeta{
CombineText: inputText,
}
}
func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
if modelName != "" {
for _, req := range r.Requests {
req.SetModelName(modelName)
}
}
}
type GeminiEmbeddingResponse struct {
Embedding ContentEmbedding `json:"embedding"`
}
type GeminiBatchEmbeddingResponse struct {
Embeddings []*ContentEmbedding `json:"embeddings"`
}
type ContentEmbedding struct {
Values []float64 `json:"values"`
}

View File

@@ -57,6 +57,8 @@ type MidjourneyDto struct {
StartTime int64 `json:"startTime"` StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"` FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"` ImageUrl string `json:"imageUrl"`
VideoUrl string `json:"videoUrl"`
VideoUrls []ImgUrls `json:"videoUrls"`
Status string `json:"status"` Status string `json:"status"`
Progress string `json:"progress"` Progress string `json:"progress"`
FailReason string `json:"failReason"` FailReason string `json:"failReason"`
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
Properties *Properties `json:"properties"` Properties *Properties `json:"properties"`
} }
type ImgUrls struct {
Url string `json:"url"`
}
type MidjourneyStatus struct { type MidjourneyStatus struct {
Status int `json:"status"` Status int `json:"status"`
} }

80
dto/openai_image.go Normal file
View File

@@ -0,0 +1,80 @@
package dto
import (
"encoding/json"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N uint `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style json.RawMessage `json:"style,omitempty"`
User json.RawMessage `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background json.RawMessage `json:"background,omitempty"`
Moderation json.RawMessage `json:"moderation,omitempty"`
OutputFormat json.RawMessage `json:"output_format,omitempty"`
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
}
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
var sizeRatio = 1.0
var qualityRatio = 1.0
if strings.HasPrefix(i.Model, "dall-e") {
// Size
if i.Size == "256x256" {
sizeRatio = 0.4
} else if i.Size == "512x512" {
sizeRatio = 0.45
} else if i.Size == "1024x1024" {
sizeRatio = 1
} else if i.Size == "1024x1792" || i.Size == "1792x1024" {
sizeRatio = 2
}
if i.Model == "dall-e-3" && i.Quality == "hd" {
qualityRatio = 2.0
if i.Size == "1024x1792" || i.Size == "1792x1024" {
qualityRatio = 1.5
}
}
}
// not support token count for dalle
return &types.TokenCountMeta{
CombineText: i.Prompt,
MaxTokens: 1584,
ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
}
}
func (i *ImageRequest) IsStream(c *gin.Context) bool {
return false
}
func (i *ImageRequest) SetModelName(modelName string) {
if modelName != "" {
i.Model = modelName
}
}
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
}
type ImageData struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}

View File

@@ -2,70 +2,211 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"fmt"
"one-api/common" "one-api/common"
"one-api/types"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
type ResponseFormat struct { type ResponseFormat struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"` JsonSchema json.RawMessage `json:"json_schema,omitempty"`
} }
type FormatJsonSchema struct { type FormatJsonSchema struct {
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Schema any `json:"schema,omitempty"` Schema any `json:"schema,omitempty"`
Strict any `json:"strict,omitempty"` Strict json.RawMessage `json:"strict,omitempty"`
} }
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Prefix any `json:"prefix,omitempty"` Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"` Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"` StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"` MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"` ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
TopP float64 `json:"top_p,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"` TopP float64 `json:"top_p,omitempty"`
Stop any `json:"stop,omitempty"` TopK int `json:"top_k,omitempty"`
N int `json:"n,omitempty"` Stop any `json:"stop,omitempty"`
Input any `json:"input,omitempty"` N int `json:"n,omitempty"`
Instruction string `json:"instruction,omitempty"` Input any `json:"input,omitempty"`
Size string `json:"size,omitempty"` Instruction string `json:"instruction,omitempty"`
Functions any `json:"functions,omitempty"` Size string `json:"size,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` Functions json.RawMessage `json:"functions,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"` EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"` Seed float64 `json:"seed,omitempty"`
Tools []ToolCallRequest `json:"tools,omitempty"` ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` Tools []ToolCallRequest `json:"tools,omitempty"`
User string `json:"user,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` User string `json:"user,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"`
Modalities any `json:"modalities,omitempty"` Dimensions int `json:"dimensions,omitempty"`
Audio any `json:"audio,omitempty"` Modalities json.RawMessage `json:"modalities,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"` // ali Audio json.RawMessage `json:"audio,omitempty"`
ExtraBody any `json:"extra_body,omitempty"` EnableThinking any `json:"enable_thinking,omitempty"` // ali
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` THINKING json.RawMessage `json:"thinking,omitempty"` // doubao,zhipu_v4
// OpenRouter Params ExtraBody json.RawMessage `json:"extra_body,omitempty"`
SearchParameters any `json:"search_parameters,omitempty"` //xai
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
// OpenRouter Params
Usage json.RawMessage `json:"usage,omitempty"`
Reasoning json.RawMessage `json:"reasoning,omitempty"` Reasoning json.RawMessage `json:"reasoning,omitempty"`
// Ali Qwen Params
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
// 用匿名参数接收额外参数例如ollama的think参数在此接收
Extra map[string]json.RawMessage `json:"-"`
}
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
var tokenCountMeta types.TokenCountMeta
var texts = make([]string, 0)
var fileMeta = make([]*types.FileMeta, 0)
if r.Prompt != nil {
switch v := r.Prompt.(type) {
case string:
texts = append(texts, v)
case []any:
for _, item := range v {
if str, ok := item.(string); ok {
texts = append(texts, str)
}
}
default:
texts = append(texts, fmt.Sprintf("%v", r.Prompt))
}
}
if r.Input != nil {
inputs := r.ParseInput()
texts = append(texts, inputs...)
}
if r.MaxCompletionTokens > r.MaxTokens {
tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
} else {
tokenCountMeta.MaxTokens = int(r.MaxTokens)
}
for _, message := range r.Messages {
tokenCountMeta.MessagesCount++
texts = append(texts, message.Role)
if message.Content != nil {
if message.Name != nil {
tokenCountMeta.NameCount++
texts = append(texts, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == ContentTypeImageURL {
imageUrl := m.GetImageMedia()
if imageUrl != nil {
if imageUrl.Url != "" {
meta := &types.FileMeta{
FileType: types.FileTypeImage,
}
meta.OriginData = imageUrl.Url
meta.Detail = imageUrl.Detail
fileMeta = append(fileMeta, meta)
}
}
} else if m.Type == ContentTypeInputAudio {
inputAudio := m.GetInputAudio()
if inputAudio != nil {
meta := &types.FileMeta{
FileType: types.FileTypeAudio,
}
meta.OriginData = inputAudio.Data
fileMeta = append(fileMeta, meta)
}
} else if m.Type == ContentTypeFile {
file := m.GetFile()
if file != nil {
meta := &types.FileMeta{
FileType: types.FileTypeFile,
}
meta.OriginData = file.FileData
fileMeta = append(fileMeta, meta)
}
} else if m.Type == ContentTypeVideoUrl {
videoUrl := m.GetVideoUrl()
if videoUrl != nil && videoUrl.Url != "" {
meta := &types.FileMeta{
FileType: types.FileTypeVideo,
}
meta.OriginData = videoUrl.Url
fileMeta = append(fileMeta, meta)
}
} else {
texts = append(texts, m.Text)
}
}
}
}
if r.Tools != nil {
openaiTools := r.Tools
for _, tool := range openaiTools {
tokenCountMeta.ToolsCount++
texts = append(texts, tool.Function.Name)
if tool.Function.Description != "" {
texts = append(texts, tool.Function.Description)
}
if tool.Function.Parameters != nil {
texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
}
}
//toolTokens := CountTokenInput(countStr, request.Model)
//tkm += 8
//tkm += toolTokens
}
tokenCountMeta.CombineText = strings.Join(texts, "\n")
tokenCountMeta.Files = fileMeta
return &tokenCountMeta
}
func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
return r.Stream
}
func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
if modelName != "" {
r.Model = modelName
}
} }
func (r *GeneralOpenAIRequest) ToMap() map[string]any { func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any) result := make(map[string]any)
data, _ := common.EncodeJson(r) data, _ := common.Marshal(r)
_ = common.DecodeJson(data, &result) _ = common.Unmarshal(data, &result)
return result return result
} }
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
if strings.HasPrefix(r.Model, "o") {
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
return "developer"
}
} else if strings.HasPrefix(r.Model, "gpt-5") {
return "developer"
}
return "system"
}
type ToolCallRequest struct { type ToolCallRequest struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
Type string `json:"type"` Type string `json:"type"`
@@ -83,8 +224,11 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"` IncludeUsage bool `json:"include_usage,omitempty"`
} }
func (r *GeneralOpenAIRequest) GetMaxTokens() int { func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
return int(r.MaxTokens) if r.MaxCompletionTokens != 0 {
return r.MaxCompletionTokens
}
return r.MaxTokens
} }
func (r *GeneralOpenAIRequest) ParseInput() []string { func (r *GeneralOpenAIRequest) ParseInput() []string {
@@ -107,16 +251,16 @@ func (r *GeneralOpenAIRequest) ParseInput() []string {
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content json.RawMessage `json:"content"` Content any `json:"content"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
Prefix *bool `json:"prefix,omitempty"` Prefix *bool `json:"prefix,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"` ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"` Reasoning string `json:"reasoning,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"` ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent parsedContent []MediaContent
parsedStringContent *string //parsedStringContent *string
} }
type MediaContent struct { type MediaContent struct {
@@ -132,21 +276,65 @@ type MediaContent struct {
func (m *MediaContent) GetImageMedia() *MessageImageUrl { func (m *MediaContent) GetImageMedia() *MessageImageUrl {
if m.ImageUrl != nil { if m.ImageUrl != nil {
return m.ImageUrl.(*MessageImageUrl) if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
return m.ImageUrl.(*MessageImageUrl)
}
if itemMap, ok := m.ImageUrl.(map[string]any); ok {
out := &MessageImageUrl{
Url: common.Interface2String(itemMap["url"]),
Detail: common.Interface2String(itemMap["detail"]),
MimeType: common.Interface2String(itemMap["mime_type"]),
}
return out
}
} }
return nil return nil
} }
func (m *MediaContent) GetInputAudio() *MessageInputAudio { func (m *MediaContent) GetInputAudio() *MessageInputAudio {
if m.InputAudio != nil { if m.InputAudio != nil {
return m.InputAudio.(*MessageInputAudio) if _, ok := m.InputAudio.(*MessageInputAudio); ok {
return m.InputAudio.(*MessageInputAudio)
}
if itemMap, ok := m.InputAudio.(map[string]any); ok {
out := &MessageInputAudio{
Data: common.Interface2String(itemMap["data"]),
Format: common.Interface2String(itemMap["format"]),
}
return out
}
} }
return nil return nil
} }
func (m *MediaContent) GetFile() *MessageFile { func (m *MediaContent) GetFile() *MessageFile {
if m.File != nil { if m.File != nil {
return m.File.(*MessageFile) if _, ok := m.File.(*MessageFile); ok {
return m.File.(*MessageFile)
}
if itemMap, ok := m.File.(map[string]any); ok {
out := &MessageFile{
FileName: common.Interface2String(itemMap["file_name"]),
FileData: common.Interface2String(itemMap["file_data"]),
FileId: common.Interface2String(itemMap["file_id"]),
}
return out
}
}
return nil
}
func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
if m.VideoUrl != nil {
if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
return m.VideoUrl.(*MessageVideoUrl)
}
if itemMap, ok := m.VideoUrl.(map[string]any); ok {
out := &MessageVideoUrl{
Url: common.Interface2String(itemMap["url"]),
}
return out
}
} }
return nil return nil
} }
@@ -182,6 +370,7 @@ const (
ContentTypeInputAudio = "input_audio" ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file" ContentTypeFile = "file"
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别 ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
//ContentTypeAudioUrl = "audio_url"
) )
func (m *Message) GetPrefix() bool { func (m *Message) GetPrefix() bool {
@@ -212,6 +401,186 @@ func (m *Message) SetToolCalls(toolCalls any) {
} }
func (m *Message) StringContent() string { func (m *Message) StringContent() string {
switch m.Content.(type) {
case string:
return m.Content.(string)
case []any:
var contentStr string
for _, contentItem := range m.Content.([]any) {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
func (m *Message) SetNullContent() {
m.Content = nil
m.parsedContent = nil
}
func (m *Message) SetStringContent(content string) {
m.Content = content
m.parsedContent = nil
}
func (m *Message) SetMediaContent(content []MediaContent) {
m.Content = content
m.parsedContent = content
}
func (m *Message) IsStringContent() bool {
_, ok := m.Content.(string)
if ok {
return true
}
return false
}
func (m *Message) ParseContent() []MediaContent {
if m.Content == nil {
return nil
}
if len(m.parsedContent) > 0 {
return m.parsedContent
}
var contentList []MediaContent
// 先尝试解析为字符串
content, ok := m.Content.(string)
if ok {
contentList = []MediaContent{{
Type: ContentTypeText,
Text: content,
}}
m.parsedContent = contentList
return contentList
}
// 尝试解析为数组
//var arrayContent []map[string]interface{}
arrayContent, ok := m.Content.([]any)
if !ok {
return contentList
}
for _, contentItemAny := range arrayContent {
mediaItem, ok := contentItemAny.(MediaContent)
if ok {
contentList = append(contentList, mediaItem)
continue
}
contentItem, ok := contentItemAny.(map[string]any)
if !ok {
continue
}
contentType, ok := contentItem["type"].(string)
if !ok {
continue
}
switch contentType {
case ContentTypeText:
if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: text,
})
}
case ContentTypeImageURL:
imageUrl := contentItem["image_url"]
temp := &MessageImageUrl{
Detail: "high",
}
switch v := imageUrl.(type) {
case string:
temp.Url = v
case map[string]interface{}:
url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string)
if ok2 {
temp.Detail = detail
}
if ok1 {
temp.Url = url
}
}
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: temp,
})
case ContentTypeInputAudio:
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string)
if ok1 && ok2 {
temp := &MessageInputAudio{
Data: data,
Format: format,
}
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: temp,
})
}
}
case ContentTypeFile:
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
fileId, ok3 := fileData["file_id"].(string)
if ok3 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileId: fileId,
},
})
} else {
fileName, ok1 := fileData["filename"].(string)
fileDataStr, ok2 := fileData["file_data"].(string)
if ok1 && ok2 {
contentList = append(contentList, MediaContent{
Type: ContentTypeFile,
File: &MessageFile{
FileName: fileName,
FileData: fileDataStr,
},
})
}
}
}
case ContentTypeVideoUrl:
if videoUrl, ok := contentItem["video_url"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeVideoUrl,
VideoUrl: &MessageVideoUrl{
Url: videoUrl,
},
})
}
}
}
if len(contentList) > 0 {
m.parsedContent = contentList
}
return contentList
}
// old code
/*func (m *Message) StringContent() string {
if m.parsedStringContent != nil { if m.parsedStringContent != nil {
return *m.parsedStringContent return *m.parsedStringContent
} }
@@ -382,33 +751,106 @@ func (m *Message) ParseContent() []MediaContent {
m.parsedContent = contentList m.parsedContent = contentList
} }
return contentList return contentList
} }*/
type WebSearchOptions struct { type WebSearchOptions struct {
SearchContextSize string `json:"search_context_size,omitempty"` SearchContextSize string `json:"search_context_size,omitempty"`
UserLocation json.RawMessage `json:"user_location,omitempty"` UserLocation json.RawMessage `json:"user_location,omitempty"`
} }
// https://platform.openai.com/docs/api-reference/responses/create
type OpenAIResponsesRequest struct { type OpenAIResponsesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input json.RawMessage `json:"input,omitempty"` Input any `json:"input,omitempty"`
Include json.RawMessage `json:"include,omitempty"` Include json.RawMessage `json:"include,omitempty"`
Instructions json.RawMessage `json:"instructions,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
Metadata json.RawMessage `json:"metadata,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
PreviousResponseID string `json:"previous_response_id,omitempty"` PreviousResponseID string `json:"previous_response_id,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"`
ServiceTier string `json:"service_tier,omitempty"` ServiceTier string `json:"service_tier,omitempty"`
Store bool `json:"store,omitempty"` Store bool `json:"store,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Text json.RawMessage `json:"text,omitempty"` Text json.RawMessage `json:"text,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Tools []ResponsesToolsCall `json:"tools,omitempty"` Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少MCP 参数太多不确定,所以用 map
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
Truncation string `json:"truncation,omitempty"` Truncation string `json:"truncation,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
Prompt json.RawMessage `json:"prompt,omitempty"`
}
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
var fileMeta = make([]*types.FileMeta, 0)
var texts = make([]string, 0)
if r.Input != nil {
inputs := r.ParseInput()
for _, input := range inputs {
if input.Type == "input_image" {
if input.ImageUrl != "" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage,
OriginData: input.ImageUrl,
Detail: input.Detail,
})
}
} else if input.Type == "input_file" {
if input.FileUrl != "" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeFile,
OriginData: input.FileUrl,
})
}
} else {
texts = append(texts, input.Text)
}
}
}
if len(r.Instructions) > 0 {
texts = append(texts, string(r.Instructions))
}
if len(r.Metadata) > 0 {
texts = append(texts, string(r.Metadata))
}
if len(r.Text) > 0 {
texts = append(texts, string(r.Text))
}
if len(r.ToolChoice) > 0 {
texts = append(texts, string(r.ToolChoice))
}
if len(r.Prompt) > 0 {
texts = append(texts, string(r.Prompt))
}
if len(r.Tools) > 0 {
toolStr, _ := common.Marshal(r.Tools)
texts = append(texts, string(toolStr))
}
return &types.TokenCountMeta{
CombineText: strings.Join(texts, "\n"),
Files: fileMeta,
MaxTokens: int(r.MaxOutputTokens),
}
}
func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
return r.Stream
}
func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
if modelName != "" {
r.Model = modelName
}
} }
type Reasoning struct { type Reasoning struct {
@@ -416,21 +858,80 @@ type Reasoning struct {
Summary string `json:"summary,omitempty"` Summary string `json:"summary,omitempty"`
} }
type ResponsesToolsCall struct { type MediaInput struct {
Type string `json:"type"` Type string `json:"type"`
// Web Search Text string `json:"text,omitempty"`
UserLocation json.RawMessage `json:"user_location,omitempty"` FileUrl string `json:"file_url,omitempty"`
SearchContextSize string `json:"search_context_size,omitempty"` ImageUrl string `json:"image_url,omitempty"`
// File Search Detail string `json:"detail,omitempty"` // 仅 input_image 有效
VectorStoreIds []string `json:"vector_store_ids,omitempty"` }
MaxNumResults uint `json:"max_num_results,omitempty"`
Filters json.RawMessage `json:"filters,omitempty"` // ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
// Computer Use // Reference implementation mirrors Message.ParseContent:
DisplayWidth uint `json:"display_width,omitempty"` // - input can be a string, treated as an input_text item
DisplayHeight uint `json:"display_height,omitempty"` // - input can be an array of objects with a `type` field
Environment string `json:"environment,omitempty"` // supported types: input_text, input_image, input_file
// Function func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
Name string `json:"name,omitempty"` if r.Input == nil {
Description string `json:"description,omitempty"` return nil
Parameters json.RawMessage `json:"parameters,omitempty"` }
var inputs []MediaInput
// Try string first
if str, ok := r.Input.(string); ok {
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
return inputs
}
// Try array of parts
if array, ok := r.Input.([]any); ok {
for _, itemAny := range array {
// Already parsed MediaInput
if media, ok := itemAny.(MediaInput); ok {
inputs = append(inputs, media)
continue
}
// Generic map
item, ok := itemAny.(map[string]any)
if !ok {
continue
}
typeVal, ok := item["type"].(string)
if !ok {
continue
}
switch typeVal {
case "input_text":
text, _ := item["text"].(string)
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
case "input_image":
// image_url may be string or object with url field
var imageUrl string
switch v := item["image_url"].(type) {
case string:
imageUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
imageUrl = url
}
}
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
case "input_file":
// file_url may be string or object with url field
var fileUrl string
switch v := item["file_url"].(type) {
case string:
fileUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
fileUrl = url
}
}
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
}
}
}
return inputs
} }

View File

@@ -1,10 +1,19 @@
package dto package dto
import "encoding/json" import (
"encoding/json"
"fmt"
"one-api/types"
)
type SimpleResponse struct { type SimpleResponse struct {
Usage `json:"usage"` Usage `json:"usage"`
Error *OpenAIError `json:"error"` Error any `json:"error"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(s.Error)
} }
type TextResponse struct { type TextResponse struct {
@@ -26,12 +35,17 @@ type OpenAITextResponse struct {
Id string `json:"id"` Id string `json:"id"`
Model string `json:"model"` Model string `json:"model"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created any `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"` Choices []OpenAITextResponseChoice `json:"choices"`
Error *OpenAIError `json:"error,omitempty"` Error any `json:"error,omitempty"`
Usage `json:"usage"` Usage `json:"usage"`
} }
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
type OpenAIEmbeddingResponseItem struct { type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"` Object string `json:"object"`
Index int `json:"index"` Index int `json:"index"`
@@ -45,6 +59,19 @@ type OpenAIEmbeddingResponse struct {
Usage `json:"usage"` Usage `json:"usage"`
} }
type FlexibleEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding any `json:"embedding"`
}
type FlexibleEmbeddingResponse struct {
Object string `json:"object"`
Data []FlexibleEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
}
type ChatCompletionsStreamResponseChoice struct { type ChatCompletionsStreamResponseChoice struct {
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"` Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
Logprobs *any `json:"logprobs"` Logprobs *any `json:"logprobs"`
@@ -83,7 +110,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) { func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
c.ReasoningContent = &s c.ReasoningContent = &s
c.Reasoning = &s //c.Reasoning = &s
} }
type ToolCallResponse struct { type ToolCallResponse struct {
@@ -116,6 +143,13 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`
} }
func (c *ChatCompletionsStreamResponse) IsFinished() bool {
if len(c.Choices) == 0 {
return false
}
return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
}
func (c *ChatCompletionsStreamResponse) IsToolCall() bool { func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
if len(c.Choices) == 0 { if len(c.Choices) == 0 {
return false return false
@@ -130,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
return nil return nil
} }
func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
if !c.IsToolCall() {
return
}
for choiceIdx := range c.Choices {
for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
}
}
}
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse { func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices)) choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices) copy(choices, c.Choices)
@@ -178,6 +225,8 @@ type Usage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"` InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
// OpenRouter Params
Cost any `json:"cost,omitempty"`
} }
type InputTokenDetails struct { type InputTokenDetails struct {
@@ -195,28 +244,33 @@ type OutputTokenDetails struct {
} }
type OpenAIResponsesResponse struct { type OpenAIResponsesResponse struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
CreatedAt int `json:"created_at"` CreatedAt int `json:"created_at"`
Status string `json:"status"` Status string `json:"status"`
Error *OpenAIError `json:"error,omitempty"` Error any `json:"error,omitempty"`
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
Instructions string `json:"instructions"` Instructions string `json:"instructions"`
MaxOutputTokens int `json:"max_output_tokens"` MaxOutputTokens int `json:"max_output_tokens"`
Model string `json:"model"` Model string `json:"model"`
Output []ResponsesOutput `json:"output"` Output []ResponsesOutput `json:"output"`
ParallelToolCalls bool `json:"parallel_tool_calls"` ParallelToolCalls bool `json:"parallel_tool_calls"`
PreviousResponseID string `json:"previous_response_id"` PreviousResponseID string `json:"previous_response_id"`
Reasoning *Reasoning `json:"reasoning"` Reasoning *Reasoning `json:"reasoning"`
Store bool `json:"store"` Store bool `json:"store"`
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
ToolChoice string `json:"tool_choice"` ToolChoice string `json:"tool_choice"`
Tools []ResponsesToolsCall `json:"tools"` Tools []map[string]any `json:"tools"`
TopP float64 `json:"top_p"` TopP float64 `json:"top_p"`
Truncation string `json:"truncation"` Truncation string `json:"truncation"`
Usage *Usage `json:"usage"` Usage *Usage `json:"usage"`
User json.RawMessage `json:"user"` User json.RawMessage `json:"user"`
Metadata json.RawMessage `json:"metadata"` Metadata json.RawMessage `json:"metadata"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
} }
type IncompleteDetails struct { type IncompleteDetails struct {
@@ -258,3 +312,45 @@ type ResponsesStreamResponse struct {
Delta string `json:"delta,omitempty"` Delta string `json:"delta,omitempty"`
Item *ResponsesOutput `json:"item,omitempty"` Item *ResponsesOutput `json:"item,omitempty"`
} }
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func GetOpenAIError(errorField any) *types.OpenAIError {
if errorField == nil {
return nil
}
switch err := errorField.(type) {
case types.OpenAIError:
return &err
case *types.OpenAIError:
return err
case map[string]interface{}:
// 处理从JSON解析来的map结构
openaiErr := &types.OpenAIError{}
if errType, ok := err["type"].(string); ok {
openaiErr.Type = errType
}
if errMsg, ok := err["message"].(string); ok {
openaiErr.Message = errMsg
}
if errParam, ok := err["param"].(string); ok {
openaiErr.Param = errParam
}
if errCode, ok := err["code"]; ok {
openaiErr.Code = errCode
}
return openaiErr
case string:
// 处理简单字符串错误
return &types.OpenAIError{
Type: "error",
Message: err,
}
default:
// 未知类型,尝试转换为字符串
return &types.OpenAIError{
Type: "unknown_error",
Message: fmt.Sprintf("%v", err),
}
}
}

View File

@@ -1,26 +1,35 @@
package dto package dto
type OpenAIModelPermission struct { import "one-api/constant"
Id string `json:"id"`
Object string `json:"object"` // 这里不好动就不动了,本来想独立出来的(
Created int `json:"created"` type OpenAIModels struct {
AllowCreateEngine bool `json:"allow_create_engine"` Id string `json:"id"`
AllowSampling bool `json:"allow_sampling"` Object string `json:"object"`
AllowLogprobs bool `json:"allow_logprobs"` Created int `json:"created"`
AllowSearchIndices bool `json:"allow_search_indices"` OwnedBy string `json:"owned_by"`
AllowView bool `json:"allow_view"` SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
AllowFineTuning bool `json:"allow_fine_tuning"`
Organization string `json:"organization"`
Group *string `json:"group"`
IsBlocking bool `json:"is_blocking"`
} }
type OpenAIModels struct { type AnthropicModel struct {
Id string `json:"id"` ID string `json:"id"`
Object string `json:"object"` CreatedAt string `json:"created_at"`
Created int `json:"created"` DisplayName string `json:"display_name"`
OwnedBy string `json:"owned_by"` Type string `json:"type"`
Permission []OpenAIModelPermission `json:"permission"` }
Root string `json:"root"`
Parent *string `json:"parent"` type GeminiModel struct {
Name interface{} `json:"name"`
BaseModelId interface{} `json:"baseModelId"`
Version interface{} `json:"version"`
DisplayName interface{} `json:"displayName"`
Description interface{} `json:"description"`
InputTokenLimit interface{} `json:"inputTokenLimit"`
OutputTokenLimit interface{} `json:"outputTokenLimit"`
SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
Thinking interface{} `json:"thinking"`
Temperature interface{} `json:"temperature"`
MaxTemperature interface{} `json:"maxTemperature"`
TopP interface{} `json:"topP"`
TopK interface{} `json:"topK"`
} }

Some files were not shown because too many files have changed in this diff Show More