Merge branch 'alpha' into main
This commit is contained in:
@@ -5,3 +5,4 @@
|
|||||||
.gitignore
|
.gitignore
|
||||||
Makefile
|
Makefile
|
||||||
docs
|
docs
|
||||||
|
.eslintcache
|
||||||
12
.env.example
12
.env.example
@@ -7,6 +7,8 @@
|
|||||||
# 调试相关配置
|
# 调试相关配置
|
||||||
# 启用pprof
|
# 启用pprof
|
||||||
# ENABLE_PPROF=true
|
# ENABLE_PPROF=true
|
||||||
|
# 启用调试模式
|
||||||
|
# DEBUG=true
|
||||||
|
|
||||||
# 数据库相关配置
|
# 数据库相关配置
|
||||||
# 数据库连接字符串
|
# 数据库连接字符串
|
||||||
@@ -41,6 +43,14 @@
|
|||||||
# 更新任务启用
|
# 更新任务启用
|
||||||
# UPDATE_TASK=true
|
# UPDATE_TASK=true
|
||||||
|
|
||||||
|
# 对话超时设置
|
||||||
|
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||||
|
# RELAY_TIMEOUT=0
|
||||||
|
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||||
|
# STREAMING_TIMEOUT=300
|
||||||
|
|
||||||
|
# Gemini 识别图片 最大图片数量
|
||||||
|
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||||
|
|
||||||
# 会话密钥
|
# 会话密钥
|
||||||
# SESSION_SECRET=random_string
|
# SESSION_SECRET=random_string
|
||||||
@@ -58,8 +68,6 @@
|
|||||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||||
# DIFY_DEBUG=true
|
# DIFY_DEBUG=true
|
||||||
# 设置流式一次回复的超时时间
|
|
||||||
# STREAMING_TIMEOUT=90
|
|
||||||
|
|
||||||
|
|
||||||
# 节点类型
|
# 节点类型
|
||||||
|
|||||||
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
### PR 类型
|
||||||
|
|
||||||
|
- [ ] Bug 修复
|
||||||
|
- [ ] 新功能
|
||||||
|
- [ ] 文档更新
|
||||||
|
- [ ] 其他
|
||||||
|
|
||||||
|
### PR 是否包含破坏性更新?
|
||||||
|
|
||||||
|
- [ ] 是
|
||||||
|
- [ ] 否
|
||||||
|
|
||||||
|
### PR 描述
|
||||||
|
|
||||||
|
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||||
|
|
||||||
|
### **重要提示**
|
||||||
|
|
||||||
|
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
name: Publish Docker image (amd64)
|
name: Publish Docker image (alpha)
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
branches:
|
||||||
- '*'
|
- alpha
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
name:
|
name:
|
||||||
description: 'reason'
|
description: "reason"
|
||||||
required: false
|
required: false
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
push_to_registries:
|
push_to_registries:
|
||||||
name: Push Docker image to multiple registries
|
name: Push Docker image to multiple registries
|
||||||
@@ -22,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Save version info
|
- name: Save version info
|
||||||
run: |
|
run: |
|
||||||
git describe --tags > VERSION
|
echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
- name: Log in to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
@@ -37,6 +38,9 @@ jobs:
|
|||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
- name: Extract metadata (tags, labels) for Docker
|
- name: Extract metadata (tags, labels) for Docker
|
||||||
id: meta
|
id: meta
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
@@ -44,11 +48,15 @@ jobs:
|
|||||||
images: |
|
images: |
|
||||||
calciumion/new-api
|
calciumion/new-api
|
||||||
ghcr.io/${{ github.repository }}
|
ghcr.io/${{ github.repository }}
|
||||||
|
tags: |
|
||||||
|
type=raw,value=alpha
|
||||||
|
type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
|
||||||
|
|
||||||
- name: Build and push Docker images
|
- name: Build and push Docker images
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
|
platforms: linux/amd64,linux/arm64
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
7
.github/workflows/docker-image-arm64.yml
vendored
7
.github/workflows/docker-image-arm64.yml
vendored
@@ -1,14 +1,9 @@
|
|||||||
name: Publish Docker image (arm64)
|
name: Publish Docker image (Multi Registries)
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- '*'
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
name:
|
|
||||||
description: 'reason'
|
|
||||||
required: false
|
|
||||||
jobs:
|
jobs:
|
||||||
push_to_registries:
|
push_to_registries:
|
||||||
name: Push Docker image to multiple registries
|
name: Push Docker image to multiple registries
|
||||||
|
|||||||
13
.github/workflows/linux-release.yml
vendored
13
.github/workflows/linux-release.yml
vendored
@@ -3,6 +3,11 @@ permissions:
|
|||||||
contents: write
|
contents: write
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
name:
|
||||||
|
description: 'reason'
|
||||||
|
required: false
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- '*'
|
||||||
@@ -15,16 +20,16 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- uses: actions/setup-node@v3
|
- uses: oven-sh/setup-bun@v2
|
||||||
with:
|
with:
|
||||||
node-version: 18
|
bun-version: latest
|
||||||
- name: Build Frontend
|
- name: Build Frontend
|
||||||
env:
|
env:
|
||||||
CI: ""
|
CI: ""
|
||||||
run: |
|
run: |
|
||||||
cd web
|
cd web
|
||||||
npm install
|
bun install
|
||||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||||
cd ..
|
cd ..
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v3
|
uses: actions/setup-go@v3
|
||||||
|
|||||||
14
.github/workflows/macos-release.yml
vendored
14
.github/workflows/macos-release.yml
vendored
@@ -3,6 +3,11 @@ permissions:
|
|||||||
contents: write
|
contents: write
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
name:
|
||||||
|
description: 'reason'
|
||||||
|
required: false
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- '*'
|
||||||
@@ -15,16 +20,17 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- uses: actions/setup-node@v3
|
- uses: oven-sh/setup-bun@v2
|
||||||
with:
|
with:
|
||||||
node-version: 18
|
bun-version: latest
|
||||||
- name: Build Frontend
|
- name: Build Frontend
|
||||||
env:
|
env:
|
||||||
CI: ""
|
CI: ""
|
||||||
|
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||||
run: |
|
run: |
|
||||||
cd web
|
cd web
|
||||||
npm install
|
bun install
|
||||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||||
cd ..
|
cd ..
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v3
|
uses: actions/setup-go@v3
|
||||||
|
|||||||
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: Check PR Branching Strategy
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened, edited]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-branching-strategy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Enforce branching strategy
|
||||||
|
run: |
|
||||||
|
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||||
|
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||||
|
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||||
|
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Branching strategy check passed."
|
||||||
13
.github/workflows/windows-release.yml
vendored
13
.github/workflows/windows-release.yml
vendored
@@ -3,6 +3,11 @@ permissions:
|
|||||||
contents: write
|
contents: write
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
name:
|
||||||
|
description: 'reason'
|
||||||
|
required: false
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- '*'
|
||||||
@@ -18,16 +23,16 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
- uses: actions/setup-node@v3
|
- uses: oven-sh/setup-bun@v2
|
||||||
with:
|
with:
|
||||||
node-version: 18
|
bun-version: latest
|
||||||
- name: Build Frontend
|
- name: Build Frontend
|
||||||
env:
|
env:
|
||||||
CI: ""
|
CI: ""
|
||||||
run: |
|
run: |
|
||||||
cd web
|
cd web
|
||||||
npm install
|
bun install
|
||||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||||
cd ..
|
cd ..
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v3
|
uses: actions/setup-go@v3
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -11,3 +11,4 @@ web/dist
|
|||||||
one-api
|
one-api
|
||||||
.DS_Store
|
.DS_Store
|
||||||
tiktoken_cache
|
tiktoken_cache
|
||||||
|
.eslintcache
|
||||||
@@ -2,6 +2,7 @@ FROM oven/bun:latest AS builder
|
|||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
COPY web/package.json .
|
COPY web/package.json .
|
||||||
|
COPY web/bun.lock .
|
||||||
RUN bun install
|
RUN bun install
|
||||||
COPY ./web .
|
COPY ./web .
|
||||||
COPY ./VERSION .
|
COPY ./VERSION .
|
||||||
@@ -24,8 +25,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
|
|||||||
|
|
||||||
FROM alpine
|
FROM alpine
|
||||||
|
|
||||||
RUN apk update \
|
RUN apk upgrade --no-cache \
|
||||||
&& apk upgrade \
|
|
||||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||||
&& update-ca-certificates
|
&& update-ca-certificates
|
||||||
|
|
||||||
|
|||||||
240
LICENSE
240
LICENSE
@@ -1,201 +1,103 @@
|
|||||||
Apache License
|
# **New API 许可协议 (Licensing)**
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
本项目采用**基于使用场景的双重许可 (Usage-Based Dual Licensing)** 模式。
|
||||||
|
|
||||||
1. Definitions.
|
**核心原则:**
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
- **默认许可:** 本项目默认在 **GNU Affero 通用公共许可证 v3.0 (AGPLv3)** 下提供。任何用户在遵守 AGPLv3 条款和下述附加限制的前提下,均可免费使用。
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
- **商业许可:** 在特定商业场景下,或当您希望获得 AGPLv3 之外的权利时,**必须**获取**商业许可证 (Commercial License)**。
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
---
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
## **1. 开源许可证 (Open Source License): AGPLv3 - 适用于基础使用**
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
- 在遵守 **AGPLv3** 条款的前提下,您可以自由地使用、修改和分发 New API。AGPLv3 的完整文本可以访问 [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html) 获取。
|
||||||
exercising permissions granted by this License.
|
- **核心义务:** AGPLv3 的一个关键要求是,如果您修改了 New API 并通过网络提供服务 (SaaS),或者分发了修改后的版本,您必须以 AGPLv3 许可证向所有用户提供相应的**完整源代码**。
|
||||||
|
- **附加限制 (重要):** 在仅使用 AGPLv3 开源许可证的情况下,您**必须**完整保留项目代码中原有的品牌标识、LOGO 及版权声明信息。**禁止以任何形式修改、移除或遮盖**这些信息。如需移除,必须获取商业许可证。
|
||||||
|
- 使用前请务必仔细阅读并理解 AGPLv3 的所有条款及上述附加限制。
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
## **2. 商业许可证 (Commercial License) - 适用于高级场景及闭源需求**
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
在以下任一情况下,您**必须**联系我们获取并签署一份商业许可证,才能合法使用 New API:
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
- **场景一:移除品牌和版权信息**
|
||||||
Object form, made available under the License, as indicated by a
|
您希望在您的产品或服务中移除 New API 的 LOGO、UI界面中的版权声明或其他品牌标识。
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
- **场景二:规避 AGPLv3 开源义务**
|
||||||
form, that is based on (or derived from) the Work and for which the
|
您基于 New API 进行了修改,并希望:
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
- 通过网络提供服务(SaaS),但**不希望**向您的服务用户公开您修改后的源代码。
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
- 分发一个集成了 New API 的软件产品,但**不希望**以 AGPLv3 许可证发布您的产品或公开源代码。
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
- **场景三:企业政策与集成需求**
|
||||||
the original version of the Work and any modifications or additions
|
- 您所在公司的政策、客户合同或项目要求不允许使用 AGPLv3 许可的软件。
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
- 您需要进行 OEM 集成,将 New API 作为您闭源商业产品的一部分进行再分发。
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
- **场景四:需要商业支持与保障**
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
您需要 AGPLv3 未提供的商业保障,如官方技术支持等。
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
**获取商业许可:**
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
请通过电子邮件 **support@quantumnous.com** 联系 New API 团队洽谈商业授权事宜。
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
## **3. 贡献 (Contributions)**
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
- 我们欢迎社区对 New API 的贡献。所有向本项目提交的贡献(例如通过 Pull Request)都将被视为在 **AGPLv3** 许可证下提供。
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
- 通过向本项目提交贡献,即表示您同意您的代码以 AGPLv3 许可证授权给本项目及所有后续使用者(无论这些使用者最终遵循 AGPLv3 还是商业许可)。
|
||||||
modifications, and in Source or Object form, provided that You
|
- 您也理解并同意,您的贡献可能会被包含在根据商业许可证分发的 New API 版本中。
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
## **4. 其他条款 (Other Terms)**
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
- 关于商业许可证的具体条款、条件和价格,以双方签署的正式商业许可协议为准。
|
||||||
stating that You changed the files; and
|
- 项目维护者保留根据需要更新本许可政策的权利。相关更新将通过项目官方渠道(如代码仓库、官方网站)进行通知。
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
---
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
# **New API Licensing**
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
This project uses a **Usage-Based Dual Licensing** model.
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
**Core Principles:**
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
- **Default License:** This project is available by default under the **GNU Affero General Public License v3.0 (AGPLv3)**. Any user may use it free of charge, provided they comply with both the AGPLv3 terms and the additional restrictions listed below.
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
- **Commercial License:** For specific commercial scenarios, or if you require rights beyond those granted by AGPLv3, you **must** obtain a **Commercial License**.
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
---
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
## **1. Open Source License: AGPLv3 – For Basic Usage**
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
- Under the terms of the **AGPLv3**, you are free to use, modify, and distribute New API. The complete AGPLv3 license text can be viewed at [https://www.gnu.org/licenses/agpl-3.0.html](https://www.gnu.org/licenses/agpl-3.0.html).
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
- **Core Obligation:** A key AGPLv3 requirement is that if you modify New API and provide it as a network service (SaaS), or distribute a modified version, you must make the **complete corresponding source code** available to all users under the AGPLv3 license.
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
- **Additional Restriction (Important):** When using only the AGPLv3 open-source license, you **must** retain all original branding, logos, and copyright statements within the project’s code. **You are strictly prohibited from modifying, removing, or concealing** any such information. If you wish to remove this, you must obtain a Commercial License.
|
||||||
or other liability obligations and/or rights consistent with this
|
- Please read and ensure that you fully understand all AGPLv3 terms and the above additional restriction before use.
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
## **2. Commercial License – For Advanced Scenarios & Closed Source Needs**
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
You **must** contact us to obtain and sign a Commercial License in any of the following scenarios in order to legally use New API:
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
- **Scenario 1: Removal of Branding and Copyright**
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
You wish to remove the New API logo, copyright statement, or other branding elements from your product or service.
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
- **Scenario 2: Avoidance of AGPLv3 Open Source Obligations**
|
||||||
|
You have modified New API and wish to:
|
||||||
|
- Offer it as a network service (SaaS) **without** disclosing your modifications' source code to your users.
|
||||||
|
- Distribute a software product integrated with New API **without** releasing your product under AGPLv3 or open-sourcing the code.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
- **Scenario 3: Enterprise Policy & Integration Needs**
|
||||||
you may not use this file except in compliance with the License.
|
- Your organization’s policies, client contracts, or project requirements prohibit the use of AGPLv3-licensed software.
|
||||||
You may obtain a copy of the License at
|
- You require OEM integration and need to redistribute New API as part of your closed-source commercial product.
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
- **Scenario 4: Commercial Support and Assurances**
|
||||||
|
You require commercial assurances not provided by AGPLv3, such as official technical support.
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
**Obtaining a Commercial License:**
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
Please contact the New API team via email at **support@quantumnous.com** to discuss commercial licensing.
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
## **3. Contributions**
|
||||||
limitations under the License.
|
|
||||||
|
- We welcome community contributions to New API. All contributions (e.g., via Pull Request) are deemed to be provided under the **AGPLv3** license.
|
||||||
|
- By submitting a contribution, you agree that your code is licensed to this project and all downstream users under the AGPLv3 license (regardless of whether those users ultimately operate under AGPLv3 or a Commercial License).
|
||||||
|
- You also acknowledge and agree that your contribution may be included in New API releases distributed under a Commercial License.
|
||||||
|
|
||||||
|
## **4. Other Terms**
|
||||||
|
|
||||||
|
- The specific terms, conditions, and pricing of the Commercial License are governed by the formal commercial license agreement executed by both parties.
|
||||||
|
- Project maintainers reserve the right to update this licensing policy as needed. Updates will be communicated via official project channels (e.g., repository, official website).
|
||||||
|
|||||||
24
README.en.md
24
README.en.md
@@ -40,6 +40,28 @@
|
|||||||
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
|
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
|
||||||
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
|
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
|
||||||
|
|
||||||
|
<h2>🤝 Trusted Partners</h2>
|
||||||
|
<p id="premium-sponsors"> </p>
|
||||||
|
<p align="center"><strong>No particular order</strong></p>
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://www.cherry-ai.com/" target=_blank><img
|
||||||
|
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://bda.pku.edu.cn/" target=_blank><img
|
||||||
|
src="./docs/images/pku.png" alt="Peking University" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
|
||||||
|
src="./docs/images/ucloud.png" alt="UCloud" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.aliyun.com/" target=_blank><img
|
||||||
|
src="./docs/images/aliyun.png" alt="Alibaba Cloud" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://io.net/" target=_blank><img
|
||||||
|
src="./docs/images/io-net.png" alt="IO.NET" height="120"
|
||||||
|
/></a>
|
||||||
|
</p>
|
||||||
|
<p> </p>
|
||||||
|
|
||||||
## 📚 Documentation
|
## 📚 Documentation
|
||||||
|
|
||||||
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
|
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||||
@@ -100,7 +122,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
|
|||||||
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
||||||
|
|
||||||
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
||||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
|
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 300 seconds
|
||||||
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
||||||
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
||||||
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
||||||
|
|||||||
25
README.md
25
README.md
@@ -40,6 +40,28 @@
|
|||||||
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||||
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||||
|
|
||||||
|
<h2>🤝 我们信任的合作伙伴</h2>
|
||||||
|
<p id="premium-sponsors"> </p>
|
||||||
|
<p align="center"><strong>排名不分先后</strong></p>
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://www.cherry-ai.com/" target=_blank><img
|
||||||
|
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://bda.pku.edu.cn/" target=_blank><img
|
||||||
|
src="./docs/images/pku.png" alt="北京大学" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
|
||||||
|
src="./docs/images/ucloud.png" alt="UCloud 优刻得" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://www.aliyun.com/" target=_blank><img
|
||||||
|
src="./docs/images/aliyun.png" alt="阿里云" height="120"
|
||||||
|
/></a>
|
||||||
|
<a href="https://io.net/" target=_blank><img
|
||||||
|
src="./docs/images/io-net.png" alt="IO.NET" height="120"
|
||||||
|
/></a>
|
||||||
|
</p>
|
||||||
|
<p> </p>
|
||||||
|
|
||||||
## 📚 文档
|
## 📚 文档
|
||||||
|
|
||||||
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
详细文档请访问我们的官方Wiki:[https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||||
@@ -100,7 +122,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
|||||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||||
|
|
||||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
|
- `STREAMING_TIMEOUT`:流式回复超时时间,默认300秒
|
||||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||||
@@ -180,7 +202,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
|||||||
|
|
||||||
其他基于New API的项目:
|
其他基于New API的项目:
|
||||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
|
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
|
||||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
|
|
||||||
|
|
||||||
## 帮助支持
|
## 帮助支持
|
||||||
|
|
||||||
|
|||||||
75
common/api_type.go
Normal file
75
common/api_type.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
func ChannelType2APIType(channelType int) (int, bool) {
|
||||||
|
apiType := -1
|
||||||
|
switch channelType {
|
||||||
|
case constant.ChannelTypeOpenAI:
|
||||||
|
apiType = constant.APITypeOpenAI
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
apiType = constant.APITypeAnthropic
|
||||||
|
case constant.ChannelTypeBaidu:
|
||||||
|
apiType = constant.APITypeBaidu
|
||||||
|
case constant.ChannelTypePaLM:
|
||||||
|
apiType = constant.APITypePaLM
|
||||||
|
case constant.ChannelTypeZhipu:
|
||||||
|
apiType = constant.APITypeZhipu
|
||||||
|
case constant.ChannelTypeAli:
|
||||||
|
apiType = constant.APITypeAli
|
||||||
|
case constant.ChannelTypeXunfei:
|
||||||
|
apiType = constant.APITypeXunfei
|
||||||
|
case constant.ChannelTypeAIProxyLibrary:
|
||||||
|
apiType = constant.APITypeAIProxyLibrary
|
||||||
|
case constant.ChannelTypeTencent:
|
||||||
|
apiType = constant.APITypeTencent
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
apiType = constant.APITypeGemini
|
||||||
|
case constant.ChannelTypeZhipu_v4:
|
||||||
|
apiType = constant.APITypeZhipuV4
|
||||||
|
case constant.ChannelTypeOllama:
|
||||||
|
apiType = constant.APITypeOllama
|
||||||
|
case constant.ChannelTypePerplexity:
|
||||||
|
apiType = constant.APITypePerplexity
|
||||||
|
case constant.ChannelTypeAws:
|
||||||
|
apiType = constant.APITypeAws
|
||||||
|
case constant.ChannelTypeCohere:
|
||||||
|
apiType = constant.APITypeCohere
|
||||||
|
case constant.ChannelTypeDify:
|
||||||
|
apiType = constant.APITypeDify
|
||||||
|
case constant.ChannelTypeJina:
|
||||||
|
apiType = constant.APITypeJina
|
||||||
|
case constant.ChannelCloudflare:
|
||||||
|
apiType = constant.APITypeCloudflare
|
||||||
|
case constant.ChannelTypeSiliconFlow:
|
||||||
|
apiType = constant.APITypeSiliconFlow
|
||||||
|
case constant.ChannelTypeVertexAi:
|
||||||
|
apiType = constant.APITypeVertexAi
|
||||||
|
case constant.ChannelTypeMistral:
|
||||||
|
apiType = constant.APITypeMistral
|
||||||
|
case constant.ChannelTypeDeepSeek:
|
||||||
|
apiType = constant.APITypeDeepSeek
|
||||||
|
case constant.ChannelTypeMokaAI:
|
||||||
|
apiType = constant.APITypeMokaAI
|
||||||
|
case constant.ChannelTypeVolcEngine:
|
||||||
|
apiType = constant.APITypeVolcEngine
|
||||||
|
case constant.ChannelTypeBaiduV2:
|
||||||
|
apiType = constant.APITypeBaiduV2
|
||||||
|
case constant.ChannelTypeOpenRouter:
|
||||||
|
apiType = constant.APITypeOpenRouter
|
||||||
|
case constant.ChannelTypeXinference:
|
||||||
|
apiType = constant.APITypeXinference
|
||||||
|
case constant.ChannelTypeXai:
|
||||||
|
apiType = constant.APITypeXai
|
||||||
|
case constant.ChannelTypeCoze:
|
||||||
|
apiType = constant.APITypeCoze
|
||||||
|
case constant.ChannelTypeJimeng:
|
||||||
|
apiType = constant.APITypeJimeng
|
||||||
|
case constant.ChannelTypeMoonshot:
|
||||||
|
apiType = constant.APITypeMoonshot
|
||||||
|
}
|
||||||
|
if apiType == -1 {
|
||||||
|
return constant.APITypeOpenAI, false
|
||||||
|
}
|
||||||
|
return apiType, true
|
||||||
|
}
|
||||||
@@ -83,6 +83,7 @@ var GitHubClientId = ""
|
|||||||
var GitHubClientSecret = ""
|
var GitHubClientSecret = ""
|
||||||
var LinuxDOClientId = ""
|
var LinuxDOClientId = ""
|
||||||
var LinuxDOClientSecret = ""
|
var LinuxDOClientSecret = ""
|
||||||
|
var LinuxDOMinimumTrustLevel = 0
|
||||||
|
|
||||||
var WeChatServerAddress = ""
|
var WeChatServerAddress = ""
|
||||||
var WeChatServerToken = ""
|
var WeChatServerToken = ""
|
||||||
@@ -195,105 +196,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ChannelTypeUnknown = 0
|
TopUpStatusPending = "pending"
|
||||||
ChannelTypeOpenAI = 1
|
TopUpStatusSuccess = "success"
|
||||||
ChannelTypeMidjourney = 2
|
TopUpStatusExpired = "expired"
|
||||||
ChannelTypeAzure = 3
|
|
||||||
ChannelTypeOllama = 4
|
|
||||||
ChannelTypeMidjourneyPlus = 5
|
|
||||||
ChannelTypeOpenAIMax = 6
|
|
||||||
ChannelTypeOhMyGPT = 7
|
|
||||||
ChannelTypeCustom = 8
|
|
||||||
ChannelTypeAILS = 9
|
|
||||||
ChannelTypeAIProxy = 10
|
|
||||||
ChannelTypePaLM = 11
|
|
||||||
ChannelTypeAPI2GPT = 12
|
|
||||||
ChannelTypeAIGC2D = 13
|
|
||||||
ChannelTypeAnthropic = 14
|
|
||||||
ChannelTypeBaidu = 15
|
|
||||||
ChannelTypeZhipu = 16
|
|
||||||
ChannelTypeAli = 17
|
|
||||||
ChannelTypeXunfei = 18
|
|
||||||
ChannelType360 = 19
|
|
||||||
ChannelTypeOpenRouter = 20
|
|
||||||
ChannelTypeAIProxyLibrary = 21
|
|
||||||
ChannelTypeFastGPT = 22
|
|
||||||
ChannelTypeTencent = 23
|
|
||||||
ChannelTypeGemini = 24
|
|
||||||
ChannelTypeMoonshot = 25
|
|
||||||
ChannelTypeZhipu_v4 = 26
|
|
||||||
ChannelTypePerplexity = 27
|
|
||||||
ChannelTypeLingYiWanWu = 31
|
|
||||||
ChannelTypeAws = 33
|
|
||||||
ChannelTypeCohere = 34
|
|
||||||
ChannelTypeMiniMax = 35
|
|
||||||
ChannelTypeSunoAPI = 36
|
|
||||||
ChannelTypeDify = 37
|
|
||||||
ChannelTypeJina = 38
|
|
||||||
ChannelCloudflare = 39
|
|
||||||
ChannelTypeSiliconFlow = 40
|
|
||||||
ChannelTypeVertexAi = 41
|
|
||||||
ChannelTypeMistral = 42
|
|
||||||
ChannelTypeDeepSeek = 43
|
|
||||||
ChannelTypeMokaAI = 44
|
|
||||||
ChannelTypeVolcEngine = 45
|
|
||||||
ChannelTypeBaiduV2 = 46
|
|
||||||
ChannelTypeXinference = 47
|
|
||||||
ChannelTypeXai = 48
|
|
||||||
ChannelTypeCoze = 49
|
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
|
||||||
"", // 0
|
|
||||||
"https://api.openai.com", // 1
|
|
||||||
"https://oa.api2d.net", // 2
|
|
||||||
"", // 3
|
|
||||||
"http://localhost:11434", // 4
|
|
||||||
"https://api.openai-sb.com", // 5
|
|
||||||
"https://api.openaimax.com", // 6
|
|
||||||
"https://api.ohmygpt.com", // 7
|
|
||||||
"", // 8
|
|
||||||
"https://api.caipacity.com", // 9
|
|
||||||
"https://api.aiproxy.io", // 10
|
|
||||||
"", // 11
|
|
||||||
"https://api.api2gpt.com", // 12
|
|
||||||
"https://api.aigc2d.com", // 13
|
|
||||||
"https://api.anthropic.com", // 14
|
|
||||||
"https://aip.baidubce.com", // 15
|
|
||||||
"https://open.bigmodel.cn", // 16
|
|
||||||
"https://dashscope.aliyuncs.com", // 17
|
|
||||||
"", // 18
|
|
||||||
"https://api.360.cn", // 19
|
|
||||||
"https://openrouter.ai/api", // 20
|
|
||||||
"https://api.aiproxy.io", // 21
|
|
||||||
"https://fastgpt.run/api/openapi", // 22
|
|
||||||
"https://hunyuan.tencentcloudapi.com", //23
|
|
||||||
"https://generativelanguage.googleapis.com", //24
|
|
||||||
"https://api.moonshot.cn", //25
|
|
||||||
"https://open.bigmodel.cn", //26
|
|
||||||
"https://api.perplexity.ai", //27
|
|
||||||
"", //28
|
|
||||||
"", //29
|
|
||||||
"", //30
|
|
||||||
"https://api.lingyiwanwu.com", //31
|
|
||||||
"", //32
|
|
||||||
"", //33
|
|
||||||
"https://api.cohere.ai", //34
|
|
||||||
"https://api.minimax.chat", //35
|
|
||||||
"", //36
|
|
||||||
"https://api.dify.ai", //37
|
|
||||||
"https://api.jina.ai", //38
|
|
||||||
"https://api.cloudflare.com", //39
|
|
||||||
"https://api.siliconflow.cn", //40
|
|
||||||
"", //41
|
|
||||||
"https://api.mistral.ai", //42
|
|
||||||
"https://api.deepseek.com", //43
|
|
||||||
"https://api.moka.ai", //44
|
|
||||||
"https://ark.cn-beijing.volces.com", //45
|
|
||||||
"https://qianfan.baidubce.com", //46
|
|
||||||
"", //47
|
|
||||||
"https://api.x.ai", //48
|
|
||||||
"https://api.coze.cn", //49
|
|
||||||
}
|
|
||||||
|
|||||||
21
common/copy.go
Normal file
21
common/copy.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/antlabs/pcopy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DeepCopy[T any](src *T) (*T, error) {
|
||||||
|
if src == nil {
|
||||||
|
return nil, fmt.Errorf("copy source cannot be nil")
|
||||||
|
}
|
||||||
|
var dst T
|
||||||
|
err := pcopy.Copy(&dst, src)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if &dst == nil {
|
||||||
|
return nil, fmt.Errorf("copy result cannot be nil")
|
||||||
|
}
|
||||||
|
return &dst, nil
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stringWriter interface {
|
type stringWriter interface {
|
||||||
@@ -52,6 +53,8 @@ type CustomEvent struct {
|
|||||||
Id string
|
Id string
|
||||||
Retry uint
|
Retry uint
|
||||||
Data interface{}
|
Data interface{}
|
||||||
|
|
||||||
|
Mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func encode(writer io.Writer, event CustomEvent) error {
|
func encode(writer io.Writer, event CustomEvent) error {
|
||||||
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||||
|
r.Mutex.Lock()
|
||||||
|
defer r.Mutex.Unlock()
|
||||||
header := w.Header()
|
header := w.Header()
|
||||||
header["Content-Type"] = contentType
|
header["Content-Type"] = contentType
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,15 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
|
const (
|
||||||
|
DatabaseTypeMySQL = "mysql"
|
||||||
|
DatabaseTypeSQLite = "sqlite"
|
||||||
|
DatabaseTypePostgreSQL = "postgres"
|
||||||
|
)
|
||||||
|
|
||||||
var UsingSQLite = false
|
var UsingSQLite = false
|
||||||
var UsingPostgreSQL = false
|
var UsingPostgreSQL = false
|
||||||
|
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||||
var UsingMySQL = false
|
var UsingMySQL = false
|
||||||
var UsingClickHouse = false
|
var UsingClickHouse = false
|
||||||
|
|
||||||
var SQLitePath = "one-api.db?_busy_timeout=5000"
|
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||||
|
|||||||
32
common/endpoint_defaults.go
Normal file
32
common/endpoint_defaults.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// EndpointInfo 描述单个端点的默认请求信息
|
||||||
|
// path: 上游路径
|
||||||
|
// method: HTTP 请求方式,例如 POST/GET
|
||||||
|
// 目前均为 POST,后续可扩展
|
||||||
|
//
|
||||||
|
// json 标签用于直接序列化到 API 输出
|
||||||
|
// 例如:{"path":"/v1/chat/completions","method":"POST"}
|
||||||
|
|
||||||
|
type EndpointInfo struct {
|
||||||
|
Path string `json:"path"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultEndpointInfoMap 保存内置端点的默认 Path 与 Method
|
||||||
|
var defaultEndpointInfoMap = map[constant.EndpointType]EndpointInfo{
|
||||||
|
constant.EndpointTypeOpenAI: {Path: "/v1/chat/completions", Method: "POST"},
|
||||||
|
constant.EndpointTypeOpenAIResponse: {Path: "/v1/responses", Method: "POST"},
|
||||||
|
constant.EndpointTypeAnthropic: {Path: "/v1/messages", Method: "POST"},
|
||||||
|
constant.EndpointTypeGemini: {Path: "/v1beta/models/{model}:generateContent", Method: "POST"},
|
||||||
|
constant.EndpointTypeJinaRerank: {Path: "/rerank", Method: "POST"},
|
||||||
|
constant.EndpointTypeImageGeneration: {Path: "/v1/images/generations", Method: "POST"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultEndpointInfo 返回指定端点类型的默认信息以及是否存在
|
||||||
|
func GetDefaultEndpointInfo(et constant.EndpointType) (EndpointInfo, bool) {
|
||||||
|
info, ok := defaultEndpointInfoMap[et]
|
||||||
|
return info, ok
|
||||||
|
}
|
||||||
41
common/endpoint_type.go
Normal file
41
common/endpoint_type.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "one-api/constant"
|
||||||
|
|
||||||
|
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||||
|
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||||
|
var endpointTypes []constant.EndpointType
|
||||||
|
switch channelType {
|
||||||
|
case constant.ChannelTypeJina:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||||
|
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
|
||||||
|
//case constant.ChannelTypeSunoAPI:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
|
||||||
|
//case constant.ChannelTypeKling:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
|
||||||
|
//case constant.ChannelTypeJimeng:
|
||||||
|
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
|
||||||
|
case constant.ChannelTypeAws:
|
||||||
|
fallthrough
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||||
|
case constant.ChannelTypeVertexAi:
|
||||||
|
fallthrough
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||||
|
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||||
|
default:
|
||||||
|
if IsOpenAIResponseOnlyModel(modelName) {
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||||
|
} else {
|
||||||
|
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if IsImageGenerationModel(modelName) {
|
||||||
|
// add to first
|
||||||
|
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
|
||||||
|
}
|
||||||
|
return endpointTypes
|
||||||
|
}
|
||||||
@@ -2,10 +2,13 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/constant"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const KeyRequestBody = "key_request_body"
|
const KeyRequestBody = "key_request_body"
|
||||||
@@ -29,9 +32,12 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
//if DebugEnabled {
|
||||||
|
// println("UnmarshalBodyReusable request body:", string(requestBody))
|
||||||
|
//}
|
||||||
contentType := c.Request.Header.Get("Content-Type")
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
if strings.HasPrefix(contentType, "application/json") {
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
err = json.Unmarshal(requestBody, &v)
|
err = Unmarshal(requestBody, &v)
|
||||||
} else {
|
} else {
|
||||||
// skip for now
|
// skip for now
|
||||||
// TODO: someday non json request have variant model, we will need to implementation this
|
// TODO: someday non json request have variant model, we will need to implementation this
|
||||||
@@ -43,3 +49,67 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||||
|
c.Set(string(key), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||||
|
return c.Get(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||||
|
return c.GetString(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||||
|
return c.GetInt(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||||
|
return c.GetBool(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||||
|
return c.GetStringSlice(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||||
|
return c.GetStringMap(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||||
|
return c.GetTime(string(key))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
||||||
|
if value, ok := c.Get(string(key)); ok {
|
||||||
|
if v, ok := value.(T); ok {
|
||||||
|
return v, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var t T
|
||||||
|
return t, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApiError(c *gin.Context, err error) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApiErrorMsg(c *gin.Context, msg string) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": msg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApiSuccess(c *gin.Context, data any) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
34
common/hash.go
Normal file
34
common/hash.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Sha256Raw(data []byte) []byte {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sha1Raw(data []byte) []byte {
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write(data)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sha1(data []byte) string {
|
||||||
|
return hex.EncodeToString(Sha1Raw(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
func HmacSha256Raw(message, key []byte) []byte {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write(message)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func HmacSha256(message, key string) string {
|
||||||
|
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"one-api/constant"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -24,7 +25,7 @@ func printHelp() {
|
|||||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadEnv() {
|
func InitEnv() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
if *PrintVersion {
|
if *PrintVersion {
|
||||||
@@ -95,4 +96,25 @@ func LoadEnv() {
|
|||||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||||
|
|
||||||
|
initConstantEnv()
|
||||||
|
}
|
||||||
|
|
||||||
|
func initConstantEnv() {
|
||||||
|
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
|
||||||
|
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||||
|
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||||
|
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||||
|
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||||
|
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||||
|
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||||
|
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||||
|
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||||
|
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||||
|
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||||
|
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||||
|
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||||
|
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||||
|
// 是否启用错误日志
|
||||||
|
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,14 +5,18 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
func DecodeJson(data []byte, v any) error {
|
func Unmarshal(data []byte, v any) error {
|
||||||
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
|
return json.Unmarshal(data, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeJsonStr(data string, v any) error {
|
func UnmarshalJsonStr(data string, v any) error {
|
||||||
return DecodeJson(StringToByteSlice(data), v)
|
return json.Unmarshal(StringToByteSlice(data), v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func EncodeJson(v any) ([]byte, error) {
|
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||||
|
return json.NewDecoder(reader).Decode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Marshal(v any) ([]byte, error) {
|
||||||
return json.Marshal(v)
|
return json.Marshal(v)
|
||||||
}
|
}
|
||||||
|
|||||||
42
common/model.go
Normal file
42
common/model.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||||
|
OpenAIResponseOnlyModels = []string{
|
||||||
|
"o3-pro",
|
||||||
|
"o3-deep-research",
|
||||||
|
"o4-mini-deep-research",
|
||||||
|
}
|
||||||
|
ImageGenerationModels = []string{
|
||||||
|
"dall-e-3",
|
||||||
|
"dall-e-2",
|
||||||
|
"gpt-image-1",
|
||||||
|
"prefix:imagen-",
|
||||||
|
"flux-",
|
||||||
|
"flux.1-",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||||
|
for _, m := range OpenAIResponseOnlyModels {
|
||||||
|
if strings.Contains(modelName, m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsImageGenerationModel(modelName string) bool {
|
||||||
|
modelName = strings.ToLower(modelName)
|
||||||
|
for _, m := range ImageGenerationModels {
|
||||||
|
if strings.Contains(modelName, m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
82
common/page_info.go
Normal file
82
common/page_info.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PageInfo struct {
|
||||||
|
Page int `json:"page"` // page num 页码
|
||||||
|
PageSize int `json:"page_size"` // page size 页大小
|
||||||
|
|
||||||
|
Total int `json:"total"` // 总条数,后设置
|
||||||
|
Items any `json:"items"` // 数据,后设置
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetStartIdx() int {
|
||||||
|
return (p.Page - 1) * p.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetEndIdx() int {
|
||||||
|
return p.Page * p.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetPageSize() int {
|
||||||
|
return p.PageSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) GetPage() int {
|
||||||
|
return p.Page
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) SetTotal(total int) {
|
||||||
|
p.Total = total
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PageInfo) SetItems(items any) {
|
||||||
|
p.Items = items
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPageQuery(c *gin.Context) *PageInfo {
|
||||||
|
pageInfo := &PageInfo{}
|
||||||
|
// 手动获取并处理每个参数
|
||||||
|
if page, err := strconv.Atoi(c.Query("p")); err == nil {
|
||||||
|
pageInfo.Page = page
|
||||||
|
}
|
||||||
|
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
||||||
|
pageInfo.PageSize = pageSize
|
||||||
|
}
|
||||||
|
if pageInfo.Page < 1 {
|
||||||
|
// 兼容
|
||||||
|
page, _ := strconv.Atoi(c.Query("p"))
|
||||||
|
if page != 0 {
|
||||||
|
pageInfo.Page = page
|
||||||
|
} else {
|
||||||
|
pageInfo.Page = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pageInfo.PageSize == 0 {
|
||||||
|
// 兼容
|
||||||
|
pageSize, _ := strconv.Atoi(c.Query("ps"))
|
||||||
|
if pageSize != 0 {
|
||||||
|
pageInfo.PageSize = pageSize
|
||||||
|
}
|
||||||
|
if pageInfo.PageSize == 0 {
|
||||||
|
pageSize, _ = strconv.Atoi(c.Query("size")) // token page
|
||||||
|
if pageSize != 0 {
|
||||||
|
pageInfo.PageSize = pageSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pageInfo.PageSize == 0 {
|
||||||
|
pageInfo.PageSize = ItemsPerPage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if pageInfo.PageSize > 100 {
|
||||||
|
pageInfo.PageSize = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
return pageInfo
|
||||||
|
}
|
||||||
5
common/quota.go
Normal file
5
common/quota.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
func GetTrustQuota() int {
|
||||||
|
return int(10 * QuotaPerUnit)
|
||||||
|
}
|
||||||
@@ -16,6 +16,10 @@ import (
|
|||||||
var RDB *redis.Client
|
var RDB *redis.Client
|
||||||
var RedisEnabled = true
|
var RedisEnabled = true
|
||||||
|
|
||||||
|
func RedisKeyCacheSeconds() int {
|
||||||
|
return SyncFrequency
|
||||||
|
}
|
||||||
|
|
||||||
// InitRedisClient This function is called after init()
|
// InitRedisClient This function is called after init()
|
||||||
func InitRedisClient() (err error) {
|
func InitRedisClient() (err error) {
|
||||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||||
@@ -92,12 +96,12 @@ func RedisDel(key string) error {
|
|||||||
return RDB.Del(ctx, key).Err()
|
return RDB.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RedisHDelObj(key string) error {
|
func RedisDelKey(key string) error {
|
||||||
if DebugEnabled {
|
if DebugEnabled {
|
||||||
SysLog(fmt.Sprintf("Redis HDEL: key=%s", key))
|
SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
return RDB.HDel(ctx, key).Err()
|
return RDB.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||||
@@ -141,7 +145,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
|||||||
|
|
||||||
txn := RDB.TxPipeline()
|
txn := RDB.TxPipeline()
|
||||||
txn.HSet(ctx, key, data)
|
txn.HSet(ctx, key, data)
|
||||||
txn.Expire(ctx, key, expiration)
|
|
||||||
|
// 只有在 expiration 大于 0 时才设置过期时间
|
||||||
|
if expiration > 0 {
|
||||||
|
txn.Expire(ctx, key, expiration)
|
||||||
|
}
|
||||||
|
|
||||||
_, err := txn.Exec(ctx)
|
_, err := txn.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
177
common/str.go
177
common/str.go
@@ -1,9 +1,13 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,16 +35,30 @@ func MapToJsonStr(m map[string]interface{}) string {
|
|||||||
return string(bytes)
|
return string(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func StrToMap(str string) map[string]interface{} {
|
func StrToMap(str string) (map[string]interface{}, error) {
|
||||||
m := make(map[string]interface{})
|
m := make(map[string]interface{})
|
||||||
err := json.Unmarshal([]byte(str), &m)
|
err := Unmarshal([]byte(str), &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return m
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsJsonStr(str string) bool {
|
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||||
|
var js []interface{}
|
||||||
|
err := json.Unmarshal([]byte(str), &js)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return js, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsJsonArray(str string) bool {
|
||||||
|
var js []interface{}
|
||||||
|
return json.Unmarshal([]byte(str), &js) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsJsonObject(str string) bool {
|
||||||
var js map[string]interface{}
|
var js map[string]interface{}
|
||||||
return json.Unmarshal([]byte(str), &js) == nil
|
return json.Unmarshal([]byte(str), &js) == nil
|
||||||
}
|
}
|
||||||
@@ -68,3 +86,152 @@ func StringToByteSlice(s string) []byte {
|
|||||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EncodeBase64(str string) string {
|
||||||
|
return base64.StdEncoding.EncodeToString([]byte(str))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetJsonString(data any) string {
|
||||||
|
if data == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(data)
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskEmail masks a user email to prevent PII leakage in logs
|
||||||
|
// Returns "***masked***" if email is empty, otherwise shows only the domain part
|
||||||
|
func MaskEmail(email string) string {
|
||||||
|
if email == "" {
|
||||||
|
return "***masked***"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the @ symbol
|
||||||
|
atIndex := strings.Index(email, "@")
|
||||||
|
if atIndex == -1 {
|
||||||
|
// No @ symbol found, return masked
|
||||||
|
return "***masked***"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return only the domain part with @ symbol
|
||||||
|
return "***@" + email[atIndex+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostTail returns the tail parts of a domain/host that should be preserved.
|
||||||
|
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
|
||||||
|
func maskHostTail(parts []string) []string {
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
lastPart := parts[len(parts)-1]
|
||||||
|
secondLastPart := parts[len(parts)-2]
|
||||||
|
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
|
||||||
|
// Likely country code TLD like co.uk, com.cn
|
||||||
|
return []string{secondLastPart, lastPart}
|
||||||
|
}
|
||||||
|
return []string{lastPart}
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
|
||||||
|
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
|
||||||
|
func maskHostForURL(host string) string {
|
||||||
|
parts := strings.Split(host, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "***"
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
return "***." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
|
||||||
|
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
|
||||||
|
func maskHostForPlainDomain(domain string) string {
|
||||||
|
parts := strings.Split(domain, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
tail := maskHostTail(parts)
|
||||||
|
numStars := len(parts) - len(tail)
|
||||||
|
if numStars < 1 {
|
||||||
|
numStars = 1
|
||||||
|
}
|
||||||
|
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
|
||||||
|
return stars + "." + strings.Join(tail, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
|
||||||
|
// Example:
|
||||||
|
// http://example.com -> http://***.com
|
||||||
|
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
|
||||||
|
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
|
||||||
|
// 192.168.1.1 -> ***.***.***.***
|
||||||
|
// openai.com -> ***.com
|
||||||
|
// www.openai.com -> ***.***.com
|
||||||
|
// api.openai.com -> ***.***.com
|
||||||
|
func MaskSensitiveInfo(str string) string {
|
||||||
|
// Mask URLs
|
||||||
|
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||||
|
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
|
||||||
|
u, err := url.Parse(urlStr)
|
||||||
|
if err != nil {
|
||||||
|
return urlStr
|
||||||
|
}
|
||||||
|
|
||||||
|
host := u.Host
|
||||||
|
if host == "" {
|
||||||
|
return urlStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask host with unified logic
|
||||||
|
maskedHost := maskHostForURL(host)
|
||||||
|
|
||||||
|
result := u.Scheme + "://" + maskedHost
|
||||||
|
|
||||||
|
// Mask path
|
||||||
|
if u.Path != "" && u.Path != "/" {
|
||||||
|
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
|
||||||
|
maskedPathParts := make([]string, len(pathParts))
|
||||||
|
for i := range pathParts {
|
||||||
|
if pathParts[i] != "" {
|
||||||
|
maskedPathParts[i] = "***"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(maskedPathParts) > 0 {
|
||||||
|
result += "/" + strings.Join(maskedPathParts, "/")
|
||||||
|
}
|
||||||
|
} else if u.Path == "/" {
|
||||||
|
result += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mask query parameters
|
||||||
|
if u.RawQuery != "" {
|
||||||
|
values, err := url.ParseQuery(u.RawQuery)
|
||||||
|
if err != nil {
|
||||||
|
// If can't parse query, just mask the whole query string
|
||||||
|
result += "?***"
|
||||||
|
} else {
|
||||||
|
maskedParams := make([]string, 0, len(values))
|
||||||
|
for key := range values {
|
||||||
|
maskedParams = append(maskedParams, key+"=***")
|
||||||
|
}
|
||||||
|
if len(maskedParams) > 0 {
|
||||||
|
result += "?" + strings.Join(maskedParams, "&")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mask domain names without protocol (like openai.com, www.openai.com)
|
||||||
|
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||||
|
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
|
||||||
|
return maskHostForPlainDomain(domain)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mask IP addresses
|
||||||
|
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||||
|
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
|
||||||
|
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|||||||
24
common/sys_log.go
Normal file
24
common/sys_log.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SysLog(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SysError(s string) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FatalLog(v ...any) {
|
||||||
|
t := time.Now()
|
||||||
|
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
150
common/totp.go
Normal file
150
common/totp.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pquerna/otp"
|
||||||
|
"github.com/pquerna/otp/totp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 备用码配置
|
||||||
|
BackupCodeLength = 8 // 备用码长度
|
||||||
|
BackupCodeCount = 4 // 生成备用码数量
|
||||||
|
|
||||||
|
// 限制配置
|
||||||
|
MaxFailAttempts = 5 // 最大失败尝试次数
|
||||||
|
LockoutDuration = 300 // 锁定时间(秒)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateTOTPSecret 生成TOTP密钥和配置
|
||||||
|
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
|
||||||
|
issuer := Get2FAIssuer()
|
||||||
|
return totp.Generate(totp.GenerateOpts{
|
||||||
|
Issuer: issuer,
|
||||||
|
AccountName: accountName,
|
||||||
|
Period: 30,
|
||||||
|
Digits: otp.DigitsSix,
|
||||||
|
Algorithm: otp.AlgorithmSHA1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateTOTPCode 验证TOTP验证码
|
||||||
|
func ValidateTOTPCode(secret, code string) bool {
|
||||||
|
// 清理验证码格式
|
||||||
|
cleanCode := strings.ReplaceAll(code, " ", "")
|
||||||
|
if len(cleanCode) != 6 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证验证码
|
||||||
|
return totp.Validate(cleanCode, secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateBackupCodes 生成备用恢复码
|
||||||
|
func GenerateBackupCodes() ([]string, error) {
|
||||||
|
codes := make([]string, BackupCodeCount)
|
||||||
|
|
||||||
|
for i := 0; i < BackupCodeCount; i++ {
|
||||||
|
code, err := generateRandomBackupCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
codes[i] = code
|
||||||
|
}
|
||||||
|
|
||||||
|
return codes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomBackupCode 生成单个备用码
|
||||||
|
func generateRandomBackupCode() (string, error) {
|
||||||
|
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
code := make([]byte, BackupCodeLength)
|
||||||
|
|
||||||
|
for i := range code {
|
||||||
|
randomBytes := make([]byte, 1)
|
||||||
|
_, err := rand.Read(randomBytes)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
code[i] = charset[int(randomBytes[0])%len(charset)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式化为 XXXX-XXXX 格式
|
||||||
|
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateBackupCode 验证备用码格式
|
||||||
|
func ValidateBackupCode(code string) bool {
|
||||||
|
// 移除所有分隔符并转为大写
|
||||||
|
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||||
|
if len(cleanCode) != BackupCodeLength {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查字符是否合法
|
||||||
|
for _, char := range cleanCode {
|
||||||
|
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeBackupCode 标准化备用码格式
|
||||||
|
func NormalizeBackupCode(code string) string {
|
||||||
|
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
|
||||||
|
if len(cleanCode) == BackupCodeLength {
|
||||||
|
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
|
||||||
|
}
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashBackupCode 对备用码进行哈希
|
||||||
|
func HashBackupCode(code string) (string, error) {
|
||||||
|
normalizedCode := NormalizeBackupCode(code)
|
||||||
|
return Password2Hash(normalizedCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAIssuer 获取2FA发行者名称
|
||||||
|
func Get2FAIssuer() string {
|
||||||
|
return SystemName
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvOrDefault 获取环境变量或默认值
|
||||||
|
func getEnvOrDefault(key, defaultValue string) string {
|
||||||
|
if value, exists := os.LookupEnv(key); exists {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateNumericCode 验证数字验证码格式
|
||||||
|
func ValidateNumericCode(code string) (string, error) {
|
||||||
|
// 移除空格
|
||||||
|
code = strings.ReplaceAll(code, " ", "")
|
||||||
|
|
||||||
|
if len(code) != 6 {
|
||||||
|
return "", fmt.Errorf("验证码必须是6位数字")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为纯数字
|
||||||
|
if _, err := strconv.Atoi(code); err != nil {
|
||||||
|
return "", fmt.Errorf("验证码只能包含数字")
|
||||||
|
}
|
||||||
|
|
||||||
|
return code, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateQRCodeData 生成二维码数据
|
||||||
|
func GenerateQRCodeData(secret, username string) string {
|
||||||
|
issuer := Get2FAIssuer()
|
||||||
|
accountName := fmt.Sprintf("%s (%s)", username, issuer)
|
||||||
|
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
|
||||||
|
issuer, accountName, secret, issuer)
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||||
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
|
||||||
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||||
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||||
output, err := c.Output()
|
output, err := c.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||||
}
|
}
|
||||||
|
durationStr := string(bytes.TrimSpace(output))
|
||||||
|
if durationStr == "N/A" {
|
||||||
|
// Create a temporary output file name
|
||||||
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||||
|
}
|
||||||
|
tmpName := tmpFp.Name()
|
||||||
|
// Close immediately so ffmpeg can open the file on Windows.
|
||||||
|
_ = tmpFp.Close()
|
||||||
|
defer os.Remove(tmpName)
|
||||||
|
|
||||||
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||||
|
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||||
|
if err := ffmpegCmd.Run(); err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recalculate the duration of the new file
|
||||||
|
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||||
|
output, err := c.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||||
|
}
|
||||||
|
durationStr = string(bytes.TrimSpace(output))
|
||||||
|
}
|
||||||
|
return strconv.ParseFloat(durationStr, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||||
|
func BuildURL(base string, endpoint string) string {
|
||||||
|
u, err := url.Parse(base)
|
||||||
|
if err != nil {
|
||||||
|
return base + endpoint
|
||||||
|
}
|
||||||
|
end := endpoint
|
||||||
|
if end == "" {
|
||||||
|
end = "/"
|
||||||
|
}
|
||||||
|
ref, err := url.Parse(end)
|
||||||
|
if err != nil {
|
||||||
|
return base + endpoint
|
||||||
|
}
|
||||||
|
return u.ResolveReference(ref).String()
|
||||||
}
|
}
|
||||||
|
|||||||
26
constant/README.md
Normal file
26
constant/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# constant 包 (`/constant`)
|
||||||
|
|
||||||
|
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||||
|
|
||||||
|
## 当前文件
|
||||||
|
|
||||||
|
| 文件 | 说明 |
|
||||||
|
|----------------------|---------------------------------------------------------------------|
|
||||||
|
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||||
|
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||||
|
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||||
|
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||||
|
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||||
|
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||||
|
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||||
|
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||||
|
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||||
|
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||||
|
|
||||||
|
## 使用约定
|
||||||
|
|
||||||
|
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||||
|
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||||
|
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||||
|
|
||||||
|
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||||
36
constant/api_type.go
Normal file
36
constant/api_type.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
APITypeOpenAI = iota
|
||||||
|
APITypeAnthropic
|
||||||
|
APITypePaLM
|
||||||
|
APITypeBaidu
|
||||||
|
APITypeZhipu
|
||||||
|
APITypeAli
|
||||||
|
APITypeXunfei
|
||||||
|
APITypeAIProxyLibrary
|
||||||
|
APITypeTencent
|
||||||
|
APITypeGemini
|
||||||
|
APITypeZhipuV4
|
||||||
|
APITypeOllama
|
||||||
|
APITypePerplexity
|
||||||
|
APITypeAws
|
||||||
|
APITypeCohere
|
||||||
|
APITypeDify
|
||||||
|
APITypeJina
|
||||||
|
APITypeCloudflare
|
||||||
|
APITypeSiliconFlow
|
||||||
|
APITypeVertexAi
|
||||||
|
APITypeMistral
|
||||||
|
APITypeDeepSeek
|
||||||
|
APITypeMokaAI
|
||||||
|
APITypeVolcEngine
|
||||||
|
APITypeBaiduV2
|
||||||
|
APITypeOpenRouter
|
||||||
|
APITypeXinference
|
||||||
|
APITypeXai
|
||||||
|
APITypeCoze
|
||||||
|
APITypeJimeng
|
||||||
|
APITypeMoonshot // this one is only for count, do not add any channel after this
|
||||||
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
)
|
||||||
@@ -1,14 +1,5 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import "one-api/common"
|
|
||||||
|
|
||||||
var (
|
|
||||||
TokenCacheSeconds = common.SyncFrequency
|
|
||||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
|
||||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
|
||||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cache keys
|
// Cache keys
|
||||||
const (
|
const (
|
||||||
UserGroupKeyFmt = "user_group:%d"
|
UserGroupKeyFmt = "user_group:%d"
|
||||||
|
|||||||
111
constant/channel.go
Normal file
111
constant/channel.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChannelTypeUnknown = 0
|
||||||
|
ChannelTypeOpenAI = 1
|
||||||
|
ChannelTypeMidjourney = 2
|
||||||
|
ChannelTypeAzure = 3
|
||||||
|
ChannelTypeOllama = 4
|
||||||
|
ChannelTypeMidjourneyPlus = 5
|
||||||
|
ChannelTypeOpenAIMax = 6
|
||||||
|
ChannelTypeOhMyGPT = 7
|
||||||
|
ChannelTypeCustom = 8
|
||||||
|
ChannelTypeAILS = 9
|
||||||
|
ChannelTypeAIProxy = 10
|
||||||
|
ChannelTypePaLM = 11
|
||||||
|
ChannelTypeAPI2GPT = 12
|
||||||
|
ChannelTypeAIGC2D = 13
|
||||||
|
ChannelTypeAnthropic = 14
|
||||||
|
ChannelTypeBaidu = 15
|
||||||
|
ChannelTypeZhipu = 16
|
||||||
|
ChannelTypeAli = 17
|
||||||
|
ChannelTypeXunfei = 18
|
||||||
|
ChannelType360 = 19
|
||||||
|
ChannelTypeOpenRouter = 20
|
||||||
|
ChannelTypeAIProxyLibrary = 21
|
||||||
|
ChannelTypeFastGPT = 22
|
||||||
|
ChannelTypeTencent = 23
|
||||||
|
ChannelTypeGemini = 24
|
||||||
|
ChannelTypeMoonshot = 25
|
||||||
|
ChannelTypeZhipu_v4 = 26
|
||||||
|
ChannelTypePerplexity = 27
|
||||||
|
ChannelTypeLingYiWanWu = 31
|
||||||
|
ChannelTypeAws = 33
|
||||||
|
ChannelTypeCohere = 34
|
||||||
|
ChannelTypeMiniMax = 35
|
||||||
|
ChannelTypeSunoAPI = 36
|
||||||
|
ChannelTypeDify = 37
|
||||||
|
ChannelTypeJina = 38
|
||||||
|
ChannelCloudflare = 39
|
||||||
|
ChannelTypeSiliconFlow = 40
|
||||||
|
ChannelTypeVertexAi = 41
|
||||||
|
ChannelTypeMistral = 42
|
||||||
|
ChannelTypeDeepSeek = 43
|
||||||
|
ChannelTypeMokaAI = 44
|
||||||
|
ChannelTypeVolcEngine = 45
|
||||||
|
ChannelTypeBaiduV2 = 46
|
||||||
|
ChannelTypeXinference = 47
|
||||||
|
ChannelTypeXai = 48
|
||||||
|
ChannelTypeCoze = 49
|
||||||
|
ChannelTypeKling = 50
|
||||||
|
ChannelTypeJimeng = 51
|
||||||
|
ChannelTypeVidu = 52
|
||||||
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
var ChannelBaseURLs = []string{
|
||||||
|
"", // 0
|
||||||
|
"https://api.openai.com", // 1
|
||||||
|
"https://oa.api2d.net", // 2
|
||||||
|
"", // 3
|
||||||
|
"http://localhost:11434", // 4
|
||||||
|
"https://api.openai-sb.com", // 5
|
||||||
|
"https://api.openaimax.com", // 6
|
||||||
|
"https://api.ohmygpt.com", // 7
|
||||||
|
"", // 8
|
||||||
|
"https://api.caipacity.com", // 9
|
||||||
|
"https://api.aiproxy.io", // 10
|
||||||
|
"", // 11
|
||||||
|
"https://api.api2gpt.com", // 12
|
||||||
|
"https://api.aigc2d.com", // 13
|
||||||
|
"https://api.anthropic.com", // 14
|
||||||
|
"https://aip.baidubce.com", // 15
|
||||||
|
"https://open.bigmodel.cn", // 16
|
||||||
|
"https://dashscope.aliyuncs.com", // 17
|
||||||
|
"", // 18
|
||||||
|
"https://api.360.cn", // 19
|
||||||
|
"https://openrouter.ai/api", // 20
|
||||||
|
"https://api.aiproxy.io", // 21
|
||||||
|
"https://fastgpt.run/api/openapi", // 22
|
||||||
|
"https://hunyuan.tencentcloudapi.com", //23
|
||||||
|
"https://generativelanguage.googleapis.com", //24
|
||||||
|
"https://api.moonshot.cn", //25
|
||||||
|
"https://open.bigmodel.cn", //26
|
||||||
|
"https://api.perplexity.ai", //27
|
||||||
|
"", //28
|
||||||
|
"", //29
|
||||||
|
"", //30
|
||||||
|
"https://api.lingyiwanwu.com", //31
|
||||||
|
"", //32
|
||||||
|
"", //33
|
||||||
|
"https://api.cohere.ai", //34
|
||||||
|
"https://api.minimax.chat", //35
|
||||||
|
"", //36
|
||||||
|
"https://api.dify.ai", //37
|
||||||
|
"https://api.jina.ai", //38
|
||||||
|
"https://api.cloudflare.com", //39
|
||||||
|
"https://api.siliconflow.cn", //40
|
||||||
|
"", //41
|
||||||
|
"https://api.mistral.ai", //42
|
||||||
|
"https://api.deepseek.com", //43
|
||||||
|
"https://api.moka.ai", //44
|
||||||
|
"https://ark.cn-beijing.volces.com", //45
|
||||||
|
"https://qianfan.baidubce.com", //46
|
||||||
|
"", //47
|
||||||
|
"https://api.x.ai", //48
|
||||||
|
"https://api.coze.cn", //49
|
||||||
|
"https://api.klingai.com", //50
|
||||||
|
"https://visual.volcengineapi.com", //51
|
||||||
|
"https://api.vidu.cn", //52
|
||||||
|
}
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
package constant
|
|
||||||
|
|
||||||
var (
|
|
||||||
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
|
|
||||||
ChanelSettingProxy = "proxy" // Proxy 代理
|
|
||||||
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
|
|
||||||
)
|
|
||||||
@@ -1,10 +1,49 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
|
type ContextKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ContextKeyRequestStartTime = "request_start_time"
|
ContextKeyTokenCountMeta ContextKey = "token_count_meta"
|
||||||
ContextKeyUserSetting = "user_setting"
|
ContextKeyPromptTokens ContextKey = "prompt_tokens"
|
||||||
ContextKeyUserQuota = "user_quota"
|
|
||||||
ContextKeyUserStatus = "user_status"
|
ContextKeyOriginalModel ContextKey = "original_model"
|
||||||
ContextKeyUserEmail = "user_email"
|
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||||
ContextKeyUserGroup = "user_group"
|
|
||||||
|
/* token related keys */
|
||||||
|
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||||
|
ContextKeyTokenKey ContextKey = "token_key"
|
||||||
|
ContextKeyTokenId ContextKey = "token_id"
|
||||||
|
ContextKeyTokenGroup ContextKey = "token_group"
|
||||||
|
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||||
|
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||||
|
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||||
|
|
||||||
|
/* channel related keys */
|
||||||
|
ContextKeyChannelId ContextKey = "channel_id"
|
||||||
|
ContextKeyChannelName ContextKey = "channel_name"
|
||||||
|
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||||
|
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||||
|
ContextKeyChannelType ContextKey = "channel_type"
|
||||||
|
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||||
|
ContextKeyChannelOtherSetting ContextKey = "channel_other_setting"
|
||||||
|
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||||
|
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||||
|
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||||
|
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||||
|
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||||
|
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||||
|
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
|
||||||
|
ContextKeyChannelKey ContextKey = "channel_key"
|
||||||
|
|
||||||
|
/* user related keys */
|
||||||
|
ContextKeyUserId ContextKey = "id"
|
||||||
|
ContextKeyUserSetting ContextKey = "user_setting"
|
||||||
|
ContextKeyUserQuota ContextKey = "user_quota"
|
||||||
|
ContextKeyUserStatus ContextKey = "user_status"
|
||||||
|
ContextKeyUserEmail ContextKey = "user_email"
|
||||||
|
ContextKeyUserGroup ContextKey = "user_group"
|
||||||
|
ContextKeyUsingGroup ContextKey = "group"
|
||||||
|
ContextKeyUserName ContextKey = "username"
|
||||||
|
|
||||||
|
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||||
)
|
)
|
||||||
|
|||||||
16
constant/endpoint_type.go
Normal file
16
constant/endpoint_type.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
type EndpointType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
EndpointTypeOpenAI EndpointType = "openai"
|
||||||
|
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||||
|
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||||
|
EndpointTypeGemini EndpointType = "gemini"
|
||||||
|
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||||
|
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||||
|
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||||
|
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||||
|
//EndpointTypeKling EndpointType = "kling"
|
||||||
|
//EndpointTypeJimeng EndpointType = "jimeng"
|
||||||
|
)
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import (
|
|
||||||
"one-api/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
var StreamingTimeout int
|
var StreamingTimeout int
|
||||||
var DifyDebug bool
|
var DifyDebug bool
|
||||||
var MaxFileDownloadMB int
|
var MaxFileDownloadMB int
|
||||||
@@ -17,39 +13,3 @@ var NotifyLimitCount int
|
|||||||
var NotificationLimitDurationMinute int
|
var NotificationLimitDurationMinute int
|
||||||
var GenerateDefaultToken bool
|
var GenerateDefaultToken bool
|
||||||
var ErrorLogEnabled bool
|
var ErrorLogEnabled bool
|
||||||
|
|
||||||
//var GeminiModelMap = map[string]string{
|
|
||||||
// "gemini-1.0-pro": "v1",
|
|
||||||
//}
|
|
||||||
|
|
||||||
func InitEnv() {
|
|
||||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
|
||||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
|
||||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
|
||||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
|
||||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
|
||||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
|
||||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
|
||||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
|
||||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
|
||||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
|
||||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
|
||||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
|
||||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
|
||||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
|
||||||
// 是否启用错误日志
|
|
||||||
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
|
||||||
|
|
||||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
|
||||||
//if modelVersionMapStr == "" {
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
|
||||||
// parts := strings.Split(pair, ":")
|
|
||||||
// if len(parts) == 2 {
|
|
||||||
// GeminiModelMap[parts[0]] = parts[1]
|
|
||||||
// } else {
|
|
||||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ const (
|
|||||||
MjActionPan = "PAN"
|
MjActionPan = "PAN"
|
||||||
MjActionSwapFace = "SWAP_FACE"
|
MjActionSwapFace = "SWAP_FACE"
|
||||||
MjActionUpload = "UPLOAD"
|
MjActionUpload = "UPLOAD"
|
||||||
|
MjActionVideo = "VIDEO"
|
||||||
|
MjActionEdits = "EDITS"
|
||||||
)
|
)
|
||||||
|
|
||||||
var MidjourneyModel2Action = map[string]string{
|
var MidjourneyModel2Action = map[string]string{
|
||||||
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
|
|||||||
"mj_pan": MjActionPan,
|
"mj_pan": MjActionPan,
|
||||||
"swap_face": MjActionSwapFace,
|
"swap_face": MjActionSwapFace,
|
||||||
"mj_upload": MjActionUpload,
|
"mj_upload": MjActionUpload,
|
||||||
|
"mj_video": MjActionVideo,
|
||||||
|
"mj_edits": MjActionEdits,
|
||||||
}
|
}
|
||||||
|
|||||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
type MultiKeyMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||||
|
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||||
|
)
|
||||||
@@ -10,6 +10,9 @@ const (
|
|||||||
const (
|
const (
|
||||||
SunoActionMusic = "MUSIC"
|
SunoActionMusic = "MUSIC"
|
||||||
SunoActionLyrics = "LYRICS"
|
SunoActionLyrics = "LYRICS"
|
||||||
|
|
||||||
|
TaskActionGenerate = "generate"
|
||||||
|
TaskActionTextGenerate = "textGenerate"
|
||||||
)
|
)
|
||||||
|
|
||||||
var SunoModel2Action = map[string]string{
|
var SunoModel2Action = map[string]string{
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
package constant
|
|
||||||
|
|
||||||
var (
|
|
||||||
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
|
||||||
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
|
||||||
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
|
||||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
|
||||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
|
||||||
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
NotifyTypeEmail = "email" // Email 邮件
|
|
||||||
NotifyTypeWebhook = "webhook" // Webhook
|
|
||||||
)
|
|
||||||
@@ -7,11 +7,16 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -130,7 +135,11 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
for k := range headers {
|
for k := range headers {
|
||||||
req.Header.Add(k, headers.Get(k))
|
req.Header.Add(k, headers.Get(k))
|
||||||
}
|
}
|
||||||
res, err := service.GetHttpClient().Do(req)
|
client, err := service.NewProxyHttpClient(channel.GetSetting().Proxy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -304,34 +313,70 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
|||||||
return balance, nil
|
return balance, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||||
|
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||||
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type MoonshotBalanceData struct {
|
||||||
|
AvailableBalance float64 `json:"available_balance"`
|
||||||
|
VoucherBalance float64 `json:"voucher_balance"`
|
||||||
|
CashBalance float64 `json:"cash_balance"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MoonshotBalanceResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data MoonshotBalanceData `json:"data"`
|
||||||
|
Scode string `json:"scode"`
|
||||||
|
Status bool `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
response := MoonshotBalanceResponse{}
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if !response.Status || response.Code != 0 {
|
||||||
|
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||||
|
}
|
||||||
|
availableBalanceCny := response.Data.AvailableBalance
|
||||||
|
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||||
|
channel.UpdateBalance(availableBalanceUsd)
|
||||||
|
return availableBalanceUsd, nil
|
||||||
|
}
|
||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
if channel.GetBaseURL() == "" {
|
if channel.GetBaseURL() == "" {
|
||||||
channel.BaseURL = &baseURL
|
channel.BaseURL = &baseURL
|
||||||
}
|
}
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeOpenAI:
|
case constant.ChannelTypeOpenAI:
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
case common.ChannelTypeAzure:
|
case constant.ChannelTypeAzure:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
case common.ChannelTypeCustom:
|
case constant.ChannelTypeCustom:
|
||||||
baseURL = channel.GetBaseURL()
|
baseURL = channel.GetBaseURL()
|
||||||
//case common.ChannelTypeOpenAISB:
|
//case common.ChannelTypeOpenAISB:
|
||||||
// return updateChannelOpenAISBBalance(channel)
|
// return updateChannelOpenAISBBalance(channel)
|
||||||
case common.ChannelTypeAIProxy:
|
case constant.ChannelTypeAIProxy:
|
||||||
return updateChannelAIProxyBalance(channel)
|
return updateChannelAIProxyBalance(channel)
|
||||||
case common.ChannelTypeAPI2GPT:
|
case constant.ChannelTypeAPI2GPT:
|
||||||
return updateChannelAPI2GPTBalance(channel)
|
return updateChannelAPI2GPTBalance(channel)
|
||||||
case common.ChannelTypeAIGC2D:
|
case constant.ChannelTypeAIGC2D:
|
||||||
return updateChannelAIGC2DBalance(channel)
|
return updateChannelAIGC2DBalance(channel)
|
||||||
case common.ChannelTypeSiliconFlow:
|
case constant.ChannelTypeSiliconFlow:
|
||||||
return updateChannelSiliconFlowBalance(channel)
|
return updateChannelSiliconFlowBalance(channel)
|
||||||
case common.ChannelTypeDeepSeek:
|
case constant.ChannelTypeDeepSeek:
|
||||||
return updateChannelDeepSeekBalance(channel)
|
return updateChannelDeepSeekBalance(channel)
|
||||||
case common.ChannelTypeOpenRouter:
|
case constant.ChannelTypeOpenRouter:
|
||||||
return updateChannelOpenRouterBalance(channel)
|
return updateChannelOpenRouterBalance(channel)
|
||||||
|
case constant.ChannelTypeMoonshot:
|
||||||
|
return updateChannelMoonshotBalance(channel)
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
}
|
}
|
||||||
@@ -370,26 +415,24 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
|||||||
func UpdateChannelBalance(c *gin.Context) {
|
func UpdateChannelBalance(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err := model.GetChannelById(id, true)
|
channel, err := model.CacheGetChannel(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "多密钥渠道不支持余额查询",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
balance, err := updateChannelBalance(channel)
|
balance, err := updateChannelBalance(channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -397,7 +440,6 @@ func UpdateChannelBalance(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"balance": balance,
|
"balance": balance,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateAllChannelsBalance() error {
|
func updateAllChannelsBalance() error {
|
||||||
@@ -409,6 +451,9 @@ func updateAllChannelsBalance() error {
|
|||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if channel.ChannelInfo.IsMultiKey {
|
||||||
|
continue // skip multi-key channels
|
||||||
|
}
|
||||||
// TODO: support Azure
|
// TODO: support Azure
|
||||||
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||||
// continue
|
// continue
|
||||||
@@ -419,7 +464,7 @@ func updateAllChannelsBalance() error {
|
|||||||
} else {
|
} else {
|
||||||
// err is nil & balance <= 0 means quota is used up
|
// err is nil & balance <= 0 means quota is used up
|
||||||
if balance <= 0 {
|
if balance <= 0 {
|
||||||
service.DisableChannel(channel.Id, channel.Name, "余额不足")
|
service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
@@ -431,10 +476,7 @@ func UpdateAllChannelsBalance(c *gin.Context) {
|
|||||||
// TODO: make it async
|
// TODO: make it async
|
||||||
err := updateAllChannelsBalance()
|
err := updateAllChannelsBalance()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -11,14 +11,16 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -29,16 +31,49 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
type testResult struct {
|
||||||
|
context *gin.Context
|
||||||
|
localErr error
|
||||||
|
newAPIError *types.NewAPIError
|
||||||
|
}
|
||||||
|
|
||||||
|
func testChannel(channel *model.Channel, testModel string) testResult {
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
if channel.Type == common.ChannelTypeMidjourney {
|
if channel.Type == constant.ChannelTypeMidjourney {
|
||||||
return errors.New("midjourney channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("midjourney channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeMidjourneyPlus {
|
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||||
return errors.New("midjourney plus channel test is not supported!!!"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("midjourney plus channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if channel.Type == common.ChannelTypeSunoAPI {
|
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||||
return errors.New("suno channel test is not supported"), nil
|
return testResult{
|
||||||
|
localErr: errors.New("suno channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if channel.Type == constant.ChannelTypeKling {
|
||||||
|
return testResult{
|
||||||
|
localErr: errors.New("kling channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if channel.Type == constant.ChannelTypeJimeng {
|
||||||
|
return testResult{
|
||||||
|
localErr: errors.New("jimeng channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if channel.Type == constant.ChannelTypeVidu {
|
||||||
|
return testResult{
|
||||||
|
localErr: errors.New("vidu channel test is not supported"),
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
@@ -50,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||||
strings.Contains(testModel, "embed") ||
|
strings.Contains(testModel, "embed") ||
|
||||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||||
requestPath = "/v1/embeddings" // 修改请求路径
|
requestPath = "/v1/embeddings" // 修改请求路径
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,80 +110,162 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
cache, err := model.GetUserCache(1)
|
cache, err := model.GetUserCache(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cache.WriteContext(c)
|
cache.WriteContext(c)
|
||||||
|
|
||||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||||
c.Request.Header.Set("Content-Type", "application/json")
|
c.Request.Header.Set("Content-Type", "application/json")
|
||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
group, _ := model.GetUserGroup(1, false)
|
group, _ := model.GetUserGroup(1, false)
|
||||||
c.Set("group", group)
|
c.Set("group", group)
|
||||||
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||||
|
if newAPIError != nil {
|
||||||
info := relaycommon.GenRelayInfo(c)
|
return testResult{
|
||||||
|
context: c,
|
||||||
err = helper.ModelMappedHelper(c, info)
|
localErr: newAPIError,
|
||||||
if err != nil {
|
newAPIError: newAPIError,
|
||||||
return err, nil
|
}
|
||||||
}
|
}
|
||||||
testModel = info.UpstreamModelName
|
request := buildTestRequest(testModel)
|
||||||
|
|
||||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
// Determine relay format based on request path
|
||||||
|
relayFormat := types.RelayFormatOpenAI
|
||||||
|
if c.Request.URL.Path == "/v1/embeddings" {
|
||||||
|
relayFormat = types.RelayFormatEmbedding
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info.InitChannelMeta(c)
|
||||||
|
|
||||||
|
err = helper.ModelMappedHelper(c, info, request)
|
||||||
|
if err != nil {
|
||||||
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testModel = info.UpstreamModelName
|
||||||
|
request.Model = testModel
|
||||||
|
|
||||||
|
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||||||
|
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildTestRequest(testModel)
|
//// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
//logInfo := info
|
||||||
logInfo := *info
|
//logInfo.ApiKey = ""
|
||||||
logInfo.ApiKey = ""
|
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
|
||||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor.Init(info)
|
adaptor.Init(info)
|
||||||
|
|
||||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
var convertedRequest any
|
||||||
|
// 根据 RelayMode 选择正确的转换函数
|
||||||
|
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||||
|
// 创建一个 EmbeddingRequest
|
||||||
|
embeddingRequest := dto.EmbeddingRequest{
|
||||||
|
Input: request.Input,
|
||||||
|
Model: request.Model,
|
||||||
|
}
|
||||||
|
// 调用专门用于 Embedding 的转换函数
|
||||||
|
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||||||
|
} else {
|
||||||
|
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||||||
|
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(convertedRequest)
|
jsonData, err := json.Marshal(convertedRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
requestBody := bytes.NewBuffer(jsonData)
|
requestBody := bytes.NewBuffer(jsonData)
|
||||||
c.Request.Body = io.NopCloser(requestBody)
|
c.Request.Body = io.NopCloser(requestBody)
|
||||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
httpResp = resp.(*http.Response)
|
httpResp = resp.(*http.Response)
|
||||||
if httpResp.StatusCode != http.StatusOK {
|
if httpResp.StatusCode != http.StatusOK {
|
||||||
err := service.RelayErrorHandler(httpResp, true)
|
err := service.RelayErrorHandler(httpResp, true)
|
||||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: respErr,
|
||||||
|
newAPIError: respErr,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if usageA == nil {
|
if usageA == nil {
|
||||||
return errors.New("usage is nil"), nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: errors.New("usage is nil"),
|
||||||
|
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
usage := usageA.(*dto.Usage)
|
usage := usageA.(*dto.Usage)
|
||||||
result := w.Result()
|
result := w.Result()
|
||||||
respBody, err := io.ReadAll(result.Body)
|
respBody, err := io.ReadAll(result.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: err,
|
||||||
|
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
info.PromptTokens = usage.PromptTokens
|
info.PromptTokens = usage.PromptTokens
|
||||||
|
|
||||||
@@ -165,12 +282,27 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
|
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
ChannelId: channel.Id,
|
||||||
|
PromptTokens: usage.PromptTokens,
|
||||||
|
CompletionTokens: usage.CompletionTokens,
|
||||||
|
ModelName: info.OriginModelName,
|
||||||
|
TokenName: "模型测试",
|
||||||
|
Quota: quota,
|
||||||
|
Content: "模型测试",
|
||||||
|
UseTimeSeconds: int(consumedTime),
|
||||||
|
IsStream: info.IsStream,
|
||||||
|
Group: info.UsingGroup,
|
||||||
|
Other: other,
|
||||||
|
})
|
||||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||||
return nil, nil
|
return testResult{
|
||||||
|
context: c,
|
||||||
|
localErr: nil,
|
||||||
|
newAPIError: nil,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||||
@@ -185,7 +317,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
strings.Contains(model, "bge-") {
|
strings.Contains(model, "bge-") {
|
||||||
testRequest.Model = model
|
testRequest.Model = model
|
||||||
// Embedding 请求
|
// Embedding 请求
|
||||||
testRequest.Input = []string{"hello world"}
|
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
|
||||||
return testRequest
|
return testRequest
|
||||||
}
|
}
|
||||||
// 并非Embedding 模型
|
// 并非Embedding 模型
|
||||||
@@ -196,14 +328,14 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
testRequest.MaxTokens = 50
|
testRequest.MaxTokens = 50
|
||||||
}
|
}
|
||||||
} else if strings.Contains(model, "gemini") {
|
} else if strings.Contains(model, "gemini") {
|
||||||
testRequest.MaxTokens = 300
|
testRequest.MaxTokens = 3000
|
||||||
} else {
|
} else {
|
||||||
testRequest.MaxTokens = 10
|
testRequest.MaxTokens = 10
|
||||||
}
|
}
|
||||||
content, _ := json.Marshal("hi")
|
|
||||||
testMessage := dto.Message{
|
testMessage := dto.Message{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: content,
|
Content: "hi",
|
||||||
}
|
}
|
||||||
testRequest.Model = model
|
testRequest.Model = model
|
||||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||||||
@@ -213,31 +345,41 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
|||||||
func TestChannel(c *gin.Context) {
|
func TestChannel(c *gin.Context) {
|
||||||
channelId, err := strconv.Atoi(c.Param("id"))
|
channelId, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err := model.GetChannelById(channelId, true)
|
channel, err := model.CacheGetChannel(channelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
channel, err = model.GetChannelById(channelId, true)
|
||||||
"success": false,
|
if err != nil {
|
||||||
"message": err.Error(),
|
common.ApiError(c, err)
|
||||||
})
|
return
|
||||||
return
|
}
|
||||||
}
|
}
|
||||||
|
//defer func() {
|
||||||
|
// if channel.ChannelInfo.IsMultiKey {
|
||||||
|
// go func() { _ = channel.SaveChannelInfo() }()
|
||||||
|
// }
|
||||||
|
//}()
|
||||||
testModel := c.Query("model")
|
testModel := c.Query("model")
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, _ = testChannel(channel, testModel)
|
result := testChannel(channel, testModel)
|
||||||
|
if result.localErr != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": result.localErr.Error(),
|
||||||
|
"time": 0.0,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
consumedTime := float64(milliseconds) / 1000.0
|
consumedTime := float64(milliseconds) / 1000.0
|
||||||
if err != nil {
|
if result.newAPIError != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": result.newAPIError.Error(),
|
||||||
"time": consumedTime,
|
"time": consumedTime,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
@@ -262,52 +404,59 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
testAllChannelsRunning = true
|
testAllChannelsRunning = true
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||||||
if err != nil {
|
if getChannelErr != nil {
|
||||||
return err
|
return getChannelErr
|
||||||
}
|
}
|
||||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||||
if disableThreshold == 0 {
|
if disableThreshold == 0 {
|
||||||
disableThreshold = 10000000 // a impossible value
|
disableThreshold = 10000000 // a impossible value
|
||||||
}
|
}
|
||||||
gopool.Go(func() {
|
gopool.Go(func() {
|
||||||
|
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
|
||||||
|
defer func() {
|
||||||
|
testAllChannelsLock.Lock()
|
||||||
|
testAllChannelsRunning = false
|
||||||
|
testAllChannelsLock.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err, openaiWithStatusErr := testChannel(channel, "")
|
result := testChannel(channel, "")
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
|
||||||
shouldBanChannel := false
|
shouldBanChannel := false
|
||||||
|
newAPIError := result.newAPIError
|
||||||
// request error disables the channel
|
// request error disables the channel
|
||||||
if openaiWithStatusErr != nil {
|
if newAPIError != nil {
|
||||||
oaiErr := openaiWithStatusErr.Error
|
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
|
||||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if milliseconds > disableThreshold {
|
// 当错误检查通过,才检查响应时间
|
||||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||||
shouldBanChannel = true
|
if milliseconds > disableThreshold {
|
||||||
|
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||||
|
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
||||||
|
shouldBanChannel = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// disable channel
|
// disable channel
|
||||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable channel
|
// enable channel
|
||||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
|
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||||||
service.EnableChannel(channel.Id, channel.Name)
|
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.UpdateResponseTime(milliseconds)
|
channel.UpdateResponseTime(milliseconds)
|
||||||
time.Sleep(common.RequestInterval)
|
time.Sleep(common.RequestInterval)
|
||||||
}
|
}
|
||||||
testAllChannelsLock.Lock()
|
|
||||||
testAllChannelsRunning = false
|
|
||||||
testAllChannelsLock.Unlock()
|
|
||||||
if notify {
|
if notify {
|
||||||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||||
}
|
}
|
||||||
@@ -318,10 +467,7 @@ func testAllChannels(notify bool) error {
|
|||||||
func TestAllChannels(c *gin.Context) {
|
func TestAllChannels(c *gin.Context) {
|
||||||
err := testAllChannels(true)
|
err := testAllChannels(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -332,6 +478,10 @@ func TestAllChannels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AutomaticallyTestChannels(frequency int) {
|
func AutomaticallyTestChannels(frequency int) {
|
||||||
|
if frequency <= 0 {
|
||||||
|
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||||
|
return
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||||
common.SysLog("testing all channels")
|
common.SysLog("testing all channels")
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
104
controller/console_migrate.go
Normal file
104
controller/console_migrate.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
// 用于迁移检测的旧键,该文件下个版本会删除
|
||||||
|
|
||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||||
|
func MigrateConsoleSetting(c *gin.Context) {
|
||||||
|
// 读取全部 option
|
||||||
|
opts, err := model.AllOption()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 建立 map
|
||||||
|
valMap := map[string]string{}
|
||||||
|
for _, o := range opts {
|
||||||
|
valMap[o.Key] = o.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 APIInfo
|
||||||
|
if v := valMap["ApiInfo"]; v != "" {
|
||||||
|
var arr []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||||
|
if len(arr) > 50 {
|
||||||
|
arr = arr[:50]
|
||||||
|
}
|
||||||
|
bytes, _ := json.Marshal(arr)
|
||||||
|
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||||
|
}
|
||||||
|
model.UpdateOption("ApiInfo", "")
|
||||||
|
}
|
||||||
|
// Announcements 直接搬
|
||||||
|
if v := valMap["Announcements"]; v != "" {
|
||||||
|
model.UpdateOption("console_setting.announcements", v)
|
||||||
|
model.UpdateOption("Announcements", "")
|
||||||
|
}
|
||||||
|
// FAQ 转换
|
||||||
|
if v := valMap["FAQ"]; v != "" {
|
||||||
|
var arr []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||||
|
out := []map[string]interface{}{}
|
||||||
|
for _, item := range arr {
|
||||||
|
q, _ := item["question"].(string)
|
||||||
|
if q == "" {
|
||||||
|
q, _ = item["title"].(string)
|
||||||
|
}
|
||||||
|
a, _ := item["answer"].(string)
|
||||||
|
if a == "" {
|
||||||
|
a, _ = item["content"].(string)
|
||||||
|
}
|
||||||
|
if q != "" && a != "" {
|
||||||
|
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) > 50 {
|
||||||
|
out = out[:50]
|
||||||
|
}
|
||||||
|
bytes, _ := json.Marshal(out)
|
||||||
|
model.UpdateOption("console_setting.faq", string(bytes))
|
||||||
|
}
|
||||||
|
model.UpdateOption("FAQ", "")
|
||||||
|
}
|
||||||
|
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||||
|
url := valMap["UptimeKumaUrl"]
|
||||||
|
slug := valMap["UptimeKumaSlug"]
|
||||||
|
if url != "" && slug != "" {
|
||||||
|
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||||
|
groups := []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"categoryName": "old",
|
||||||
|
"url": url,
|
||||||
|
"slug": slug,
|
||||||
|
"description": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
bytes, _ := json.Marshal(groups)
|
||||||
|
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||||
|
}
|
||||||
|
// 清空旧键内容
|
||||||
|
if url != "" {
|
||||||
|
model.UpdateOption("UptimeKumaUrl", "")
|
||||||
|
}
|
||||||
|
if slug != "" {
|
||||||
|
model.UpdateOption("UptimeKumaSlug", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除旧键记录
|
||||||
|
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||||
|
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||||
|
|
||||||
|
// 重新加载 OptionMap
|
||||||
|
model.InitOptionMap()
|
||||||
|
common.SysLog("console setting migrated")
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||||
|
}
|
||||||
@@ -5,13 +5,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GitHubOAuthResponse struct {
|
type GitHubOAuthResponse struct {
|
||||||
@@ -103,10 +104,7 @@ func GitHubOAuth(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
githubUser, err := getGitHubUserInfoByCode(code)
|
githubUser, err := getGitHubUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -185,10 +183,7 @@ func GitHubBind(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
githubUser, err := getGitHubUserInfoByCode(code)
|
githubUser, err := getGitHubUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -207,19 +202,13 @@ func GitHubBind(c *gin.Context) {
|
|||||||
user.Id = id.(int)
|
user.Id = id.(int)
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.GitHubId = githubUser.Login
|
user.GitHubId = githubUser.Login
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -239,10 +228,7 @@ func GenerateOAuthCode(c *gin.Context) {
|
|||||||
session.Set("oauth_state", state)
|
session.Set("oauth_state", state)
|
||||||
err := session.Save()
|
err := session.Save()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -1,15 +1,17 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetGroups(c *gin.Context) {
|
func GetGroups(c *gin.Context) {
|
||||||
groupNames := make([]string, 0)
|
groupNames := make([]string, 0)
|
||||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||||
groupNames = append(groupNames, groupName)
|
groupNames = append(groupNames, groupName)
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
|||||||
userGroup := ""
|
userGroup := ""
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
userGroup, _ = model.GetUserGroup(userId, false)
|
userGroup, _ = model.GetUserGroup(userId, false)
|
||||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||||
// UserUsableGroups contains the groups that the user can use
|
// UserUsableGroups contains the groups that the user can use
|
||||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||||
if desc, ok := userUsableGroups[groupName]; ok {
|
if desc, ok := userUsableGroups[groupName]; ok {
|
||||||
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if setting.GroupInUserUsableGroups("auto") {
|
||||||
|
usableGroups["auto"] = map[string]interface{}{
|
||||||
|
"ratio": "自动",
|
||||||
|
"desc": setting.GetUsableGroupDescription("auto"),
|
||||||
|
}
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
|
|||||||
@@ -38,10 +38,7 @@ func LinuxDoBind(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,20 +60,14 @@ func LinuxDoBind(c *gin.Context) {
|
|||||||
|
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -202,10 +193,7 @@ func LinuxdoOAuth(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,21 +220,29 @@ func LinuxdoOAuth(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if common.RegisterEnabled {
|
if common.RegisterEnabled {
|
||||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
|
||||||
user.DisplayName = linuxdoUser.Name
|
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||||
user.Role = common.RoleCommonUser
|
user.DisplayName = linuxdoUser.Name
|
||||||
user.Status = common.UserStatusEnabled
|
user.Role = common.RoleCommonUser
|
||||||
|
user.Status = common.UserStatusEnabled
|
||||||
|
|
||||||
affCode := session.Get("aff")
|
affCode := session.Get("aff")
|
||||||
inviterId := 0
|
inviterId := 0
|
||||||
if affCode != nil {
|
if affCode != nil {
|
||||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Insert(inviterId); err != nil {
|
if err := user.Insert(inviterId); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,14 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func GetAllLogs(c *gin.Context) {
|
func GetAllLogs(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if pageSize < 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
logType, _ := strconv.Atoi(c.Query("type"))
|
logType, _ := strconv.Atoi(c.Query("type"))
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
@@ -26,38 +19,19 @@ func GetAllLogs(c *gin.Context) {
|
|||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||||
group := c.Query("group")
|
group := c.Query("group")
|
||||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(logs)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": map[string]any{
|
return
|
||||||
"items": logs,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserLogs(c *gin.Context) {
|
func GetUserLogs(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if pageSize < 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
if pageSize > 100 {
|
|
||||||
pageSize = 100
|
|
||||||
}
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
logType, _ := strconv.Atoi(c.Query("type"))
|
logType, _ := strconv.Atoi(c.Query("type"))
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
@@ -65,24 +39,14 @@ func GetUserLogs(c *gin.Context) {
|
|||||||
tokenName := c.Query("token_name")
|
tokenName := c.Query("token_name")
|
||||||
modelName := c.Query("model_name")
|
modelName := c.Query("model_name")
|
||||||
group := c.Query("group")
|
group := c.Query("group")
|
||||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(logs)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": map[string]any{
|
|
||||||
"items": logs,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,10 +54,7 @@ func SearchAllLogs(c *gin.Context) {
|
|||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
logs, err := model.SearchAllLogs(keyword)
|
logs, err := model.SearchAllLogs(keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -109,10 +70,7 @@ func SearchUserLogs(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
logs, err := model.SearchUserLogs(userId, keyword)
|
logs, err := model.SearchUserLogs(userId, keyword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -198,10 +156,7 @@ func DeleteHistoryLogs(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -5,17 +5,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpdateMidjourneyTaskBulk() {
|
func UpdateMidjourneyTaskBulk() {
|
||||||
@@ -29,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||||
taskChannelM := make(map[int][]string)
|
taskChannelM := make(map[int][]string)
|
||||||
taskM := make(map[string]*model.Midjourney)
|
taskM := make(map[string]*model.Midjourney)
|
||||||
nullTaskIds := make([]int, 0)
|
nullTaskIds := make([]int, 0)
|
||||||
@@ -48,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||||
} else {
|
} else {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(taskChannelM) == 0 {
|
if len(taskChannelM) == 0 {
|
||||||
@@ -58,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -82,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
})
|
})
|
||||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 设置超时时间
|
// 设置超时时间
|
||||||
@@ -94,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||||
resp, err := service.GetHttpClient().Do(req)
|
resp, err := service.GetHttpClient().Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
var responseItems []dto.MidjourneyDto
|
var responseItems []dto.MidjourneyDto
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
@@ -146,9 +146,25 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||||
task.Buttons = string(buttonStr)
|
task.Buttons = string(buttonStr)
|
||||||
}
|
}
|
||||||
|
// 映射 VideoUrl
|
||||||
|
task.VideoUrl = responseItem.VideoUrl
|
||||||
|
|
||||||
|
// 映射 VideoUrls - 将数组序列化为 JSON 字符串
|
||||||
|
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
|
||||||
|
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
|
||||||
|
task.VideoUrls = "[]" // 失败时设置为空数组
|
||||||
|
} else {
|
||||||
|
task.VideoUrls = string(videoUrlsStr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
task.VideoUrls = "" // 空值时清空字段
|
||||||
|
}
|
||||||
|
|
||||||
shouldReturnQuota := false
|
shouldReturnQuota := false
|
||||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
if task.Quota != 0 {
|
if task.Quota != 0 {
|
||||||
shouldReturnQuota = true
|
shouldReturnQuota = true
|
||||||
@@ -156,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
|
|||||||
}
|
}
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
if shouldReturnQuota {
|
if shouldReturnQuota {
|
||||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
|
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, logger.LogQuota(task.Quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -209,15 +225,26 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
|||||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
// 检查 VideoUrl 是否需要更新
|
||||||
|
if oldTask.VideoUrl != newTask.VideoUrl {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 检查 VideoUrls 是否需要更新
|
||||||
|
if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 {
|
||||||
|
newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls)
|
||||||
|
if oldTask.VideoUrls != string(newVideoUrlsStr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else if oldTask.VideoUrls != "" {
|
||||||
|
// 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllMidjourney(c *gin.Context) {
|
func GetAllMidjourney(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析其他查询参数
|
// 解析其他查询参数
|
||||||
queryParams := model.TaskQueryParams{
|
queryParams := model.TaskQueryParams{
|
||||||
@@ -227,31 +254,24 @@ func GetAllMidjourney(c *gin.Context) {
|
|||||||
EndTimestamp: c.Query("end_timestamp"),
|
EndTimestamp: c.Query("end_timestamp"),
|
||||||
}
|
}
|
||||||
|
|
||||||
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
if logs == nil {
|
total := model.CountAllTasks(queryParams)
|
||||||
logs = make([]*model.Midjourney, 0)
|
|
||||||
}
|
|
||||||
if setting.MjForwardUrlEnabled {
|
if setting.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range logs {
|
for i, midjourney := range items {
|
||||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
logs[i] = midjourney
|
items[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(items)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": logs,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserMidjourney(c *gin.Context) {
|
func GetUserMidjourney(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
log.Printf("userId = %d \n", userId)
|
|
||||||
|
|
||||||
queryParams := model.TaskQueryParams{
|
queryParams := model.TaskQueryParams{
|
||||||
MjID: c.Query("mj_id"),
|
MjID: c.Query("mj_id"),
|
||||||
@@ -259,19 +279,16 @@ func GetUserMidjourney(c *gin.Context) {
|
|||||||
EndTimestamp: c.Query("end_timestamp"),
|
EndTimestamp: c.Query("end_timestamp"),
|
||||||
}
|
}
|
||||||
|
|
||||||
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
if logs == nil {
|
total := model.CountAllUserTask(userId, queryParams)
|
||||||
logs = make([]*model.Midjourney, 0)
|
|
||||||
}
|
|
||||||
if setting.MjForwardUrlEnabled {
|
if setting.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range logs {
|
for i, midjourney := range items {
|
||||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
logs[i] = midjourney
|
items[i] = midjourney
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(items)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": logs,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/console_setting"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/operation_setting"
|
||||||
"one-api/setting/system_setting"
|
"one-api/setting/system_setting"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -24,57 +26,90 @@ func TestStatus(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 获取HTTP统计信息
|
||||||
|
httpStats := middleware.GetStats()
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "Server is running",
|
"message": "Server is running",
|
||||||
|
"http_stats": httpStats,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetStatus(c *gin.Context) {
|
func GetStatus(c *gin.Context) {
|
||||||
|
|
||||||
|
cs := console_setting.GetConsoleSetting()
|
||||||
|
|
||||||
|
data := gin.H{
|
||||||
|
"version": common.Version,
|
||||||
|
"start_time": common.StartTime,
|
||||||
|
"email_verification": common.EmailVerificationEnabled,
|
||||||
|
"github_oauth": common.GitHubOAuthEnabled,
|
||||||
|
"github_client_id": common.GitHubClientId,
|
||||||
|
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||||
|
"linuxdo_client_id": common.LinuxDOClientId,
|
||||||
|
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||||
|
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||||
|
"telegram_bot_name": common.TelegramBotName,
|
||||||
|
"system_name": common.SystemName,
|
||||||
|
"logo": common.Logo,
|
||||||
|
"footer_html": common.Footer,
|
||||||
|
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||||
|
"wechat_login": common.WeChatAuthEnabled,
|
||||||
|
"server_address": setting.ServerAddress,
|
||||||
|
"price": setting.Price,
|
||||||
|
"stripe_unit_price": setting.StripeUnitPrice,
|
||||||
|
"min_topup": setting.MinTopUp,
|
||||||
|
"stripe_min_topup": setting.StripeMinTopUp,
|
||||||
|
"turnstile_check": common.TurnstileCheckEnabled,
|
||||||
|
"turnstile_site_key": common.TurnstileSiteKey,
|
||||||
|
"top_up_link": common.TopUpLink,
|
||||||
|
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||||
|
"quota_per_unit": common.QuotaPerUnit,
|
||||||
|
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||||
|
"enable_batch_update": common.BatchUpdateEnabled,
|
||||||
|
"enable_drawing": common.DrawingEnabled,
|
||||||
|
"enable_task": common.TaskEnabled,
|
||||||
|
"enable_data_export": common.DataExportEnabled,
|
||||||
|
"data_export_default_time": common.DataExportDefaultTime,
|
||||||
|
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||||
|
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||||
|
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||||
|
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||||
|
"chats": setting.Chats,
|
||||||
|
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||||
|
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||||
|
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||||
|
"pay_methods": setting.PayMethods,
|
||||||
|
"usd_exchange_rate": setting.USDExchangeRate,
|
||||||
|
|
||||||
|
// 面板启用开关
|
||||||
|
"api_info_enabled": cs.ApiInfoEnabled,
|
||||||
|
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||||
|
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||||
|
"faq_enabled": cs.FAQEnabled,
|
||||||
|
|
||||||
|
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||||
|
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||||
|
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||||
|
"setup": constant.Setup,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据启用状态注入可选内容
|
||||||
|
if cs.ApiInfoEnabled {
|
||||||
|
data["api_info"] = console_setting.GetApiInfo()
|
||||||
|
}
|
||||||
|
if cs.AnnouncementsEnabled {
|
||||||
|
data["announcements"] = console_setting.GetAnnouncements()
|
||||||
|
}
|
||||||
|
if cs.FAQEnabled {
|
||||||
|
data["faq"] = console_setting.GetFAQ()
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": gin.H{
|
"data": data,
|
||||||
"version": common.Version,
|
|
||||||
"start_time": common.StartTime,
|
|
||||||
"email_verification": common.EmailVerificationEnabled,
|
|
||||||
"github_oauth": common.GitHubOAuthEnabled,
|
|
||||||
"github_client_id": common.GitHubClientId,
|
|
||||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
|
||||||
"linuxdo_client_id": common.LinuxDOClientId,
|
|
||||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
|
||||||
"telegram_bot_name": common.TelegramBotName,
|
|
||||||
"system_name": common.SystemName,
|
|
||||||
"logo": common.Logo,
|
|
||||||
"footer_html": common.Footer,
|
|
||||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
|
||||||
"wechat_login": common.WeChatAuthEnabled,
|
|
||||||
"server_address": setting.ServerAddress,
|
|
||||||
"price": setting.Price,
|
|
||||||
"min_topup": setting.MinTopUp,
|
|
||||||
"turnstile_check": common.TurnstileCheckEnabled,
|
|
||||||
"turnstile_site_key": common.TurnstileSiteKey,
|
|
||||||
"top_up_link": common.TopUpLink,
|
|
||||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
|
||||||
"quota_per_unit": common.QuotaPerUnit,
|
|
||||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
|
||||||
"enable_batch_update": common.BatchUpdateEnabled,
|
|
||||||
"enable_drawing": common.DrawingEnabled,
|
|
||||||
"enable_task": common.TaskEnabled,
|
|
||||||
"enable_data_export": common.DataExportEnabled,
|
|
||||||
"data_export_default_time": common.DataExportDefaultTime,
|
|
||||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
|
||||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
|
||||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
|
||||||
"chats": setting.Chats,
|
|
||||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
|
||||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
|
||||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
|
||||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
|
||||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
|
||||||
"setup": constant.Setup,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -184,10 +219,7 @@ func SendEmailVerification(c *gin.Context) {
|
|||||||
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
||||||
err := common.SendEmail(subject, email, content)
|
err := common.SendEmail(subject, email, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -223,10 +255,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
|||||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||||
err := common.SendEmail(subject, email, content)
|
err := common.SendEmail(subject, email, content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -261,10 +290,7 @@ func ResetPassword(c *gin.Context) {
|
|||||||
password := common.GenerateVerificationCode(12)
|
password := common.GenerateVerificationCode(12)
|
||||||
err = model.ResetUserPasswordByEmail(req.Email, password)
|
err = model.ResetUserPasswordByEmail(req.Email, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
||||||
|
|||||||
27
controller/missing_models.go
Normal file
27
controller/missing_models.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetMissingModels returns the list of model names that are referenced by channels
|
||||||
|
// but do not have corresponding records in the models meta table.
|
||||||
|
// This helps administrators quickly discover models that need configuration.
|
||||||
|
func GetMissingModels(c *gin.Context) {
|
||||||
|
missing, err := model.GetMissingModels()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": missing,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/samber/lo"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -14,7 +15,8 @@ import (
|
|||||||
"one-api/relay/channel/minimax"
|
"one-api/relay/channel/minimax"
|
||||||
"one-api/relay/channel/moonshot"
|
"one-api/relay/channel/moonshot"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
"one-api/setting"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@@ -23,30 +25,10 @@ var openAIModels []dto.OpenAIModels
|
|||||||
var openAIModelsMap map[string]dto.OpenAIModels
|
var openAIModelsMap map[string]dto.OpenAIModels
|
||||||
var channelId2Models map[int][]string
|
var channelId2Models map[int][]string
|
||||||
|
|
||||||
func getPermission() []dto.OpenAIModelPermission {
|
|
||||||
var permission []dto.OpenAIModelPermission
|
|
||||||
permission = append(permission, dto.OpenAIModelPermission{
|
|
||||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
|
||||||
Object: "model_permission",
|
|
||||||
Created: 1626777600,
|
|
||||||
AllowCreateEngine: true,
|
|
||||||
AllowSampling: true,
|
|
||||||
AllowLogprobs: true,
|
|
||||||
AllowSearchIndices: false,
|
|
||||||
AllowView: true,
|
|
||||||
AllowFineTuning: false,
|
|
||||||
Organization: "*",
|
|
||||||
Group: nil,
|
|
||||||
IsBlocking: false,
|
|
||||||
})
|
|
||||||
return permission
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
permission := getPermission()
|
for i := 0; i < constant.APITypeDummy; i++ {
|
||||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
if i == constant.APITypeAIProxyLibrary {
|
||||||
if i == relayconstant.APITypeAIProxyLibrary {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
adaptor := relay.GetAdaptor(i)
|
adaptor := relay.GetAdaptor(i)
|
||||||
@@ -54,69 +36,51 @@ func init() {
|
|||||||
modelNames := adaptor.GetModelList()
|
modelNames := adaptor.GetModelList()
|
||||||
for _, modelName := range modelNames {
|
for _, modelName := range modelNames {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: channelName,
|
OwnedBy: channelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, modelName := range ai360.ModelList {
|
for _, modelName := range ai360.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: ai360.ChannelName,
|
OwnedBy: ai360.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range moonshot.ModelList {
|
for _, modelName := range moonshot.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: moonshot.ChannelName,
|
OwnedBy: moonshot.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range lingyiwanwu.ModelList {
|
for _, modelName := range lingyiwanwu.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: lingyiwanwu.ChannelName,
|
OwnedBy: lingyiwanwu.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for _, modelName := range minimax.ModelList {
|
for _, modelName := range minimax.ModelList {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: minimax.ChannelName,
|
OwnedBy: minimax.ChannelName,
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "midjourney",
|
OwnedBy: "midjourney",
|
||||||
Permission: permission,
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||||
@@ -124,25 +88,29 @@ func init() {
|
|||||||
openAIModelsMap[aiModel.Id] = aiModel
|
openAIModelsMap[aiModel.Id] = aiModel
|
||||||
}
|
}
|
||||||
channelId2Models = make(map[int][]string)
|
channelId2Models = make(map[int][]string)
|
||||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
apiType, success := common.ChannelType2APIType(i)
|
||||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
|
||||||
|
ChannelType: i,
|
||||||
|
}}
|
||||||
adaptor := relay.GetAdaptor(apiType)
|
adaptor := relay.GetAdaptor(apiType)
|
||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
channelId2Models[i] = adaptor.GetModelList()
|
channelId2Models[i] = adaptor.GetModelList()
|
||||||
}
|
}
|
||||||
|
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||||
|
return m.Id
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context, modelType int) {
|
||||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||||
permission := getPermission()
|
|
||||||
|
|
||||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||||
if modelLimitEnable {
|
if modelLimitEnable {
|
||||||
s, ok := c.Get("token_model_limit")
|
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||||
var tokenModelLimit map[string]bool
|
var tokenModelLimit map[string]bool
|
||||||
if ok {
|
if ok {
|
||||||
tokenModelLimit = s.(map[string]bool)
|
tokenModelLimit = s.(map[string]bool)
|
||||||
@@ -150,23 +118,22 @@ func ListModels(c *gin.Context) {
|
|||||||
tokenModelLimit = map[string]bool{}
|
tokenModelLimit = map[string]bool{}
|
||||||
}
|
}
|
||||||
for allowModel, _ := range tokenModelLimit {
|
for allowModel, _ := range tokenModelLimit {
|
||||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||||
|
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||||
} else {
|
} else {
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||||
Id: allowModel,
|
Id: allowModel,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "custom",
|
OwnedBy: "custom",
|
||||||
Permission: permission,
|
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||||
Root: allowModel,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
userGroup, err := model.GetUserGroup(userId, true)
|
userGroup, err := model.GetUserGroup(userId, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -175,31 +142,73 @@ func ListModels(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
group := userGroup
|
group := userGroup
|
||||||
tokenGroup := c.GetString("token_group")
|
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||||
if tokenGroup != "" {
|
if tokenGroup != "" {
|
||||||
group = tokenGroup
|
group = tokenGroup
|
||||||
}
|
}
|
||||||
models := model.GetGroupModels(group)
|
var models []string
|
||||||
for _, s := range models {
|
if tokenGroup == "auto" {
|
||||||
if _, ok := openAIModelsMap[s]; ok {
|
for _, autoGroup := range setting.AutoGroups {
|
||||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||||
|
for _, g := range groupModels {
|
||||||
|
if !common.StringsContains(models, g) {
|
||||||
|
models = append(models, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
models = model.GetGroupEnabledModels(group)
|
||||||
|
}
|
||||||
|
for _, modelName := range models {
|
||||||
|
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||||
|
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||||
|
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||||
} else {
|
} else {
|
||||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||||
Id: s,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
OwnedBy: "custom",
|
OwnedBy: "custom",
|
||||||
Permission: permission,
|
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||||
Root: s,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.JSON(200, gin.H{
|
switch modelType {
|
||||||
"success": true,
|
case constant.ChannelTypeAnthropic:
|
||||||
"data": userOpenAiModels,
|
useranthropicModels := make([]dto.AnthropicModel, len(userOpenAiModels))
|
||||||
})
|
for i, model := range userOpenAiModels {
|
||||||
|
useranthropicModels[i] = dto.AnthropicModel{
|
||||||
|
ID: model.Id,
|
||||||
|
CreatedAt: time.Unix(int64(model.Created), 0).UTC().Format(time.RFC3339),
|
||||||
|
DisplayName: model.Id,
|
||||||
|
Type: "model",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"data": useranthropicModels,
|
||||||
|
"first_id": useranthropicModels[0].ID,
|
||||||
|
"has_more": false,
|
||||||
|
"last_id": useranthropicModels[len(useranthropicModels)-1].ID,
|
||||||
|
})
|
||||||
|
case constant.ChannelTypeGemini:
|
||||||
|
userGeminiModels := make([]dto.GeminiModel, len(userOpenAiModels))
|
||||||
|
for i, model := range userOpenAiModels {
|
||||||
|
userGeminiModels[i] = dto.GeminiModel{
|
||||||
|
Name: model.Id,
|
||||||
|
DisplayName: model.Id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"models": userGeminiModels,
|
||||||
|
"nextPageToken": nil,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": userOpenAiModels,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ChannelListModels(c *gin.Context) {
|
func ChannelListModels(c *gin.Context) {
|
||||||
@@ -223,10 +232,20 @@ func EnabledListModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveModel(c *gin.Context) {
|
func RetrieveModel(c *gin.Context, modelType int) {
|
||||||
modelId := c.Param("model")
|
modelId := c.Param("model")
|
||||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||||||
c.JSON(200, aiModel)
|
switch modelType {
|
||||||
|
case constant.ChannelTypeAnthropic:
|
||||||
|
c.JSON(200, dto.AnthropicModel{
|
||||||
|
ID: aiModel.Id,
|
||||||
|
CreatedAt: time.Unix(int64(aiModel.Created), 0).UTC().Format(time.RFC3339),
|
||||||
|
DisplayName: aiModel.Id,
|
||||||
|
Type: "model",
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(200, aiModel)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
openAIError := dto.OpenAIError{
|
openAIError := dto.OpenAIError{
|
||||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||||
|
|||||||
330
controller/model_meta.go
Normal file
330
controller/model_meta.go
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAllModelsMeta 获取模型列表(分页)
|
||||||
|
func GetAllModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 批量填充附加字段,提升列表接口性能
|
||||||
|
enrichModels(modelsMeta)
|
||||||
|
var total int64
|
||||||
|
model.DB.Model(&model.Model{}).Count(&total)
|
||||||
|
|
||||||
|
// 统计供应商计数(全部数据,不受分页影响)
|
||||||
|
vendorCounts, _ := model.GetVendorModelCounts()
|
||||||
|
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(modelsMeta)
|
||||||
|
common.ApiSuccess(c, gin.H{
|
||||||
|
"items": modelsMeta,
|
||||||
|
"total": total,
|
||||||
|
"page": pageInfo.GetPage(),
|
||||||
|
"page_size": pageInfo.GetPageSize(),
|
||||||
|
"vendor_counts": vendorCounts,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchModelsMeta 搜索模型列表
|
||||||
|
func SearchModelsMeta(c *gin.Context) {
|
||||||
|
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
vendor := c.Query("vendor")
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
|
||||||
|
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 批量填充附加字段,提升列表接口性能
|
||||||
|
enrichModels(modelsMeta)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(modelsMeta)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelMeta 根据 ID 获取单条模型信息
|
||||||
|
func GetModelMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var m model.Model
|
||||||
|
if err := model.DB.First(&m, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
enrichModels([]*model.Model{&m})
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateModelMeta 新建模型
|
||||||
|
func CreateModelMeta(c *gin.Context) {
|
||||||
|
var m model.Model
|
||||||
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.ModelName == "" {
|
||||||
|
common.ApiErrorMsg(c, "模型名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateModelMeta 更新模型
|
||||||
|
func UpdateModelMeta(c *gin.Context) {
|
||||||
|
statusOnly := c.Query("status_only") == "true"
|
||||||
|
|
||||||
|
var m model.Model
|
||||||
|
if err := c.ShouldBindJSON(&m); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少模型 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if statusOnly {
|
||||||
|
// 只更新状态,防止误清空其他字段
|
||||||
|
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "模型名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, &m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteModelMeta 删除模型
|
||||||
|
func DeleteModelMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
model.RefreshPricing()
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询
|
||||||
|
func enrichModels(models []*model.Model) {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) 拆分精确与规则匹配
|
||||||
|
exactNames := make([]string, 0)
|
||||||
|
exactIdx := make(map[string][]int) // modelName -> indices in models
|
||||||
|
ruleIndices := make([]int, 0)
|
||||||
|
for i, m := range models {
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.NameRule == model.NameRuleExact {
|
||||||
|
exactNames = append(exactNames, m.ModelName)
|
||||||
|
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
|
||||||
|
} else {
|
||||||
|
ruleIndices = append(ruleIndices, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 批量查询精确模型的绑定渠道
|
||||||
|
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
|
||||||
|
|
||||||
|
// 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存
|
||||||
|
for name, indices := range exactIdx {
|
||||||
|
chs := channelsByModel[name]
|
||||||
|
for _, idx := range indices {
|
||||||
|
mm := models[idx]
|
||||||
|
if mm.Endpoints == "" {
|
||||||
|
eps := model.GetModelSupportEndpointTypes(mm.ModelName)
|
||||||
|
if b, err := json.Marshal(eps); err == nil {
|
||||||
|
mm.Endpoints = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mm.BoundChannels = chs
|
||||||
|
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
|
||||||
|
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ruleIndices) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 一次性读取定价缓存,内存匹配所有规则模型
|
||||||
|
pricings := model.GetPricing()
|
||||||
|
|
||||||
|
// 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合
|
||||||
|
matchedNamesByIdx := make(map[int][]string)
|
||||||
|
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
|
||||||
|
groupSetByIdx := make(map[int]map[string]struct{})
|
||||||
|
quotaSetByIdx := make(map[int]map[int]struct{})
|
||||||
|
|
||||||
|
for _, p := range pricings {
|
||||||
|
for _, idx := range ruleIndices {
|
||||||
|
mm := models[idx]
|
||||||
|
var matched bool
|
||||||
|
switch mm.NameRule {
|
||||||
|
case model.NameRulePrefix:
|
||||||
|
matched = strings.HasPrefix(p.ModelName, mm.ModelName)
|
||||||
|
case model.NameRuleSuffix:
|
||||||
|
matched = strings.HasSuffix(p.ModelName, mm.ModelName)
|
||||||
|
case model.NameRuleContains:
|
||||||
|
matched = strings.Contains(p.ModelName, mm.ModelName)
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
|
||||||
|
|
||||||
|
es := endpointSetByIdx[idx]
|
||||||
|
if es == nil {
|
||||||
|
es = make(map[constant.EndpointType]struct{})
|
||||||
|
endpointSetByIdx[idx] = es
|
||||||
|
}
|
||||||
|
for _, et := range p.SupportedEndpointTypes {
|
||||||
|
es[et] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
gs := groupSetByIdx[idx]
|
||||||
|
if gs == nil {
|
||||||
|
gs = make(map[string]struct{})
|
||||||
|
groupSetByIdx[idx] = gs
|
||||||
|
}
|
||||||
|
for _, g := range p.EnableGroup {
|
||||||
|
gs[g] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
qs := quotaSetByIdx[idx]
|
||||||
|
if qs == nil {
|
||||||
|
qs = make(map[int]struct{})
|
||||||
|
quotaSetByIdx[idx] = qs
|
||||||
|
}
|
||||||
|
qs[p.QuotaType] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) 汇总所有匹配到的模型名称,批量查询一次渠道
|
||||||
|
allMatchedSet := make(map[string]struct{})
|
||||||
|
for _, names := range matchedNamesByIdx {
|
||||||
|
for _, n := range names {
|
||||||
|
allMatchedSet[n] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
allMatched := make([]string, 0, len(allMatchedSet))
|
||||||
|
for n := range allMatchedSet {
|
||||||
|
allMatched = append(allMatched, n)
|
||||||
|
}
|
||||||
|
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
|
||||||
|
|
||||||
|
// 6) 回填每个规则模型的并集信息
|
||||||
|
for _, idx := range ruleIndices {
|
||||||
|
mm := models[idx]
|
||||||
|
|
||||||
|
// 端点并集 -> 序列化
|
||||||
|
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
|
||||||
|
eps := make([]constant.EndpointType, 0, len(es))
|
||||||
|
for et := range es {
|
||||||
|
eps = append(eps, et)
|
||||||
|
}
|
||||||
|
if b, err := json.Marshal(eps); err == nil {
|
||||||
|
mm.Endpoints = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 分组并集
|
||||||
|
if gs, ok := groupSetByIdx[idx]; ok {
|
||||||
|
groups := make([]string, 0, len(gs))
|
||||||
|
for g := range gs {
|
||||||
|
groups = append(groups, g)
|
||||||
|
}
|
||||||
|
mm.EnableGroups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配额类型集合(保持去重并排序)
|
||||||
|
if qs, ok := quotaSetByIdx[idx]; ok {
|
||||||
|
arr := make([]int, 0, len(qs))
|
||||||
|
for k := range qs {
|
||||||
|
arr = append(arr, k)
|
||||||
|
}
|
||||||
|
sort.Ints(arr)
|
||||||
|
mm.QuotaTypes = arr
|
||||||
|
}
|
||||||
|
|
||||||
|
// 渠道并集
|
||||||
|
names := matchedNamesByIdx[idx]
|
||||||
|
channelSet := make(map[string]model.BoundChannel)
|
||||||
|
for _, n := range names {
|
||||||
|
for _, ch := range matchedChannelsByModel[n] {
|
||||||
|
key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
||||||
|
channelSet[key] = ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(channelSet) > 0 {
|
||||||
|
chs := make([]model.BoundChannel, 0, len(channelSet))
|
||||||
|
for _, ch := range channelSet {
|
||||||
|
chs = append(chs, ch)
|
||||||
|
}
|
||||||
|
mm.BoundChannels = chs
|
||||||
|
}
|
||||||
|
|
||||||
|
// 匹配信息
|
||||||
|
mm.MatchedModels = names
|
||||||
|
mm.MatchedCount = len(names)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oidcResponse.AccessToken == "" {
|
if oidcResponse.AccessToken == "" {
|
||||||
common.SysError("OIDC 获取 Token 失败,请检查设置!")
|
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
}
|
}
|
||||||
defer res2.Body.Close()
|
defer res2.Body.Close()
|
||||||
if res2.StatusCode != http.StatusOK {
|
if res2.StatusCode != http.StatusOK {
|
||||||
common.SysError("OIDC 获取用户信息失败!请检查设置!")
|
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||||
common.SysError("OIDC 获取用户信息为空!请检查设置!")
|
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
|
||||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||||
}
|
}
|
||||||
return &oidcUser, nil
|
return &oidcUser, nil
|
||||||
@@ -126,10 +126,7 @@ func OidcAuth(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
oidcUser, err := getOidcUserInfoByCode(code)
|
oidcUser, err := getOidcUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -195,10 +192,7 @@ func OidcBind(c *gin.Context) {
|
|||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
oidcUser, err := getOidcUserInfoByCode(code)
|
oidcUser, err := getOidcUserInfoByCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user := model.User{
|
user := model.User{
|
||||||
@@ -217,19 +211,13 @@ func OidcBind(c *gin.Context) {
|
|||||||
user.Id = id.(int)
|
user.Id = id.(int)
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.OidcId = oidcUser.OpenID
|
user.OidcId = oidcUser.OpenID
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
|
"one-api/setting/console_setting"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"one-api/setting/system_setting"
|
"one-api/setting/system_setting"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "GroupRatio":
|
case "GroupRatio":
|
||||||
err = setting.CheckGroupRatio(option.Value)
|
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -119,14 +121,46 @@ func UpdateOption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case "console_setting.api_info":
|
||||||
|
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "console_setting.announcements":
|
||||||
|
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "console_setting.faq":
|
||||||
|
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "console_setting.uptime_kuma_groups":
|
||||||
|
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = model.UpdateOption(option.Key, option.Value)
|
err = model.UpdateOption(option.Key, option.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -3,67 +3,58 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/types"
|
||||||
"one-api/setting"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Playground(c *gin.Context) {
|
func Playground(c *gin.Context) {
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
var newAPIError *types.NewAPIError
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if openaiErr != nil {
|
if newAPIError != nil {
|
||||||
c.JSON(openaiErr.StatusCode, gin.H{
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
"error": openaiErr.Error,
|
"error": newAPIError.ToOpenAIError(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
useAccessToken := c.GetBool("use_access_token")
|
useAccessToken := c.GetBool("use_access_token")
|
||||||
if useAccessToken {
|
if useAccessToken {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
playgroundRequest := &dto.PlayGroundRequest{}
|
group := c.GetString("group")
|
||||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
modelName := c.GetString("original_model")
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// Write user context to ensure acceptUnsetRatio is available
|
||||||
|
userCache, err := model.GetUserCache(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
userCache.WriteContext(c)
|
||||||
|
|
||||||
if playgroundRequest.Model == "" {
|
tempToken := &model.Token{
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
UserId: userId,
|
||||||
|
Name: fmt.Sprintf("playground-%s", group),
|
||||||
|
Group: group,
|
||||||
|
}
|
||||||
|
_ = middleware.SetupContextForToken(c, tempToken)
|
||||||
|
_, newAPIError = getChannel(c, group, modelName, 0)
|
||||||
|
if newAPIError != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("original_model", playgroundRequest.Model)
|
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||||
group := playgroundRequest.Group
|
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||||
userGroup := c.GetString("group")
|
|
||||||
|
|
||||||
if group == "" {
|
Relay(c, types.RelayFormatOpenAI)
|
||||||
group = userGroup
|
|
||||||
} else {
|
|
||||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Set("group", group)
|
|
||||||
}
|
|
||||||
c.Set("token_name", "playground-"+group)
|
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
|
||||||
if err != nil {
|
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
|
||||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
|
||||||
Relay(c)
|
|
||||||
}
|
}
|
||||||
|
|||||||
90
controller/prefill_group.go
Normal file
90
controller/prefill_group.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetPrefillGroups 获取预填组列表,可通过 ?type=xxx 过滤
|
||||||
|
func GetPrefillGroups(c *gin.Context) {
|
||||||
|
groupType := c.Query("type")
|
||||||
|
groups, err := model.GetAllPrefillGroups(groupType)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePrefillGroup 创建新的预填组
|
||||||
|
func CreatePrefillGroup(c *gin.Context) {
|
||||||
|
var g model.PrefillGroup
|
||||||
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if g.Name == "" || g.Type == "" {
|
||||||
|
common.ApiErrorMsg(c, "组名称和类型不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建前检查名称
|
||||||
|
if dup, err := model.IsPrefillGroupNameDuplicated(0, g.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePrefillGroup 更新预填组
|
||||||
|
func UpdatePrefillGroup(c *gin.Context) {
|
||||||
|
var g model.PrefillGroup
|
||||||
|
if err := c.ShouldBindJSON(&g); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if g.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少组 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsPrefillGroupNameDuplicated(g.Id, g.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "组名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePrefillGroup 删除预填组
|
||||||
|
func DeletePrefillGroup(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DeletePrefillGroupByID(id); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"one-api/setting/operation_setting"
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetPricing(c *gin.Context) {
|
func GetPricing(c *gin.Context) {
|
||||||
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
|||||||
userId, exists := c.Get("id")
|
userId, exists := c.Get("id")
|
||||||
usableGroup := map[string]string{}
|
usableGroup := map[string]string{}
|
||||||
groupRatio := map[string]float64{}
|
groupRatio := map[string]float64{}
|
||||||
for s, f := range setting.GetGroupRatioCopy() {
|
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||||
groupRatio[s] = f
|
groupRatio[s] = f
|
||||||
}
|
}
|
||||||
var group string
|
var group string
|
||||||
@@ -20,27 +21,36 @@ func GetPricing(c *gin.Context) {
|
|||||||
user, err := model.GetUserCache(userId.(int))
|
user, err := model.GetUserCache(userId.(int))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
group = user.Group
|
group = user.Group
|
||||||
|
for g := range groupRatio {
|
||||||
|
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||||
|
if ok {
|
||||||
|
groupRatio[g] = ratio
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usableGroup = setting.GetUserUsableGroups(group)
|
usableGroup = setting.GetUserUsableGroups(group)
|
||||||
// check groupRatio contains usableGroup
|
// check groupRatio contains usableGroup
|
||||||
for group := range setting.GetGroupRatioCopy() {
|
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||||
if _, ok := usableGroup[group]; !ok {
|
if _, ok := usableGroup[group]; !ok {
|
||||||
delete(groupRatio, group)
|
delete(groupRatio, group)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": pricing,
|
"data": pricing,
|
||||||
"group_ratio": groupRatio,
|
"vendors": model.GetVendors(),
|
||||||
"usable_group": usableGroup,
|
"group_ratio": groupRatio,
|
||||||
|
"usable_group": usableGroup,
|
||||||
|
"supported_endpoint": model.GetSupportedEndpointMap(),
|
||||||
|
"auto_groups": setting.AutoGroups,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResetModelRatio(c *gin.Context) {
|
func ResetModelRatio(c *gin.Context) {
|
||||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -49,7 +59,7 @@ func ResetModelRatio(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
|
|||||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetRatioConfig(c *gin.Context) {
|
||||||
|
if !ratio_setting.IsExposeRatioEnabled() {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "倍率配置接口未启用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": ratio_setting.GetExposedData(),
|
||||||
|
})
|
||||||
|
}
|
||||||
474
controller/ratio_sync.go
Normal file
474
controller/ratio_sync.go
Normal file
@@ -0,0 +1,474 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/logger"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTimeoutSeconds = 10
|
||||||
|
defaultEndpoint = "/api/ratio_config"
|
||||||
|
maxConcurrentFetches = 8
|
||||||
|
)
|
||||||
|
|
||||||
|
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||||
|
|
||||||
|
type upstreamResult struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Data map[string]any `json:"data,omitempty"`
|
||||||
|
Err string `json:"err,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchUpstreamRatios(c *gin.Context) {
|
||||||
|
var req dto.UpstreamRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Timeout <= 0 {
|
||||||
|
req.Timeout = defaultTimeoutSeconds
|
||||||
|
}
|
||||||
|
|
||||||
|
var upstreams []dto.UpstreamDTO
|
||||||
|
|
||||||
|
if len(req.Upstreams) > 0 {
|
||||||
|
for _, u := range req.Upstreams {
|
||||||
|
if strings.HasPrefix(u.BaseURL, "http") {
|
||||||
|
if u.Endpoint == "" {
|
||||||
|
u.Endpoint = defaultEndpoint
|
||||||
|
}
|
||||||
|
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||||
|
upstreams = append(upstreams, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if len(req.ChannelIDs) > 0 {
|
||||||
|
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||||
|
for _, id64 := range req.ChannelIDs {
|
||||||
|
intIds = append(intIds, int(id64))
|
||||||
|
}
|
||||||
|
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, ch := range dbChannels {
|
||||||
|
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||||
|
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||||
|
ID: ch.Id,
|
||||||
|
Name: ch.Name,
|
||||||
|
BaseURL: strings.TrimRight(base, "/"),
|
||||||
|
Endpoint: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
ch := make(chan upstreamResult, len(upstreams))
|
||||||
|
|
||||||
|
sem := make(chan struct{}, maxConcurrentFetches)
|
||||||
|
|
||||||
|
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||||
|
|
||||||
|
for _, chn := range upstreams {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(chItem dto.UpstreamDTO) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
sem <- struct{}{}
|
||||||
|
defer func() { <-sem }()
|
||||||
|
|
||||||
|
endpoint := chItem.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = defaultEndpoint
|
||||||
|
} else if !strings.HasPrefix(endpoint, "/") {
|
||||||
|
endpoint = "/" + endpoint
|
||||||
|
}
|
||||||
|
fullURL := chItem.BaseURL + endpoint
|
||||||
|
|
||||||
|
uniqueName := chItem.Name
|
||||||
|
if chItem.ID != 0 {
|
||||||
|
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 兼容两种上游接口格式:
|
||||||
|
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||||
|
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||||
|
var body struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data json.RawMessage `json:"data"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||||
|
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !body.Success {
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试按 type1 解析
|
||||||
|
var type1Data map[string]any
|
||||||
|
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||||
|
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||||
|
isType1 := false
|
||||||
|
for _, rt := range ratioTypes {
|
||||||
|
if _, ok := type1Data[rt]; ok {
|
||||||
|
isType1 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isType1 {
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||||
|
var pricingItems []struct {
|
||||||
|
ModelName string `json:"model_name"`
|
||||||
|
QuotaType int `json:"quota_type"`
|
||||||
|
ModelRatio float64 `json:"model_ratio"`
|
||||||
|
ModelPrice float64 `json:"model_price"`
|
||||||
|
CompletionRatio float64 `json:"completion_ratio"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||||
|
logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelRatioMap := make(map[string]float64)
|
||||||
|
completionRatioMap := make(map[string]float64)
|
||||||
|
modelPriceMap := make(map[string]float64)
|
||||||
|
|
||||||
|
for _, item := range pricingItems {
|
||||||
|
if item.QuotaType == 1 {
|
||||||
|
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||||
|
} else {
|
||||||
|
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||||
|
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||||
|
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
converted := make(map[string]any)
|
||||||
|
|
||||||
|
if len(modelRatioMap) > 0 {
|
||||||
|
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||||
|
for k, v := range modelRatioMap {
|
||||||
|
ratioAny[k] = v
|
||||||
|
}
|
||||||
|
converted["model_ratio"] = ratioAny
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(completionRatioMap) > 0 {
|
||||||
|
compAny := make(map[string]any, len(completionRatioMap))
|
||||||
|
for k, v := range completionRatioMap {
|
||||||
|
compAny[k] = v
|
||||||
|
}
|
||||||
|
converted["completion_ratio"] = compAny
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelPriceMap) > 0 {
|
||||||
|
priceAny := make(map[string]any, len(modelPriceMap))
|
||||||
|
for k, v := range modelPriceMap {
|
||||||
|
priceAny[k] = v
|
||||||
|
}
|
||||||
|
converted["model_price"] = priceAny
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||||
|
}(chn)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(ch)
|
||||||
|
|
||||||
|
localData := ratio_setting.GetExposedData()
|
||||||
|
|
||||||
|
var testResults []dto.TestResult
|
||||||
|
var successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
for r := range ch {
|
||||||
|
if r.Err != "" {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "error",
|
||||||
|
Error: r.Err,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "success",
|
||||||
|
})
|
||||||
|
successfulChannels = append(successfulChannels, struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}{name: r.Name, data: r.Data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
differences := buildDifferences(localData, successfulChannels)
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"differences": differences,
|
||||||
|
"test_results": testResults,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}) map[string]map[string]dto.DifferenceItem {
|
||||||
|
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||||
|
|
||||||
|
allModels := make(map[string]struct{})
|
||||||
|
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
|
for modelName := range localRatio {
|
||||||
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
for modelName := range upstreamRatio {
|
||||||
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
confidenceMap := make(map[string]map[string]bool)
|
||||||
|
|
||||||
|
// 预处理阶段:检查pricing接口的可信度
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
confidenceMap[channel.name] = make(map[string]bool)
|
||||||
|
|
||||||
|
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||||
|
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||||
|
|
||||||
|
if hasModelRatio && hasCompletionRatio {
|
||||||
|
// 遍历所有模型,检查是否满足不可信条件
|
||||||
|
for modelName := range allModels {
|
||||||
|
// 默认为可信
|
||||||
|
confidenceMap[channel.name][modelName] = true
|
||||||
|
|
||||||
|
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||||
|
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||||
|
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||||
|
// 转换为float64进行比较
|
||||||
|
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||||
|
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||||
|
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||||
|
confidenceMap[channel.name][modelName] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||||
|
for modelName := range allModels {
|
||||||
|
confidenceMap[channel.name][modelName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName := range allModels {
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
var localValue interface{} = nil
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
|
if val, exists := localRatio[modelName]; exists {
|
||||||
|
localValue = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues := make(map[string]interface{})
|
||||||
|
confidenceValues := make(map[string]bool)
|
||||||
|
hasUpstreamValue := false
|
||||||
|
hasDifference := false
|
||||||
|
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
var upstreamValue interface{} = nil
|
||||||
|
|
||||||
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
if val, exists := upstreamRatio[modelName]; exists {
|
||||||
|
upstreamValue = val
|
||||||
|
hasUpstreamValue = true
|
||||||
|
|
||||||
|
if localValue != nil && localValue != val {
|
||||||
|
hasDifference = true
|
||||||
|
} else if localValue == val {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamValue == nil && localValue == nil {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
|
||||||
|
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||||
|
hasDifference = true
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues[channel.name] = upstreamValue
|
||||||
|
|
||||||
|
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldInclude := false
|
||||||
|
|
||||||
|
if localValue != nil {
|
||||||
|
if hasDifference {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if hasUpstreamValue {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldInclude {
|
||||||
|
if differences[modelName] == nil {
|
||||||
|
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||||
|
}
|
||||||
|
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||||
|
Current: localValue,
|
||||||
|
Upstreams: upstreamValues,
|
||||||
|
Confidence: confidenceValues,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channelHasDiff := make(map[string]bool)
|
||||||
|
for _, ratioMap := range differences {
|
||||||
|
for _, item := range ratioMap {
|
||||||
|
for chName, val := range item.Upstreams {
|
||||||
|
if val != nil && val != "same" {
|
||||||
|
channelHasDiff[chName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName, ratioMap := range differences {
|
||||||
|
for ratioType, item := range ratioMap {
|
||||||
|
for chName := range item.Upstreams {
|
||||||
|
if !channelHasDiff[chName] {
|
||||||
|
delete(item.Upstreams, chName)
|
||||||
|
delete(item.Confidence, chName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
allSame := true
|
||||||
|
for _, v := range item.Upstreams {
|
||||||
|
if v != "same" {
|
||||||
|
allSame = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(item.Upstreams) == 0 || allSame {
|
||||||
|
delete(ratioMap, ratioType)
|
||||||
|
} else {
|
||||||
|
differences[modelName][ratioType] = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ratioMap) == 0 {
|
||||||
|
delete(differences, modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return differences
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSyncableChannels(c *gin.Context) {
|
||||||
|
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var syncableChannels []dto.SyncableChannel
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.GetBaseURL() != "" {
|
||||||
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
|
ID: channel.Id,
|
||||||
|
Name: channel.Name,
|
||||||
|
BaseURL: channel.GetBaseURL(),
|
||||||
|
Status: channel.Status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": syncableChannels,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,90 +1,52 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllRedemptions(c *gin.Context) {
|
func GetAllRedemptions(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
if pageSize < 1 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(redemptions)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": gin.H{
|
|
||||||
"items": redemptions,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchRedemptions(c *gin.Context) {
|
func SearchRedemptions(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
if pageSize < 1 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
pageInfo.SetTotal(int(total))
|
||||||
"success": true,
|
pageInfo.SetItems(redemptions)
|
||||||
"message": "",
|
common.ApiSuccess(c, pageInfo)
|
||||||
"data": gin.H{
|
|
||||||
"items": redemptions,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRedemption(c *gin.Context) {
|
func GetRedemption(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redemption, err := model.GetRedemptionById(id)
|
redemption, err := model.GetRedemptionById(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -99,13 +61,10 @@ func AddRedemption(c *gin.Context) {
|
|||||||
redemption := model.Redemption{}
|
redemption := model.Redemption{}
|
||||||
err := c.ShouldBindJSON(&redemption)
|
err := c.ShouldBindJSON(&redemption)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "兑换码名称长度必须在1-20之间",
|
"message": "兑换码名称长度必须在1-20之间",
|
||||||
@@ -126,6 +85,10 @@ func AddRedemption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
var keys []string
|
var keys []string
|
||||||
for i := 0; i < redemption.Count; i++ {
|
for i := 0; i < redemption.Count; i++ {
|
||||||
key := common.GetUUID()
|
key := common.GetUUID()
|
||||||
@@ -135,6 +98,7 @@ func AddRedemption(c *gin.Context) {
|
|||||||
Key: key,
|
Key: key,
|
||||||
CreatedTime: common.GetTimestamp(),
|
CreatedTime: common.GetTimestamp(),
|
||||||
Quota: redemption.Quota,
|
Quota: redemption.Quota,
|
||||||
|
ExpiredTime: redemption.ExpiredTime,
|
||||||
}
|
}
|
||||||
err = cleanRedemption.Insert()
|
err = cleanRedemption.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -159,10 +123,7 @@ func DeleteRedemption(c *gin.Context) {
|
|||||||
id, _ := strconv.Atoi(c.Param("id"))
|
id, _ := strconv.Atoi(c.Param("id"))
|
||||||
err := model.DeleteRedemptionById(id)
|
err := model.DeleteRedemptionById(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -177,33 +138,30 @@ func UpdateRedemption(c *gin.Context) {
|
|||||||
redemption := model.Redemption{}
|
redemption := model.Redemption{}
|
||||||
err := c.ShouldBindJSON(&redemption)
|
err := c.ShouldBindJSON(&redemption)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if statusOnly != "" {
|
if statusOnly == "" {
|
||||||
cleanRedemption.Status = redemption.Status
|
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||||
} else {
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
// If you add more fields, please also update redemption.Update()
|
// If you add more fields, please also update redemption.Update()
|
||||||
cleanRedemption.Name = redemption.Name
|
cleanRedemption.Name = redemption.Name
|
||||||
cleanRedemption.Quota = redemption.Quota
|
cleanRedemption.Quota = redemption.Quota
|
||||||
|
cleanRedemption.ExpiredTime = redemption.ExpiredTime
|
||||||
|
}
|
||||||
|
if statusOnly != "" {
|
||||||
|
cleanRedemption.Status = redemption.Status
|
||||||
}
|
}
|
||||||
err = cleanRedemption.Update()
|
err = cleanRedemption.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -213,3 +171,24 @@ func UpdateRedemption(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteInvalidRedemption(c *gin.Context) {
|
||||||
|
rows, err := model.DeleteInvalidRedemptions()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": rows,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateExpiredTime(expired int64) error {
|
||||||
|
if expired != 0 && expired < common.GetTimestamp() {
|
||||||
|
return errors.New("过期时间不能早于当前时间")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,114 +2,192 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
constant2 "one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/middleware"
|
"one-api/middleware"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/constant"
|
relaycommon "one-api/relay/common"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||||
var err *dto.OpenAIErrorWithStatusCode
|
var err *types.NewAPIError
|
||||||
switch relayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||||
err = relay.ImageHelper(c)
|
err = relay.ImageHelper(c, info)
|
||||||
case relayconstant.RelayModeAudioSpeech:
|
case relayconstant.RelayModeAudioSpeech:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranslation:
|
case relayconstant.RelayModeAudioTranslation:
|
||||||
fallthrough
|
fallthrough
|
||||||
case relayconstant.RelayModeAudioTranscription:
|
case relayconstant.RelayModeAudioTranscription:
|
||||||
err = relay.AudioHelper(c)
|
err = relay.AudioHelper(c, info)
|
||||||
case relayconstant.RelayModeRerank:
|
case relayconstant.RelayModeRerank:
|
||||||
err = relay.RerankHelper(c, relayMode)
|
err = relay.RerankHelper(c, info)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
err = relay.EmbeddingHelper(c)
|
err = relay.EmbeddingHelper(c, info)
|
||||||
case relayconstant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
err = relay.ResponsesHelper(c)
|
err = relay.ResponsesHelper(c, info)
|
||||||
case relayconstant.RelayModeGemini:
|
|
||||||
err = relay.GeminiHelper(c)
|
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
if constant2.ErrorLogEnabled && err != nil {
|
|
||||||
// 保存错误日志到mysql中
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
modelName := c.GetString("original_model")
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
other := make(map[string]interface{})
|
|
||||||
other["error_type"] = err.Error.Type
|
|
||||||
other["error_code"] = err.Error.Code
|
|
||||||
other["status_code"] = err.StatusCode
|
|
||||||
other["channel_id"] = channelId
|
|
||||||
other["channel_name"] = c.GetString("channel_name")
|
|
||||||
other["channel_type"] = c.GetInt("channel_type")
|
|
||||||
|
|
||||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
var err *types.NewAPIError
|
||||||
|
if strings.Contains(c.Request.URL.Path, "embed") {
|
||||||
|
err = relay.GeminiEmbeddingHandler(c, info)
|
||||||
|
} else {
|
||||||
|
err = relay.GeminiHelper(c, info)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||||
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
|
||||||
|
var (
|
||||||
|
newAPIError *types.NewAPIError
|
||||||
|
ws *websocket.Conn
|
||||||
|
)
|
||||||
|
|
||||||
|
if relayFormat == types.RelayFormatOpenAIRealtime {
|
||||||
|
var err error
|
||||||
|
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if newAPIError != nil {
|
||||||
|
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||||
|
switch relayFormat {
|
||||||
|
case types.RelayFormatOpenAIRealtime:
|
||||||
|
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||||
|
case types.RelayFormatClaude:
|
||||||
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
|
"type": "error",
|
||||||
|
"error": newAPIError.ToClaudeError(),
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(newAPIError.StatusCode, gin.H{
|
||||||
|
"error": newAPIError.ToOpenAIError(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := request.GetTokenCountMeta()
|
||||||
|
|
||||||
|
if setting.ShouldCheckPromptSensitive() {
|
||||||
|
contains, words := service.CheckSensitiveText(meta.CombineText)
|
||||||
|
if contains {
|
||||||
|
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := service.CountRequestToken(c, meta, relayInfo)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo.SetPromptTokens(tokens)
|
||||||
|
|
||||||
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
|
||||||
|
if err != nil {
|
||||||
|
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||||
|
|
||||||
|
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
|
if newAPIError != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||||
|
if newAPIError != nil && preConsumedQuota != 0 {
|
||||||
|
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
logger.LogError(c, err.Error())
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
newAPIError = err
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
openaiErr = relayRequest(c, relayMode, channel)
|
addUsedChannel(c, channel.Id)
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
|
||||||
if openaiErr == nil {
|
switch relayFormat {
|
||||||
return // 成功处理请求,直接返回
|
case types.RelayFormatOpenAIRealtime:
|
||||||
|
newAPIError = relay.WssHelper(c, relayInfo)
|
||||||
|
case types.RelayFormatClaude:
|
||||||
|
newAPIError = relay.ClaudeHelper(c, relayInfo)
|
||||||
|
case types.RelayFormatGemini:
|
||||||
|
newAPIError = geminiRelayHandler(c, relayInfo)
|
||||||
|
default:
|
||||||
|
newAPIError = relayHandler(c, relayInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
if newAPIError == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c, retryLogStr)
|
logger.LogInfo(c, retryLogStr)
|
||||||
}
|
|
||||||
|
|
||||||
if openaiErr != nil {
|
|
||||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
|
||||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
|
|
||||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
||||||
}
|
|
||||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
|
||||||
c.JSON(openaiErr.StatusCode, gin.H{
|
|
||||||
"error": openaiErr.Error,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,132 +198,13 @@ var upgrader = websocket.Upgrader{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func WssRelay(c *gin.Context) {
|
|
||||||
// 将 HTTP 连接升级为 WebSocket 连接
|
|
||||||
|
|
||||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
||||||
defer ws.Close()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
|
||||||
helper.WssError(c, ws, openaiErr.Error)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
|
||||||
group := c.GetString("group")
|
|
||||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, err.Error())
|
|
||||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
|
||||||
|
|
||||||
if openaiErr == nil {
|
|
||||||
return // 成功处理请求,直接返回
|
|
||||||
}
|
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
|
||||||
|
|
||||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
||||||
common.LogInfo(c, retryLogStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if openaiErr != nil {
|
|
||||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
|
||||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
|
||||||
}
|
|
||||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
|
||||||
helper.WssError(c, ws, openaiErr.Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func RelayClaude(c *gin.Context) {
|
|
||||||
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
|
||||||
group := c.GetString("group")
|
|
||||||
originalModel := c.GetString("original_model")
|
|
||||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
|
||||||
|
|
||||||
for i := 0; i <= common.RetryTimes; i++ {
|
|
||||||
channel, err := getChannel(c, group, originalModel, i)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(c, err.Error())
|
|
||||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
claudeErr = claudeRequest(c, channel)
|
|
||||||
|
|
||||||
if claudeErr == nil {
|
|
||||||
return // 成功处理请求,直接返回
|
|
||||||
}
|
|
||||||
|
|
||||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
|
||||||
|
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
|
||||||
|
|
||||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
if len(useChannel) > 1 {
|
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
|
||||||
common.LogInfo(c, retryLogStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if claudeErr != nil {
|
|
||||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
|
||||||
c.JSON(claudeErr.StatusCode, gin.H{
|
|
||||||
"type": "error",
|
|
||||||
"error": claudeErr.Error,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relayHandler(c, relayMode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relay.WssHelper(c, ws)
|
|
||||||
}
|
|
||||||
|
|
||||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
|
||||||
addUsedChannel(c, channel.Id)
|
|
||||||
requestBody, _ := common.GetRequestBody(c)
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
return relay.ClaudeHelper(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addUsedChannel(c *gin.Context, channelId int) {
|
func addUsedChannel(c *gin.Context, channelId int) {
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||||
if retryCount == 0 {
|
if retryCount == 0 {
|
||||||
autoBan := c.GetBool("auto_ban")
|
autoBan := c.GetBool("auto_ban")
|
||||||
autoBanInt := 1
|
autoBanInt := 1
|
||||||
@@ -259,19 +218,28 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
|||||||
AutoBan: &autoBanInt,
|
AutoBan: &autoBanInt,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败(retry): %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
|
if channel == nil {
|
||||||
|
return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在(数据库一致性已被破坏,retry)", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
|
||||||
|
}
|
||||||
|
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
if newAPIError != nil {
|
||||||
|
return nil, newAPIError
|
||||||
}
|
}
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
||||||
return channel, nil
|
return channel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||||||
if openaiErr == nil {
|
if openaiErr == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if openaiErr.LocalError {
|
if types.IsChannelError(openaiErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if types.IsSkipRetryError(openaiErr) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if retryTimes <= 0 {
|
if retryTimes <= 0 {
|
||||||
@@ -294,10 +262,6 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
channelType := c.GetInt("channel_type")
|
|
||||||
if channelType == common.ChannelTypeAnthropic {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == 408 {
|
if openaiErr.StatusCode == 408 {
|
||||||
@@ -310,45 +274,85 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
gopool.Go(func() {
|
||||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
|
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||||
|
service.DisableChannel(channelError, err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||||
|
// 保存错误日志到mysql中
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
modelName := c.GetString("original_model")
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userGroup := c.GetString("group")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
other := make(map[string]interface{})
|
||||||
|
other["error_type"] = err.GetErrorType()
|
||||||
|
other["error_code"] = err.GetErrorCode()
|
||||||
|
other["status_code"] = err.StatusCode
|
||||||
|
other["channel_id"] = channelId
|
||||||
|
other["channel_name"] = c.GetString("channel_name")
|
||||||
|
other["channel_type"] = c.GetInt("channel_type")
|
||||||
|
adminInfo := make(map[string]interface{})
|
||||||
|
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||||
|
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||||
|
if isMultiKey {
|
||||||
|
adminInfo["is_multi_key"] = true
|
||||||
|
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||||
|
}
|
||||||
|
other["admin_info"] = adminInfo
|
||||||
|
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
relayMode := c.GetInt("relay_mode")
|
relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
|
||||||
var err *dto.MidjourneyResponse
|
|
||||||
switch relayMode {
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
|
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
|
||||||
|
"type": "upstream_error",
|
||||||
|
"code": 4,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var mjErr *dto.MidjourneyResponse
|
||||||
|
switch relayInfo.RelayMode {
|
||||||
case relayconstant.RelayModeMidjourneyNotify:
|
case relayconstant.RelayModeMidjourneyNotify:
|
||||||
err = relay.RelayMidjourneyNotify(c)
|
mjErr = relay.RelayMidjourneyNotify(c)
|
||||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
|
||||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
mjErr = relay.RelayMidjourneyTaskImageSeed(c)
|
||||||
case relayconstant.RelayModeSwapFace:
|
case relayconstant.RelayModeSwapFace:
|
||||||
err = relay.RelaySwapFace(c)
|
mjErr = relay.RelaySwapFace(c, relayInfo)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
|
||||||
}
|
}
|
||||||
//err = relayMidjourneySubmit(c, relayMode)
|
//err = relayMidjourneySubmit(c, relayMode)
|
||||||
log.Println(err)
|
log.Println(mjErr)
|
||||||
if err != nil {
|
if mjErr != nil {
|
||||||
statusCode := http.StatusBadRequest
|
statusCode := http.StatusBadRequest
|
||||||
if err.Code == 30 {
|
if mjErr.Code == 30 {
|
||||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||||
statusCode = http.StatusTooManyRequests
|
statusCode = http.StatusTooManyRequests
|
||||||
}
|
}
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
"description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
|
||||||
"type": "upstream_error",
|
"type": "upstream_error",
|
||||||
"code": err.Code,
|
"code": mjErr.Code,
|
||||||
})
|
})
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,26 +392,27 @@ func RelayTask(c *gin.Context) {
|
|||||||
retryTimes = 0
|
retryTimes = 0
|
||||||
}
|
}
|
||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||||
if err != nil {
|
if newAPIError != nil {
|
||||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
taskErr = taskRelayHandler(c, relayMode)
|
taskErr = taskRelayHandler(c, relayMode)
|
||||||
}
|
}
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c, retryLogStr)
|
logger.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||||
@@ -420,7 +425,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||||
var err *dto.TaskError
|
var err *dto.TaskError
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID:
|
||||||
err = relay.RelayTaskFetch(c, relayMode)
|
err = relay.RelayTaskFetch(c, relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayTaskSubmit(c, relayMode)
|
err = relay.RelayTaskSubmit(c, relayMode)
|
||||||
|
|||||||
@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
|
|||||||
|
|
||||||
// If root doesn't exist, validate and create admin account
|
// If root doesn't exist, validate and create admin account
|
||||||
if !rootExists {
|
if !rootExists {
|
||||||
|
// Validate username length: max 12 characters to align with model.User validation
|
||||||
|
if len(req.Username) > 12 {
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户名长度不能超过12个字符",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
// Validate password
|
// Validate password
|
||||||
if req.Password != req.ConfirmPassword {
|
if req.Password != req.ConfirmPassword {
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
|
|||||||
136
controller/swag_video.go
Normal file
136
controller/swag_video.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// VideoGenerations
|
||||||
|
// @Summary 生成视频
|
||||||
|
// @Description 调用视频生成接口生成视频
|
||||||
|
// @Description 支持多种视频生成服务:
|
||||||
|
// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
|
||||||
|
// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||||
|
// @Param request body dto.VideoRequest true "视频生成请求参数"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /v1/video/generations [post]
|
||||||
|
func VideoGenerations(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoGenerationsTaskId
|
||||||
|
// @Summary 查询视频
|
||||||
|
// @Description 根据任务ID查询视频生成任务的状态和结果
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /v1/video/generations/{task_id} [get]
|
||||||
|
func VideoGenerationsTaskId(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// KlingText2VideoGenerations
|
||||||
|
// @Summary 可灵文生视频
|
||||||
|
// @Description 调用可灵AI文生视频接口,生成视频内容
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||||
|
// @Param request body KlingText2VideoRequest true "视频生成请求参数"
|
||||||
|
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /kling/v1/videos/text2video [post]
|
||||||
|
func KlingText2VideoGenerations(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingText2VideoRequest struct {
|
||||||
|
ModelName string `json:"model_name,omitempty" example:"kling-v1"`
|
||||||
|
Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
|
||||||
|
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||||
|
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||||
|
Mode string `json:"mode,omitempty" example:"std"`
|
||||||
|
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||||
|
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||||
|
Duration string `json:"duration,omitempty" example:"5"`
|
||||||
|
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||||
|
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingCameraControl struct {
|
||||||
|
Type string `json:"type,omitempty" example:"simple"`
|
||||||
|
Config *KlingCameraConfig `json:"config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingCameraConfig struct {
|
||||||
|
Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
|
||||||
|
Vertical float64 `json:"vertical,omitempty" example:"0"`
|
||||||
|
Pan float64 `json:"pan,omitempty" example:"0"`
|
||||||
|
Tilt float64 `json:"tilt,omitempty" example:"0"`
|
||||||
|
Roll float64 `json:"roll,omitempty" example:"0"`
|
||||||
|
Zoom float64 `json:"zoom,omitempty" example:"0"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KlingImage2VideoGenerations
|
||||||
|
// @Summary 可灵官方-图生视频
|
||||||
|
// @Description 调用可灵AI图生视频接口,生成视频内容
|
||||||
|
// @Tags Video
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||||
|
// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
|
||||||
|
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||||
|
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||||
|
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||||
|
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||||
|
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||||
|
// @Router /kling/v1/videos/image2video [post]
|
||||||
|
func KlingImage2VideoGenerations(c *gin.Context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
type KlingImage2VideoRequest struct {
|
||||||
|
ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
|
||||||
|
Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
|
||||||
|
Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
|
||||||
|
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||||
|
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||||
|
Mode string `json:"mode,omitempty" example:"std"`
|
||||||
|
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||||
|
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||||
|
Duration string `json:"duration,omitempty" example:"5"`
|
||||||
|
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||||
|
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KlingImage2videoTaskId godoc
|
||||||
|
// @Summary 可灵任务查询--图生视频
|
||||||
|
// @Description Query the status and result of a Kling video generation task by task ID
|
||||||
|
// @Tags Origin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Router /kling/v1/videos/image2video/{task_id} [get]
|
||||||
|
func KlingImage2videoTaskId(c *gin.Context) {}
|
||||||
|
|
||||||
|
// KlingText2videoTaskId godoc
|
||||||
|
// @Summary 可灵任务查询--文生视频
|
||||||
|
// @Description Query the status and result of a Kling text-to-video generation task by task ID
|
||||||
|
// @Tags Origin
|
||||||
|
// @Accept json
|
||||||
|
// @Produce json
|
||||||
|
// @Param task_id path string true "Task ID"
|
||||||
|
// @Router /kling/v1/videos/text2video/{task_id} [get]
|
||||||
|
func KlingText2videoTaskId(c *gin.Context) {}
|
||||||
@@ -5,18 +5,20 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/samber/lo"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/samber/lo"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpdateTaskBulk() {
|
func UpdateTaskBulk() {
|
||||||
@@ -53,9 +55,9 @@ func UpdateTaskBulk() {
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||||
} else {
|
} else {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(taskChannelM) == 0 {
|
if len(taskChannelM) == 0 {
|
||||||
@@ -75,7 +77,9 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
|||||||
case constant.TaskPlatformSuno:
|
case constant.TaskPlatformSuno:
|
||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
default:
|
default:
|
||||||
common.SysLog("未知平台")
|
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
|
|||||||
for channelId, taskIds := range taskChannelM {
|
for channelId, taskIds := range taskChannelM {
|
||||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||||
if len(taskIds) == 0 {
|
if len(taskIds) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -103,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
"progress": "100%",
|
"progress": "100%",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -115,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
"ids": taskIds,
|
"ids": taskIds,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||||
err = json.Unmarshal(responseBody, &responseItems)
|
err = json.Unmarshal(responseBody, &responseItems)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !responseItems.IsSuccess() {
|
if !responseItems.IsSuccess() {
|
||||||
@@ -151,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
} else {
|
} else {
|
||||||
quota := task.Quota
|
quota := task.Quota
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||||
}
|
}
|
||||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -175,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
|||||||
|
|
||||||
err = task.Update()
|
err = task.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -223,10 +227,8 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAllTask(c *gin.Context) {
|
func GetAllTask(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||||
// 解析其他查询参数
|
// 解析其他查询参数
|
||||||
@@ -237,25 +239,18 @@ func GetAllTask(c *gin.Context) {
|
|||||||
Action: c.Query("action"),
|
Action: c.Query("action"),
|
||||||
StartTimestamp: startTimestamp,
|
StartTimestamp: startTimestamp,
|
||||||
EndTimestamp: endTimestamp,
|
EndTimestamp: endTimestamp,
|
||||||
|
ChannelID: c.Query("channel_id"),
|
||||||
}
|
}
|
||||||
|
|
||||||
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
if logs == nil {
|
total := model.TaskCountAllTasks(queryParams)
|
||||||
logs = make([]*model.Task, 0)
|
pageInfo.SetTotal(int(total))
|
||||||
}
|
pageInfo.SetItems(items)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "",
|
|
||||||
"data": logs,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserTask(c *gin.Context) {
|
func GetUserTask(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
@@ -271,14 +266,9 @@ func GetUserTask(c *gin.Context) {
|
|||||||
EndTimestamp: endTimestamp,
|
EndTimestamp: endTimestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||||
if logs == nil {
|
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||||
logs = make([]*model.Task, 0)
|
pageInfo.SetTotal(int(total))
|
||||||
}
|
pageInfo.SetItems(items)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"success": true,
|
|
||||||
"message": "",
|
|
||||||
"data": logs,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
148
controller/task_video.go
Normal file
148
controller/task_video.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/relay"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||||
|
for channelId, taskIds := range taskChannelM {
|
||||||
|
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||||
|
if len(taskIds) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||||
|
if err != nil {
|
||||||
|
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||||
|
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
if errUpdate != nil {
|
||||||
|
common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||||
|
}
|
||||||
|
adaptor := relay.GetTaskAdaptor(platform)
|
||||||
|
if adaptor == nil {
|
||||||
|
return fmt.Errorf("video adaptor not found")
|
||||||
|
}
|
||||||
|
for _, taskId := range taskIds {
|
||||||
|
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||||
|
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||||
|
if channel.GetBaseURL() != "" {
|
||||||
|
baseURL = channel.GetBaseURL()
|
||||||
|
}
|
||||||
|
|
||||||
|
task := taskM[taskId]
|
||||||
|
if task == nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||||
|
return fmt.Errorf("task %s not found", taskId)
|
||||||
|
}
|
||||||
|
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||||
|
"task_id": taskId,
|
||||||
|
"action": task.Action,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
//if resp.StatusCode != http.StatusOK {
|
||||||
|
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||||
|
//}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskResult := &relaycommon.TaskInfo{}
|
||||||
|
// try parse as New API response format
|
||||||
|
var responseItems dto.TaskResponse[model.Task]
|
||||||
|
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
|
||||||
|
t := responseItems.Data
|
||||||
|
taskResult.TaskID = t.TaskID
|
||||||
|
taskResult.Status = string(t.Status)
|
||||||
|
taskResult.Url = t.FailReason
|
||||||
|
taskResult.Progress = t.Progress
|
||||||
|
taskResult.Reason = t.FailReason
|
||||||
|
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||||
|
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||||
|
} else {
|
||||||
|
task.Data = responseBody
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if taskResult.Status == "" {
|
||||||
|
return fmt.Errorf("task %s status is empty", taskId)
|
||||||
|
}
|
||||||
|
task.Status = model.TaskStatus(taskResult.Status)
|
||||||
|
switch taskResult.Status {
|
||||||
|
case model.TaskStatusSubmitted:
|
||||||
|
task.Progress = "10%"
|
||||||
|
case model.TaskStatusQueued:
|
||||||
|
task.Progress = "20%"
|
||||||
|
case model.TaskStatusInProgress:
|
||||||
|
task.Progress = "30%"
|
||||||
|
if task.StartTime == 0 {
|
||||||
|
task.StartTime = now
|
||||||
|
}
|
||||||
|
case model.TaskStatusSuccess:
|
||||||
|
task.Progress = "100%"
|
||||||
|
if task.FinishTime == 0 {
|
||||||
|
task.FinishTime = now
|
||||||
|
}
|
||||||
|
task.FailReason = taskResult.Url
|
||||||
|
case model.TaskStatusFailure:
|
||||||
|
task.Status = model.TaskStatusFailure
|
||||||
|
task.Progress = "100%"
|
||||||
|
if task.FinishTime == 0 {
|
||||||
|
task.FinishTime = now
|
||||||
|
}
|
||||||
|
task.FailReason = taskResult.Reason
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||||
|
quota := task.Quota
|
||||||
|
if quota != 0 {
|
||||||
|
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||||
|
logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||||
|
}
|
||||||
|
if taskResult.Progress != "" {
|
||||||
|
task.Progress = taskResult.Progress
|
||||||
|
}
|
||||||
|
if err := task.Update(); err != nil {
|
||||||
|
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -12,29 +12,16 @@ import (
|
|||||||
|
|
||||||
func GetAllTokens(c *gin.Context) {
|
func GetAllTokens(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
size, _ := strconv.Atoi(c.Query("size"))
|
tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 0 {
|
|
||||||
p = 0
|
|
||||||
}
|
|
||||||
if size <= 0 {
|
|
||||||
size = common.ItemsPerPage
|
|
||||||
} else if size > 100 {
|
|
||||||
size = 100
|
|
||||||
}
|
|
||||||
tokens, err := model.GetAllUserTokens(userId, p*size, size)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
total, _ := model.CountUserTokens(userId)
|
||||||
"success": true,
|
pageInfo.SetTotal(int(total))
|
||||||
"message": "",
|
pageInfo.SetItems(tokens)
|
||||||
"data": tokens,
|
common.ApiSuccess(c, pageInfo)
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,10 +31,7 @@ func SearchTokens(c *gin.Context) {
|
|||||||
token := c.Query("token")
|
token := c.Query("token")
|
||||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -62,18 +46,12 @@ func GetToken(c *gin.Context) {
|
|||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := model.GetTokenByIds(id, userId)
|
token, err := model.GetTokenByIds(id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -89,10 +67,7 @@ func GetTokenStatus(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
token, err := model.GetTokenByIds(tokenId, userId)
|
token, err := model.GetTokenByIds(tokenId, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
expiredAt := token.ExpiredTime
|
expiredAt := token.ExpiredTime
|
||||||
@@ -162,10 +137,7 @@ func AddToken(c *gin.Context) {
|
|||||||
token := model.Token{}
|
token := model.Token{}
|
||||||
err := c.ShouldBindJSON(&token)
|
err := c.ShouldBindJSON(&token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(token.Name) > 30 {
|
if len(token.Name) > 30 {
|
||||||
@@ -181,7 +153,7 @@ func AddToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成令牌失败",
|
"message": "生成令牌失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
@@ -200,10 +172,7 @@ func AddToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = cleanToken.Insert()
|
err = cleanToken.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -218,10 +187,7 @@ func DeleteToken(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
err := model.DeleteTokenById(id, userId)
|
err := model.DeleteTokenById(id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -237,10 +203,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
token := model.Token{}
|
token := model.Token{}
|
||||||
err := c.ShouldBindJSON(&token)
|
err := c.ShouldBindJSON(&token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(token.Name) > 30 {
|
if len(token.Name) > 30 {
|
||||||
@@ -252,10 +215,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if token.Status == common.TokenStatusEnabled {
|
if token.Status == common.TokenStatusEnabled {
|
||||||
@@ -289,10 +249,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = cleanToken.Update()
|
err = cleanToken.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -302,3 +259,29 @@ func UpdateToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TokenBatch struct {
|
||||||
|
Ids []int `json:"ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteTokenBatch(c *gin.Context) {
|
||||||
|
tokenBatch := TokenBatch{}
|
||||||
|
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": count,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
@@ -97,16 +98,14 @@ func RequestEpay(c *gin.Context) {
|
|||||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payType := "wxpay"
|
|
||||||
if req.PaymentMethod == "zfb" {
|
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||||
payType = "alipay"
|
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||||
}
|
return
|
||||||
if req.PaymentMethod == "wx" {
|
|
||||||
req.PaymentMethod = "wxpay"
|
|
||||||
payType = "wxpay"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
callBackAddress := service.GetCallbackAddress()
|
callBackAddress := service.GetCallbackAddress()
|
||||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
|
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||||
@@ -116,7 +115,7 @@ func RequestEpay(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||||
Type: payType,
|
Type: req.PaymentMethod,
|
||||||
ServiceTradeNo: tradeNo,
|
ServiceTradeNo: tradeNo,
|
||||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||||
@@ -233,7 +232,7 @@ func EpayNotify(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
|
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", logger.LogQuota(quotaToAdd), topUp.Money))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||||
|
|||||||
275
controller/topup_stripe.go
Normal file
275
controller/topup_stripe.go
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/setting"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stripe/stripe-go/v81"
|
||||||
|
"github.com/stripe/stripe-go/v81/checkout/session"
|
||||||
|
"github.com/stripe/stripe-go/v81/webhook"
|
||||||
|
"github.com/thanhpk/randstr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PaymentMethodStripe = "stripe"
|
||||||
|
)
|
||||||
|
|
||||||
|
var stripeAdaptor = &StripeAdaptor{}
|
||||||
|
|
||||||
|
type StripePayRequest struct {
|
||||||
|
Amount int64 `json:"amount"`
|
||||||
|
PaymentMethod string `json:"payment_method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StripeAdaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
|
||||||
|
if req.Amount < getStripeMinTopup() {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := c.GetInt("id")
|
||||||
|
group, err := model.GetUserGroup(id, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payMoney := getStripePayMoney(float64(req.Amount), group)
|
||||||
|
if payMoney <= 0.01 {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||||
|
if req.PaymentMethod != PaymentMethodStripe {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Amount < getStripeMinTopup() {
|
||||||
|
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Amount > 10000 {
|
||||||
|
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.GetInt("id")
|
||||||
|
user, _ := model.GetUserById(id, false)
|
||||||
|
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||||
|
|
||||||
|
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
|
||||||
|
referenceId := "ref_" + common.Sha1([]byte(reference))
|
||||||
|
|
||||||
|
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp := &model.TopUp{
|
||||||
|
UserId: id,
|
||||||
|
Amount: req.Amount,
|
||||||
|
Money: chargedMoney,
|
||||||
|
TradeNo: referenceId,
|
||||||
|
CreateTime: time.Now().Unix(),
|
||||||
|
Status: common.TopUpStatusPending,
|
||||||
|
}
|
||||||
|
err = topUp.Insert()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"message": "success",
|
||||||
|
"data": gin.H{
|
||||||
|
"pay_link": payLink,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequestStripeAmount(c *gin.Context) {
|
||||||
|
var req StripePayRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stripeAdaptor.RequestAmount(c, &req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RequestStripePay(c *gin.Context) {
|
||||||
|
var req StripePayRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stripeAdaptor.RequestPay(c, &req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func StripeWebhook(c *gin.Context) {
|
||||||
|
payload, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||||
|
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
signature := c.GetHeader("Stripe-Signature")
|
||||||
|
endpointSecret := setting.StripeWebhookSecret
|
||||||
|
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
||||||
|
IgnoreAPIVersionMismatch: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||||
|
c.AbortWithStatus(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch event.Type {
|
||||||
|
case stripe.EventTypeCheckoutSessionCompleted:
|
||||||
|
sessionCompleted(event)
|
||||||
|
case stripe.EventTypeCheckoutSessionExpired:
|
||||||
|
sessionExpired(event)
|
||||||
|
default:
|
||||||
|
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionCompleted(event stripe.Event) {
|
||||||
|
customerId := event.GetObjectValue("customer")
|
||||||
|
referenceId := event.GetObjectValue("client_reference_id")
|
||||||
|
status := event.GetObjectValue("status")
|
||||||
|
if "complete" != status {
|
||||||
|
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := model.Recharge(referenceId, customerId)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err.Error(), referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||||
|
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||||
|
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionExpired(event stripe.Event) {
|
||||||
|
referenceId := event.GetObjectValue("client_reference_id")
|
||||||
|
status := event.GetObjectValue("status")
|
||||||
|
if "expired" != status {
|
||||||
|
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(referenceId) == 0 {
|
||||||
|
log.Println("未提供支付单号")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||||
|
if topUp == nil {
|
||||||
|
log.Println("充值订单不存在", referenceId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if topUp.Status != common.TopUpStatusPending {
|
||||||
|
log.Println("充值订单状态错误", referenceId)
|
||||||
|
}
|
||||||
|
|
||||||
|
topUp.Status = common.TopUpStatusExpired
|
||||||
|
err := topUp.Update()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("充值订单已过期", referenceId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||||
|
if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
|
||||||
|
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||||
|
}
|
||||||
|
|
||||||
|
stripe.Key = setting.StripeApiSecret
|
||||||
|
|
||||||
|
params := &stripe.CheckoutSessionParams{
|
||||||
|
ClientReferenceID: stripe.String(referenceId),
|
||||||
|
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
||||||
|
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
||||||
|
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||||
|
{
|
||||||
|
Price: stripe.String(setting.StripePriceId),
|
||||||
|
Quantity: stripe.Int64(amount),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "" == customerId {
|
||||||
|
if "" != email {
|
||||||
|
params.CustomerEmail = stripe.String(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||||
|
} else {
|
||||||
|
params.Customer = stripe.String(customerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := session.New(params)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetChargedAmount(count float64, user model.User) float64 {
|
||||||
|
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||||
|
if topUpGroupRatio == 0 {
|
||||||
|
topUpGroupRatio = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return count * topUpGroupRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStripePayMoney(amount float64, group string) float64 {
|
||||||
|
if !common.DisplayInCurrencyEnabled {
|
||||||
|
amount = amount / common.QuotaPerUnit
|
||||||
|
}
|
||||||
|
// Using float64 for monetary calculations is acceptable here due to the small amounts involved
|
||||||
|
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||||
|
if topupGroupRatio == 0 {
|
||||||
|
topupGroupRatio = 1
|
||||||
|
}
|
||||||
|
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
||||||
|
return payMoney
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStripeMinTopup() int64 {
|
||||||
|
minTopup := setting.StripeMinTopUp
|
||||||
|
if !common.DisplayInCurrencyEnabled {
|
||||||
|
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||||
|
}
|
||||||
|
return int64(minTopup)
|
||||||
|
}
|
||||||
553
controller/twofa.go
Normal file
553
controller/twofa.go
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Setup2FARequest 设置2FA请求结构
|
||||||
|
type Setup2FARequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify2FARequest 验证2FA请求结构
|
||||||
|
type Verify2FARequest struct {
|
||||||
|
Code string `json:"code" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup2FAResponse 设置2FA响应结构
|
||||||
|
type Setup2FAResponse struct {
|
||||||
|
Secret string `json:"secret"`
|
||||||
|
QRCodeData string `json:"qr_code_data"`
|
||||||
|
BackupCodes []string `json:"backup_codes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup2FA 初始化2FA设置
|
||||||
|
func Setup2FA(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 检查用户是否已经启用2FA
|
||||||
|
existing, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing != nil && existing.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户已启用2FA,请先禁用后重新设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果存在已禁用的2FA记录,先删除它
|
||||||
|
if existing != nil && !existing.IsEnabled {
|
||||||
|
if err := existing.Delete(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
existing = nil // 重置为nil,后续将创建新记录
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户信息
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成TOTP密钥
|
||||||
|
key, err := common.GenerateTOTPSecret(user.Username)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成2FA密钥失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成TOTP密钥失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成备用码
|
||||||
|
backupCodes, err := common.GenerateBackupCodes()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成二维码数据
|
||||||
|
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
|
||||||
|
|
||||||
|
// 创建或更新2FA记录(暂未启用)
|
||||||
|
twoFA := &model.TwoFA{
|
||||||
|
UserId: userId,
|
||||||
|
Secret: key.Secret(),
|
||||||
|
IsEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if existing != nil {
|
||||||
|
// 更新现有记录
|
||||||
|
twoFA.Id = existing.Id
|
||||||
|
err = twoFA.Update()
|
||||||
|
} else {
|
||||||
|
// 创建新记录
|
||||||
|
err = twoFA.Create()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建备用码记录
|
||||||
|
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "保存备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "2FA设置初始化成功,请使用认证器扫描二维码并输入验证码完成设置",
|
||||||
|
"data": Setup2FAResponse{
|
||||||
|
Secret: key.Secret(),
|
||||||
|
QRCodeData: qrCodeData,
|
||||||
|
BackupCodes: backupCodes,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable2FA 启用2FA
|
||||||
|
func Enable2FA(c *gin.Context) {
|
||||||
|
var req Setup2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "请先完成2FA初始化设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "2FA已经启用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启用2FA
|
||||||
|
if err := twoFA.Enable(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "两步验证启用成功",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable2FA 禁用2FA
|
||||||
|
func Disable2FA(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码或备用码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
isValidTOTP := false
|
||||||
|
isValidBackup := false
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// 尝试验证TOTP
|
||||||
|
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP {
|
||||||
|
// 尝试验证备用码
|
||||||
|
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP && !isValidBackup {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用2FA
|
||||||
|
if err := model.DisableTwoFA(userId); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "两步验证已禁用",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get2FAStatus 获取用户2FA状态
|
||||||
|
func Get2FAStatus(c *gin.Context) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status := map[string]interface{}{
|
||||||
|
"enabled": false,
|
||||||
|
"locked": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if twoFA != nil {
|
||||||
|
status["enabled"] = twoFA.IsEnabled
|
||||||
|
status["locked"] = twoFA.IsLocked()
|
||||||
|
if twoFA.IsEnabled {
|
||||||
|
// 获取剩余备用码数量
|
||||||
|
backupCount, err := model.GetUnusedBackupCodeCount(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysLog("获取备用码数量失败: " + err.Error())
|
||||||
|
} else {
|
||||||
|
status["backup_codes_remaining"] = backupCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegenerateBackupCodes 重新生成备用码
|
||||||
|
func RegenerateBackupCodes(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成新的备用码
|
||||||
|
backupCodes, err := common.GenerateBackupCodes()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "生成备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("生成备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存新的备用码
|
||||||
|
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "保存备用码失败",
|
||||||
|
})
|
||||||
|
common.SysLog("保存备用码失败: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "备用码重新生成成功",
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"backup_codes": backupCodes,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify2FALogin 登录时验证2FA
|
||||||
|
func Verify2FALogin(c *gin.Context) {
|
||||||
|
var req Verify2FARequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "参数错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从会话中获取pending用户信息
|
||||||
|
session := sessions.Default(c)
|
||||||
|
pendingUserId := session.Get("pending_user_id")
|
||||||
|
if pendingUserId == nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "会话已过期,请重新登录",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userId, ok := pendingUserId.(int)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "会话数据无效,请重新登录",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 获取用户信息
|
||||||
|
user, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户不存在",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取2FA记录
|
||||||
|
twoFA, err := model.GetTwoFAByUserId(user.Id)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if twoFA == nil || !twoFA.IsEnabled {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证TOTP验证码或备用码
|
||||||
|
cleanCode, err := common.ValidateNumericCode(req.Code)
|
||||||
|
isValidTOTP := false
|
||||||
|
isValidBackup := false
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// 尝试验证TOTP
|
||||||
|
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP {
|
||||||
|
// 尝试验证备用码
|
||||||
|
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidTOTP && !isValidBackup {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "验证码或备用码错误,请重试",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2FA验证成功,清理pending会话信息并完成登录
|
||||||
|
session.Delete("pending_username")
|
||||||
|
session.Delete("pending_user_id")
|
||||||
|
session.Save()
|
||||||
|
|
||||||
|
setupLogin(user, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Admin2FAStats 管理员获取2FA统计信息
|
||||||
|
func Admin2FAStats(c *gin.Context) {
|
||||||
|
stats, err := model.GetTwoFAStats()
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": stats,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminDisable2FA 管理员强制禁用用户2FA
|
||||||
|
func AdminDisable2FA(c *gin.Context) {
|
||||||
|
userIdStr := c.Param("id")
|
||||||
|
userId, err := strconv.Atoi(userIdStr)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户ID格式错误",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查目标用户权限
|
||||||
|
targetUser, err := model.GetUserById(userId, false)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
myRole := c.GetInt("role")
|
||||||
|
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无权操作同级或更高级用户的2FA设置",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用2FA
|
||||||
|
if err := model.DisableTwoFA(userId); err != nil {
|
||||||
|
if errors.Is(err, model.ErrTwoFANotEnabled) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "用户未启用2FA",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录操作日志
|
||||||
|
adminId := c.GetInt("id")
|
||||||
|
model.RecordLog(userId, model.LogTypeManage,
|
||||||
|
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "用户2FA已被强制禁用",
|
||||||
|
})
|
||||||
|
}
|
||||||
154
controller/uptime_kuma.go
Normal file
154
controller/uptime_kuma.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"one-api/setting/console_setting"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
requestTimeout = 30 * time.Second
|
||||||
|
httpTimeout = 10 * time.Second
|
||||||
|
uptimeKeySuffix = "_24"
|
||||||
|
apiStatusPath = "/api/status-page/"
|
||||||
|
apiHeartbeatPath = "/api/status-page/heartbeat/"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Monitor struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Uptime float64 `json:"uptime"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
Group string `json:"group,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UptimeGroupResult struct {
|
||||||
|
CategoryName string `json:"categoryName"`
|
||||||
|
Monitors []Monitor `json:"monitors"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return errors.New("non-200 status")
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.NewDecoder(resp.Body).Decode(dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
|
||||||
|
url, _ := groupConfig["url"].(string)
|
||||||
|
slug, _ := groupConfig["slug"].(string)
|
||||||
|
categoryName, _ := groupConfig["categoryName"].(string)
|
||||||
|
|
||||||
|
result := UptimeGroupResult{
|
||||||
|
CategoryName: categoryName,
|
||||||
|
Monitors: []Monitor{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if url == "" || slug == "" {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimSuffix(url, "/")
|
||||||
|
|
||||||
|
var statusData struct {
|
||||||
|
PublicGroupList []struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
MonitorList []struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
} `json:"monitorList"`
|
||||||
|
} `json:"publicGroupList"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var heartbeatData struct {
|
||||||
|
HeartbeatList map[string][]struct {
|
||||||
|
Status int `json:"status"`
|
||||||
|
} `json:"heartbeatList"`
|
||||||
|
UptimeList map[string]float64 `json:"uptimeList"`
|
||||||
|
}
|
||||||
|
|
||||||
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
|
g.Go(func() error {
|
||||||
|
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||||
|
})
|
||||||
|
g.Go(func() error {
|
||||||
|
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||||
|
})
|
||||||
|
|
||||||
|
if g.Wait() != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pg := range statusData.PublicGroupList {
|
||||||
|
if len(pg.MonitorList) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range pg.MonitorList {
|
||||||
|
monitor := Monitor{
|
||||||
|
Name: m.Name,
|
||||||
|
Group: pg.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
monitorID := strconv.Itoa(m.ID)
|
||||||
|
|
||||||
|
if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
|
||||||
|
monitor.Uptime = uptime
|
||||||
|
}
|
||||||
|
|
||||||
|
if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
|
||||||
|
monitor.Status = heartbeats[0].Status
|
||||||
|
}
|
||||||
|
|
||||||
|
result.Monitors = append(result.Monitors, monitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUptimeKumaStatus(c *gin.Context) {
|
||||||
|
groups := console_setting.GetUptimeKumaGroups()
|
||||||
|
if len(groups) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: httpTimeout}
|
||||||
|
results := make([]UptimeGroupResult, len(groups))
|
||||||
|
|
||||||
|
g, gCtx := errgroup.WithContext(ctx)
|
||||||
|
for i, group := range groups {
|
||||||
|
i, group := i, group
|
||||||
|
g.Go(func() error {
|
||||||
|
results[i] = fetchGroupData(gCtx, client, group)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Wait()
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||||
|
}
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetAllQuotaDates(c *gin.Context) {
|
func GetAllQuotaDates(c *gin.Context) {
|
||||||
@@ -13,10 +15,7 @@ func GetAllQuotaDates(c *gin.Context) {
|
|||||||
username := c.Query("username")
|
username := c.Query("username")
|
||||||
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -41,10 +40,7 @@ func GetUserQuotaDates(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/logger"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -61,6 +63,32 @@ func Login(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否启用2FA
|
||||||
|
if model.IsTwoFAEnabled(user.Id) {
|
||||||
|
// 设置pending session,等待2FA验证
|
||||||
|
session := sessions.Default(c)
|
||||||
|
session.Set("pending_username", user.Username)
|
||||||
|
session.Set("pending_user_id", user.Id)
|
||||||
|
err := session.Save()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "无法保存会话信息,请重试",
|
||||||
|
"success": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": "请输入两步验证码",
|
||||||
|
"success": true,
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"require_2fa": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
setupLogin(&user, c)
|
setupLogin(&user, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +193,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "数据库错误,请稍后重试",
|
"message": "数据库错误,请稍后重试",
|
||||||
})
|
})
|
||||||
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if exist {
|
if exist {
|
||||||
@@ -187,10 +215,7 @@ func Register(c *gin.Context) {
|
|||||||
cleanUser.Email = user.Email
|
cleanUser.Email = user.Email
|
||||||
}
|
}
|
||||||
if err := cleanUser.Insert(inviterId); err != nil {
|
if err := cleanUser.Insert(inviterId); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,7 +236,7 @@ func Register(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成默认令牌失败",
|
"message": "生成默认令牌失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate token key: " + err.Error())
|
common.SysLog("failed to generate token key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 生成默认令牌
|
// 生成默认令牌
|
||||||
@@ -226,6 +251,9 @@ func Register(c *gin.Context) {
|
|||||||
UnlimitedQuota: true,
|
UnlimitedQuota: true,
|
||||||
ModelLimitsEnabled: false,
|
ModelLimitsEnabled: false,
|
||||||
}
|
}
|
||||||
|
if setting.DefaultUseAutoGroup {
|
||||||
|
token.Group = "auto"
|
||||||
|
}
|
||||||
if err := token.Insert(); err != nil {
|
if err := token.Insert(); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -243,83 +271,45 @@ func Register(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetAllUsers(c *gin.Context) {
|
func GetAllUsers(c *gin.Context) {
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
users, total, err := model.GetAllUsers(pageInfo)
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if pageSize < 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
pageInfo.SetTotal(int(total))
|
||||||
"message": "",
|
pageInfo.SetItems(users)
|
||||||
"data": gin.H{
|
|
||||||
"items": users,
|
common.ApiSuccess(c, pageInfo)
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUsers(c *gin.Context) {
|
func SearchUsers(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
group := c.Query("group")
|
group := c.Query("group")
|
||||||
p, _ := strconv.Atoi(c.Query("p"))
|
pageInfo := common.GetPageQuery(c)
|
||||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
if p < 1 {
|
|
||||||
p = 1
|
|
||||||
}
|
|
||||||
if pageSize < 0 {
|
|
||||||
pageSize = common.ItemsPerPage
|
|
||||||
}
|
|
||||||
startIdx := (p - 1) * pageSize
|
|
||||||
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"success": true,
|
pageInfo.SetTotal(int(total))
|
||||||
"message": "",
|
pageInfo.SetItems(users)
|
||||||
"data": gin.H{
|
common.ApiSuccess(c, pageInfo)
|
||||||
"items": users,
|
|
||||||
"total": total,
|
|
||||||
"page": p,
|
|
||||||
"page_size": pageSize,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(c *gin.Context) {
|
func GetUser(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, err := model.GetUserById(id, false)
|
user, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt("role")
|
||||||
@@ -342,10 +332,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// get rand int 28-32
|
// get rand int 28-32
|
||||||
@@ -356,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
"success": false,
|
"success": false,
|
||||||
"message": "生成失败",
|
"message": "生成失败",
|
||||||
})
|
})
|
||||||
common.SysError("failed to generate key: " + err.Error())
|
common.SysLog("failed to generate key: " + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.SetAccessToken(key)
|
user.SetAccessToken(key)
|
||||||
@@ -370,10 +357,7 @@ func GenerateAccessToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -393,18 +377,12 @@ func TransferAffQuota(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tran := TransferAffQuotaRequest{}
|
tran := TransferAffQuotaRequest{}
|
||||||
if err := c.ShouldBindJSON(&tran); err != nil {
|
if err := c.ShouldBindJSON(&tran); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = user.TransferAffQuotaToQuota(tran.Quota)
|
err = user.TransferAffQuotaToQuota(tran.Quota)
|
||||||
@@ -425,10 +403,7 @@ func GetAffCode(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.AffCode == "" {
|
if user.AffCode == "" {
|
||||||
@@ -453,12 +428,12 @@ func GetSelf(c *gin.Context) {
|
|||||||
id := c.GetInt("id")
|
id := c.GetInt("id")
|
||||||
user, err := model.GetUserById(id, false)
|
user, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||||
|
user.Remark = ""
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -474,16 +449,13 @@ func GetUserModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
user, err := model.GetUserCache(id)
|
user, err := model.GetUserCache(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
groups := setting.GetUserUsableGroups(user.Group)
|
groups := setting.GetUserUsableGroups(user.Group)
|
||||||
var models []string
|
var models []string
|
||||||
for group := range groups {
|
for group := range groups {
|
||||||
for _, g := range model.GetGroupModels(group) {
|
for _, g := range model.GetGroupEnabledModels(group) {
|
||||||
if !common.StringsContains(models, g) {
|
if !common.StringsContains(models, g) {
|
||||||
models = append(models, g)
|
models = append(models, g)
|
||||||
}
|
}
|
||||||
@@ -519,10 +491,7 @@ func UpdateUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
originUser, err := model.GetUserById(updatedUser.Id, false)
|
originUser, err := model.GetUserById(updatedUser.Id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt("role")
|
||||||
@@ -545,14 +514,11 @@ func UpdateUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
updatePassword := updatedUser.Password != ""
|
updatePassword := updatedUser.Password != ""
|
||||||
if err := updatedUser.Edit(updatePassword); err != nil {
|
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if originUser.Quota != updatedUser.Quota {
|
if originUser.Quota != updatedUser.Quota {
|
||||||
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
|
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
@@ -594,17 +560,11 @@ func UpdateSelf(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := cleanUser.Update(updatePassword); err != nil {
|
if err := cleanUser.Update(updatePassword); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -635,18 +595,12 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
|
|||||||
func DeleteUser(c *gin.Context) {
|
func DeleteUser(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
originUser, err := model.GetUserById(id, false)
|
originUser, err := model.GetUserById(id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
myRole := c.GetInt("role")
|
myRole := c.GetInt("role")
|
||||||
@@ -681,10 +635,7 @@ func DeleteSelf(c *gin.Context) {
|
|||||||
|
|
||||||
err := model.DeleteUserById(id)
|
err := model.DeleteUserById(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -730,10 +681,7 @@ func CreateUser(c *gin.Context) {
|
|||||||
DisplayName: user.DisplayName,
|
DisplayName: user.DisplayName,
|
||||||
}
|
}
|
||||||
if err := cleanUser.Insert(0); err != nil {
|
if err := cleanUser.Insert(0); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -843,10 +791,7 @@ func ManageUser(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := user.Update(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
clearUser := model.User{
|
clearUser := model.User{
|
||||||
@@ -878,20 +823,14 @@ func EmailBind(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err := user.FillUserById()
|
err := user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.Email = email
|
user.Email = email
|
||||||
// no need to check if this email already taken, because we have used verification code to check it
|
// no need to check if this email already taken, because we have used verification code to check it
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -905,27 +844,67 @@ type topUpRequest struct {
|
|||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var topUpLock = sync.Mutex{}
|
var topUpLocks sync.Map
|
||||||
|
var topUpCreateLock sync.Mutex
|
||||||
|
|
||||||
|
type topUpTryLock struct {
|
||||||
|
ch chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTopUpTryLock() *topUpTryLock {
|
||||||
|
return &topUpTryLock{ch: make(chan struct{}, 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *topUpTryLock) TryLock() bool {
|
||||||
|
select {
|
||||||
|
case l.ch <- struct{}{}:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *topUpTryLock) Unlock() {
|
||||||
|
select {
|
||||||
|
case <-l.ch:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTopUpLock(userID int) *topUpTryLock {
|
||||||
|
if v, ok := topUpLocks.Load(userID); ok {
|
||||||
|
return v.(*topUpTryLock)
|
||||||
|
}
|
||||||
|
topUpCreateLock.Lock()
|
||||||
|
defer topUpCreateLock.Unlock()
|
||||||
|
if v, ok := topUpLocks.Load(userID); ok {
|
||||||
|
return v.(*topUpTryLock)
|
||||||
|
}
|
||||||
|
l := newTopUpTryLock()
|
||||||
|
topUpLocks.Store(userID, l)
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
func TopUp(c *gin.Context) {
|
func TopUp(c *gin.Context) {
|
||||||
topUpLock.Lock()
|
id := c.GetInt("id")
|
||||||
defer topUpLock.Unlock()
|
lock := getTopUpLock(id)
|
||||||
req := topUpRequest{}
|
if !lock.TryLock() {
|
||||||
err := c.ShouldBindJSON(&req)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": "充值处理中,请稍后重试",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id := c.GetInt("id")
|
defer lock.Unlock()
|
||||||
|
req := topUpRequest{}
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
quota, err := model.Redeem(req.Key, id)
|
quota, err := model.Redeem(req.Key, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -933,7 +912,6 @@ func TopUp(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": quota,
|
"data": quota,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateUserSettingRequest struct {
|
type UpdateUserSettingRequest struct {
|
||||||
@@ -943,6 +921,7 @@ type UpdateUserSettingRequest struct {
|
|||||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||||
NotificationEmail string `json:"notification_email,omitempty"`
|
NotificationEmail string `json:"notification_email,omitempty"`
|
||||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||||
|
RecordIpLog bool `json:"record_ip_log"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserSetting(c *gin.Context) {
|
func UpdateUserSetting(c *gin.Context) {
|
||||||
@@ -956,7 +935,7 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证预警类型
|
// 验证预警类型
|
||||||
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
|
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无效的预警类型",
|
"message": "无效的预警类型",
|
||||||
@@ -974,7 +953,7 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果是webhook类型,验证webhook地址
|
// 如果是webhook类型,验证webhook地址
|
||||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||||
if req.WebhookUrl == "" {
|
if req.WebhookUrl == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -993,7 +972,7 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果是邮件类型,验证邮箱地址
|
// 如果是邮件类型,验证邮箱地址
|
||||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||||
// 验证邮箱格式
|
// 验证邮箱格式
|
||||||
if !strings.Contains(req.NotificationEmail, "@") {
|
if !strings.Contains(req.NotificationEmail, "@") {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -1007,31 +986,29 @@ func UpdateUserSetting(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
user, err := model.GetUserById(userId, true)
|
user, err := model.GetUserById(userId, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建设置
|
// 构建设置
|
||||||
settings := map[string]interface{}{
|
settings := dto.UserSetting{
|
||||||
constant.UserSettingNotifyType: req.QuotaWarningType,
|
NotifyType: req.QuotaWarningType,
|
||||||
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
QuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||||
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
|
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
|
||||||
|
RecordIpLog: req.RecordIpLog,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果是webhook类型,添加webhook相关设置
|
// 如果是webhook类型,添加webhook相关设置
|
||||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||||
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
|
settings.WebhookUrl = req.WebhookUrl
|
||||||
if req.WebhookSecret != "" {
|
if req.WebhookSecret != "" {
|
||||||
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
|
settings.WebhookSecret = req.WebhookSecret
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果提供了通知邮箱,添加到设置中
|
// 如果提供了通知邮箱,添加到设置中
|
||||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||||
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
|
settings.NotificationEmail = req.NotificationEmail
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新用户设置
|
// 更新用户设置
|
||||||
|
|||||||
124
controller/vendor_meta.go
Normal file
124
controller/vendor_meta.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAllVendors 获取供应商列表(分页)
|
||||||
|
func GetAllVendors(c *gin.Context) {
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
vendors, err := model.GetAllVendors(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var total int64
|
||||||
|
model.DB.Model(&model.Vendor{}).Count(&total)
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(vendors)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchVendors 搜索供应商
|
||||||
|
func SearchVendors(c *gin.Context) {
|
||||||
|
keyword := c.Query("keyword")
|
||||||
|
pageInfo := common.GetPageQuery(c)
|
||||||
|
vendors, total, err := model.SearchVendors(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pageInfo.SetTotal(int(total))
|
||||||
|
pageInfo.SetItems(vendors)
|
||||||
|
common.ApiSuccess(c, pageInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVendorMeta 根据 ID 获取供应商
|
||||||
|
func GetVendorMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v, err := model.GetVendorByID(id)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateVendorMeta 新建供应商
|
||||||
|
func CreateVendorMeta(c *gin.Context) {
|
||||||
|
var v model.Vendor
|
||||||
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.Name == "" {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称不能为空")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 创建前先检查名称
|
||||||
|
if dup, err := model.IsVendorNameDuplicated(0, v.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Insert(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateVendorMeta 更新供应商
|
||||||
|
func UpdateVendorMeta(c *gin.Context) {
|
||||||
|
var v model.Vendor
|
||||||
|
if err := c.ShouldBindJSON(&v); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v.Id == 0 {
|
||||||
|
common.ApiErrorMsg(c, "缺少供应商 ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 名称冲突检查
|
||||||
|
if dup, err := model.IsVendorNameDuplicated(v.Id, v.Name); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
} else if dup {
|
||||||
|
common.ApiErrorMsg(c, "供应商名称已存在")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.Update(); err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, &v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteVendorMeta 删除供应商
|
||||||
|
func DeleteVendorMeta(c *gin.Context) {
|
||||||
|
idStr := c.Param("id")
|
||||||
|
id, err := strconv.Atoi(idStr)
|
||||||
|
if err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := model.DB.Delete(&model.Vendor{}, id).Error; err != nil {
|
||||||
|
common.ApiError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
common.ApiSuccess(c, nil)
|
||||||
|
}
|
||||||
@@ -4,13 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wechatLoginResponse struct {
|
type wechatLoginResponse struct {
|
||||||
@@ -150,19 +151,13 @@ func WeChatBind(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
err = user.FillUserById()
|
err = user.FillUserById()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user.WeChatId = wechatId
|
user.WeChatId = wechatId
|
||||||
err = user.Update(false)
|
err = user.Update(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
common.ApiError(c, err)
|
||||||
"success": false,
|
|
||||||
"message": err.Error(),
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ services:
|
|||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||||
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释
|
# - STREAMING_TIMEOUT=300 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||||
|
|||||||
197
docs/api/web_api.md
Normal file
197
docs/api/web_api.md
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
# New API – Web 界面后端接口文档
|
||||||
|
|
||||||
|
> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
|
||||||
|
>
|
||||||
|
> 接口前缀统一为 `https://<your-domain>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
|
||||||
|
>
|
||||||
|
> 鉴权级别说明:
|
||||||
|
> * **公开** – 不需要登录即可调用
|
||||||
|
> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
|
||||||
|
> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
|
||||||
|
> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 初始化 / 系统状态
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/setup | 公开 | 获取系统初始化状态 |
|
||||||
|
| POST | /api/setup | 公开 | 完成首次安装向导 |
|
||||||
|
| GET | /api/status | 公开 | 获取运行状态摘要 |
|
||||||
|
| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
|
||||||
|
| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
|
||||||
|
|
||||||
|
## 2. 公共信息
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/models | 用户 | 获取前端可用模型列表 |
|
||||||
|
| GET | /api/notice | 公开 | 获取公告栏内容 |
|
||||||
|
| GET | /api/about | 公开 | 关于页面信息 |
|
||||||
|
| GET | /api/home_page_content | 公开 | 首页自定义内容 |
|
||||||
|
| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
|
||||||
|
| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
|
||||||
|
|
||||||
|
## 3. 邮件 / 身份验证
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
|
||||||
|
| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
|
||||||
|
| POST | /api/user/reset | 公开 | 提交重置密码请求 |
|
||||||
|
|
||||||
|
## 4. OAuth / 第三方登录
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||||
|
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||||
|
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||||
|
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||||
|
| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
|
||||||
|
| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
|
||||||
|
| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
|
||||||
|
| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
|
||||||
|
| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
|
||||||
|
|
||||||
|
## 5. 用户模块
|
||||||
|
### 5.1 账号注册/登录
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| POST | /api/user/register | 公开 | 注册新账号 |
|
||||||
|
| POST | /api/user/login | 公开 | 用户登录 |
|
||||||
|
| GET | /api/user/logout | 用户 | 退出登录 |
|
||||||
|
| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
|
||||||
|
| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
|
||||||
|
|
||||||
|
### 5.2 用户自身操作 (需登录)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
|
||||||
|
| GET | /api/user/self | 用户 | 获取个人资料 |
|
||||||
|
| GET | /api/user/models | 用户 | 获取模型可见性 |
|
||||||
|
| PUT | /api/user/self | 用户 | 修改个人资料 |
|
||||||
|
| DELETE | /api/user/self | 用户 | 注销账号 |
|
||||||
|
| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
|
||||||
|
| GET | /api/user/aff | 用户 | 获取推广码信息 |
|
||||||
|
| POST | /api/user/topup | 用户 | 余额直充 |
|
||||||
|
| POST | /api/user/pay | 用户 | 提交支付订单 |
|
||||||
|
| POST | /api/user/amount | 用户 | 余额支付 |
|
||||||
|
| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
|
||||||
|
| PUT | /api/user/setting | 用户 | 更新用户设置 |
|
||||||
|
|
||||||
|
### 5.3 管理员用户管理
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/user/ | 管理员 | 获取全部用户列表 |
|
||||||
|
| GET | /api/user/search | 管理员 | 搜索用户 |
|
||||||
|
| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
|
||||||
|
| POST | /api/user/ | 管理员 | 创建用户 |
|
||||||
|
| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
|
||||||
|
| PUT | /api/user/ | 管理员 | 更新用户 |
|
||||||
|
| DELETE | /api/user/:id | 管理员 | 删除用户 |
|
||||||
|
|
||||||
|
## 6. 站点选项 (Root)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/option/ | Root | 获取全局配置 |
|
||||||
|
| PUT | /api/option/ | Root | 更新全局配置 |
|
||||||
|
| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
|
||||||
|
| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
|
||||||
|
|
||||||
|
## 7. 模型倍率同步 (Root)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
|
||||||
|
| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
|
||||||
|
|
||||||
|
## 8. 渠道管理 (管理员)
|
||||||
|
| 方法 | 路径 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | /api/channel/ | 获取渠道列表 |
|
||||||
|
| GET | /api/channel/search | 搜索渠道 |
|
||||||
|
| GET | /api/channel/models | 查询渠道模型能力 |
|
||||||
|
| GET | /api/channel/models_enabled | 查询启用模型能力 |
|
||||||
|
| GET | /api/channel/:id | 获取单个渠道 |
|
||||||
|
| GET | /api/channel/test | 批量测试渠道连通性 |
|
||||||
|
| GET | /api/channel/test/:id | 单个渠道测试 |
|
||||||
|
| GET | /api/channel/update_balance | 批量刷新余额 |
|
||||||
|
| GET | /api/channel/update_balance/:id | 单个刷新余额 |
|
||||||
|
| POST | /api/channel/ | 新增渠道 |
|
||||||
|
| PUT | /api/channel/ | 更新渠道 |
|
||||||
|
| DELETE | /api/channel/disabled | 删除已禁用渠道 |
|
||||||
|
| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
|
||||||
|
| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
|
||||||
|
| PUT | /api/channel/tag | 编辑渠道标签 |
|
||||||
|
| DELETE | /api/channel/:id | 删除渠道 |
|
||||||
|
| POST | /api/channel/batch | 批量删除渠道 |
|
||||||
|
| POST | /api/channel/fix | 修复渠道能力表 |
|
||||||
|
| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
|
||||||
|
| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
|
||||||
|
| POST | /api/channel/batch/tag | 批量设置渠道标签 |
|
||||||
|
| GET | /api/channel/tag/models | 根据标签获取模型 |
|
||||||
|
| POST | /api/channel/copy/:id | 复制渠道 |
|
||||||
|
|
||||||
|
## 9. Token 管理
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/token/ | 用户 | 获取全部 Token |
|
||||||
|
| GET | /api/token/search | 用户 | 搜索 Token |
|
||||||
|
| GET | /api/token/:id | 用户 | 获取单个 Token |
|
||||||
|
| POST | /api/token/ | 用户 | 创建 Token |
|
||||||
|
| PUT | /api/token/ | 用户 | 更新 Token |
|
||||||
|
| DELETE | /api/token/:id | 用户 | 删除 Token |
|
||||||
|
| POST | /api/token/batch | 用户 | 批量删除 Token |
|
||||||
|
|
||||||
|
## 10. 兑换码管理 (管理员)
|
||||||
|
| 方法 | 路径 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | /api/redemption/ | 获取兑换码列表 |
|
||||||
|
| GET | /api/redemption/search | 搜索兑换码 |
|
||||||
|
| GET | /api/redemption/:id | 获取单个兑换码 |
|
||||||
|
| POST | /api/redemption/ | 创建兑换码 |
|
||||||
|
| PUT | /api/redemption/ | 更新兑换码 |
|
||||||
|
| DELETE | /api/redemption/invalid | 删除无效兑换码 |
|
||||||
|
| DELETE | /api/redemption/:id | 删除兑换码 |
|
||||||
|
|
||||||
|
## 11. 日志
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/log/ | 管理员 | 获取全部日志 |
|
||||||
|
| DELETE | /api/log/ | 管理员 | 删除历史日志 |
|
||||||
|
| GET | /api/log/stat | 管理员 | 日志统计 |
|
||||||
|
| GET | /api/log/self/stat | 用户 | 我的日志统计 |
|
||||||
|
| GET | /api/log/search | 管理员 | 搜索全部日志 |
|
||||||
|
| GET | /api/log/self | 用户 | 获取我的日志 |
|
||||||
|
| GET | /api/log/self/search | 用户 | 搜索我的日志 |
|
||||||
|
| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
|
||||||
|
|
||||||
|
## 12. 数据统计
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
|
||||||
|
| GET | /api/data/self | 用户 | 我的用量按日期统计 |
|
||||||
|
|
||||||
|
## 13. 分组
|
||||||
|
| GET | /api/group/ | 管理员 | 获取全部分组列表 |
|
||||||
|
|
||||||
|
## 14. Midjourney 任务
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
|
||||||
|
| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
|
||||||
|
|
||||||
|
## 15. 任务中心
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /api/task/self | 用户 | 获取我的任务 |
|
||||||
|
| GET | /api/task/ | 管理员 | 获取全部任务 |
|
||||||
|
|
||||||
|
## 16. 账户计费面板 (Dashboard)
|
||||||
|
| 方法 | 路径 | 鉴权 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
|
||||||
|
| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
|
||||||
|
| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
|
||||||
|
| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
> **更新日期**:2025.07.17
|
||||||
BIN
docs/images/aliyun.png
Normal file
BIN
docs/images/aliyun.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.0 KiB |
BIN
docs/images/cherry-studio.png
Normal file
BIN
docs/images/cherry-studio.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
BIN
docs/images/io-net.png
Normal file
BIN
docs/images/io-net.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.0 KiB |
BIN
docs/images/pku.png
Normal file
BIN
docs/images/pku.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
BIN
docs/images/ucloud.png
Normal file
BIN
docs/images/ucloud.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 11 KiB |
24
dto/audio.go
24
dto/audio.go
@@ -1,5 +1,11 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type AudioRequest struct {
|
type AudioRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input string `json:"input"`
|
Input string `json:"input"`
|
||||||
@@ -8,6 +14,24 @@ type AudioRequest struct {
|
|||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
meta := &types.TokenCountMeta{
|
||||||
|
CombineText: r.Input,
|
||||||
|
TokenType: types.TokenTypeTextNumber,
|
||||||
|
}
|
||||||
|
return meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *AudioRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type AudioResponse struct {
|
type AudioResponse struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
}
|
}
|
||||||
|
|||||||
14
dto/channel_settings.go
Normal file
14
dto/channel_settings.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type ChannelSettings struct {
|
||||||
|
ForceFormat bool `json:"force_format,omitempty"`
|
||||||
|
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||||
|
Proxy string `json:"proxy"`
|
||||||
|
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
|
||||||
|
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||||
|
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelOtherSettings struct {
|
||||||
|
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||||
|
}
|
||||||
398
dto/claude.go
398
dto/claude.go
@@ -1,6 +1,14 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type ClaudeMetadata struct {
|
type ClaudeMetadata struct {
|
||||||
UserId string `json:"user_id"`
|
UserId string `json:"user_id"`
|
||||||
@@ -20,11 +28,11 @@ type ClaudeMediaMessage struct {
|
|||||||
Delta string `json:"delta,omitempty"`
|
Delta string `json:"delta,omitempty"`
|
||||||
CacheControl json.RawMessage `json:"cache_control,omitempty"`
|
CacheControl json.RawMessage `json:"cache_control,omitempty"`
|
||||||
// tool_calls
|
// tool_calls
|
||||||
Id string `json:"id,omitempty"`
|
Id string `json:"id,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
Input any `json:"input,omitempty"`
|
Input any `json:"input,omitempty"`
|
||||||
Content json.RawMessage `json:"content,omitempty"`
|
Content any `json:"content,omitempty"`
|
||||||
ToolUseId string `json:"tool_use_id,omitempty"`
|
ToolUseId string `json:"tool_use_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) SetText(s string) {
|
func (c *ClaudeMediaMessage) SetText(s string) {
|
||||||
@@ -39,34 +47,54 @@ func (c *ClaudeMediaMessage) GetText() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) IsStringContent() bool {
|
func (c *ClaudeMediaMessage) IsStringContent() bool {
|
||||||
var content string
|
if c.Content == nil {
|
||||||
return json.Unmarshal(c.Content, &content) == nil
|
return false
|
||||||
|
}
|
||||||
|
_, ok := c.Content.(string)
|
||||||
|
if ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) GetStringContent() string {
|
func (c *ClaudeMediaMessage) GetStringContent() string {
|
||||||
var content string
|
if c.Content == nil {
|
||||||
if err := json.Unmarshal(c.Content, &content); err == nil {
|
return ""
|
||||||
return content
|
|
||||||
}
|
}
|
||||||
|
switch c.Content.(type) {
|
||||||
|
case string:
|
||||||
|
return c.Content.(string)
|
||||||
|
case []any:
|
||||||
|
var contentStr string
|
||||||
|
for _, contentItem := range c.Content.([]any) {
|
||||||
|
contentMap, ok := contentItem.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if contentMap["type"] == ContentTypeText {
|
||||||
|
if subStr, ok := contentMap["text"].(string); ok {
|
||||||
|
contentStr += subStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return contentStr
|
||||||
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
func (c *ClaudeMediaMessage) GetJsonRowString() string {
|
||||||
jsonContent, _ := json.Marshal(c)
|
jsonContent, _ := common.Marshal(c)
|
||||||
return string(jsonContent)
|
return string(jsonContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) SetContent(content any) {
|
func (c *ClaudeMediaMessage) SetContent(content any) {
|
||||||
jsonContent, _ := json.Marshal(content)
|
c.Content = content
|
||||||
c.Content = jsonContent
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
|
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
|
||||||
var mediaContent []ClaudeMediaMessage
|
mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
|
||||||
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
|
return mediaContent
|
||||||
return mediaContent
|
|
||||||
}
|
|
||||||
return make([]ClaudeMediaMessage, 0)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeMessageSource struct {
|
type ClaudeMessageSource struct {
|
||||||
@@ -82,14 +110,36 @@ type ClaudeMessage struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMessage) IsStringContent() bool {
|
func (c *ClaudeMessage) IsStringContent() bool {
|
||||||
|
if c.Content == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
_, ok := c.Content.(string)
|
_, ok := c.Content.(string)
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMessage) GetStringContent() string {
|
func (c *ClaudeMessage) GetStringContent() string {
|
||||||
if c.IsStringContent() {
|
if c.Content == nil {
|
||||||
return c.Content.(string)
|
return ""
|
||||||
}
|
}
|
||||||
|
switch c.Content.(type) {
|
||||||
|
case string:
|
||||||
|
return c.Content.(string)
|
||||||
|
case []any:
|
||||||
|
var contentStr string
|
||||||
|
for _, contentItem := range c.Content.([]any) {
|
||||||
|
contentMap, ok := contentItem.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if contentMap["type"] == ContentTypeText {
|
||||||
|
if subStr, ok := contentMap["text"].(string); ok {
|
||||||
|
contentStr += subStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return contentStr
|
||||||
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,15 +148,7 @@ func (c *ClaudeMessage) SetStringContent(content string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
||||||
// map content to []ClaudeMediaMessage
|
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
|
||||||
// parse to json
|
|
||||||
jsonContent, _ := json.Marshal(c.Content)
|
|
||||||
var contentList []ClaudeMediaMessage
|
|
||||||
err := json.Unmarshal(jsonContent, &contentList)
|
|
||||||
if err != nil {
|
|
||||||
return make([]ClaudeMediaMessage, 0), err
|
|
||||||
}
|
|
||||||
return contentList, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
@@ -121,6 +163,27 @@ type InputSchema struct {
|
|||||||
Required any `json:"required,omitempty"`
|
Required any `json:"required,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClaudeWebSearchTool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
MaxUses int `json:"max_uses,omitempty"`
|
||||||
|
UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeWebSearchUserLocation struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Timezone string `json:"timezone,omitempty"`
|
||||||
|
Country string `json:"country,omitempty"`
|
||||||
|
Region string `json:"region,omitempty"`
|
||||||
|
City string `json:"city,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeToolChoice struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type ClaudeRequest struct {
|
type ClaudeRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
@@ -139,9 +202,210 @@ type ClaudeRequest struct {
|
|||||||
Thinking *Thinking `json:"thinking,omitempty"`
|
Thinking *Thinking `json:"thinking,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var tokenCountMeta = types.TokenCountMeta{
|
||||||
|
TokenType: types.TokenTypeTokenizer,
|
||||||
|
MaxTokens: int(c.MaxTokens),
|
||||||
|
}
|
||||||
|
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
// system
|
||||||
|
if c.System != nil {
|
||||||
|
if c.IsStringSystem() {
|
||||||
|
sys := c.GetStringSystem()
|
||||||
|
if sys != "" {
|
||||||
|
texts = append(texts, sys)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
systemMedia := c.ParseSystem()
|
||||||
|
for _, media := range systemMedia {
|
||||||
|
switch media.Type {
|
||||||
|
case "text":
|
||||||
|
texts = append(texts, media.GetText())
|
||||||
|
case "image":
|
||||||
|
if media.Source != nil {
|
||||||
|
data := media.Source.Url
|
||||||
|
if data == "" {
|
||||||
|
data = common.Interface2String(media.Source.Data)
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// messages
|
||||||
|
for _, message := range c.Messages {
|
||||||
|
tokenCountMeta.MessagesCount++
|
||||||
|
texts = append(texts, message.Role)
|
||||||
|
if message.IsStringContent() {
|
||||||
|
content := message.GetStringContent()
|
||||||
|
if content != "" {
|
||||||
|
texts = append(texts, content)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
content, _ := message.ParseContent()
|
||||||
|
for _, media := range content {
|
||||||
|
switch media.Type {
|
||||||
|
case "text":
|
||||||
|
texts = append(texts, media.GetText())
|
||||||
|
case "image":
|
||||||
|
if media.Source != nil {
|
||||||
|
data := media.Source.Url
|
||||||
|
if data == "" {
|
||||||
|
data = common.Interface2String(media.Source.Data)
|
||||||
|
}
|
||||||
|
if data != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
if media.Name != "" {
|
||||||
|
texts = append(texts, media.Name)
|
||||||
|
}
|
||||||
|
if media.Input != nil {
|
||||||
|
b, _ := common.Marshal(media.Input)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
if media.Content != nil {
|
||||||
|
b, _ := common.Marshal(media.Content)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tools
|
||||||
|
if c.Tools != nil {
|
||||||
|
tools := c.GetTools()
|
||||||
|
normalTools, webSearchTools := ProcessTools(tools)
|
||||||
|
if normalTools != nil {
|
||||||
|
for _, t := range normalTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
if t.Name != "" {
|
||||||
|
texts = append(texts, t.Name)
|
||||||
|
}
|
||||||
|
if t.Description != "" {
|
||||||
|
texts = append(texts, t.Description)
|
||||||
|
}
|
||||||
|
if t.InputSchema != nil {
|
||||||
|
b, _ := common.Marshal(t.InputSchema)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if webSearchTools != nil {
|
||||||
|
for _, t := range webSearchTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
if t.Name != "" {
|
||||||
|
texts = append(texts, t.Name)
|
||||||
|
}
|
||||||
|
if t.UserLocation != nil {
|
||||||
|
b, _ := common.Marshal(t.UserLocation)
|
||||||
|
texts = append(texts, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCountMeta.CombineText = strings.Join(texts, "\n")
|
||||||
|
tokenCountMeta.Files = fileMeta
|
||||||
|
return &tokenCountMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
|
||||||
|
return c.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
c.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
|
||||||
|
for _, message := range c.Messages {
|
||||||
|
content, _ := message.ParseContent()
|
||||||
|
for _, mediaMessage := range content {
|
||||||
|
if mediaMessage.Id == toolCallId {
|
||||||
|
return mediaMessage.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTool 添加工具到请求中
|
||||||
|
func (c *ClaudeRequest) AddTool(tool any) {
|
||||||
|
if c.Tools == nil {
|
||||||
|
c.Tools = make([]any, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tools := c.Tools.(type) {
|
||||||
|
case []any:
|
||||||
|
c.Tools = append(tools, tool)
|
||||||
|
default:
|
||||||
|
// 如果Tools不是[]any类型,重新初始化为[]any
|
||||||
|
c.Tools = []any{tool}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTools 获取工具列表
|
||||||
|
func (c *ClaudeRequest) GetTools() []any {
|
||||||
|
if c.Tools == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tools := c.Tools.(type) {
|
||||||
|
case []any:
|
||||||
|
return tools
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessTools 处理工具列表,支持类型断言
|
||||||
|
func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
|
||||||
|
var normalTools []*Tool
|
||||||
|
var webSearchTools []*ClaudeWebSearchTool
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
switch t := tool.(type) {
|
||||||
|
case *Tool:
|
||||||
|
normalTools = append(normalTools, t)
|
||||||
|
case *ClaudeWebSearchTool:
|
||||||
|
webSearchTools = append(webSearchTools, t)
|
||||||
|
case Tool:
|
||||||
|
normalTools = append(normalTools, &t)
|
||||||
|
case ClaudeWebSearchTool:
|
||||||
|
webSearchTools = append(webSearchTools, &t)
|
||||||
|
default:
|
||||||
|
// 未知类型,跳过
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return normalTools, webSearchTools
|
||||||
|
}
|
||||||
|
|
||||||
type Thinking struct {
|
type Thinking struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
BudgetTokens int `json:"budget_tokens"`
|
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Thinking) GetBudgetTokens() int {
|
||||||
|
if c.BudgetTokens == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *c.BudgetTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeRequest) IsStringSystem() bool {
|
func (c *ClaudeRequest) IsStringSystem() bool {
|
||||||
@@ -161,24 +425,13 @@ func (c *ClaudeRequest) SetStringSystem(system string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
|
||||||
// map content to []ClaudeMediaMessage
|
mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
|
||||||
// parse to json
|
return mediaContent
|
||||||
jsonContent, _ := json.Marshal(c.System)
|
|
||||||
var contentList []ClaudeMediaMessage
|
|
||||||
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
|
|
||||||
return contentList
|
|
||||||
}
|
|
||||||
return make([]ClaudeMediaMessage, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClaudeError struct {
|
|
||||||
Type string `json:"type,omitempty"`
|
|
||||||
Message string `json:"message,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeErrorWithStatusCode struct {
|
type ClaudeErrorWithStatusCode struct {
|
||||||
Error ClaudeError `json:"error"`
|
Error types.ClaudeError `json:"error"`
|
||||||
StatusCode int `json:"status_code"`
|
StatusCode int `json:"status_code"`
|
||||||
LocalError bool
|
LocalError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,7 +443,7 @@ type ClaudeResponse struct {
|
|||||||
Completion string `json:"completion,omitempty"`
|
Completion string `json:"completion,omitempty"`
|
||||||
StopReason string `json:"stop_reason,omitempty"`
|
StopReason string `json:"stop_reason,omitempty"`
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Error *ClaudeError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||||
Index *int `json:"index,omitempty"`
|
Index *int `json:"index,omitempty"`
|
||||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||||
@@ -211,9 +464,50 @@ func (c *ClaudeResponse) GetIndex() int {
|
|||||||
return *c.Index
|
return *c.Index
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeUsage struct {
|
// GetClaudeError 从动态错误类型中提取ClaudeError结构
|
||||||
InputTokens int `json:"input_tokens"`
|
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
|
||||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
if c.Error == nil {
|
||||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
return nil
|
||||||
OutputTokens int `json:"output_tokens"`
|
}
|
||||||
|
|
||||||
|
switch err := c.Error.(type) {
|
||||||
|
case types.ClaudeError:
|
||||||
|
return &err
|
||||||
|
case *types.ClaudeError:
|
||||||
|
return err
|
||||||
|
case map[string]interface{}:
|
||||||
|
// 处理从JSON解析来的map结构
|
||||||
|
claudeErr := &types.ClaudeError{}
|
||||||
|
if errType, ok := err["type"].(string); ok {
|
||||||
|
claudeErr.Type = errType
|
||||||
|
}
|
||||||
|
if errMsg, ok := err["message"].(string); ok {
|
||||||
|
claudeErr.Message = errMsg
|
||||||
|
}
|
||||||
|
return claudeErr
|
||||||
|
case string:
|
||||||
|
// 处理简单字符串错误
|
||||||
|
return &types.ClaudeError{
|
||||||
|
Type: "error",
|
||||||
|
Message: err,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 未知类型,尝试转换为字符串
|
||||||
|
return &types.ClaudeError{
|
||||||
|
Type: "unknown_error",
|
||||||
|
Message: fmt.Sprintf("%v", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeUsage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||||
|
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClaudeServerToolUse struct {
|
||||||
|
WebSearchRequests int `json:"web_search_requests"`
|
||||||
}
|
}
|
||||||
|
|||||||
28
dto/dalle.go
28
dto/dalle.go
@@ -1,28 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import "encoding/json"
|
|
||||||
|
|
||||||
type ImageRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Prompt string `json:"prompt" binding:"required"`
|
|
||||||
N int `json:"n,omitempty"`
|
|
||||||
Size string `json:"size,omitempty"`
|
|
||||||
Quality string `json:"quality,omitempty"`
|
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
|
||||||
Style string `json:"style,omitempty"`
|
|
||||||
User string `json:"user,omitempty"`
|
|
||||||
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
|
||||||
Background string `json:"background,omitempty"`
|
|
||||||
Moderation string `json:"moderation,omitempty"`
|
|
||||||
OutputFormat string `json:"output_format,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ImageResponse struct {
|
|
||||||
Data []ImageData `json:"data"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
}
|
|
||||||
type ImageData struct {
|
|
||||||
Url string `json:"url"`
|
|
||||||
B64Json string `json:"b64_json"`
|
|
||||||
RevisedPrompt string `json:"revised_prompt"`
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type EmbeddingOptions struct {
|
type EmbeddingOptions struct {
|
||||||
Seed int `json:"seed,omitempty"`
|
Seed int `json:"seed,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
@@ -24,9 +31,32 @@ type EmbeddingRequest struct {
|
|||||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r EmbeddingRequest) ParseInput() []string {
|
func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
for _, input := range inputs {
|
||||||
|
texts = append(texts, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *EmbeddingRequest) ParseInput() []string {
|
||||||
if r.Input == nil {
|
if r.Input == nil {
|
||||||
return nil
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
var input []string
|
var input []string
|
||||||
switch r.Input.(type) {
|
switch r.Input.(type) {
|
||||||
|
|||||||
12
dto/error.go
12
dto/error.go
@@ -1,5 +1,7 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
|
import "one-api/types"
|
||||||
|
|
||||||
type OpenAIError struct {
|
type OpenAIError struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GeneralErrorResponse struct {
|
type GeneralErrorResponse struct {
|
||||||
Error OpenAIError `json:"error"`
|
Error types.OpenAIError `json:"error"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Msg string `json:"msg"`
|
Msg string `json:"msg"`
|
||||||
Err string `json:"err"`
|
Err string `json:"err"`
|
||||||
ErrorMsg string `json:"error_msg"`
|
ErrorMsg string `json:"error_msg"`
|
||||||
Header struct {
|
Header struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
} `json:"header"`
|
} `json:"header"`
|
||||||
|
|||||||
384
dto/gemini.go
Normal file
384
dto/gemini.go
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/logger"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GeminiChatRequest struct {
|
||||||
|
Contents []GeminiChatContent `json:"contents"`
|
||||||
|
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||||
|
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||||
|
Tools json.RawMessage `json:"tools,omitempty"`
|
||||||
|
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var files []*types.FileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
var maxTokens int
|
||||||
|
|
||||||
|
if r.GenerationConfig.MaxOutputTokens > 0 {
|
||||||
|
maxTokens = int(r.GenerationConfig.MaxOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
var inputTexts []string
|
||||||
|
for _, content := range r.Contents {
|
||||||
|
for _, part := range content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||||
|
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeAudio,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeVideo,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
files = append(files, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
OriginData: part.InlineData.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
Files: files,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
|
||||||
|
if c.Query("alt") == "sse" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) SetModelName(modelName string) {
|
||||||
|
// GeminiChatRequest does not have a model field, so this method does nothing.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||||
|
var tools []GeminiChatTool
|
||||||
|
if strings.HasSuffix(string(r.Tools), "[") {
|
||||||
|
// is array
|
||||||
|
if err := common.Unmarshal(r.Tools, &tools); err != nil {
|
||||||
|
logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(string(r.Tools), "{") {
|
||||||
|
// is object
|
||||||
|
singleTool := GeminiChatTool{}
|
||||||
|
if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
|
||||||
|
logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tools = []GeminiChatTool{singleTool}
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
|
||||||
|
if len(tools) == 0 {
|
||||||
|
r.Tools = json.RawMessage("[]")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal the tools to JSON
|
||||||
|
data, err := common.Marshal(tools)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(nil, "error_marshalling_tools: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Tools = data
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiThinkingConfig struct {
|
||||||
|
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||||
|
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
|
||||||
|
c.ThinkingBudget = &budget
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiInlineData struct {
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
|
||||||
|
func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
|
||||||
|
type Alias GeminiInlineData // Use type alias to avoid recursion
|
||||||
|
var aux struct {
|
||||||
|
Alias
|
||||||
|
MimeTypeSnake string `json:"mime_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := common.Unmarshal(data, &aux); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
|
||||||
|
|
||||||
|
// Prioritize snake_case if present
|
||||||
|
if aux.MimeTypeSnake != "" {
|
||||||
|
g.MimeType = aux.MimeTypeSnake
|
||||||
|
} else if aux.MimeType != "" { // Fallback to camelCase from Alias
|
||||||
|
g.MimeType = aux.MimeType
|
||||||
|
}
|
||||||
|
// g.Data would be populated by aux.Alias.Data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
FunctionName string `json:"name"`
|
||||||
|
Arguments any `json:"args"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiFunctionResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Response map[string]interface{} `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiPartExecutableCode struct {
|
||||||
|
Language string `json:"language,omitempty"`
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiPartCodeExecutionResult struct {
|
||||||
|
Outcome string `json:"outcome,omitempty"`
|
||||||
|
Output string `json:"output,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiFileData struct {
|
||||||
|
MimeType string `json:"mimeType,omitempty"`
|
||||||
|
FileUri string `json:"fileUri,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiPart struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
Thought bool `json:"thought,omitempty"`
|
||||||
|
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||||
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
|
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||||
|
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||||
|
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||||
|
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
|
||||||
|
func (p *GeminiPart) UnmarshalJSON(data []byte) error {
|
||||||
|
// Alias to avoid recursion during unmarshalling
|
||||||
|
type Alias GeminiPart
|
||||||
|
var aux struct {
|
||||||
|
Alias
|
||||||
|
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := common.Unmarshal(data, &aux); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign fields from alias
|
||||||
|
*p = GeminiPart(aux.Alias)
|
||||||
|
|
||||||
|
// Prioritize snake_case for InlineData if present
|
||||||
|
if aux.InlineDataSnake != nil {
|
||||||
|
p.InlineData = aux.InlineDataSnake
|
||||||
|
} else if aux.InlineData != nil { // Fallback to camelCase from Alias
|
||||||
|
p.InlineData = aux.InlineData
|
||||||
|
}
|
||||||
|
// Other fields like Text, FunctionCall etc. are already populated via aux.Alias
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatContent struct {
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
Parts []GeminiPart `json:"parts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatSafetySettings struct {
|
||||||
|
Category string `json:"category"`
|
||||||
|
Threshold string `json:"threshold"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatTool struct {
|
||||||
|
GoogleSearch any `json:"googleSearch,omitempty"`
|
||||||
|
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
|
||||||
|
CodeExecution any `json:"codeExecution,omitempty"`
|
||||||
|
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatGenerationConfig struct {
|
||||||
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"topP,omitempty"`
|
||||||
|
TopK float64 `json:"topK,omitempty"`
|
||||||
|
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||||
|
CandidateCount int `json:"candidateCount,omitempty"`
|
||||||
|
StopSequences []string `json:"stopSequences,omitempty"`
|
||||||
|
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||||
|
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||||
|
Seed int64 `json:"seed,omitempty"`
|
||||||
|
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||||
|
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||||
|
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatCandidate struct {
|
||||||
|
Content GeminiChatContent `json:"content"`
|
||||||
|
FinishReason *string `json:"finishReason"`
|
||||||
|
Index int64 `json:"index"`
|
||||||
|
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatSafetyRating struct {
|
||||||
|
Category string `json:"category"`
|
||||||
|
Probability string `json:"probability"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatPromptFeedback struct {
|
||||||
|
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiChatResponse struct {
|
||||||
|
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||||
|
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||||
|
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiUsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
|
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
|
||||||
|
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiPromptTokensDetails struct {
|
||||||
|
Modality string `json:"modality"`
|
||||||
|
TokenCount int `json:"tokenCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Imagen related structs
|
||||||
|
type GeminiImageRequest struct {
|
||||||
|
Instances []GeminiImageInstance `json:"instances"`
|
||||||
|
Parameters GeminiImageParameters `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiImageInstance struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiImageParameters struct {
|
||||||
|
SampleCount int `json:"sampleCount,omitempty"`
|
||||||
|
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||||
|
PersonGeneration string `json:"personGeneration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiImageResponse struct {
|
||||||
|
Predictions []GeminiImagePrediction `json:"predictions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiImagePrediction struct {
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||||
|
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
|
||||||
|
SafetyAttributes any `json:"safetyAttributes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedding related structs
|
||||||
|
type GeminiEmbeddingRequest struct {
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Content GeminiChatContent `json:"content"`
|
||||||
|
TaskType string `json:"taskType,omitempty"`
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
OutputDimensionality int `json:"outputDimensionality,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
// Gemini embedding requests are not streamed
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var inputTexts []string
|
||||||
|
for _, part := range r.Content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiEmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiBatchEmbeddingRequest struct {
|
||||||
|
Requests []*GeminiEmbeddingRequest `json:"requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) IsStream(c *gin.Context) bool {
|
||||||
|
// Gemini batch embedding requests are not streamed
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var inputTexts []string
|
||||||
|
for _, request := range r.Requests {
|
||||||
|
meta := request.GetTokenCountMeta()
|
||||||
|
if meta != nil && meta.CombineText != "" {
|
||||||
|
inputTexts = append(inputTexts, meta.CombineText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: inputText,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeminiBatchEmbeddingRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
for _, req := range r.Requests {
|
||||||
|
req.SetModelName(modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiEmbeddingResponse struct {
|
||||||
|
Embedding ContentEmbedding `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiBatchEmbeddingResponse struct {
|
||||||
|
Embeddings []*ContentEmbedding `json:"embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContentEmbedding struct {
|
||||||
|
Values []float64 `json:"values"`
|
||||||
|
}
|
||||||
@@ -57,6 +57,8 @@ type MidjourneyDto struct {
|
|||||||
StartTime int64 `json:"startTime"`
|
StartTime int64 `json:"startTime"`
|
||||||
FinishTime int64 `json:"finishTime"`
|
FinishTime int64 `json:"finishTime"`
|
||||||
ImageUrl string `json:"imageUrl"`
|
ImageUrl string `json:"imageUrl"`
|
||||||
|
VideoUrl string `json:"videoUrl"`
|
||||||
|
VideoUrls []ImgUrls `json:"videoUrls"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Progress string `json:"progress"`
|
Progress string `json:"progress"`
|
||||||
FailReason string `json:"failReason"`
|
FailReason string `json:"failReason"`
|
||||||
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
|
|||||||
Properties *Properties `json:"properties"`
|
Properties *Properties `json:"properties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ImgUrls struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
type MidjourneyStatus struct {
|
type MidjourneyStatus struct {
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
}
|
}
|
||||||
|
|||||||
80
dto/openai_image.go
Normal file
80
dto/openai_image.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"one-api/types"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ImageRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Prompt string `json:"prompt" binding:"required"`
|
||||||
|
N uint `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Quality string `json:"quality,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
Style json.RawMessage `json:"style,omitempty"`
|
||||||
|
User json.RawMessage `json:"user,omitempty"`
|
||||||
|
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
|
||||||
|
Background json.RawMessage `json:"background,omitempty"`
|
||||||
|
Moderation json.RawMessage `json:"moderation,omitempty"`
|
||||||
|
OutputFormat json.RawMessage `json:"output_format,omitempty"`
|
||||||
|
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
|
||||||
|
PartialImages json.RawMessage `json:"partial_images,omitempty"`
|
||||||
|
// Stream bool `json:"stream,omitempty"`
|
||||||
|
Watermark *bool `json:"watermark,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var sizeRatio = 1.0
|
||||||
|
var qualityRatio = 1.0
|
||||||
|
|
||||||
|
if strings.HasPrefix(i.Model, "dall-e") {
|
||||||
|
// Size
|
||||||
|
if i.Size == "256x256" {
|
||||||
|
sizeRatio = 0.4
|
||||||
|
} else if i.Size == "512x512" {
|
||||||
|
sizeRatio = 0.45
|
||||||
|
} else if i.Size == "1024x1024" {
|
||||||
|
sizeRatio = 1
|
||||||
|
} else if i.Size == "1024x1792" || i.Size == "1792x1024" {
|
||||||
|
sizeRatio = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.Model == "dall-e-3" && i.Quality == "hd" {
|
||||||
|
qualityRatio = 2.0
|
||||||
|
if i.Size == "1024x1792" || i.Size == "1792x1024" {
|
||||||
|
qualityRatio = 1.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// not support token count for dalle
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: i.Prompt,
|
||||||
|
MaxTokens: 1584,
|
||||||
|
ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
i.Model = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageResponse struct {
|
||||||
|
Data []ImageData `json:"data"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
}
|
||||||
|
type ImageData struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
B64Json string `json:"b64_json"`
|
||||||
|
RevisedPrompt string `json:"revised_prompt"`
|
||||||
|
}
|
||||||
@@ -2,70 +2,211 @@ package dto
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResponseFormat struct {
|
type ResponseFormat struct {
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FormatJsonSchema struct {
|
type FormatJsonSchema struct {
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Schema any `json:"schema,omitempty"`
|
Schema any `json:"schema,omitempty"`
|
||||||
Strict any `json:"strict,omitempty"`
|
Strict json.RawMessage `json:"strict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeneralOpenAIRequest struct {
|
type GeneralOpenAIRequest struct {
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
Prompt any `json:"prompt,omitempty"`
|
Prompt any `json:"prompt,omitempty"`
|
||||||
Prefix any `json:"prefix,omitempty"`
|
Prefix any `json:"prefix,omitempty"`
|
||||||
Suffix any `json:"suffix,omitempty"`
|
Suffix any `json:"suffix,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
Stop any `json:"stop,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
N int `json:"n,omitempty"`
|
Stop any `json:"stop,omitempty"`
|
||||||
Input any `json:"input,omitempty"`
|
N int `json:"n,omitempty"`
|
||||||
Instruction string `json:"instruction,omitempty"`
|
Input any `json:"input,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Instruction string `json:"instruction,omitempty"`
|
||||||
Functions any `json:"functions,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
Functions json.RawMessage `json:"functions,omitempty"`
|
||||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||||
Seed float64 `json:"seed,omitempty"`
|
EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
|
||||||
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
|
Seed float64 `json:"seed,omitempty"`
|
||||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||||
ToolChoice any `json:"tool_choice,omitempty"`
|
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||||
User string `json:"user,omitempty"`
|
ToolChoice any `json:"tool_choice,omitempty"`
|
||||||
LogProbs bool `json:"logprobs,omitempty"`
|
User string `json:"user,omitempty"`
|
||||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
LogProbs bool `json:"logprobs,omitempty"`
|
||||||
Dimensions int `json:"dimensions,omitempty"`
|
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||||
Modalities any `json:"modalities,omitempty"`
|
Dimensions int `json:"dimensions,omitempty"`
|
||||||
Audio any `json:"audio,omitempty"`
|
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
Audio json.RawMessage `json:"audio,omitempty"`
|
||||||
ExtraBody any `json:"extra_body,omitempty"`
|
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao,zhipu_v4
|
||||||
// OpenRouter Params
|
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||||
|
SearchParameters any `json:"search_parameters,omitempty"` //xai
|
||||||
|
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||||
|
// OpenRouter Params
|
||||||
|
Usage json.RawMessage `json:"usage,omitempty"`
|
||||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||||
|
// Ali Qwen Params
|
||||||
|
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||||
|
// 用匿名参数接收额外参数,例如ollama的think参数在此接收
|
||||||
|
Extra map[string]json.RawMessage `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var tokenCountMeta types.TokenCountMeta
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
|
||||||
|
if r.Prompt != nil {
|
||||||
|
switch v := r.Prompt.(type) {
|
||||||
|
case string:
|
||||||
|
texts = append(texts, v)
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
texts = append(texts, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", r.Prompt))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Input != nil {
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
texts = append(texts, inputs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.MaxCompletionTokens > r.MaxTokens {
|
||||||
|
tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
|
||||||
|
} else {
|
||||||
|
tokenCountMeta.MaxTokens = int(r.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, message := range r.Messages {
|
||||||
|
tokenCountMeta.MessagesCount++
|
||||||
|
texts = append(texts, message.Role)
|
||||||
|
if message.Content != nil {
|
||||||
|
if message.Name != nil {
|
||||||
|
tokenCountMeta.NameCount++
|
||||||
|
texts = append(texts, *message.Name)
|
||||||
|
}
|
||||||
|
arrayContent := message.ParseContent()
|
||||||
|
for _, m := range arrayContent {
|
||||||
|
if m.Type == ContentTypeImageURL {
|
||||||
|
imageUrl := m.GetImageMedia()
|
||||||
|
if imageUrl != nil {
|
||||||
|
if imageUrl.Url != "" {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
}
|
||||||
|
meta.OriginData = imageUrl.Url
|
||||||
|
meta.Detail = imageUrl.Detail
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeInputAudio {
|
||||||
|
inputAudio := m.GetInputAudio()
|
||||||
|
if inputAudio != nil {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeAudio,
|
||||||
|
}
|
||||||
|
meta.OriginData = inputAudio.Data
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeFile {
|
||||||
|
file := m.GetFile()
|
||||||
|
if file != nil {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
}
|
||||||
|
meta.OriginData = file.FileData
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else if m.Type == ContentTypeVideoUrl {
|
||||||
|
videoUrl := m.GetVideoUrl()
|
||||||
|
if videoUrl != nil && videoUrl.Url != "" {
|
||||||
|
meta := &types.FileMeta{
|
||||||
|
FileType: types.FileTypeVideo,
|
||||||
|
}
|
||||||
|
meta.OriginData = videoUrl.Url
|
||||||
|
fileMeta = append(fileMeta, meta)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
texts = append(texts, m.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Tools != nil {
|
||||||
|
openaiTools := r.Tools
|
||||||
|
for _, tool := range openaiTools {
|
||||||
|
tokenCountMeta.ToolsCount++
|
||||||
|
texts = append(texts, tool.Function.Name)
|
||||||
|
if tool.Function.Description != "" {
|
||||||
|
texts = append(texts, tool.Function.Description)
|
||||||
|
}
|
||||||
|
if tool.Function.Parameters != nil {
|
||||||
|
texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//toolTokens := CountTokenInput(countStr, request.Model)
|
||||||
|
//tkm += 8
|
||||||
|
//tkm += toolTokens
|
||||||
|
}
|
||||||
|
tokenCountMeta.CombineText = strings.Join(texts, "\n")
|
||||||
|
tokenCountMeta.Files = fileMeta
|
||||||
|
return &tokenCountMeta
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return r.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||||
result := make(map[string]any)
|
result := make(map[string]any)
|
||||||
data, _ := common.EncodeJson(r)
|
data, _ := common.Marshal(r)
|
||||||
_ = common.DecodeJson(data, &result)
|
_ = common.Unmarshal(data, &result)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
|
||||||
|
if strings.HasPrefix(r.Model, "o") {
|
||||||
|
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
|
||||||
|
return "developer"
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(r.Model, "gpt-5") {
|
||||||
|
return "developer"
|
||||||
|
}
|
||||||
|
return "system"
|
||||||
|
}
|
||||||
|
|
||||||
type ToolCallRequest struct {
|
type ToolCallRequest struct {
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
@@ -83,8 +224,11 @@ type StreamOptions struct {
|
|||||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) GetMaxTokens() int {
|
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
|
||||||
return int(r.MaxTokens)
|
if r.MaxCompletionTokens != 0 {
|
||||||
|
return r.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
return r.MaxTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
||||||
@@ -107,16 +251,16 @@ func (r *GeneralOpenAIRequest) ParseInput() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content json.RawMessage `json:"content"`
|
Content any `json:"content"`
|
||||||
Name *string `json:"name,omitempty"`
|
Name *string `json:"name,omitempty"`
|
||||||
Prefix *bool `json:"prefix,omitempty"`
|
Prefix *bool `json:"prefix,omitempty"`
|
||||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
Reasoning string `json:"reasoning,omitempty"`
|
Reasoning string `json:"reasoning,omitempty"`
|
||||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||||
parsedContent []MediaContent
|
parsedContent []MediaContent
|
||||||
parsedStringContent *string
|
//parsedStringContent *string
|
||||||
}
|
}
|
||||||
|
|
||||||
type MediaContent struct {
|
type MediaContent struct {
|
||||||
@@ -132,21 +276,65 @@ type MediaContent struct {
|
|||||||
|
|
||||||
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
|
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
|
||||||
if m.ImageUrl != nil {
|
if m.ImageUrl != nil {
|
||||||
return m.ImageUrl.(*MessageImageUrl)
|
if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
|
||||||
|
return m.ImageUrl.(*MessageImageUrl)
|
||||||
|
}
|
||||||
|
if itemMap, ok := m.ImageUrl.(map[string]any); ok {
|
||||||
|
out := &MessageImageUrl{
|
||||||
|
Url: common.Interface2String(itemMap["url"]),
|
||||||
|
Detail: common.Interface2String(itemMap["detail"]),
|
||||||
|
MimeType: common.Interface2String(itemMap["mime_type"]),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
|
func (m *MediaContent) GetInputAudio() *MessageInputAudio {
|
||||||
if m.InputAudio != nil {
|
if m.InputAudio != nil {
|
||||||
return m.InputAudio.(*MessageInputAudio)
|
if _, ok := m.InputAudio.(*MessageInputAudio); ok {
|
||||||
|
return m.InputAudio.(*MessageInputAudio)
|
||||||
|
}
|
||||||
|
if itemMap, ok := m.InputAudio.(map[string]any); ok {
|
||||||
|
out := &MessageInputAudio{
|
||||||
|
Data: common.Interface2String(itemMap["data"]),
|
||||||
|
Format: common.Interface2String(itemMap["format"]),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MediaContent) GetFile() *MessageFile {
|
func (m *MediaContent) GetFile() *MessageFile {
|
||||||
if m.File != nil {
|
if m.File != nil {
|
||||||
return m.File.(*MessageFile)
|
if _, ok := m.File.(*MessageFile); ok {
|
||||||
|
return m.File.(*MessageFile)
|
||||||
|
}
|
||||||
|
if itemMap, ok := m.File.(map[string]any); ok {
|
||||||
|
out := &MessageFile{
|
||||||
|
FileName: common.Interface2String(itemMap["file_name"]),
|
||||||
|
FileData: common.Interface2String(itemMap["file_data"]),
|
||||||
|
FileId: common.Interface2String(itemMap["file_id"]),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
|
||||||
|
if m.VideoUrl != nil {
|
||||||
|
if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
|
||||||
|
return m.VideoUrl.(*MessageVideoUrl)
|
||||||
|
}
|
||||||
|
if itemMap, ok := m.VideoUrl.(map[string]any); ok {
|
||||||
|
out := &MessageVideoUrl{
|
||||||
|
Url: common.Interface2String(itemMap["url"]),
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -182,6 +370,7 @@ const (
|
|||||||
ContentTypeInputAudio = "input_audio"
|
ContentTypeInputAudio = "input_audio"
|
||||||
ContentTypeFile = "file"
|
ContentTypeFile = "file"
|
||||||
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
|
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
|
||||||
|
//ContentTypeAudioUrl = "audio_url"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Message) GetPrefix() bool {
|
func (m *Message) GetPrefix() bool {
|
||||||
@@ -212,6 +401,186 @@ func (m *Message) SetToolCalls(toolCalls any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) StringContent() string {
|
func (m *Message) StringContent() string {
|
||||||
|
switch m.Content.(type) {
|
||||||
|
case string:
|
||||||
|
return m.Content.(string)
|
||||||
|
case []any:
|
||||||
|
var contentStr string
|
||||||
|
for _, contentItem := range m.Content.([]any) {
|
||||||
|
contentMap, ok := contentItem.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if contentMap["type"] == ContentTypeText {
|
||||||
|
if subStr, ok := contentMap["text"].(string); ok {
|
||||||
|
contentStr += subStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return contentStr
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) SetNullContent() {
|
||||||
|
m.Content = nil
|
||||||
|
m.parsedContent = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) SetStringContent(content string) {
|
||||||
|
m.Content = content
|
||||||
|
m.parsedContent = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) SetMediaContent(content []MediaContent) {
|
||||||
|
m.Content = content
|
||||||
|
m.parsedContent = content
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) IsStringContent() bool {
|
||||||
|
_, ok := m.Content.(string)
|
||||||
|
if ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) ParseContent() []MediaContent {
|
||||||
|
if m.Content == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(m.parsedContent) > 0 {
|
||||||
|
return m.parsedContent
|
||||||
|
}
|
||||||
|
|
||||||
|
var contentList []MediaContent
|
||||||
|
// 先尝试解析为字符串
|
||||||
|
content, ok := m.Content.(string)
|
||||||
|
if ok {
|
||||||
|
contentList = []MediaContent{{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
Text: content,
|
||||||
|
}}
|
||||||
|
m.parsedContent = contentList
|
||||||
|
return contentList
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析为数组
|
||||||
|
//var arrayContent []map[string]interface{}
|
||||||
|
|
||||||
|
arrayContent, ok := m.Content.([]any)
|
||||||
|
if !ok {
|
||||||
|
return contentList
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, contentItemAny := range arrayContent {
|
||||||
|
mediaItem, ok := contentItemAny.(MediaContent)
|
||||||
|
if ok {
|
||||||
|
contentList = append(contentList, mediaItem)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
contentItem, ok := contentItemAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contentType, ok := contentItem["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch contentType {
|
||||||
|
case ContentTypeText:
|
||||||
|
if text, ok := contentItem["text"].(string); ok {
|
||||||
|
contentList = append(contentList, MediaContent{
|
||||||
|
Type: ContentTypeText,
|
||||||
|
Text: text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case ContentTypeImageURL:
|
||||||
|
imageUrl := contentItem["image_url"]
|
||||||
|
temp := &MessageImageUrl{
|
||||||
|
Detail: "high",
|
||||||
|
}
|
||||||
|
switch v := imageUrl.(type) {
|
||||||
|
case string:
|
||||||
|
temp.Url = v
|
||||||
|
case map[string]interface{}:
|
||||||
|
url, ok1 := v["url"].(string)
|
||||||
|
detail, ok2 := v["detail"].(string)
|
||||||
|
if ok2 {
|
||||||
|
temp.Detail = detail
|
||||||
|
}
|
||||||
|
if ok1 {
|
||||||
|
temp.Url = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contentList = append(contentList, MediaContent{
|
||||||
|
Type: ContentTypeImageURL,
|
||||||
|
ImageUrl: temp,
|
||||||
|
})
|
||||||
|
|
||||||
|
case ContentTypeInputAudio:
|
||||||
|
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
|
||||||
|
data, ok1 := audioData["data"].(string)
|
||||||
|
format, ok2 := audioData["format"].(string)
|
||||||
|
if ok1 && ok2 {
|
||||||
|
temp := &MessageInputAudio{
|
||||||
|
Data: data,
|
||||||
|
Format: format,
|
||||||
|
}
|
||||||
|
contentList = append(contentList, MediaContent{
|
||||||
|
Type: ContentTypeInputAudio,
|
||||||
|
InputAudio: temp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case ContentTypeFile:
|
||||||
|
if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
|
||||||
|
fileId, ok3 := fileData["file_id"].(string)
|
||||||
|
if ok3 {
|
||||||
|
contentList = append(contentList, MediaContent{
|
||||||
|
Type: ContentTypeFile,
|
||||||
|
File: &MessageFile{
|
||||||
|
FileId: fileId,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
fileName, ok1 := fileData["filename"].(string)
|
||||||
|
fileDataStr, ok2 := fileData["file_data"].(string)
|
||||||
|
if ok1 && ok2 {
|
||||||
|
contentList = append(contentList, MediaContent{
|
||||||
|
Type: ContentTypeFile,
|
||||||
|
File: &MessageFile{
|
||||||
|
FileName: fileName,
|
||||||
|
FileData: fileDataStr,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case ContentTypeVideoUrl:
|
||||||
|
if videoUrl, ok := contentItem["video_url"].(string); ok {
|
||||||
|
contentList = append(contentList, MediaContent{
|
||||||
|
Type: ContentTypeVideoUrl,
|
||||||
|
VideoUrl: &MessageVideoUrl{
|
||||||
|
Url: videoUrl,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(contentList) > 0 {
|
||||||
|
m.parsedContent = contentList
|
||||||
|
}
|
||||||
|
return contentList
|
||||||
|
}
|
||||||
|
|
||||||
|
// old code
|
||||||
|
/*func (m *Message) StringContent() string {
|
||||||
if m.parsedStringContent != nil {
|
if m.parsedStringContent != nil {
|
||||||
return *m.parsedStringContent
|
return *m.parsedStringContent
|
||||||
}
|
}
|
||||||
@@ -382,33 +751,106 @@ func (m *Message) ParseContent() []MediaContent {
|
|||||||
m.parsedContent = contentList
|
m.parsedContent = contentList
|
||||||
}
|
}
|
||||||
return contentList
|
return contentList
|
||||||
}
|
}*/
|
||||||
|
|
||||||
type WebSearchOptions struct {
|
type WebSearchOptions struct {
|
||||||
SearchContextSize string `json:"search_context_size,omitempty"`
|
SearchContextSize string `json:"search_context_size,omitempty"`
|
||||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/responses/create
|
||||||
type OpenAIResponsesRequest struct {
|
type OpenAIResponsesRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input json.RawMessage `json:"input,omitempty"`
|
Input any `json:"input,omitempty"`
|
||||||
Include json.RawMessage `json:"include,omitempty"`
|
Include json.RawMessage `json:"include,omitempty"`
|
||||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||||
ServiceTier string `json:"service_tier,omitempty"`
|
ServiceTier string `json:"service_tier,omitempty"`
|
||||||
Store bool `json:"store,omitempty"`
|
Store bool `json:"store,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
Text json.RawMessage `json:"text,omitempty"`
|
Text json.RawMessage `json:"text,omitempty"`
|
||||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||||
Tools []ResponsesToolsCall `json:"tools,omitempty"`
|
Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
Truncation string `json:"truncation,omitempty"`
|
Truncation string `json:"truncation,omitempty"`
|
||||||
User string `json:"user,omitempty"`
|
User string `json:"user,omitempty"`
|
||||||
|
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||||
|
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||||
|
var fileMeta = make([]*types.FileMeta, 0)
|
||||||
|
var texts = make([]string, 0)
|
||||||
|
|
||||||
|
if r.Input != nil {
|
||||||
|
inputs := r.ParseInput()
|
||||||
|
for _, input := range inputs {
|
||||||
|
if input.Type == "input_image" {
|
||||||
|
if input.ImageUrl != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeImage,
|
||||||
|
OriginData: input.ImageUrl,
|
||||||
|
Detail: input.Detail,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else if input.Type == "input_file" {
|
||||||
|
if input.FileUrl != "" {
|
||||||
|
fileMeta = append(fileMeta, &types.FileMeta{
|
||||||
|
FileType: types.FileTypeFile,
|
||||||
|
OriginData: input.FileUrl,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
texts = append(texts, input.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Instructions) > 0 {
|
||||||
|
texts = append(texts, string(r.Instructions))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Metadata) > 0 {
|
||||||
|
texts = append(texts, string(r.Metadata))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Text) > 0 {
|
||||||
|
texts = append(texts, string(r.Text))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.ToolChoice) > 0 {
|
||||||
|
texts = append(texts, string(r.ToolChoice))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Prompt) > 0 {
|
||||||
|
texts = append(texts, string(r.Prompt))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r.Tools) > 0 {
|
||||||
|
toolStr, _ := common.Marshal(r.Tools)
|
||||||
|
texts = append(texts, string(toolStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &types.TokenCountMeta{
|
||||||
|
CombineText: strings.Join(texts, "\n"),
|
||||||
|
Files: fileMeta,
|
||||||
|
MaxTokens: int(r.MaxOutputTokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
|
||||||
|
return r.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
|
||||||
|
if modelName != "" {
|
||||||
|
r.Model = modelName
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Reasoning struct {
|
type Reasoning struct {
|
||||||
@@ -416,21 +858,80 @@ type Reasoning struct {
|
|||||||
Summary string `json:"summary,omitempty"`
|
Summary string `json:"summary,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponsesToolsCall struct {
|
type MediaInput struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Web Search
|
Text string `json:"text,omitempty"`
|
||||||
UserLocation json.RawMessage `json:"user_location,omitempty"`
|
FileUrl string `json:"file_url,omitempty"`
|
||||||
SearchContextSize string `json:"search_context_size,omitempty"`
|
ImageUrl string `json:"image_url,omitempty"`
|
||||||
// File Search
|
Detail string `json:"detail,omitempty"` // 仅 input_image 有效
|
||||||
VectorStoreIds []string `json:"vector_store_ids,omitempty"`
|
}
|
||||||
MaxNumResults uint `json:"max_num_results,omitempty"`
|
|
||||||
Filters json.RawMessage `json:"filters,omitempty"`
|
// ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
|
||||||
// Computer Use
|
// Reference implementation mirrors Message.ParseContent:
|
||||||
DisplayWidth uint `json:"display_width,omitempty"`
|
// - input can be a string, treated as an input_text item
|
||||||
DisplayHeight uint `json:"display_height,omitempty"`
|
// - input can be an array of objects with a `type` field
|
||||||
Environment string `json:"environment,omitempty"`
|
// supported types: input_text, input_image, input_file
|
||||||
// Function
|
func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||||
Name string `json:"name,omitempty"`
|
if r.Input == nil {
|
||||||
Description string `json:"description,omitempty"`
|
return nil
|
||||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
}
|
||||||
|
|
||||||
|
var inputs []MediaInput
|
||||||
|
|
||||||
|
// Try string first
|
||||||
|
if str, ok := r.Input.(string); ok {
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||||
|
return inputs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try array of parts
|
||||||
|
if array, ok := r.Input.([]any); ok {
|
||||||
|
for _, itemAny := range array {
|
||||||
|
// Already parsed MediaInput
|
||||||
|
if media, ok := itemAny.(MediaInput); ok {
|
||||||
|
inputs = append(inputs, media)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Generic map
|
||||||
|
item, ok := itemAny.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
typeVal, ok := item["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch typeVal {
|
||||||
|
case "input_text":
|
||||||
|
text, _ := item["text"].(string)
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
|
||||||
|
case "input_image":
|
||||||
|
// image_url may be string or object with url field
|
||||||
|
var imageUrl string
|
||||||
|
switch v := item["image_url"].(type) {
|
||||||
|
case string:
|
||||||
|
imageUrl = v
|
||||||
|
case map[string]any:
|
||||||
|
if url, ok := v["url"].(string); ok {
|
||||||
|
imageUrl = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||||
|
case "input_file":
|
||||||
|
// file_url may be string or object with url field
|
||||||
|
var fileUrl string
|
||||||
|
switch v := item["file_url"].(type) {
|
||||||
|
case string:
|
||||||
|
fileUrl = v
|
||||||
|
case map[string]any:
|
||||||
|
if url, ok := v["url"].(string); ok {
|
||||||
|
fileUrl = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return inputs
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import "encoding/json"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
type SimpleResponse struct {
|
type SimpleResponse struct {
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
Error *OpenAIError `json:"error"`
|
Error any `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(s.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextResponse struct {
|
type TextResponse struct {
|
||||||
@@ -26,12 +35,17 @@ type OpenAITextResponse struct {
|
|||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Created int64 `json:"created"`
|
Created any `json:"created"`
|
||||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||||
Error *OpenAIError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(o.Error)
|
||||||
|
}
|
||||||
|
|
||||||
type OpenAIEmbeddingResponseItem struct {
|
type OpenAIEmbeddingResponseItem struct {
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
@@ -45,6 +59,19 @@ type OpenAIEmbeddingResponse struct {
|
|||||||
Usage `json:"usage"`
|
Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FlexibleEmbeddingResponseItem struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
Embedding any `json:"embedding"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlexibleEmbeddingResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []FlexibleEmbeddingResponseItem `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionsStreamResponseChoice struct {
|
type ChatCompletionsStreamResponseChoice struct {
|
||||||
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
|
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
|
||||||
Logprobs *any `json:"logprobs"`
|
Logprobs *any `json:"logprobs"`
|
||||||
@@ -83,7 +110,7 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
|
|||||||
|
|
||||||
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
|
||||||
c.ReasoningContent = &s
|
c.ReasoningContent = &s
|
||||||
c.Reasoning = &s
|
//c.Reasoning = &s
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCallResponse struct {
|
type ToolCallResponse struct {
|
||||||
@@ -116,6 +143,13 @@ type ChatCompletionsStreamResponse struct {
|
|||||||
Usage *Usage `json:"usage"`
|
Usage *Usage `json:"usage"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChatCompletionsStreamResponse) IsFinished() bool {
|
||||||
|
if len(c.Choices) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
|
||||||
if len(c.Choices) == 0 {
|
if len(c.Choices) == 0 {
|
||||||
return false
|
return false
|
||||||
@@ -130,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
|
||||||
|
if !c.IsToolCall() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for choiceIdx := range c.Choices {
|
||||||
|
for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
|
||||||
|
c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
|
||||||
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
|
||||||
copy(choices, c.Choices)
|
copy(choices, c.Choices)
|
||||||
@@ -178,6 +225,8 @@ type Usage struct {
|
|||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||||
|
// OpenRouter Params
|
||||||
|
Cost any `json:"cost,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type InputTokenDetails struct {
|
type InputTokenDetails struct {
|
||||||
@@ -195,28 +244,33 @@ type OutputTokenDetails struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIResponsesResponse struct {
|
type OpenAIResponsesResponse struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
CreatedAt int `json:"created_at"`
|
CreatedAt int `json:"created_at"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Error *OpenAIError `json:"error,omitempty"`
|
Error any `json:"error,omitempty"`
|
||||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||||
Instructions string `json:"instructions"`
|
Instructions string `json:"instructions"`
|
||||||
MaxOutputTokens int `json:"max_output_tokens"`
|
MaxOutputTokens int `json:"max_output_tokens"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Output []ResponsesOutput `json:"output"`
|
Output []ResponsesOutput `json:"output"`
|
||||||
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
ParallelToolCalls bool `json:"parallel_tool_calls"`
|
||||||
PreviousResponseID string `json:"previous_response_id"`
|
PreviousResponseID string `json:"previous_response_id"`
|
||||||
Reasoning *Reasoning `json:"reasoning"`
|
Reasoning *Reasoning `json:"reasoning"`
|
||||||
Store bool `json:"store"`
|
Store bool `json:"store"`
|
||||||
Temperature float64 `json:"temperature"`
|
Temperature float64 `json:"temperature"`
|
||||||
ToolChoice string `json:"tool_choice"`
|
ToolChoice string `json:"tool_choice"`
|
||||||
Tools []ResponsesToolsCall `json:"tools"`
|
Tools []map[string]any `json:"tools"`
|
||||||
TopP float64 `json:"top_p"`
|
TopP float64 `json:"top_p"`
|
||||||
Truncation string `json:"truncation"`
|
Truncation string `json:"truncation"`
|
||||||
Usage *Usage `json:"usage"`
|
Usage *Usage `json:"usage"`
|
||||||
User json.RawMessage `json:"user"`
|
User json.RawMessage `json:"user"`
|
||||||
Metadata json.RawMessage `json:"metadata"`
|
Metadata json.RawMessage `json:"metadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
|
||||||
|
return GetOpenAIError(o.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type IncompleteDetails struct {
|
type IncompleteDetails struct {
|
||||||
@@ -258,3 +312,45 @@ type ResponsesStreamResponse struct {
|
|||||||
Delta string `json:"delta,omitempty"`
|
Delta string `json:"delta,omitempty"`
|
||||||
Item *ResponsesOutput `json:"item,omitempty"`
|
Item *ResponsesOutput `json:"item,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
|
||||||
|
func GetOpenAIError(errorField any) *types.OpenAIError {
|
||||||
|
if errorField == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch err := errorField.(type) {
|
||||||
|
case types.OpenAIError:
|
||||||
|
return &err
|
||||||
|
case *types.OpenAIError:
|
||||||
|
return err
|
||||||
|
case map[string]interface{}:
|
||||||
|
// 处理从JSON解析来的map结构
|
||||||
|
openaiErr := &types.OpenAIError{}
|
||||||
|
if errType, ok := err["type"].(string); ok {
|
||||||
|
openaiErr.Type = errType
|
||||||
|
}
|
||||||
|
if errMsg, ok := err["message"].(string); ok {
|
||||||
|
openaiErr.Message = errMsg
|
||||||
|
}
|
||||||
|
if errParam, ok := err["param"].(string); ok {
|
||||||
|
openaiErr.Param = errParam
|
||||||
|
}
|
||||||
|
if errCode, ok := err["code"]; ok {
|
||||||
|
openaiErr.Code = errCode
|
||||||
|
}
|
||||||
|
return openaiErr
|
||||||
|
case string:
|
||||||
|
// 处理简单字符串错误
|
||||||
|
return &types.OpenAIError{
|
||||||
|
Type: "error",
|
||||||
|
Message: err,
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// 未知类型,尝试转换为字符串
|
||||||
|
return &types.OpenAIError{
|
||||||
|
Type: "unknown_error",
|
||||||
|
Message: fmt.Sprintf("%v", err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,26 +1,35 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type OpenAIModelPermission struct {
|
import "one-api/constant"
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
// 这里不好动就不动了,本来想独立出来的(
|
||||||
Created int `json:"created"`
|
type OpenAIModels struct {
|
||||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
Id string `json:"id"`
|
||||||
AllowSampling bool `json:"allow_sampling"`
|
Object string `json:"object"`
|
||||||
AllowLogprobs bool `json:"allow_logprobs"`
|
Created int `json:"created"`
|
||||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
OwnedBy string `json:"owned_by"`
|
||||||
AllowView bool `json:"allow_view"`
|
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
|
||||||
Organization string `json:"organization"`
|
|
||||||
Group *string `json:"group"`
|
|
||||||
IsBlocking bool `json:"is_blocking"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIModels struct {
|
type AnthropicModel struct {
|
||||||
Id string `json:"id"`
|
ID string `json:"id"`
|
||||||
Object string `json:"object"`
|
CreatedAt string `json:"created_at"`
|
||||||
Created int `json:"created"`
|
DisplayName string `json:"display_name"`
|
||||||
OwnedBy string `json:"owned_by"`
|
Type string `json:"type"`
|
||||||
Permission []OpenAIModelPermission `json:"permission"`
|
}
|
||||||
Root string `json:"root"`
|
|
||||||
Parent *string `json:"parent"`
|
type GeminiModel struct {
|
||||||
|
Name interface{} `json:"name"`
|
||||||
|
BaseModelId interface{} `json:"baseModelId"`
|
||||||
|
Version interface{} `json:"version"`
|
||||||
|
DisplayName interface{} `json:"displayName"`
|
||||||
|
Description interface{} `json:"description"`
|
||||||
|
InputTokenLimit interface{} `json:"inputTokenLimit"`
|
||||||
|
OutputTokenLimit interface{} `json:"outputTokenLimit"`
|
||||||
|
SupportedGenerationMethods []interface{} `json:"supportedGenerationMethods"`
|
||||||
|
Thinking interface{} `json:"thinking"`
|
||||||
|
Temperature interface{} `json:"temperature"`
|
||||||
|
MaxTemperature interface{} `json:"maxTemperature"`
|
||||||
|
TopP interface{} `json:"topP"`
|
||||||
|
TopK interface{} `json:"topK"`
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user