diff --git a/.claude/settings.local.json b/.claude/settings.local.json
new file mode 100644
index 00000000..010182e3
--- /dev/null
+++ b/.claude/settings.local.json
@@ -0,0 +1,15 @@
+{
+ "permissions": {
+ "allow": [
+ "Bash(git init:*)",
+ "Bash(touch:*)",
+ "Bash(git checkout:*)",
+ "Bash(git add:*)",
+ "Bash(git commit:*)",
+ "Bash(git remote add:*)",
+ "Bash(git push:*)"
+ ],
+ "deny": [],
+ "ask": []
+ }
+}
\ No newline at end of file
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..e4e8e72e
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,7 @@
+.github
+.git
+*.md
+.vscode
+.gitignore
+Makefile
+docs
\ No newline at end of file
diff --git a/.env.example b/.env.example
new file mode 100644
index 00000000..ea246427
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,75 @@
+# Port number
+# PORT=3000
+# Frontend base URL
+# FRONTEND_BASE_URL=https://your-frontend-url.com
+
+
+# Debug-related settings
+# Enable pprof
+# ENABLE_PPROF=true
+# Enable debug mode
+# DEBUG=true
+
+# Database settings
+# Database connection string
+# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
+# Log database connection string
+# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
+# SQLite database path
+# SQLITE_PATH=/path/to/sqlite.db
+# Maximum number of idle database connections
+# SQL_MAX_IDLE_CONNS=100
+# Maximum number of open database connections
+# SQL_MAX_OPEN_CONNS=1000
+# Maximum lifetime of a database connection (seconds)
+# SQL_MAX_LIFETIME=60
+
+
+# Cache settings
+# Redis connection string
+# REDIS_CONN_STRING=redis://user:password@localhost:6379/0
+# Sync frequency (seconds)
+# SYNC_FREQUENCY=60
+# Enable in-memory cache
+# MEMORY_CACHE_ENABLED=true
+# Channel update frequency (seconds)
+# CHANNEL_UPDATE_FREQUENCY=30
+# Enable batch updates
+# BATCH_UPDATE_ENABLED=true
+# Batch update interval (seconds)
+# BATCH_UPDATE_INTERVAL=5
+
+# Task and feature settings
+# Enable asynchronous task updates
+# UPDATE_TASK=true
+
+# Conversation timeout settings
+# Timeout for all requests, in seconds; default is 0, meaning no limit
+# RELAY_TIMEOUT=0
+# Timeout for no response in streaming mode, in seconds; try a larger value if you see empty completions
+# STREAMING_TIMEOUT=120
+
+# Maximum number of images for Gemini vision requests
+# GEMINI_VISION_MAX_IMAGE_NUM=16
+
+# Session secret
+# SESSION_SECRET=random_string
+
+# Other settings
+# Channel test frequency (seconds)
+# CHANNEL_TEST_FREQUENCY=10
+# Generate a default token for newly registered users
+# GENERATE_DEFAULT_TOKEN=false
+# Cohere safety setting
+# COHERE_SAFETY_SETTING=NONE
+# Whether to count image tokens
+# GET_MEDIA_TOKEN=true
+# Whether to count image tokens when not streaming (stream=false)
+# GET_MEDIA_TOKEN_NOT_STREAM=true
+# Whether Dify channels output workflow and node information to the client
+# DIFY_DEBUG=true
+
+
+# Node type
+# Set to master if this is the master node
+# NODE_TYPE=master
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 00000000..87747788
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,12 @@
+# These are supported funding model platforms
+
+github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: # Replace with a single Ko-fi username
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 00000000..dd688493
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,26 @@
+---
+name: Report a bug
+about: Describe the problem you encountered in concise but detailed language
+title: ''
+labels: bug
+assignees: ''
+
+---
+
+**Routine checks**
+
+[//]: # (Remove the space inside the brackets and fill in an x)
++ [ ] I have confirmed that there is no similar existing issue
++ [ ] I have confirmed that I have upgraded to the latest version
++ [ ] I have read the project README in full, especially the FAQ section
++ [ ] I understand and am willing to follow up on this issue, helping with testing and providing feedback
++ [ ] I understand and accept the above, and I understand that the maintainers have limited time; **issues that do not follow the rules may be ignored or closed directly**
+
+**Problem description**
+
+**Steps to reproduce**
+
+**Expected result**
+
+**Related screenshots**
+If there are none, please delete this section.
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 00000000..5b8ee14f
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+ - name: Project group chat
+ url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
+ about: "QQ group: 629454374"
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 00000000..049d89c8
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,21 @@
+---
+name: Feature request
+about: Describe the new feature you would like to see in concise but detailed language
+title: ''
+labels: enhancement
+assignees: ''
+
+---
+
+**Routine checks**
+
+[//]: # (Remove the space inside the brackets and fill in an x)
++ [ ] I have confirmed that there is no similar existing issue
++ [ ] I have confirmed that I have upgraded to the latest version
++ [ ] I have read the project README in full and confirmed that the current version cannot meet my needs
++ [ ] I understand and am willing to follow up on this issue, helping with testing and providing feedback
++ [ ] I understand and accept the above, and I understand that the maintainers have limited time; **issues that do not follow the rules may be ignored or closed directly**
+
+**Feature description**
+
+**Use cases**
diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
new file mode 100644
index 00000000..4f6e41ac
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
@@ -0,0 +1,19 @@
+### PR Type
+
+- [ ] Bug fix
+- [ ] New feature
+- [ ] Documentation update
+- [ ] Other
+
+### Does this PR contain breaking changes?
+
+- [ ] Yes
+- [ ] No
+
+### PR Description
+
+**Please describe your PR in detail below, including its purpose, implementation details, etc.**
+
+### **Important Note**
+
+**All PRs must be submitted to the `alpha` branch. Please make sure your PR targets `alpha`.**
diff --git a/.github/workflows/docker-image-alpha.yml b/.github/workflows/docker-image-alpha.yml
new file mode 100644
index 00000000..c02bd409
--- /dev/null
+++ b/.github/workflows/docker-image-alpha.yml
@@ -0,0 +1,62 @@
+name: Publish Docker image (alpha)
+
+on:
+ push:
+ branches:
+ - alpha
+ workflow_dispatch:
+ inputs:
+ name:
+ description: "reason"
+ required: false
+
+jobs:
+ push_to_registries:
+ name: Push Docker image to multiple registries
+ runs-on: ubuntu-latest
+ permissions:
+ packages: write
+ contents: read
+ steps:
+ - name: Check out the repo
+ uses: actions/checkout@v4
+
+ - name: Save version info
+ run: |
+ echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Log in to the Container registry
+ uses: docker/login-action@v3
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Extract metadata (tags, labels) for Docker
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ calciumion/new-api
+ ghcr.io/${{ github.repository }}
+ tags: |
+ type=raw,value=alpha
+ type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
+
+ - name: Build and push Docker images
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ platforms: linux/amd64,linux/arm64
+ push: true
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml
new file mode 100644
index 00000000..8e4656aa
--- /dev/null
+++ b/.github/workflows/docker-image-arm64.yml
@@ -0,0 +1,56 @@
+name: Publish Docker image (Multi Registries)
+
+on:
+ push:
+ tags:
+ - '*'
+jobs:
+ push_to_registries:
+ name: Push Docker image to multiple registries
+ runs-on: ubuntu-latest
+ permissions:
+ packages: write
+ contents: read
+ steps:
+ - name: Check out the repo
+ uses: actions/checkout@v4
+
+ - name: Save version info
+ run: |
+ git describe --tags > VERSION
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v3
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Log in to the Container registry
+ uses: docker/login-action@v3
+ with:
+ registry: ghcr.io
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Extract metadata (tags, labels) for Docker
+ id: meta
+ uses: docker/metadata-action@v5
+ with:
+ images: |
+ calciumion/new-api
+ ghcr.io/${{ github.repository }}
+
+ - name: Build and push Docker images
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ platforms: linux/amd64,linux/arm64
+ push: true
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
\ No newline at end of file
diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml
new file mode 100644
index 00000000..c87fcfce
--- /dev/null
+++ b/.github/workflows/linux-release.yml
@@ -0,0 +1,59 @@
+name: Linux Release
+permissions:
+ contents: write
+
+on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
+ push:
+ tags:
+ - '*'
+ - '!*-alpha*'
+jobs:
+ release:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ - uses: oven-sh/setup-bun@v2
+ with:
+ bun-version: latest
+ - name: Build Frontend
+ env:
+ CI: ""
+ run: |
+ cd web
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
+ cd ..
+ - name: Set up Go
+ uses: actions/setup-go@v3
+ with:
+ go-version: '>=1.18.0'
+ - name: Build Backend (amd64)
+ run: |
+ go mod download
+ go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
+
+ - name: Build Backend (arm64)
+ run: |
+ sudo apt-get update
+ DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
+ CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
+
+ - name: Release
+ uses: softprops/action-gh-release@v1
+ if: startsWith(github.ref, 'refs/tags/')
+ with:
+ files: |
+ one-api
+ one-api-arm64
+ draft: true
+ generate_release_notes: true
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml
new file mode 100644
index 00000000..1bc786ac
--- /dev/null
+++ b/.github/workflows/macos-release.yml
@@ -0,0 +1,51 @@
+name: macOS Release
+permissions:
+ contents: write
+
+on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
+ push:
+ tags:
+ - '*'
+ - '!*-alpha*'
+jobs:
+ release:
+ runs-on: macos-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ - uses: oven-sh/setup-bun@v2
+ with:
+ bun-version: latest
+ - name: Build Frontend
+ env:
+ CI: ""
+ NODE_OPTIONS: "--max-old-space-size=4096"
+ run: |
+ cd web
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
+ cd ..
+ - name: Set up Go
+ uses: actions/setup-go@v3
+ with:
+ go-version: '>=1.18.0'
+ - name: Build Backend
+ run: |
+ go mod download
+ go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
+ - name: Release
+ uses: softprops/action-gh-release@v1
+ if: startsWith(github.ref, 'refs/tags/')
+ with:
+ files: one-api-macos
+ draft: true
+ generate_release_notes: true
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/pr-target-branch-check.yml b/.github/workflows/pr-target-branch-check.yml
new file mode 100644
index 00000000..e7bd4c81
--- /dev/null
+++ b/.github/workflows/pr-target-branch-check.yml
@@ -0,0 +1,21 @@
+name: Check PR Branching Strategy
+on:
+ pull_request:
+ types: [opened, synchronize, reopened, edited]
+
+jobs:
+ check-branching-strategy:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Enforce branching strategy
+ run: |
+ if [[ "${{ github.base_ref }}" == "main" ]]; then
+ if [[ "${{ github.head_ref }}" != "alpha" ]]; then
+ echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
+ exit 1
+ fi
+ elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
+ echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
+ exit 1
+ fi
+ echo "Branching strategy check passed."
\ No newline at end of file
diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml
new file mode 100644
index 00000000..de3d83d5
--- /dev/null
+++ b/.github/workflows/windows-release.yml
@@ -0,0 +1,53 @@
+name: Windows Release
+permissions:
+ contents: write
+
+on:
+ workflow_dispatch:
+ inputs:
+ name:
+ description: 'reason'
+ required: false
+ push:
+ tags:
+ - '*'
+ - '!*-alpha*'
+jobs:
+ release:
+ runs-on: windows-latest
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ - uses: oven-sh/setup-bun@v2
+ with:
+ bun-version: latest
+ - name: Build Frontend
+ env:
+ CI: ""
+ run: |
+ cd web
+ bun install
+ DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
+ cd ..
+ - name: Set up Go
+ uses: actions/setup-go@v3
+ with:
+ go-version: '>=1.18.0'
+ - name: Build Backend
+ run: |
+ go mod download
+ go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
+ - name: Release
+ uses: softprops/action-gh-release@v1
+ if: startsWith(github.ref, 'refs/tags/')
+ with:
+ files: one-api.exe
+ draft: true
+ generate_release_notes: true
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..6a23f89e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+.idea
+.vscode
+upload
+*.exe
+*.db
+build
+*.db-journal
+logs
+web/dist
+.env
+one-api
+.DS_Store
+tiktoken_cache
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..08cc86f7
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,35 @@
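+# Stage 1: build the web frontend with Bun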
+FROM oven/bun:latest AS builder
+
+WORKDIR /build
+COPY web/package.json .
+COPY web/bun.lock .
+RUN bun install
+COPY ./web .
+COPY ./VERSION .
+RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
+
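+# Stage 2: build the Go backend, bundling the compiled frontend from the previous stage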
+FROM golang:alpine AS builder2
+
+ENV GO111MODULE=on \
+ CGO_ENABLED=0 \
+ GOOS=linux
+
+WORKDIR /build
+
+ADD go.mod go.sum ./
+RUN go mod download
+
+COPY . .
+COPY --from=builder /build/dist ./web/dist
+RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
+
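+# Stage 3: minimal runtime image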
+FROM alpine
+
+RUN apk upgrade --no-cache \
+ && apk add --no-cache ca-certificates tzdata ffmpeg \
+ && update-ca-certificates
+
+COPY --from=builder2 /build/one-api /
+EXPOSE 3000
+WORKDIR /data
+ENTRYPOINT ["/one-api"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..261eeb9e
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.en.md b/README.en.md
new file mode 100644
index 00000000..69fd32f8
--- /dev/null
+++ b/README.en.md
@@ -0,0 +1,216 @@
+
+ 中文 | English
+
+
+
+
+
+# New API
+
+🍥 Next-Generation Large Model Gateway and AI Asset Management System
+
+

+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+## 📝 Project Description
+
+> [!NOTE]
+> This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api)
+
+> [!IMPORTANT]
+> - This project is for personal learning purposes only, with no guarantee of stability or technical support.
+> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
+> - According to the [Interim Measures for the Management of Generative Artificial Intelligence Services](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
+
+🤝 Trusted Partners
+
+In no particular order
+
+## 📚 Documentation
+
+For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
+
+You can also access the AI-generated DeepWiki:
+[DeepWiki](https://deepwiki.com/QuantumNous/new-api)
+
+## ✨ Key Features
+
+New API offers a wide range of features; please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
+
+1. 🎨 Brand new UI interface
+2. 🌍 Multi-language support
+3. 💰 Online recharge functionality (YiPay)
+4. 🔍 Support for querying usage quotas with keys (works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
+5. 🔄 Compatible with the original One API database
+6. 💵 Support for pay-per-use model pricing
+7. ⚖️ Support for weighted random channel selection
+8. 📈 Data dashboard (console)
+9. 🔒 Token grouping and model restrictions
+10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC)
+11. 🔄 Support for Rerank models (Cohere and Jina), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
+12. ⚡ Support for OpenAI Realtime API (including Azure channels), [API Documentation](https://docs.newapi.pro/api/openai-realtime)
+13. ⚡ Support for Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
+14. Support for entering the chat interface via the /chat2link route
+15. 🧠 Support for setting reasoning effort through model name suffixes (see the request example after this list):
+ 1. OpenAI o-series models
+ - Add `-high` suffix for high reasoning effort (e.g.: `o3-mini-high`)
+ - Add `-medium` suffix for medium reasoning effort (e.g.: `o3-mini-medium`)
+ - Add `-low` suffix for low reasoning effort (e.g.: `o3-mini-low`)
+ 2. Claude thinking models
+ - Add `-thinking` suffix to enable thinking mode (e.g.: `claude-3-7-sonnet-20250219-thinking`)
+16. 🔄 Thinking-to-content functionality
+17. 🔄 Model rate limiting for users
+18. 💰 Cache billing support, which allows billing at a set ratio when cache is hit:
+ 1. Set the `Prompt Cache Ratio` option in `System Settings-Operation Settings`
+ 2. Set `Prompt Cache Ratio` in the channel, range 0-1, e.g., setting to 0.5 means billing at 50% when cache is hit
+ 3. Supported channels:
+ - [x] OpenAI
+ - [x] Azure
+ - [x] DeepSeek
+ - [x] Claude
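+
+For example, to use the reasoning-effort suffixes from point 15 above, a minimal request (a sketch only; the host and the `sk-xxxx` key below are placeholders) simply passes the suffixed model name:
+
+```shell
+curl https://your-new-api-host/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-xxxx" \
+  -d '{"model": "o3-mini-high", "messages": [{"role": "user", "content": "hi"}]}'
+```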
+
+## Model Support
+
+This version supports multiple models; please refer to [API Documentation-Relay Interface](https://docs.newapi.pro/api) for details:
+
+1. Third-party models **gpts** (gpt-4-gizmo-*)
+2. Third-party channel [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [API Documentation](https://docs.newapi.pro/api/midjourney-proxy-image)
+3. Third-party channel [Suno API](https://github.com/Suno-API/Suno-API) interface, [API Documentation](https://docs.newapi.pro/api/suno-music)
+4. Custom channels, supporting full call address input
+5. Rerank models ([Cohere](https://cohere.ai/) and [Jina](https://jina.ai/)), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
+6. Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
+7. Dify, currently only supports chatflow
+
+## Environment Variable Configuration
+
+For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
+
+- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
+- `STREAMING_TIMEOUT`: Streaming response timeout, default is 300 seconds
+- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
+- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
+- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
+- `GET_MEDIA_TOKEN_NOT_STREAM`: Whether to count image tokens in non-streaming cases, default is `true`
+- `UPDATE_TASK`: Whether to update asynchronous tasks (Midjourney, Suno), default is `true`
+- `COHERE_SAFETY_SETTING`: Cohere model safety settings, options are `NONE`, `CONTEXTUAL`, `STRICT`, default is `NONE`
+- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16`
+- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20`
+- `CRYPTO_SECRET`: Encryption key used for encrypting database content
+- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
+- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
+- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false`
+
+## Deployment
+
+For detailed deployment guides, please refer to [Installation Guide-Deployment Methods](https://docs.newapi.pro/installation):
+
+> [!TIP]
+> Latest Docker image: `calciumion/new-api:latest`
+
+### Multi-machine Deployment Considerations
+- Environment variable `SESSION_SECRET` must be set, otherwise login status will be inconsistent across multiple machines
+- If sharing Redis, `CRYPTO_SECRET` must be set, otherwise Redis content cannot be accessed across multiple machines
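+
+A minimal sketch of starting one node with shared secrets, a shared MySQL database, and shared Redis (all values below are placeholders):
+
+```shell
+docker run --name new-api -d --restart always -p 3000:3000 \
+  -e SESSION_SECRET=your_random_string \
+  -e CRYPTO_SECRET=your_random_string \
+  -e SQL_DSN="root:123456@tcp(db-host:3306)/oneapi" \
+  -e REDIS_CONN_STRING="redis://user:password@redis-host:6379/0" \
+  -e TZ=Asia/Shanghai \
+  -v /home/ubuntu/data/new-api:/data \
+  calciumion/new-api:latest
+```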
+
+### Deployment Requirements
+- Local database (default): SQLite (Docker deployment must mount the `/data` directory)
+- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6
+
+### Deployment Methods
+
+#### Using BaoTa Panel Docker Feature
+Install BaoTa Panel (version **9.2.0** or above), then find **New-API** in the application store and install it.
+[Tutorial with images](./docs/BT.md)
+
+#### Using Docker Compose (Recommended)
+```shell
+# Download the project
+git clone https://github.com/Calcium-Ion/new-api.git
+cd new-api
+# Edit docker-compose.yml as needed
+# Start
+docker-compose up -d
+```
+
+#### Using Docker Image Directly
+```shell
+# Using SQLite
+docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
+
+# Using MySQL
+docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
+```
+
+## Channel Retry and Cache
+Channel retry is supported; you can set the number of retries in `Settings->Operation Settings->General Settings`. Enabling caching is **recommended**.
+
+### Cache Configuration Method
+1. `REDIS_CONN_STRING`: Set Redis as cache
+2. `MEMORY_CACHE_ENABLED`: Enable memory cache (no need to set manually if Redis is set)
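+
+For example (a sketch with placeholder values), either option can be passed as an environment variable when starting the container, e.g. appended to the `docker run` commands shown above:
+
+```shell
+# use Redis as the cache
+-e REDIS_CONN_STRING="redis://user:password@redis-host:6379/0"
+# or enable the in-memory cache only
+-e MEMORY_CACHE_ENABLED=true
+```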
+
+## API Documentation
+
+For detailed API documentation, please refer to [API Documentation](https://docs.newapi.pro/api):
+
+- [Chat API](https://docs.newapi.pro/api/openai-chat)
+- [Image API](https://docs.newapi.pro/api/openai-image)
+- [Rerank API](https://docs.newapi.pro/api/jinaai-rerank)
+- [Realtime API](https://docs.newapi.pro/api/openai-realtime)
+- [Claude Chat API (messages)](https://docs.newapi.pro/api/anthropic-chat)
+
+## Related Projects
+- [One API](https://github.com/songquanpeng/one-api): Original project
+- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
+- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-generation AI one-stop B/C-end solution
+- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota with key
+
+Other projects based on New API:
+- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon): High-performance optimized version of New API
+- [VoAPI](https://github.com/VoAPI/VoAPI): Frontend beautified version based on New API
+
+## Help and Support
+
+If you have any questions, please refer to [Help and Support](https://docs.newapi.pro/support):
+- [Community Interaction](https://docs.newapi.pro/support/community-interaction)
+- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
+- [FAQ](https://docs.newapi.pro/support/faq)
+
+## 🌟 Star History
+
+[Star History Chart](https://star-history.com/#Calcium-Ion/new-api&Date)
diff --git a/VERSION b/VERSION
new file mode 100644
index 00000000..e69de29b
diff --git a/bin/migration_v0.2-v0.3.sql b/bin/migration_v0.2-v0.3.sql
new file mode 100644
index 00000000..6b08d7bf
--- /dev/null
+++ b/bin/migration_v0.2-v0.3.sql
@@ -0,0 +1,6 @@
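+-- Add each user's remaining token quota (summed from the tokens table) back into the user's account quota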
+UPDATE users
+SET quota = quota + (
+ SELECT SUM(remain_quota)
+ FROM tokens
+ WHERE tokens.user_id = users.id
+)
diff --git a/bin/migration_v0.3-v0.4.sql b/bin/migration_v0.3-v0.4.sql
new file mode 100644
index 00000000..e6103c29
--- /dev/null
+++ b/bin/migration_v0.3-v0.4.sql
@@ -0,0 +1,17 @@
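+-- Backfill the abilities table: add the default GPT-3.5/GPT-4 models for every enabled channel that does not already have them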
+INSERT INTO abilities (`group`, model, channel_id, enabled)
+SELECT c.`group`, m.model, c.id, 1
+FROM channels c
+CROSS JOIN (
+ SELECT 'gpt-3.5-turbo' AS model UNION ALL
+ SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
+ SELECT 'gpt-4' AS model UNION ALL
+ SELECT 'gpt-4-0314' AS model
+) AS m
+WHERE c.status = 1
+ AND NOT EXISTS (
+ SELECT 1
+ FROM abilities a
+ WHERE a.`group` = c.`group`
+ AND a.model = m.model
+ AND a.channel_id = c.id
+);
diff --git a/bin/time_test.sh b/bin/time_test.sh
new file mode 100644
index 00000000..2cde4a65
--- /dev/null
+++ b/bin/time_test.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+if [ $# -lt 3 ]; then
+ echo "Usage: time_test.sh []"
+ exit 1
+fi
+
+domain=$1
+key=$2
+count=$3
+model=${4:-"gpt-3.5-turbo"} # default model: gpt-3.5-turbo
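+# Example invocation (placeholder values): ./time_test.sh api.example.com sk-xxxx 5 gpt-3.5-turbo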
+
+total_time=0
+times=()
+
+for ((i=1; i<=count; i++)); do
+ result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \
+ https://"$domain"/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer $key" \
+ -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}')
+ http_code=$(echo "$result" | awk '{print $1}')
+ time=$(echo "$result" | awk '{print $2}')
+ echo "HTTP status code: $http_code, Time taken: $time"
+ total_time=$(bc <<< "$total_time + $time")
+ times+=("$time")
+done
+
+average_time=$(echo "scale=4; $total_time / $count" | bc)
+
+sum_of_squares=0
+for time in "${times[@]}"; do
+ difference=$(echo "scale=4; $time - $average_time" | bc)
+ square=$(echo "scale=4; $difference * $difference" | bc)
+ sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc)
+done
+
+standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc)
+
+echo "Average time: $average_time±$standard_deviation"
diff --git a/common/api_type.go b/common/api_type.go
new file mode 100644
index 00000000..f045866a
--- /dev/null
+++ b/common/api_type.go
@@ -0,0 +1,73 @@
+package common
+
+import "one-api/constant"
+
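+// ChannelType2APIType maps a channel type constant to the corresponding API type constant.
+// The second return value reports whether a mapping was found; unknown channel types fall back to APITypeOpenAI.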
+func ChannelType2APIType(channelType int) (int, bool) {
+ apiType := -1
+ switch channelType {
+ case constant.ChannelTypeOpenAI:
+ apiType = constant.APITypeOpenAI
+ case constant.ChannelTypeAnthropic:
+ apiType = constant.APITypeAnthropic
+ case constant.ChannelTypeBaidu:
+ apiType = constant.APITypeBaidu
+ case constant.ChannelTypePaLM:
+ apiType = constant.APITypePaLM
+ case constant.ChannelTypeZhipu:
+ apiType = constant.APITypeZhipu
+ case constant.ChannelTypeAli:
+ apiType = constant.APITypeAli
+ case constant.ChannelTypeXunfei:
+ apiType = constant.APITypeXunfei
+ case constant.ChannelTypeAIProxyLibrary:
+ apiType = constant.APITypeAIProxyLibrary
+ case constant.ChannelTypeTencent:
+ apiType = constant.APITypeTencent
+ case constant.ChannelTypeGemini:
+ apiType = constant.APITypeGemini
+ case constant.ChannelTypeZhipu_v4:
+ apiType = constant.APITypeZhipuV4
+ case constant.ChannelTypeOllama:
+ apiType = constant.APITypeOllama
+ case constant.ChannelTypePerplexity:
+ apiType = constant.APITypePerplexity
+ case constant.ChannelTypeAws:
+ apiType = constant.APITypeAws
+ case constant.ChannelTypeCohere:
+ apiType = constant.APITypeCohere
+ case constant.ChannelTypeDify:
+ apiType = constant.APITypeDify
+ case constant.ChannelTypeJina:
+ apiType = constant.APITypeJina
+ case constant.ChannelCloudflare:
+ apiType = constant.APITypeCloudflare
+ case constant.ChannelTypeSiliconFlow:
+ apiType = constant.APITypeSiliconFlow
+ case constant.ChannelTypeVertexAi:
+ apiType = constant.APITypeVertexAi
+ case constant.ChannelTypeMistral:
+ apiType = constant.APITypeMistral
+ case constant.ChannelTypeDeepSeek:
+ apiType = constant.APITypeDeepSeek
+ case constant.ChannelTypeMokaAI:
+ apiType = constant.APITypeMokaAI
+ case constant.ChannelTypeVolcEngine:
+ apiType = constant.APITypeVolcEngine
+ case constant.ChannelTypeBaiduV2:
+ apiType = constant.APITypeBaiduV2
+ case constant.ChannelTypeOpenRouter:
+ apiType = constant.APITypeOpenRouter
+ case constant.ChannelTypeXinference:
+ apiType = constant.APITypeXinference
+ case constant.ChannelTypeXai:
+ apiType = constant.APITypeXai
+ case constant.ChannelTypeCoze:
+ apiType = constant.APITypeCoze
+ case constant.ChannelTypeJimeng:
+ apiType = constant.APITypeJimeng
+ }
+ if apiType == -1 {
+ return constant.APITypeOpenAI, false
+ }
+ return apiType, true
+}
diff --git a/common/constants.go b/common/constants.go
new file mode 100644
index 00000000..30522411
--- /dev/null
+++ b/common/constants.go
@@ -0,0 +1,201 @@
+package common
+
+import (
+ //"os"
+ //"strconv"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+var StartTime = time.Now().Unix() // unit: second
+var Version = "v0.0.0" // this hard-coded value is replaced automatically at build time; no need to change it manually
+var SystemName = "New API"
+var Footer = ""
+var Logo = ""
+var TopUpLink = ""
+
+// var ChatLink = ""
+// var ChatLink2 = ""
+var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
+var DisplayInCurrencyEnabled = true
+var DisplayTokenStatEnabled = true
+var DrawingEnabled = true
+var TaskEnabled = true
+var DataExportEnabled = true
+var DataExportInterval = 5 // unit: minute
+var DataExportDefaultTime = "hour" // default time granularity for data export
+var DefaultCollapseSidebar = false // default value of collapse sidebar
+
+// Any option with "Secret" or "Token" in its key won't be returned by GetOptions
+
+var SessionSecret = uuid.New().String()
+var CryptoSecret = uuid.New().String()
+
+var OptionMap map[string]string
+var OptionMapRWMutex sync.RWMutex
+
+var ItemsPerPage = 10
+var MaxRecentItems = 100
+
+var PasswordLoginEnabled = true
+var PasswordRegisterEnabled = true
+var EmailVerificationEnabled = false
+var GitHubOAuthEnabled = false
+var LinuxDOOAuthEnabled = false
+var WeChatAuthEnabled = false
+var TelegramOAuthEnabled = false
+var TurnstileCheckEnabled = false
+var RegisterEnabled = true
+
+var EmailDomainRestrictionEnabled = false // whether to restrict allowed email domains
+var EmailAliasRestrictionEnabled = false // whether to restrict email aliases
+var EmailDomainWhitelist = []string{
+ "gmail.com",
+ "163.com",
+ "126.com",
+ "qq.com",
+ "outlook.com",
+ "hotmail.com",
+ "icloud.com",
+ "yahoo.com",
+ "foxmail.com",
+}
+var EmailLoginAuthServerList = []string{
+ "smtp.sendcloud.net",
+ "smtp.azurecomm.net",
+}
+
+var DebugEnabled bool
+var MemoryCacheEnabled bool
+
+var LogConsumeEnabled = true
+
+var SMTPServer = ""
+var SMTPPort = 587
+var SMTPSSLEnabled = false
+var SMTPAccount = ""
+var SMTPFrom = ""
+var SMTPToken = ""
+
+var GitHubClientId = ""
+var GitHubClientSecret = ""
+var LinuxDOClientId = ""
+var LinuxDOClientSecret = ""
+
+var WeChatServerAddress = ""
+var WeChatServerToken = ""
+var WeChatAccountQRCodeImageURL = ""
+
+var TurnstileSiteKey = ""
+var TurnstileSecretKey = ""
+
+var TelegramBotToken = ""
+var TelegramBotName = ""
+
+var QuotaForNewUser = 0
+var QuotaForInviter = 0
+var QuotaForInvitee = 0
+var ChannelDisableThreshold = 5.0
+var AutomaticDisableChannelEnabled = false
+var AutomaticEnableChannelEnabled = false
+var QuotaRemindThreshold = 1000
+var PreConsumedQuota = 500
+
+var RetryTimes = 0
+
+//var RootUserEmail = ""
+
+var IsMasterNode bool
+
+var requestInterval int
+var RequestInterval time.Duration
+
+var SyncFrequency int // unit is second
+
+var BatchUpdateEnabled = false
+var BatchUpdateInterval int
+
+var RelayTimeout int // unit is second
+
+var GeminiSafetySetting string
+
+// https://docs.cohere.com/docs/safety-modes Options: NONE/CONTEXTUAL/STRICT
+var CohereSafetySetting string
+
+const (
+ RequestIdKey = "X-Oneapi-Request-Id"
+)
+
+const (
+ RoleGuestUser = 0
+ RoleCommonUser = 1
+ RoleAdminUser = 10
+ RoleRootUser = 100
+)
+
+func IsValidateRole(role int) bool {
+ return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
+}
+
+var (
+ FileUploadPermission = RoleGuestUser
+ FileDownloadPermission = RoleGuestUser
+ ImageUploadPermission = RoleGuestUser
+ ImageDownloadPermission = RoleGuestUser
+)
+
+// All durations are in seconds.
+// Shouldn't be larger than RateLimitKeyExpirationDuration.
+var (
+ GlobalApiRateLimitEnable bool
+ GlobalApiRateLimitNum int
+ GlobalApiRateLimitDuration int64
+
+ GlobalWebRateLimitEnable bool
+ GlobalWebRateLimitNum int
+ GlobalWebRateLimitDuration int64
+
+ UploadRateLimitNum = 10
+ UploadRateLimitDuration int64 = 60
+
+ DownloadRateLimitNum = 10
+ DownloadRateLimitDuration int64 = 60
+
+ CriticalRateLimitNum = 20
+ CriticalRateLimitDuration int64 = 20 * 60
+)
+
+var RateLimitKeyExpirationDuration = 20 * time.Minute
+
+const (
+ UserStatusEnabled = 1 // don't use 0, 0 is the default value!
+ UserStatusDisabled = 2 // also don't use 0
+)
+
+const (
+ TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
+ TokenStatusDisabled = 2 // also don't use 0
+ TokenStatusExpired = 3
+ TokenStatusExhausted = 4
+)
+
+const (
+ RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
+ RedemptionCodeStatusDisabled = 2 // also don't use 0
+ RedemptionCodeStatusUsed = 3 // also don't use 0
+)
+
+const (
+ ChannelStatusUnknown = 0
+ ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
+ ChannelStatusManuallyDisabled = 2 // also don't use 0
+ ChannelStatusAutoDisabled = 3
+)
+
+const (
+ TopUpStatusPending = "pending"
+ TopUpStatusSuccess = "success"
+ TopUpStatusExpired = "expired"
+)
diff --git a/common/crypto.go b/common/crypto.go
new file mode 100644
index 00000000..c353188a
--- /dev/null
+++ b/common/crypto.go
@@ -0,0 +1,31 @@
+package common
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "golang.org/x/crypto/bcrypt"
+)
+
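+// GenerateHMACWithKey returns the hex-encoded HMAC-SHA256 digest of data computed with the given key.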
+func GenerateHMACWithKey(key []byte, data string) string {
+ h := hmac.New(sha256.New, key)
+ h.Write([]byte(data))
+ return hex.EncodeToString(h.Sum(nil))
+}
+
+func GenerateHMAC(data string) string {
+ h := hmac.New(sha256.New, []byte(CryptoSecret))
+ h.Write([]byte(data))
+ return hex.EncodeToString(h.Sum(nil))
+}
+
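+// Password2Hash hashes a plaintext password with bcrypt at the default cost.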
+func Password2Hash(password string) (string, error) {
+ passwordBytes := []byte(password)
+ hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
+ return string(hashedPassword), err
+}
+
+func ValidatePasswordAndHash(password string, hash string) bool {
+ err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
+ return err == nil
+}
diff --git a/common/custom-event.go b/common/custom-event.go
new file mode 100644
index 00000000..d8f9ec9f
--- /dev/null
+++ b/common/custom-event.go
@@ -0,0 +1,82 @@
+// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package common
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+)
+
+type stringWriter interface {
+ io.Writer
+ writeString(string) (int, error)
+}
+
+type stringWrapper struct {
+ io.Writer
+}
+
+func (w stringWrapper) writeString(str string) (int, error) {
+ return w.Writer.Write([]byte(str))
+}
+
+func checkWriter(writer io.Writer) stringWriter {
+ if w, ok := writer.(stringWriter); ok {
+ return w
+ } else {
+ return stringWrapper{writer}
+ }
+}
+
+// Server-Sent Events
+// W3C Working Draft 29 October 2009
+// http://www.w3.org/TR/2009/WD-eventsource-20091029/
+
+var contentType = []string{"text/event-stream"}
+var noCache = []string{"no-cache"}
+
+var fieldReplacer = strings.NewReplacer(
+ "\n", "\\n",
+ "\r", "\\r")
+
+var dataReplacer = strings.NewReplacer(
+ "\n", "\n",
+ "\r", "\\r")
+
+type CustomEvent struct {
+ Event string
+ Id string
+ Retry uint
+ Data interface{}
+}
+
+func encode(writer io.Writer, event CustomEvent) error {
+ w := checkWriter(writer)
+ return writeData(w, event.Data)
+}
+
+func writeData(w stringWriter, data interface{}) error {
+ if _, err := dataReplacer.WriteString(w, fmt.Sprint(data)); err != nil {
+ return err
+ }
+ // terminate the event with a blank line when the payload starts with "data"
+ if s, ok := data.(string); ok && strings.HasPrefix(s, "data") {
+ if _, err := w.writeString("\n\n"); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (r CustomEvent) Render(w http.ResponseWriter) error {
+ r.WriteContentType(w)
+ return encode(w, r)
+}
+
+func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
+ header := w.Header()
+ header["Content-Type"] = contentType
+
+ if _, exist := header["Cache-Control"]; !exist {
+ header["Cache-Control"] = noCache
+ }
+}
diff --git a/common/database.go b/common/database.go
new file mode 100644
index 00000000..9cbaf46a
--- /dev/null
+++ b/common/database.go
@@ -0,0 +1,15 @@
+package common
+
+const (
+ DatabaseTypeMySQL = "mysql"
+ DatabaseTypeSQLite = "sqlite"
+ DatabaseTypePostgreSQL = "postgres"
+)
+
+var UsingSQLite = false
+var UsingPostgreSQL = false
+var LogSqlType = DatabaseTypeSQLite // database type used for the log database; defaults to SQLite
+var UsingMySQL = false
+var UsingClickHouse = false
+
+var SQLitePath = "one-api.db?_busy_timeout=5000"
diff --git a/common/email-outlook-auth.go b/common/email-outlook-auth.go
new file mode 100644
index 00000000..f6a71b8e
--- /dev/null
+++ b/common/email-outlook-auth.go
@@ -0,0 +1,40 @@
+package common
+
+import (
+ "errors"
+ "net/smtp"
+ "strings"
+)
+
+type outlookAuth struct {
+ username, password string
+}
+
+func LoginAuth(username, password string) smtp.Auth {
+ return &outlookAuth{username, password}
+}
+
+func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
+ return "LOGIN", []byte{}, nil
+}
+
+func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
+ if more {
+ switch string(fromServer) {
+ case "Username:":
+ return []byte(a.username), nil
+ case "Password:":
+ return []byte(a.password), nil
+ default:
+ return nil, errors.New("unknown fromServer")
+ }
+ }
+ return nil, nil
+}
+
+func isOutlookServer(server string) bool {
+ // Compatible with Outlook mailboxes across regions as well as OFB mailboxes
+ // Ideally an option should be added to choose whether to authenticate with LOGIN
+ // This is a temporary compatibility workaround
+ return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
+}
diff --git a/common/email.go b/common/email.go
new file mode 100644
index 00000000..18e6dbf7
--- /dev/null
+++ b/common/email.go
@@ -0,0 +1,90 @@
+package common
+
+import (
+ "crypto/tls"
+ "encoding/base64"
+ "fmt"
+ "net/smtp"
+ "slices"
+ "strings"
+ "time"
+)
+
+func generateMessageID() (string, error) {
+ split := strings.Split(SMTPFrom, "@")
+ if len(split) < 2 {
+ return "", fmt.Errorf("invalid SMTP account")
+ }
+ domain := strings.Split(SMTPFrom, "@")[1]
+ return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain), nil
+}
+
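+// SendEmail sends an HTML email to receiver (multiple addresses may be separated by ";") via the configured
+// SMTP server, using an implicit-TLS connection when SMTPPort is 465 or SMTPSSLEnabled is set, and LOGIN
+// authentication for Outlook-style servers.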
+func SendEmail(subject string, receiver string, content string) error {
+ if SMTPFrom == "" { // for compatibility
+ SMTPFrom = SMTPAccount
+ }
+ id, err2 := generateMessageID()
+ if err2 != nil {
+ return err2
+ }
+ if SMTPServer == "" && SMTPAccount == "" {
+ return fmt.Errorf("SMTP 服务器未配置")
+ }
+ encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
+ "From: %s<%s>\r\n"+
+ "Subject: %s\r\n"+
+ "Date: %s\r\n"+
+ "Message-ID: %s\r\n"+ // 添加 Message-ID 头
+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
+ receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), id, content))
+ auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
+ addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
+ to := strings.Split(receiver, ";")
+ var err error
+ if SMTPPort == 465 || SMTPSSLEnabled {
+ tlsConfig := &tls.Config{
+ InsecureSkipVerify: true,
+ ServerName: SMTPServer,
+ }
+ conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
+ if err != nil {
+ return err
+ }
+ client, err := smtp.NewClient(conn, SMTPServer)
+ if err != nil {
+ return err
+ }
+ defer client.Close()
+ if err = client.Auth(auth); err != nil {
+ return err
+ }
+ if err = client.Mail(SMTPFrom); err != nil {
+ return err
+ }
+ receiverEmails := strings.Split(receiver, ";")
+ for _, receiver := range receiverEmails {
+ if err = client.Rcpt(receiver); err != nil {
+ return err
+ }
+ }
+ w, err := client.Data()
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(mail)
+ if err != nil {
+ return err
+ }
+ err = w.Close()
+ if err != nil {
+ return err
+ }
+ } else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) {
+ auth = LoginAuth(SMTPAccount, SMTPToken)
+ err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
+ } else {
+ err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
+ }
+ return err
+}
diff --git a/common/embed-file-system.go b/common/embed-file-system.go
new file mode 100644
index 00000000..3ea02cf8
--- /dev/null
+++ b/common/embed-file-system.go
@@ -0,0 +1,32 @@
+package common
+
+import (
+ "embed"
+ "github.com/gin-contrib/static"
+ "io/fs"
+ "net/http"
+)
+
+// Credit: https://github.com/gin-contrib/static/issues/19
+
+type embedFileSystem struct {
+ http.FileSystem
+}
+
+func (e embedFileSystem) Exists(prefix string, path string) bool {
+ _, err := e.Open(path)
+ if err != nil {
+ return false
+ }
+ return true
+}
+
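+// EmbedFolder exposes the targetPath subdirectory of the embedded filesystem fsEmbed as a
+// static.ServeFileSystem for use with gin-contrib/static.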
+func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
+ efs, err := fs.Sub(fsEmbed, targetPath)
+ if err != nil {
+ panic(err)
+ }
+ return embedFileSystem{
+ FileSystem: http.FS(efs),
+ }
+}
diff --git a/common/endpoint_type.go b/common/endpoint_type.go
new file mode 100644
index 00000000..a0ca73ea
--- /dev/null
+++ b/common/endpoint_type.go
@@ -0,0 +1,41 @@
+package common
+
+import "one-api/constant"
+
+// GetEndpointTypesByChannelType returns the preferred endpoint types for a channel, in priority order (every channel type supports the OpenAI endpoint)
+func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
+ var endpointTypes []constant.EndpointType
+ switch channelType {
+ case constant.ChannelTypeJina:
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
+ //case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
+ //case constant.ChannelTypeSunoAPI:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
+ //case constant.ChannelTypeKling:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
+ //case constant.ChannelTypeJimeng:
+ // endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
+ case constant.ChannelTypeAws:
+ fallthrough
+ case constant.ChannelTypeAnthropic:
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
+ case constant.ChannelTypeVertexAi:
+ fallthrough
+ case constant.ChannelTypeGemini:
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
+ case constant.ChannelTypeOpenRouter: // OpenRouter only supports the OpenAI endpoint
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+ default:
+ if IsOpenAIResponseOnlyModel(modelName) {
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
+ } else {
+ endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
+ }
+ }
+ if IsImageGenerationModel(modelName) {
+ // add to first
+ endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
+ }
+ return endpointTypes
+}
diff --git a/common/env.go b/common/env.go
new file mode 100644
index 00000000..1aa340f8
--- /dev/null
+++ b/common/env.go
@@ -0,0 +1,38 @@
+package common
+
+import (
+ "fmt"
+ "os"
+ "strconv"
+)
+
+func GetEnvOrDefault(env string, defaultValue int) int {
+ if env == "" || os.Getenv(env) == "" {
+ return defaultValue
+ }
+ num, err := strconv.Atoi(os.Getenv(env))
+ if err != nil {
+ SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
+ return defaultValue
+ }
+ return num
+}
+
+func GetEnvOrDefaultString(env string, defaultValue string) string {
+ if env == "" || os.Getenv(env) == "" {
+ return defaultValue
+ }
+ return os.Getenv(env)
+}
+
+func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
+ if env == "" || os.Getenv(env) == "" {
+ return defaultValue
+ }
+ b, err := strconv.ParseBool(os.Getenv(env))
+ if err != nil {
+ SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
+ return defaultValue
+ }
+ return b
+}
diff --git a/common/gin.go b/common/gin.go
new file mode 100644
index 00000000..8c67bb4d
--- /dev/null
+++ b/common/gin.go
@@ -0,0 +1,111 @@
+package common
+
+import (
+ "bytes"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/constant"
+ "strings"
+ "time"
+)
+
+const KeyRequestBody = "key_request_body"
+
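+// GetRequestBody reads the raw request body once and caches it in the gin context under KeyRequestBody,
+// so it can be retrieved again without re-reading the already consumed body stream.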
+func GetRequestBody(c *gin.Context) ([]byte, error) {
+ requestBody, _ := c.Get(KeyRequestBody)
+ if requestBody != nil {
+ return requestBody.([]byte), nil
+ }
+ requestBody, err := io.ReadAll(c.Request.Body)
+ if err != nil {
+ return nil, err
+ }
+ _ = c.Request.Body.Close()
+ c.Set(KeyRequestBody, requestBody)
+ return requestBody.([]byte), nil
+}
+
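+// UnmarshalBodyReusable decodes a JSON request body into v (non-JSON content types are currently skipped)
+// and then restores c.Request.Body so downstream handlers can read it again.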
+func UnmarshalBodyReusable(c *gin.Context, v any) error {
+ requestBody, err := GetRequestBody(c)
+ if err != nil {
+ return err
+ }
+ contentType := c.Request.Header.Get("Content-Type")
+ if strings.HasPrefix(contentType, "application/json") {
+ err = Unmarshal(requestBody, &v)
+ } else {
+ // skip for now
+ // TODO: if non-JSON requests carry different models someday, we will need to implement this
+ }
+ if err != nil {
+ return err
+ }
+ // Reset request body
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ return nil
+}
+
+func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
+ c.Set(string(key), value)
+}
+
+func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
+ return c.Get(string(key))
+}
+
+func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
+ return c.GetString(string(key))
+}
+
+func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
+ return c.GetInt(string(key))
+}
+
+func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
+ return c.GetBool(string(key))
+}
+
+func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
+ return c.GetStringSlice(string(key))
+}
+
+func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
+ return c.GetStringMap(string(key))
+}
+
+func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
+ return c.GetTime(string(key))
+}
+
+func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
+ if value, ok := c.Get(string(key)); ok {
+ if v, ok := value.(T); ok {
+ return v, true
+ }
+ }
+ var t T
+ return t, false
+}
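+
+// Editor's illustrative sketch: how a handler might use the typed getter above.
+// The value type stored under constant.ContextKeyChannelSetting is an assumption
+// made for this example.
+//
+//    func handler(c *gin.Context) {
+//        if setting, ok := GetContextKeyType[map[string]interface{}](c, constant.ContextKeyChannelSetting); ok {
+//            _ = setting // strongly typed value, no manual type assertion needed
+//        }
+//    }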
+
+func ApiError(c *gin.Context, err error) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+}
+
+func ApiErrorMsg(c *gin.Context, msg string) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": msg,
+ })
+}
+
+func ApiSuccess(c *gin.Context, data any) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": data,
+ })
+}
diff --git a/common/go-channel.go b/common/go-channel.go
new file mode 100644
index 00000000..f9168fc4
--- /dev/null
+++ b/common/go-channel.go
@@ -0,0 +1,53 @@
+package common
+
+import (
+ "time"
+)
+
+func SafeSendBool(ch chan bool, value bool) (closed bool) {
+ defer func() {
+ // Recover from panic if one occurred. A panic would mean the channel was closed.
+ if recover() != nil {
+ closed = true
+ }
+ }()
+
+ // This will panic if the channel is closed.
+ ch <- value
+
+ // If the code reaches here, then the channel was not closed.
+ return false
+}
+
+func SafeSendString(ch chan string, value string) (closed bool) {
+ defer func() {
+ // Recover from panic if one occurred. A panic would mean the channel was closed.
+ if recover() != nil {
+ closed = true
+ }
+ }()
+
+ // This will panic if the channel is closed.
+ ch <- value
+
+ // If the code reaches here, then the channel was not closed.
+ return false
+}
+
+// SafeSendStringTimeout tries to send value to ch, giving up after timeout seconds.
+// It returns true if the send succeeded, and false on timeout or if the channel is closed.
+func SafeSendStringTimeout(ch chan string, value string, timeout int) (sent bool) {
+ defer func() {
+ // Recover from panic if one occurred. A panic would mean the channel was closed.
+ if recover() != nil {
+ sent = false
+ }
+ }()
+
+ // This will panic if the channel is closed.
+ select {
+ case ch <- value:
+ return true
+ case <-time.After(time.Duration(timeout) * time.Second):
+ return false
+ }
+}
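+
+// Editor's illustrative sketch: typical use of SafeSendBool with a stop channel
+// that the receiver may close at any time. The names are hypothetical.
+//
+//    stopChan := make(chan bool, 1)
+//    go func() {
+//        // ... do work, then try to signal completion; a closed channel is tolerated.
+//        if closed := SafeSendBool(stopChan, true); closed {
+//            // the receiver is gone, nothing more to do
+//        }
+//    }()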
diff --git a/common/gopool.go b/common/gopool.go
new file mode 100644
index 00000000..bf5df311
--- /dev/null
+++ b/common/gopool.go
@@ -0,0 +1,24 @@
+package common
+
+import (
+ "context"
+ "fmt"
+ "github.com/bytedance/gopkg/util/gopool"
+ "math"
+)
+
+var relayGoPool gopool.Pool
+
+func init() {
+ relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
+ relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
+ if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
+ SafeSendBool(stopChan, true)
+ }
+ SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
+ })
+}
+
+func RelayCtxGo(ctx context.Context, f func()) {
+ relayGoPool.CtxGo(ctx, f)
+}
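+
+// Editor's illustrative sketch: the panic handler above looks up a "stop_chan"
+// value in the context, so callers that want to be notified of a panic can attach
+// one before submitting work. The key string must match exactly.
+//
+//    stopChan := make(chan bool, 1)
+//    ctx := context.WithValue(context.Background(), "stop_chan", stopChan)
+//    RelayCtxGo(ctx, func() {
+//        // relay work; a panic here is recovered by the pool and signalled on stopChan
+//    })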
diff --git a/common/hash.go b/common/hash.go
new file mode 100644
index 00000000..50191938
--- /dev/null
+++ b/common/hash.go
@@ -0,0 +1,34 @@
+package common
+
+import (
+ "crypto/hmac"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/hex"
+)
+
+func Sha256Raw(data []byte) []byte {
+ h := sha256.New()
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func Sha1Raw(data []byte) []byte {
+ h := sha1.New()
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func Sha1(data []byte) string {
+ return hex.EncodeToString(Sha1Raw(data))
+}
+
+func HmacSha256Raw(message, key []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(message)
+ return h.Sum(nil)
+}
+
+func HmacSha256(message, key string) string {
+ return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
+}
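+
+// Editor's illustrative sketch: HmacSha256 can be used to sign or verify webhook
+// payloads. The payload and secret below are hypothetical; compare signatures with
+// a constant-time check such as hmac.Equal on the raw bytes.
+//
+//    signature := HmacSha256(`{"event":"charge.success"}`, "my-webhook-secret")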
diff --git a/common/http.go b/common/http.go
new file mode 100644
index 00000000..d2e824ef
--- /dev/null
+++ b/common/http.go
@@ -0,0 +1,57 @@
+package common
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+
+ "github.com/gin-gonic/gin"
+)
+
+func CloseResponseBodyGracefully(httpResponse *http.Response) {
+ if httpResponse == nil || httpResponse.Body == nil {
+ return
+ }
+ err := httpResponse.Body.Close()
+ if err != nil {
+ SysError("failed to close response body: " + err.Error())
+ }
+}
+
+func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
+ if c.Writer == nil {
+ return
+ }
+
+ body := io.NopCloser(bytes.NewBuffer(data))
+
+ // Do not copy the upstream headers before the response body has been parsed, because parsing may fail.
+ // In that case we need to send an error response, but the headers would already have been written,
+ // which confuses HTTP clients: Postman, for example, reports an error and the response cannot be inspected at all.
+ if src != nil {
+ for k, v := range src.Header {
+ // avoid setting Content-Length
+ if k == "Content-Length" {
+ continue
+ }
+ c.Writer.Header().Set(k, v[0])
+ }
+ }
+
+ // set Content-Length header manually BEFORE calling WriteHeader
+ c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
+
+ // Write header with status code (this sends the headers)
+ if src != nil {
+ c.Writer.WriteHeader(src.StatusCode)
+ } else {
+ c.Writer.WriteHeader(http.StatusOK)
+ }
+
+ _, err := io.Copy(c.Writer, body)
+ if err != nil {
+ LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
+ }
+}
diff --git a/common/init.go b/common/init.go
new file mode 100644
index 00000000..d70a09dd
--- /dev/null
+++ b/common/init.go
@@ -0,0 +1,120 @@
+package common
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "one-api/constant"
+ "os"
+ "path/filepath"
+ "strconv"
+ "time"
+)
+
+var (
+ Port = flag.Int("port", 3000, "the listening port")
+ PrintVersion = flag.Bool("version", false, "print version and exit")
+ PrintHelp = flag.Bool("help", false, "print help and exit")
+ LogDir = flag.String("log-dir", "./logs", "specify the log directory")
+)
+
+func printHelp() {
+ fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
+ fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
+ fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
+ fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]")
+}
+
+func InitEnv() {
+ flag.Parse()
+
+ if *PrintVersion {
+ fmt.Println(Version)
+ os.Exit(0)
+ }
+
+ if *PrintHelp {
+ printHelp()
+ os.Exit(0)
+ }
+
+ if os.Getenv("SESSION_SECRET") != "" {
+ ss := os.Getenv("SESSION_SECRET")
+ if ss == "random_string" {
+ log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
+ log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
+ log.Fatal("Please set SESSION_SECRET to a random string.")
+ } else {
+ SessionSecret = ss
+ }
+ }
+ if os.Getenv("CRYPTO_SECRET") != "" {
+ CryptoSecret = os.Getenv("CRYPTO_SECRET")
+ } else {
+ CryptoSecret = SessionSecret
+ }
+ if os.Getenv("SQLITE_PATH") != "" {
+ SQLitePath = os.Getenv("SQLITE_PATH")
+ }
+ if *LogDir != "" {
+ var err error
+ *LogDir, err = filepath.Abs(*LogDir)
+ if err != nil {
+ log.Fatal(err)
+ }
+ if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
+ err = os.Mkdir(*LogDir, 0777)
+ if err != nil {
+ log.Fatal(err)
+ }
+ }
+ }
+
+ // Initialize variables from constants.go that were using environment variables
+ DebugEnabled = os.Getenv("DEBUG") == "true"
+ MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
+ IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
+
+ // Parse requestInterval and set RequestInterval
+ requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
+ RequestInterval = time.Duration(requestInterval) * time.Second
+
+ // Initialize variables with GetEnvOrDefault
+ SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
+ BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
+ RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
+
+ // Initialize string variables with GetEnvOrDefaultString
+ GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
+ CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
+
+ // Initialize rate limit variables
+ GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
+ GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
+ GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
+
+ GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
+ GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
+ GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
+
+ initConstantEnv()
+}
+
+func initConstantEnv() {
+ constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
+ constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
+ constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
+ // ForceStreamOption overrides the request parameters to force usage information to be returned
+ constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
+ constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
+ constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
+ constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
+ constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
+ constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
+ constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
+ constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+ // GenerateDefaultToken controls whether an initial token is generated; disabled by default.
+ constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
+ // whether error logging is enabled
+ constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
+}
diff --git a/common/json.go b/common/json.go
new file mode 100644
index 00000000..69aa952e
--- /dev/null
+++ b/common/json.go
@@ -0,0 +1,22 @@
+package common
+
+import (
+ "bytes"
+ "encoding/json"
+)
+
+func Unmarshal(data []byte, v any) error {
+ return json.Unmarshal(data, v)
+}
+
+func UnmarshalJsonStr(data string, v any) error {
+ return json.Unmarshal(StringToByteSlice(data), v)
+}
+
+func DecodeJson(reader *bytes.Reader, v any) error {
+ return json.NewDecoder(reader).Decode(v)
+}
+
+func Marshal(v any) ([]byte, error) {
+ return json.Marshal(v)
+}
diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go
new file mode 100644
index 00000000..ef5d1935
--- /dev/null
+++ b/common/limiter/limiter.go
@@ -0,0 +1,89 @@
+package limiter
+
+import (
+ "context"
+ _ "embed"
+ "fmt"
+ "github.com/go-redis/redis/v8"
+ "one-api/common"
+ "sync"
+)
+
+//go:embed lua/rate_limit.lua
+var rateLimitScript string
+
+type RedisLimiter struct {
+ client *redis.Client
+ limitScriptSHA string
+}
+
+var (
+ instance *RedisLimiter
+ once sync.Once
+)
+
+func New(ctx context.Context, r *redis.Client) *RedisLimiter {
+ once.Do(func() {
+ // preload the Lua script so later calls can use EVALSHA
+ limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
+ if err != nil {
+ common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
+ }
+ instance = &RedisLimiter{
+ client: r,
+ limitScriptSHA: limitSHA,
+ }
+ })
+
+ return instance
+}
+
+func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
+ // default configuration
+ config := &Config{
+ Capacity: 10,
+ Rate: 1,
+ Requested: 1,
+ }
+
+ // apply the functional options
+ for _, opt := range opts {
+ opt(config)
+ }
+
+ // run the rate-limit script
+ result, err := rl.client.EvalSha(
+ ctx,
+ rl.limitScriptSHA,
+ []string{key},
+ config.Requested,
+ config.Rate,
+ config.Capacity,
+ ).Int()
+
+ if err != nil {
+ return false, fmt.Errorf("rate limit failed: %w", err)
+ }
+ return result == 1, nil
+}
+
+// Config holds the limiter parameters, applied via functional options
+type Config struct {
+ Capacity int64
+ Rate int64
+ Requested int64
+}
+
+type Option func(*Config)
+
+func WithCapacity(c int64) Option {
+ return func(cfg *Config) { cfg.Capacity = c }
+}
+
+func WithRate(r int64) Option {
+ return func(cfg *Config) { cfg.Rate = r }
+}
+
+func WithRequested(n int64) Option {
+ return func(cfg *Config) { cfg.Requested = n }
+}
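+
+// Editor's illustrative sketch: wiring the limiter into request handling from
+// another package. The key format and the numbers are assumptions; WithCapacity
+// and WithRate map onto ARGV[3] and ARGV[2] of lua/rate_limit.lua.
+//
+//    rl := limiter.New(ctx, common.RDB)
+//    allowed, err := rl.Allow(ctx, "rate_limit:user:42",
+//        limiter.WithCapacity(60), // the bucket holds at most 60 tokens
+//        limiter.WithRate(1),      // and refills one token per second
+//    )
+//    if err == nil && !allowed {
+//        // over the limit: reject the request, e.g. with HTTP 429
+//    }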
diff --git a/common/limiter/lua/rate_limit.lua b/common/limiter/lua/rate_limit.lua
new file mode 100644
index 00000000..c07fd3a8
--- /dev/null
+++ b/common/limiter/lua/rate_limit.lua
@@ -0,0 +1,44 @@
+-- Token bucket rate limiter
+-- KEYS[1]: unique key identifying this limiter
+-- ARGV[1]: number of tokens requested (usually 1)
+-- ARGV[2]: token refill rate (per second)
+-- ARGV[3]: bucket capacity
+
+local key = KEYS[1]
+local requested = tonumber(ARGV[1])
+local rate = tonumber(ARGV[2])
+local capacity = tonumber(ARGV[3])
+
+-- current time (from the Redis server)
+local now = redis.call('TIME')
+local nowInSeconds = tonumber(now[1])
+
+-- read the bucket state
+local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
+local tokens = tonumber(bucket[1])
+local last_time = tonumber(bucket[2])
+
+-- initialise the bucket (first request, or the state has expired)
+if not tokens or not last_time then
+ tokens = capacity
+ last_time = nowInSeconds
+else
+ -- refill tokens for the elapsed time
+ local elapsed = nowInSeconds - last_time
+ local add_tokens = elapsed * rate
+ tokens = math.min(capacity, tokens + add_tokens)
+ last_time = nowInSeconds
+end
+
+-- decide whether the request is allowed
+local allowed = false
+if tokens >= requested then
+ tokens = tokens - requested
+ allowed = true
+end
+
+-- update the bucket state (expiry is intentionally left commented out below)
+redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
+--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- extend the expiry a little beyond one full refill
+
+return allowed and 1 or 0
\ No newline at end of file
diff --git a/common/logger.go b/common/logger.go
new file mode 100644
index 00000000..0f6dc3c3
--- /dev/null
+++ b/common/logger.go
@@ -0,0 +1,123 @@
+package common
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-gonic/gin"
+ "io"
+ "log"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+)
+
+const (
+ loggerINFO = "INFO"
+ loggerWarn = "WARN"
+ loggerError = "ERR"
+)
+
+const maxLogCount = 1000000
+
+var logCount int
+var setupLogLock sync.Mutex
+var setupLogWorking bool
+
+func SetupLogger() {
+ if *LogDir != "" {
+ ok := setupLogLock.TryLock()
+ if !ok {
+ log.Println("setup log is already working")
+ return
+ }
+ defer func() {
+ setupLogLock.Unlock()
+ setupLogWorking = false
+ }()
+ logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
+ fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ log.Fatal("failed to open log file")
+ }
+ gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+ gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
+ }
+}
+
+func SysLog(s string) {
+ t := time.Now()
+ _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+}
+
+func SysError(s string) {
+ t := time.Now()
+ _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
+}
+
+func LogInfo(ctx context.Context, msg string) {
+ logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context, msg string) {
+ logHelper(ctx, loggerWarn, msg)
+}
+
+func LogError(ctx context.Context, msg string) {
+ logHelper(ctx, loggerError, msg)
+}
+
+func logHelper(ctx context.Context, level string, msg string) {
+ writer := gin.DefaultErrorWriter
+ if level == loggerINFO {
+ writer = gin.DefaultWriter
+ }
+ id := ctx.Value(RequestIdKey)
+ if id == nil {
+ id = "SYSTEM"
+ }
+ now := time.Now()
+ _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+ logCount++ // we don't need accurate count, so no lock here
+ if logCount > maxLogCount && !setupLogWorking {
+ logCount = 0
+ setupLogWorking = true
+ gopool.Go(func() {
+ SetupLogger()
+ })
+ }
+}
+
+func FatalLog(v ...any) {
+ t := time.Now()
+ _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
+ os.Exit(1)
+}
+
+func LogQuota(quota int) string {
+ if DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
+ } else {
+ return fmt.Sprintf("%d 点额度", quota)
+ }
+}
+
+func FormatQuota(quota int) string {
+ if DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
+ } else {
+ return fmt.Sprintf("%d", quota)
+ }
+}
+
+// LogJson is intended for testing/debugging only
+func LogJson(ctx context.Context, msg string, obj any) {
+ jsonStr, err := json.Marshal(obj)
+ if err != nil {
+ LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
+ return
+ }
+ LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
+}
diff --git a/common/model.go b/common/model.go
new file mode 100644
index 00000000..14ca1911
--- /dev/null
+++ b/common/model.go
@@ -0,0 +1,42 @@
+package common
+
+import "strings"
+
+var (
+ // OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
+ OpenAIResponseOnlyModels = []string{
+ "o3-pro",
+ "o3-deep-research",
+ "o4-mini-deep-research",
+ }
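+ // Entries below are matched case-insensitively; an entry starting with "prefix:"
+ // is treated as a model-name prefix, every other entry as a substring
+ // (see IsImageGenerationModel).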
+ ImageGenerationModels = []string{
+ "dall-e-3",
+ "dall-e-2",
+ "gpt-image-1",
+ "prefix:imagen-",
+ "flux-",
+ "flux.1-",
+ }
+)
+
+func IsOpenAIResponseOnlyModel(modelName string) bool {
+ for _, m := range OpenAIResponseOnlyModels {
+ if strings.Contains(modelName, m) {
+ return true
+ }
+ }
+ return false
+}
+
+func IsImageGenerationModel(modelName string) bool {
+ modelName = strings.ToLower(modelName)
+ for _, m := range ImageGenerationModels {
+ if strings.Contains(modelName, m) {
+ return true
+ }
+ if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/common/page_info.go b/common/page_info.go
new file mode 100644
index 00000000..5e4535e3
--- /dev/null
+++ b/common/page_info.go
@@ -0,0 +1,82 @@
+package common
+
+import (
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+)
+
+type PageInfo struct {
+ Page int `json:"page"` // page num 页码
+ PageSize int `json:"page_size"` // page size 页大小
+
+ Total int `json:"total"` // 总条数,后设置
+ Items any `json:"items"` // 数据,后设置
+}
+
+func (p *PageInfo) GetStartIdx() int {
+ return (p.Page - 1) * p.PageSize
+}
+
+func (p *PageInfo) GetEndIdx() int {
+ return p.Page * p.PageSize
+}
+
+func (p *PageInfo) GetPageSize() int {
+ return p.PageSize
+}
+
+func (p *PageInfo) GetPage() int {
+ return p.Page
+}
+
+func (p *PageInfo) SetTotal(total int) {
+ p.Total = total
+}
+
+func (p *PageInfo) SetItems(items any) {
+ p.Items = items
+}
+
+func GetPageQuery(c *gin.Context) *PageInfo {
+ pageInfo := &PageInfo{}
+ // 手动获取并处理每个参数
+ if page, err := strconv.Atoi(c.Query("page")); err == nil {
+ pageInfo.Page = page
+ }
+ if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
+ pageInfo.PageSize = pageSize
+ }
+ if pageInfo.Page < 1 {
+ // 兼容
+ page, _ := strconv.Atoi(c.Query("p"))
+ if page != 0 {
+ pageInfo.Page = page
+ } else {
+ pageInfo.Page = 1
+ }
+ }
+
+ if pageInfo.PageSize == 0 {
+ // 兼容
+ pageSize, _ := strconv.Atoi(c.Query("ps"))
+ if pageSize != 0 {
+ pageInfo.PageSize = pageSize
+ }
+ if pageInfo.PageSize == 0 {
+ pageSize, _ = strconv.Atoi(c.Query("size")) // token page
+ if pageSize != 0 {
+ pageInfo.PageSize = pageSize
+ }
+ }
+ if pageInfo.PageSize == 0 {
+ pageInfo.PageSize = ItemsPerPage
+ }
+ }
+
+ if pageInfo.PageSize > 100 {
+ pageInfo.PageSize = 100
+ }
+
+ return pageInfo
+}
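+
+// Editor's illustrative sketch: how a list endpoint might consume PageInfo together
+// with GORM. The model.DB handle and the Channel type are assumptions made for the
+// example.
+//
+//    func ListChannels(c *gin.Context) {
+//        pageInfo := GetPageQuery(c)
+//        var items []Channel
+//        var total int64
+//        model.DB.Model(&Channel{}).Count(&total)
+//        model.DB.Offset(pageInfo.GetStartIdx()).Limit(pageInfo.GetPageSize()).Find(&items)
+//        pageInfo.SetTotal(int(total))
+//        pageInfo.SetItems(items)
+//        ApiSuccess(c, pageInfo)
+//    }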
diff --git a/common/pprof.go b/common/pprof.go
new file mode 100644
index 00000000..4bec30f1
--- /dev/null
+++ b/common/pprof.go
@@ -0,0 +1,44 @@
+package common
+
+import (
+ "fmt"
+ "github.com/shirou/gopsutil/cpu"
+ "os"
+ "runtime/pprof"
+ "time"
+)
+
+// Monitor periodically samples CPU usage and writes a pprof profile whenever it exceeds the threshold
+func Monitor() {
+ for {
+ percent, err := cpu.Percent(time.Second, false)
+ if err != nil {
+ panic(err)
+ }
+ if percent[0] > 80 {
+ fmt.Println("cpu usage too high")
+ // write pprof file
+ if _, err := os.Stat("./pprof"); os.IsNotExist(err) {
+ err := os.Mkdir("./pprof", os.ModePerm)
+ if err != nil {
+ SysLog("创建pprof文件夹失败 " + err.Error())
+ continue
+ }
+ }
+ f, err := os.Create("./pprof/" + fmt.Sprintf("cpu-%s.pprof", time.Now().Format("20060102150405")))
+ if err != nil {
+ SysLog("创建pprof文件失败 " + err.Error())
+ continue
+ }
+ err = pprof.StartCPUProfile(f)
+ if err != nil {
+ SysLog("启动pprof失败 " + err.Error())
+ continue
+ }
+ time.Sleep(10 * time.Second) // profile for 10 seconds
+ pprof.StopCPUProfile()
+ f.Close()
+ }
+ time.Sleep(30 * time.Second)
+ }
+}
diff --git a/common/rate-limit.go b/common/rate-limit.go
new file mode 100644
index 00000000..301c101c
--- /dev/null
+++ b/common/rate-limit.go
@@ -0,0 +1,70 @@
+package common
+
+import (
+ "sync"
+ "time"
+)
+
+type InMemoryRateLimiter struct {
+ store map[string]*[]int64
+ mutex sync.Mutex
+ expirationDuration time.Duration
+}
+
+func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) {
+ if l.store == nil {
+ l.mutex.Lock()
+ if l.store == nil {
+ l.store = make(map[string]*[]int64)
+ l.expirationDuration = expirationDuration
+ if expirationDuration > 0 {
+ go l.clearExpiredItems()
+ }
+ }
+ l.mutex.Unlock()
+ }
+}
+
+func (l *InMemoryRateLimiter) clearExpiredItems() {
+ for {
+ time.Sleep(l.expirationDuration)
+ l.mutex.Lock()
+ now := time.Now().Unix()
+ for key := range l.store {
+ queue := l.store[key]
+ size := len(*queue)
+ if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) {
+ delete(l.store, key)
+ }
+ }
+ l.mutex.Unlock()
+ }
+}
+
+// Request parameter duration's unit is seconds
+func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool {
+ l.mutex.Lock()
+ defer l.mutex.Unlock()
+ // [old <-- new]
+ queue, ok := l.store[key]
+ now := time.Now().Unix()
+ if ok {
+ if len(*queue) < maxRequestNum {
+ *queue = append(*queue, now)
+ return true
+ } else {
+ if now-(*queue)[0] >= duration {
+ *queue = (*queue)[1:]
+ *queue = append(*queue, now)
+ return true
+ } else {
+ return false
+ }
+ }
+ } else {
+ s := make([]int64, 0, maxRequestNum)
+ l.store[key] = &s
+ *(l.store[key]) = append(*(l.store[key]), now)
+ }
+ return true
+}
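+
+// Editor's illustrative sketch: a process-local limiter allowing at most 60 requests
+// per key within a 60-second window. The key and limits are hypothetical.
+//
+//    var limiter InMemoryRateLimiter
+//    limiter.Init(3 * time.Minute) // drop idle keys after three minutes
+//    if !limiter.Request("ip:1.2.3.4", 60, 60) {
+//        // over the limit: reject the request
+//    }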
diff --git a/common/redis.go b/common/redis.go
new file mode 100644
index 00000000..c7287837
--- /dev/null
+++ b/common/redis.go
@@ -0,0 +1,327 @@
+package common
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "reflect"
+ "strconv"
+ "time"
+
+ "github.com/go-redis/redis/v8"
+ "gorm.io/gorm"
+)
+
+var RDB *redis.Client
+var RedisEnabled = true
+
+func RedisKeyCacheSeconds() int {
+ return SyncFrequency
+}
+
+// InitRedisClient is called after init(); Redis is enabled only when REDIS_CONN_STRING is set
+func InitRedisClient() (err error) {
+ if os.Getenv("REDIS_CONN_STRING") == "" {
+ RedisEnabled = false
+ SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
+ return nil
+ }
+ if os.Getenv("SYNC_FREQUENCY") == "" {
+ SysLog("SYNC_FREQUENCY not set, use default value 60")
+ SyncFrequency = 60
+ }
+ SysLog("Redis is enabled")
+ opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
+ if err != nil {
+ FatalLog("failed to parse Redis connection string: " + err.Error())
+ }
+ opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
+ RDB = redis.NewClient(opt)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ _, err = RDB.Ping(ctx).Result()
+ if err != nil {
+ FatalLog("Redis ping test failed: " + err.Error())
+ }
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
+ SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
+ }
+ return err
+}
+
+func ParseRedisOption() *redis.Options {
+ opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
+ if err != nil {
+ FatalLog("failed to parse Redis connection string: " + err.Error())
+ }
+ return opt
+}
+
+func RedisSet(key string, value string, expiration time.Duration) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
+ }
+ ctx := context.Background()
+ return RDB.Set(ctx, key, value, expiration).Err()
+}
+
+func RedisGet(key string) (string, error) {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis GET: key=%s", key))
+ }
+ ctx := context.Background()
+ val, err := RDB.Get(ctx, key).Result()
+ return val, err
+}
+
+//func RedisExpire(key string, expiration time.Duration) error {
+// ctx := context.Background()
+// return RDB.Expire(ctx, key, expiration).Err()
+//}
+//
+//func RedisGetEx(key string, expiration time.Duration) (string, error) {
+// ctx := context.Background()
+// return RDB.GetSet(ctx, key, expiration).Result()
+//}
+
+func RedisDel(key string) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
+ }
+ ctx := context.Background()
+ return RDB.Del(ctx, key).Err()
+}
+
+func RedisDelKey(key string) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
+ }
+ ctx := context.Background()
+ return RDB.Del(ctx, key).Err()
+}
+
+func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
+ }
+ ctx := context.Background()
+
+ data := make(map[string]interface{})
+
+ // 使用反射遍历结构体字段
+ v := reflect.ValueOf(obj).Elem()
+ t := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ field := t.Field(i)
+ value := v.Field(i)
+
+ // Skip DeletedAt field
+ if field.Type.String() == "gorm.DeletedAt" {
+ continue
+ }
+
+ // 处理指针类型
+ if value.Kind() == reflect.Ptr {
+ if value.IsNil() {
+ data[field.Name] = ""
+ continue
+ }
+ value = value.Elem()
+ }
+
+ // 处理布尔类型
+ if value.Kind() == reflect.Bool {
+ data[field.Name] = strconv.FormatBool(value.Bool())
+ continue
+ }
+
+ // 其他类型直接转换为字符串
+ data[field.Name] = fmt.Sprintf("%v", value.Interface())
+ }
+
+ txn := RDB.TxPipeline()
+ txn.HSet(ctx, key, data)
+
+ // 只有在 expiration 大于 0 时才设置过期时间
+ if expiration > 0 {
+ txn.Expire(ctx, key, expiration)
+ }
+
+ _, err := txn.Exec(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to execute transaction: %w", err)
+ }
+ return nil
+}
+
+func RedisHGetObj(key string, obj interface{}) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
+ }
+ ctx := context.Background()
+
+ result, err := RDB.HGetAll(ctx, key).Result()
+ if err != nil {
+ return fmt.Errorf("failed to load hash from Redis: %w", err)
+ }
+
+ if len(result) == 0 {
+ return fmt.Errorf("key %s not found in Redis", key)
+ }
+
+ // Handle both pointer and non-pointer values
+ val := reflect.ValueOf(obj)
+ if val.Kind() != reflect.Ptr {
+ return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
+ }
+
+ v := val.Elem()
+ if v.Kind() != reflect.Struct {
+ return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
+ }
+
+ t := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ field := t.Field(i)
+ fieldName := field.Name
+ if value, ok := result[fieldName]; ok {
+ fieldValue := v.Field(i)
+
+ // Handle pointer types
+ if fieldValue.Kind() == reflect.Ptr {
+ if value == "" {
+ continue
+ }
+ if fieldValue.IsNil() {
+ fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
+ }
+ fieldValue = fieldValue.Elem()
+ }
+
+ // Enhanced type handling for Token struct
+ switch fieldValue.Kind() {
+ case reflect.String:
+ fieldValue.SetString(value)
+ case reflect.Int, reflect.Int64:
+ intValue, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
+ }
+ fieldValue.SetInt(intValue)
+ case reflect.Bool:
+ boolValue, err := strconv.ParseBool(value)
+ if err != nil {
+ return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
+ }
+ fieldValue.SetBool(boolValue)
+ case reflect.Struct:
+ // Special handling for gorm.DeletedAt
+ if fieldValue.Type().String() == "gorm.DeletedAt" {
+ if value != "" {
+ timeValue, err := time.Parse(time.RFC3339, value)
+ if err != nil {
+ return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
+ }
+ fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
+ }
+ }
+ default:
+ return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
+ }
+ }
+ }
+
+ return nil
+}
+
+// RedisIncr atomically increments key by delta while preserving the key's existing TTL
+func RedisIncr(key string, delta int64) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
+ }
+ // check the key's remaining time to live
+ ttlCmd := RDB.TTL(context.Background(), key)
+ ttl, err := ttlCmd.Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return fmt.Errorf("failed to get TTL: %w", err)
+ }
+
+ // only a key that exists and has a TTL needs special handling
+ if ttl > 0 {
+ ctx := context.Background()
+ // start a Redis pipeline
+ txn := RDB.TxPipeline()
+
+ // increment the value (delta may be negative to decrement)
+ incrCmd := txn.IncrBy(ctx, key, delta)
+ if err := incrCmd.Err(); err != nil {
+ return err // bail out if the increment could not be queued
+ }
+
+ // re-apply the original expiration time
+ txn.Expire(ctx, key, ttl)
+
+ // execute the pipeline
+ _, err = txn.Exec(ctx)
+ return err
+ }
+ return nil
+}
+
+func RedisHIncrBy(key, field string, delta int64) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
+ }
+ ttlCmd := RDB.TTL(context.Background(), key)
+ ttl, err := ttlCmd.Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return fmt.Errorf("failed to get TTL: %w", err)
+ }
+
+ if ttl > 0 {
+ ctx := context.Background()
+ txn := RDB.TxPipeline()
+
+ incrCmd := txn.HIncrBy(ctx, key, field, delta)
+ if err := incrCmd.Err(); err != nil {
+ return err
+ }
+
+ txn.Expire(ctx, key, ttl)
+
+ _, err = txn.Exec(ctx)
+ return err
+ }
+ return nil
+}
+
+func RedisHSetField(key, field string, value interface{}) error {
+ if DebugEnabled {
+ SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
+ }
+ ttlCmd := RDB.TTL(context.Background(), key)
+ ttl, err := ttlCmd.Result()
+ if err != nil && !errors.Is(err, redis.Nil) {
+ return fmt.Errorf("failed to get TTL: %w", err)
+ }
+
+ if ttl > 0 {
+ ctx := context.Background()
+ txn := RDB.TxPipeline()
+
+ hsetCmd := txn.HSet(ctx, key, field, value)
+ if err := hsetCmd.Err(); err != nil {
+ return err
+ }
+
+ txn.Expire(ctx, key, ttl)
+
+ _, err = txn.Exec(ctx)
+ return err
+ }
+ return nil
+}
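+
+// Editor's illustrative sketch: caching a struct as a Redis hash and reading it back
+// with the reflection helpers above. The cachedToken shape is an assumption; only
+// string, int/int64, bool, pointer and gorm.DeletedAt fields are supported.
+//
+//    type cachedToken struct {
+//        Key         string
+//        RemainQuota int
+//        Unlimited   bool
+//    }
+//    _ = RedisHSetObj("token:abc", &cachedToken{Key: "abc", RemainQuota: 100}, 5*time.Minute)
+//    var t cachedToken
+//    _ = RedisHGetObj("token:abc", &t)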
diff --git a/common/str.go b/common/str.go
new file mode 100644
index 00000000..88b58c72
--- /dev/null
+++ b/common/str.go
@@ -0,0 +1,97 @@
+package common
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "math/rand"
+ "strconv"
+ "unsafe"
+)
+
+func GetStringIfEmpty(str string, defaultValue string) string {
+ if str == "" {
+ return defaultValue
+ }
+ return str
+}
+
+func GetRandomString(length int) string {
+ //rand.Seed(time.Now().UnixNano())
+ key := make([]byte, length)
+ for i := 0; i < length; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ return string(key)
+}
+
+func MapToJsonStr(m map[string]interface{}) string {
+ bytes, err := json.Marshal(m)
+ if err != nil {
+ return ""
+ }
+ return string(bytes)
+}
+
+func StrToMap(str string) (map[string]interface{}, error) {
+ m := make(map[string]interface{})
+ err := Unmarshal([]byte(str), &m)
+ if err != nil {
+ return nil, err
+ }
+ return m, nil
+}
+
+func StrToJsonArray(str string) ([]interface{}, error) {
+ var js []interface{}
+ err := json.Unmarshal([]byte(str), &js)
+ if err != nil {
+ return nil, err
+ }
+ return js, nil
+}
+
+func IsJsonArray(str string) bool {
+ var js []interface{}
+ return json.Unmarshal([]byte(str), &js) == nil
+}
+
+func IsJsonObject(str string) bool {
+ var js map[string]interface{}
+ return json.Unmarshal([]byte(str), &js) == nil
+}
+
+func String2Int(str string) int {
+ num, err := strconv.Atoi(str)
+ if err != nil {
+ return 0
+ }
+ return num
+}
+
+func StringsContains(strs []string, str string) bool {
+ for _, s := range strs {
+ if s == str {
+ return true
+ }
+ }
+ return false
+}
+
+// StringToByteSlice converts a string to a []byte without copying; the result is read-only and must not be mutated (writes may crash the program)
+func StringToByteSlice(s string) []byte {
+ tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
+ tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
+ return *(*[]byte)(unsafe.Pointer(&tmp2))
+}
+
+func EncodeBase64(str string) string {
+ return base64.StdEncoding.EncodeToString([]byte(str))
+}
+
+func GetJsonString(data any) string {
+ if data == nil {
+ return ""
+ }
+ b, _ := json.Marshal(data)
+ return string(b)
+}
diff --git a/common/topup-ratio.go b/common/topup-ratio.go
new file mode 100644
index 00000000..8f03395d
--- /dev/null
+++ b/common/topup-ratio.go
@@ -0,0 +1,33 @@
+package common
+
+import (
+ "encoding/json"
+)
+
+var TopupGroupRatio = map[string]float64{
+ "default": 1,
+ "vip": 1,
+ "svip": 1,
+}
+
+func TopupGroupRatio2JSONString() string {
+ jsonBytes, err := json.Marshal(TopupGroupRatio)
+ if err != nil {
+ SysError("error marshalling model ratio: " + err.Error())
+ }
+ return string(jsonBytes)
+}
+
+func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
+ TopupGroupRatio = make(map[string]float64)
+ return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
+}
+
+func GetTopupGroupRatio(name string) float64 {
+ ratio, ok := TopupGroupRatio[name]
+ if !ok {
+ SysError("topup group ratio not found: " + name)
+ return 1
+ }
+ return ratio
+}
diff --git a/common/utils.go b/common/utils.go
new file mode 100644
index 00000000..17aecd95
--- /dev/null
+++ b/common/utils.go
@@ -0,0 +1,304 @@
+package common
+
+import (
+ "bytes"
+ "context"
+ crand "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "html/template"
+ "io"
+ "log"
+ "math/big"
+ "math/rand"
+ "net"
+ "net/url"
+ "os"
+ "os/exec"
+ "runtime"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/pkg/errors"
+)
+
+func OpenBrowser(url string) {
+ var err error
+
+ switch runtime.GOOS {
+ case "linux":
+ err = exec.Command("xdg-open", url).Start()
+ case "windows":
+ err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
+ case "darwin":
+ err = exec.Command("open", url).Start()
+ }
+ if err != nil {
+ log.Println(err)
+ }
+}
+
+func GetIp() (ip string) {
+ ips, err := net.InterfaceAddrs()
+ if err != nil {
+ log.Println(err)
+ return ip
+ }
+
+ for _, a := range ips {
+ if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
+ if ipNet.IP.To4() != nil {
+ ip = ipNet.IP.String()
+ if strings.HasPrefix(ip, "10") {
+ return
+ }
+ if strings.HasPrefix(ip, "172") {
+ return
+ }
+ if strings.HasPrefix(ip, "192.168") {
+ return
+ }
+ ip = ""
+ }
+ }
+ }
+ return
+}
+
+var sizeKB = 1024
+var sizeMB = sizeKB * 1024
+var sizeGB = sizeMB * 1024
+
+func Bytes2Size(num int64) string {
+ numStr := ""
+ unit := "B"
+ if num/int64(sizeGB) > 1 {
+ numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
+ unit = "GB"
+ } else if num/int64(sizeMB) > 1 {
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
+ unit = "MB"
+ } else if num/int64(sizeKB) > 1 {
+ numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
+ unit = "KB"
+ } else {
+ numStr = fmt.Sprintf("%d", num)
+ }
+ return numStr + " " + unit
+}
+
+func Seconds2Time(num int) (time string) {
+ if num/31104000 > 0 {
+ time += strconv.Itoa(num/31104000) + " 年 "
+ num %= 31104000
+ }
+ if num/2592000 > 0 {
+ time += strconv.Itoa(num/2592000) + " 个月 "
+ num %= 2592000
+ }
+ if num/86400 > 0 {
+ time += strconv.Itoa(num/86400) + " 天 "
+ num %= 86400
+ }
+ if num/3600 > 0 {
+ time += strconv.Itoa(num/3600) + " 小时 "
+ num %= 3600
+ }
+ if num/60 > 0 {
+ time += strconv.Itoa(num/60) + " 分钟 "
+ num %= 60
+ }
+ time += strconv.Itoa(num) + " 秒"
+ return
+}
+
+func Interface2String(inter interface{}) string {
+ switch inter.(type) {
+ case string:
+ return inter.(string)
+ case int:
+ return fmt.Sprintf("%d", inter.(int))
+ case float64:
+ return fmt.Sprintf("%f", inter.(float64))
+ }
+ return "Not Implemented"
+}
+
+func UnescapeHTML(x string) interface{} {
+ return template.HTML(x)
+}
+
+func IntMax(a int, b int) int {
+ if a >= b {
+ return a
+ } else {
+ return b
+ }
+}
+
+func IsIP(s string) bool {
+ ip := net.ParseIP(s)
+ return ip != nil
+}
+
+func GetUUID() string {
+ code := uuid.New().String()
+ code = strings.Replace(code, "-", "", -1)
+ return code
+}
+
+const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+func init() {
+ rand.New(rand.NewSource(time.Now().UnixNano()))
+}
+
+func GenerateRandomCharsKey(length int) (string, error) {
+ b := make([]byte, length)
+ maxI := big.NewInt(int64(len(keyChars)))
+
+ for i := range b {
+ n, err := crand.Int(crand.Reader, maxI)
+ if err != nil {
+ return "", err
+ }
+ b[i] = keyChars[n.Int64()]
+ }
+
+ return string(b), nil
+}
+
+func GenerateRandomKey(length int) (string, error) {
+ bytes := make([]byte, length*3/4) // for a 48-character output this should be 36 bytes
+ if _, err := crand.Read(bytes); err != nil {
+ return "", err
+ }
+ return base64.StdEncoding.EncodeToString(bytes), nil
+}
+
+func GenerateKey() (string, error) {
+ //rand.Seed(time.Now().UnixNano())
+ return GenerateRandomCharsKey(48)
+}
+
+func GetRandomInt(max int) int {
+ //rand.Seed(time.Now().UnixNano())
+ return rand.Intn(max)
+}
+
+func GetTimestamp() int64 {
+ return time.Now().Unix()
+}
+
+func GetTimeString() string {
+ now := time.Now()
+ return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
+
+func Max(a int, b int) int {
+ if a >= b {
+ return a
+ } else {
+ return b
+ }
+}
+
+func MessageWithRequestId(message string, id string) string {
+ return fmt.Sprintf("%s (request id: %s)", message, id)
+}
+
+func RandomSleep() {
+ // Sleep for 0-3000 ms
+ time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
+}
+
+func GetPointer[T any](v T) *T {
+ return &v
+}
+
+func Any2Type[T any](data any) (T, error) {
+ var zero T
+ bytes, err := json.Marshal(data)
+ if err != nil {
+ return zero, err
+ }
+ var res T
+ err = json.Unmarshal(bytes, &res)
+ if err != nil {
+ return zero, err
+ }
+ return res, nil
+}
+
+// SaveTmpFile saves data to a temporary file. A random string is appended to the given filename.
+func SaveTmpFile(filename string, data io.Reader) (string, error) {
+ f, err := os.CreateTemp(os.TempDir(), filename)
+ if err != nil {
+ return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
+ }
+ defer f.Close()
+
+ _, err = io.Copy(f, data)
+ if err != nil {
+ return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
+ }
+
+ return f.Name(), nil
+}
+
+// GetAudioDuration returns the duration of an audio file in seconds.
+func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
+ // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
+ c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
+ output, err := c.Output()
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to get audio duration")
+ }
+ durationStr := string(bytes.TrimSpace(output))
+ if durationStr == "N/A" {
+ // Create a temporary output file name
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to create temporary file")
+ }
+ tmpName := tmpFp.Name()
+ // Close immediately so ffmpeg can open the file on Windows.
+ _ = tmpFp.Close()
+ defer os.Remove(tmpName)
+
+ // ffmpeg -y -i filename -vcodec copy -acodec copy
+ ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
+ if err := ffmpegCmd.Run(); err != nil {
+ return 0, errors.Wrap(err, "failed to run ffmpeg")
+ }
+
+ // Recalculate the duration of the new file
+ c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
+ output, err := c.Output()
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
+ }
+ durationStr = string(bytes.TrimSpace(output))
+ }
+ return strconv.ParseFloat(durationStr, 64)
+}
+
+// BuildURL concatenates base and endpoint, returns the complete url string
+func BuildURL(base string, endpoint string) string {
+ u, err := url.Parse(base)
+ if err != nil {
+ return base + endpoint
+ }
+ end := endpoint
+ if end == "" {
+ end = "/"
+ }
+ ref, err := url.Parse(end)
+ if err != nil {
+ return base + endpoint
+ }
+ return u.ResolveReference(ref).String()
+}
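+
+// Editor's illustrative sketch: BuildURL relies on url.ResolveReference, so an
+// absolute endpoint path replaces the base path while a relative one is joined
+// against it. The example values are hypothetical.
+//
+//    BuildURL("https://api.openai.com", "/v1/chat/completions")
+//    // -> "https://api.openai.com/v1/chat/completions"
+//    BuildURL("https://example.com/proxy/", "v1/models")
+//    // -> "https://example.com/proxy/v1/models"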
diff --git a/common/validate.go b/common/validate.go
new file mode 100644
index 00000000..b3c78591
--- /dev/null
+++ b/common/validate.go
@@ -0,0 +1,9 @@
+package common
+
+import "github.com/go-playground/validator/v10"
+
+var Validate *validator.Validate
+
+func init() {
+ Validate = validator.New()
+}
diff --git a/common/verification.go b/common/verification.go
new file mode 100644
index 00000000..d8ccd6ea
--- /dev/null
+++ b/common/verification.go
@@ -0,0 +1,77 @@
+package common
+
+import (
+ "github.com/google/uuid"
+ "strings"
+ "sync"
+ "time"
+)
+
+type verificationValue struct {
+ code string
+ time time.Time
+}
+
+const (
+ EmailVerificationPurpose = "v"
+ PasswordResetPurpose = "r"
+)
+
+var verificationMutex sync.Mutex
+var verificationMap map[string]verificationValue
+var verificationMapMaxSize = 10
+var VerificationValidMinutes = 10
+
+func GenerateVerificationCode(length int) string {
+ code := uuid.New().String()
+ code = strings.Replace(code, "-", "", -1)
+ if length == 0 {
+ return code
+ }
+ return code[:length]
+}
+
+func RegisterVerificationCodeWithKey(key string, code string, purpose string) {
+ verificationMutex.Lock()
+ defer verificationMutex.Unlock()
+ verificationMap[purpose+key] = verificationValue{
+ code: code,
+ time: time.Now(),
+ }
+ if len(verificationMap) > verificationMapMaxSize {
+ removeExpiredPairs()
+ }
+}
+
+func VerifyCodeWithKey(key string, code string, purpose string) bool {
+ verificationMutex.Lock()
+ defer verificationMutex.Unlock()
+ value, okay := verificationMap[purpose+key]
+ now := time.Now()
+ if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 {
+ return false
+ }
+ return code == value.code
+}
+
+func DeleteKey(key string, purpose string) {
+ verificationMutex.Lock()
+ defer verificationMutex.Unlock()
+ delete(verificationMap, purpose+key)
+}
+
+// no lock inside, so the caller must lock the verificationMap before calling!
+func removeExpiredPairs() {
+ now := time.Now()
+ for key := range verificationMap {
+ if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 {
+ delete(verificationMap, key)
+ }
+ }
+}
+
+func init() {
+ verificationMutex.Lock()
+ defer verificationMutex.Unlock()
+ verificationMap = make(map[string]verificationValue)
+}
diff --git a/constant/README.md b/constant/README.md
new file mode 100644
index 00000000..12a9ffad
--- /dev/null
+++ b/constant/README.md
@@ -0,0 +1,26 @@
+# constant package (`/constant`)
+
+This directory holds only globally reusable **constant definitions**; it contains no business logic and no dependencies.
+
+## Current files
+
+| File | Description |
+|----------------------|---------------------------------------------------------------------|
+| `azure.go` | Azure-related global constants, e.g. `AzureNoRemoveDotTime` (the cut-off time controlling removal of `.`). |
+| `cache_key.go` | Cache-key format strings and token field constants, to keep cache naming consistent. |
+| `channel_setting.go` | Channel-level setting keys, such as `proxy` and `force_format`. |
+| `context_key.go` | The `ContextKey` type and the context keys used throughout the project (request time, token/channel/user information, etc.). |
+| `env.go` | Environment-derived global variables, injected at startup from configuration or environment variables. |
+| `finish_reason.go` | The set of `finish_reason` string constants returned by OpenAI/GPT requests. |
+| `midjourney.go` | Midjourney error codes, action constants, and the model-to-action mapping table. |
+| `setup.go` | Whether the project has finished initial setup (the `Setup` boolean). |
+| `task.go` | Task platform and action constants plus model-to-action mappings, e.g. Suno and Midjourney. |
+| `user_setting.go` | User-setting keys and notification types (Email/Webhook), etc. |
+
+## Usage conventions
+
+1. The `constant` package **may only be imported by other packages**; **it must not import any other package in this project**. If strictly necessary, only the **Go standard library** may be imported.
+2. Do not put business flow, database access, third-party service calls, or any other logic in this directory.
+3. When adding new constants, keep names semantically clear and update the **Current files** table in this README so team members can quickly understand their purpose.
+
+> ⚠️ Violating these conventions introduces unnecessary coupling between packages and hurts maintainability and testability. Please check before committing.
\ No newline at end of file
diff --git a/constant/api_type.go b/constant/api_type.go
new file mode 100644
index 00000000..6ba5f257
--- /dev/null
+++ b/constant/api_type.go
@@ -0,0 +1,35 @@
+package constant
+
+const (
+ APITypeOpenAI = iota
+ APITypeAnthropic
+ APITypePaLM
+ APITypeBaidu
+ APITypeZhipu
+ APITypeAli
+ APITypeXunfei
+ APITypeAIProxyLibrary
+ APITypeTencent
+ APITypeGemini
+ APITypeZhipuV4
+ APITypeOllama
+ APITypePerplexity
+ APITypeAws
+ APITypeCohere
+ APITypeDify
+ APITypeJina
+ APITypeCloudflare
+ APITypeSiliconFlow
+ APITypeVertexAi
+ APITypeMistral
+ APITypeDeepSeek
+ APITypeMokaAI
+ APITypeVolcEngine
+ APITypeBaiduV2
+ APITypeOpenRouter
+ APITypeXinference
+ APITypeXai
+ APITypeCoze
+ APITypeJimeng
+ APITypeDummy // this one is only for count, do not add any channel after this
+)
diff --git a/constant/azure.go b/constant/azure.go
new file mode 100644
index 00000000..d84040ce
--- /dev/null
+++ b/constant/azure.go
@@ -0,0 +1,5 @@
+package constant
+
+import "time"
+
+var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix()
diff --git a/constant/cache_key.go b/constant/cache_key.go
new file mode 100644
index 00000000..0601396a
--- /dev/null
+++ b/constant/cache_key.go
@@ -0,0 +1,14 @@
+package constant
+
+// Cache keys
+const (
+ UserGroupKeyFmt = "user_group:%d"
+ UserQuotaKeyFmt = "user_quota:%d"
+ UserEnabledKeyFmt = "user_enabled:%d"
+ UserUsernameKeyFmt = "user_name:%d"
+)
+
+const (
+ TokenFiledRemainQuota = "RemainQuota"
+ TokenFieldGroup = "Group"
+)
diff --git a/constant/channel.go b/constant/channel.go
new file mode 100644
index 00000000..224121e7
--- /dev/null
+++ b/constant/channel.go
@@ -0,0 +1,109 @@
+package constant
+
+const (
+ ChannelTypeUnknown = 0
+ ChannelTypeOpenAI = 1
+ ChannelTypeMidjourney = 2
+ ChannelTypeAzure = 3
+ ChannelTypeOllama = 4
+ ChannelTypeMidjourneyPlus = 5
+ ChannelTypeOpenAIMax = 6
+ ChannelTypeOhMyGPT = 7
+ ChannelTypeCustom = 8
+ ChannelTypeAILS = 9
+ ChannelTypeAIProxy = 10
+ ChannelTypePaLM = 11
+ ChannelTypeAPI2GPT = 12
+ ChannelTypeAIGC2D = 13
+ ChannelTypeAnthropic = 14
+ ChannelTypeBaidu = 15
+ ChannelTypeZhipu = 16
+ ChannelTypeAli = 17
+ ChannelTypeXunfei = 18
+ ChannelType360 = 19
+ ChannelTypeOpenRouter = 20
+ ChannelTypeAIProxyLibrary = 21
+ ChannelTypeFastGPT = 22
+ ChannelTypeTencent = 23
+ ChannelTypeGemini = 24
+ ChannelTypeMoonshot = 25
+ ChannelTypeZhipu_v4 = 26
+ ChannelTypePerplexity = 27
+ ChannelTypeLingYiWanWu = 31
+ ChannelTypeAws = 33
+ ChannelTypeCohere = 34
+ ChannelTypeMiniMax = 35
+ ChannelTypeSunoAPI = 36
+ ChannelTypeDify = 37
+ ChannelTypeJina = 38
+ ChannelCloudflare = 39
+ ChannelTypeSiliconFlow = 40
+ ChannelTypeVertexAi = 41
+ ChannelTypeMistral = 42
+ ChannelTypeDeepSeek = 43
+ ChannelTypeMokaAI = 44
+ ChannelTypeVolcEngine = 45
+ ChannelTypeBaiduV2 = 46
+ ChannelTypeXinference = 47
+ ChannelTypeXai = 48
+ ChannelTypeCoze = 49
+ ChannelTypeKling = 50
+ ChannelTypeJimeng = 51
+ ChannelTypeDummy // this one is only for count, do not add any channel after this
+
+)
+
+var ChannelBaseURLs = []string{
+ "", // 0
+ "https://api.openai.com", // 1
+ "https://oa.api2d.net", // 2
+ "", // 3
+ "http://localhost:11434", // 4
+ "https://api.openai-sb.com", // 5
+ "https://api.openaimax.com", // 6
+ "https://api.ohmygpt.com", // 7
+ "", // 8
+ "https://api.caipacity.com", // 9
+ "https://api.aiproxy.io", // 10
+ "", // 11
+ "https://api.api2gpt.com", // 12
+ "https://api.aigc2d.com", // 13
+ "https://api.anthropic.com", // 14
+ "https://aip.baidubce.com", // 15
+ "https://open.bigmodel.cn", // 16
+ "https://dashscope.aliyuncs.com", // 17
+ "", // 18
+ "https://api.360.cn", // 19
+ "https://openrouter.ai/api", // 20
+ "https://api.aiproxy.io", // 21
+ "https://fastgpt.run/api/openapi", // 22
+ "https://hunyuan.tencentcloudapi.com", //23
+ "https://generativelanguage.googleapis.com", //24
+ "https://api.moonshot.cn", //25
+ "https://open.bigmodel.cn", //26
+ "https://api.perplexity.ai", //27
+ "", //28
+ "", //29
+ "", //30
+ "https://api.lingyiwanwu.com", //31
+ "", //32
+ "", //33
+ "https://api.cohere.ai", //34
+ "https://api.minimax.chat", //35
+ "", //36
+ "https://api.dify.ai", //37
+ "https://api.jina.ai", //38
+ "https://api.cloudflare.com", //39
+ "https://api.siliconflow.cn", //40
+ "", //41
+ "https://api.mistral.ai", //42
+ "https://api.deepseek.com", //43
+ "https://api.moka.ai", //44
+ "https://ark.cn-beijing.volces.com", //45
+ "https://qianfan.baidubce.com", //46
+ "", //47
+ "https://api.x.ai", //48
+ "https://api.coze.cn", //49
+ "https://api.klingai.com", //50
+ "https://visual.volcengineapi.com", //51
+}
diff --git a/constant/context_key.go b/constant/context_key.go
new file mode 100644
index 00000000..4eaf3d00
--- /dev/null
+++ b/constant/context_key.go
@@ -0,0 +1,44 @@
+package constant
+
+type ContextKey string
+
+const (
+ ContextKeyOriginalModel ContextKey = "original_model"
+ ContextKeyRequestStartTime ContextKey = "request_start_time"
+
+ /* token related keys */
+ ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
+ ContextKeyTokenKey ContextKey = "token_key"
+ ContextKeyTokenId ContextKey = "token_id"
+ ContextKeyTokenGroup ContextKey = "token_group"
+ ContextKeyTokenAllowIps ContextKey = "allow_ips"
+ ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
+ ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
+ ContextKeyTokenModelLimit ContextKey = "token_model_limit"
+
+ /* channel related keys */
+ ContextKeyChannelId ContextKey = "channel_id"
+ ContextKeyChannelName ContextKey = "channel_name"
+ ContextKeyChannelCreateTime ContextKey = "channel_create_time"
+ ContextKeyChannelBaseUrl ContextKey = "base_url"
+ ContextKeyChannelType ContextKey = "channel_type"
+ ContextKeyChannelSetting ContextKey = "channel_setting"
+ ContextKeyChannelParamOverride ContextKey = "param_override"
+ ContextKeyChannelOrganization ContextKey = "channel_organization"
+ ContextKeyChannelAutoBan ContextKey = "auto_ban"
+ ContextKeyChannelModelMapping ContextKey = "model_mapping"
+ ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
+ ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
+ ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
+ ContextKeyChannelKey ContextKey = "channel_key"
+
+ /* user related keys */
+ ContextKeyUserId ContextKey = "id"
+ ContextKeyUserSetting ContextKey = "user_setting"
+ ContextKeyUserQuota ContextKey = "user_quota"
+ ContextKeyUserStatus ContextKey = "user_status"
+ ContextKeyUserEmail ContextKey = "user_email"
+ ContextKeyUserGroup ContextKey = "user_group"
+ ContextKeyUsingGroup ContextKey = "group"
+ ContextKeyUserName ContextKey = "username"
+)
diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go
new file mode 100644
index 00000000..ef096b75
--- /dev/null
+++ b/constant/endpoint_type.go
@@ -0,0 +1,16 @@
+package constant
+
+type EndpointType string
+
+const (
+ EndpointTypeOpenAI EndpointType = "openai"
+ EndpointTypeOpenAIResponse EndpointType = "openai-response"
+ EndpointTypeAnthropic EndpointType = "anthropic"
+ EndpointTypeGemini EndpointType = "gemini"
+ EndpointTypeJinaRerank EndpointType = "jina-rerank"
+ EndpointTypeImageGeneration EndpointType = "image-generation"
+ //EndpointTypeMidjourney EndpointType = "midjourney-proxy"
+ //EndpointTypeSuno EndpointType = "suno-proxy"
+ //EndpointTypeKling EndpointType = "kling"
+ //EndpointTypeJimeng EndpointType = "jimeng"
+)
diff --git a/constant/env.go b/constant/env.go
new file mode 100644
index 00000000..8bc2f131
--- /dev/null
+++ b/constant/env.go
@@ -0,0 +1,15 @@
+package constant
+
+var StreamingTimeout int
+var DifyDebug bool
+var MaxFileDownloadMB int
+var ForceStreamOption bool
+var GetMediaToken bool
+var GetMediaTokenNotStream bool
+var UpdateTask bool
+var AzureDefaultAPIVersion string
+var GeminiVisionMaxImageNum int
+var NotifyLimitCount int
+var NotificationLimitDurationMinute int
+var GenerateDefaultToken bool
+var ErrorLogEnabled bool
diff --git a/constant/finish_reason.go b/constant/finish_reason.go
new file mode 100644
index 00000000..5a752a5f
--- /dev/null
+++ b/constant/finish_reason.go
@@ -0,0 +1,9 @@
+package constant
+
+var (
+ FinishReasonStop = "stop"
+ FinishReasonToolCalls = "tool_calls"
+ FinishReasonLength = "length"
+ FinishReasonFunctionCall = "function_call"
+ FinishReasonContentFilter = "content_filter"
+)
diff --git a/constant/midjourney.go b/constant/midjourney.go
new file mode 100644
index 00000000..5934be2f
--- /dev/null
+++ b/constant/midjourney.go
@@ -0,0 +1,48 @@
+package constant
+
+const (
+ MjErrorUnknown = 5
+ MjRequestError = 4
+)
+
+const (
+ MjActionImagine = "IMAGINE"
+ MjActionDescribe = "DESCRIBE"
+ MjActionBlend = "BLEND"
+ MjActionUpscale = "UPSCALE"
+ MjActionVariation = "VARIATION"
+ MjActionReRoll = "REROLL"
+ MjActionInPaint = "INPAINT"
+ MjActionModal = "MODAL"
+ MjActionZoom = "ZOOM"
+ MjActionCustomZoom = "CUSTOM_ZOOM"
+ MjActionShorten = "SHORTEN"
+ MjActionHighVariation = "HIGH_VARIATION"
+ MjActionLowVariation = "LOW_VARIATION"
+ MjActionPan = "PAN"
+ MjActionSwapFace = "SWAP_FACE"
+ MjActionUpload = "UPLOAD"
+ MjActionVideo = "VIDEO"
+ MjActionEdits = "EDITS"
+)
+
+var MidjourneyModel2Action = map[string]string{
+ "mj_imagine": MjActionImagine,
+ "mj_describe": MjActionDescribe,
+ "mj_blend": MjActionBlend,
+ "mj_upscale": MjActionUpscale,
+ "mj_variation": MjActionVariation,
+ "mj_reroll": MjActionReRoll,
+ "mj_modal": MjActionModal,
+ "mj_inpaint": MjActionInPaint,
+ "mj_zoom": MjActionZoom,
+ "mj_custom_zoom": MjActionCustomZoom,
+ "mj_shorten": MjActionShorten,
+ "mj_high_variation": MjActionHighVariation,
+ "mj_low_variation": MjActionLowVariation,
+ "mj_pan": MjActionPan,
+ "swap_face": MjActionSwapFace,
+ "mj_upload": MjActionUpload,
+ "mj_video": MjActionVideo,
+ "mj_edits": MjActionEdits,
+}
diff --git a/constant/multi_key_mode.go b/constant/multi_key_mode.go
new file mode 100644
index 00000000..cd0cdbff
--- /dev/null
+++ b/constant/multi_key_mode.go
@@ -0,0 +1,8 @@
+package constant
+
+type MultiKeyMode string
+
+const (
+ MultiKeyModeRandom MultiKeyMode = "random" // random selection
+ MultiKeyModePolling MultiKeyMode = "polling" // round-robin
+)
diff --git a/constant/setup.go b/constant/setup.go
new file mode 100644
index 00000000..26ecc883
--- /dev/null
+++ b/constant/setup.go
@@ -0,0 +1,3 @@
+package constant
+
+var Setup = false
diff --git a/constant/task.go b/constant/task.go
new file mode 100644
index 00000000..e7af39a6
--- /dev/null
+++ b/constant/task.go
@@ -0,0 +1,23 @@
+package constant
+
+type TaskPlatform string
+
+const (
+ TaskPlatformSuno TaskPlatform = "suno"
+ TaskPlatformMidjourney = "mj"
+ TaskPlatformKling TaskPlatform = "kling"
+ TaskPlatformJimeng TaskPlatform = "jimeng"
+)
+
+const (
+ SunoActionMusic = "MUSIC"
+ SunoActionLyrics = "LYRICS"
+
+ TaskActionGenerate = "generate"
+ TaskActionTextGenerate = "textGenerate"
+)
+
+var SunoModel2Action = map[string]string{
+ "suno_music": SunoActionMusic,
+ "suno_lyrics": SunoActionLyrics,
+}
diff --git a/controller/billing.go b/controller/billing.go
new file mode 100644
index 00000000..1fb83633
--- /dev/null
+++ b/controller/billing.go
@@ -0,0 +1,92 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/model"
+)
+
+func GetSubscription(c *gin.Context) {
+ var remainQuota int
+ var usedQuota int
+ var err error
+ var token *model.Token
+ var expiredTime int64
+ if common.DisplayTokenStatEnabled {
+ tokenId := c.GetInt("token_id")
+ token, err = model.GetTokenById(tokenId)
+ if err == nil && token != nil {
+ expiredTime = token.ExpiredTime
+ remainQuota = token.RemainQuota
+ usedQuota = token.UsedQuota
+ }
+ } else {
+ userId := c.GetInt("id")
+ remainQuota, err = model.GetUserQuota(userId, false)
+ if err == nil {
+ usedQuota, err = model.GetUserUsedQuota(userId)
+ }
+ }
+ if expiredTime <= 0 {
+ expiredTime = 0
+ }
+ if err != nil {
+ openAIError := dto.OpenAIError{
+ Message: err.Error(),
+ Type: "upstream_error",
+ }
+ c.JSON(200, gin.H{
+ "error": openAIError,
+ })
+ return
+ }
+ quota := remainQuota + usedQuota
+ amount := float64(quota)
+ if common.DisplayInCurrencyEnabled {
+ amount /= common.QuotaPerUnit
+ }
+ if token != nil && token.UnlimitedQuota {
+ amount = 100000000
+ }
+ subscription := OpenAISubscriptionResponse{
+ Object: "billing_subscription",
+ HasPaymentMethod: true,
+ SoftLimitUSD: amount,
+ HardLimitUSD: amount,
+ SystemHardLimitUSD: amount,
+ AccessUntil: expiredTime,
+ }
+ c.JSON(200, subscription)
+ return
+}
+
+func GetUsage(c *gin.Context) {
+ var quota int
+ var err error
+ var token *model.Token
+ if common.DisplayTokenStatEnabled {
+ tokenId := c.GetInt("token_id")
+ token, err = model.GetTokenById(tokenId)
+ // only read the token once the lookup succeeded
+ if err == nil && token != nil {
+ quota = token.UsedQuota
+ }
+ } else {
+ userId := c.GetInt("id")
+ quota, err = model.GetUserUsedQuota(userId)
+ }
+ if err != nil {
+ openAIError := dto.OpenAIError{
+ Message: err.Error(),
+ Type: "new_api_error",
+ }
+ c.JSON(200, gin.H{
+ "error": openAIError,
+ })
+ return
+ }
+ amount := float64(quota)
+ if common.DisplayInCurrencyEnabled {
+ amount /= common.QuotaPerUnit
+ }
+ usage := OpenAIUsageResponse{
+ Object: "list",
+ TotalUsage: amount * 100,
+ }
+ c.JSON(200, usage)
+ return
+}
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
new file mode 100644
index 00000000..5152e060
--- /dev/null
+++ b/controller/channel-billing.go
@@ -0,0 +1,492 @@
+package controller
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/model"
+ "one-api/service"
+ "one-api/setting"
+ "one-api/types"
+ "strconv"
+ "time"
+
+ "github.com/shopspring/decimal"
+
+ "github.com/gin-gonic/gin"
+)
+
+// https://github.com/songquanpeng/one-api/issues/79
+
+type OpenAISubscriptionResponse struct {
+ Object string `json:"object"`
+ HasPaymentMethod bool `json:"has_payment_method"`
+ SoftLimitUSD float64 `json:"soft_limit_usd"`
+ HardLimitUSD float64 `json:"hard_limit_usd"`
+ SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
+ AccessUntil int64 `json:"access_until"`
+}
+
+type OpenAIUsageDailyCost struct {
+ Timestamp float64 `json:"timestamp"`
+ LineItems []struct {
+ Name string `json:"name"`
+ Cost float64 `json:"cost"`
+ }
+}
+
+type OpenAICreditGrants struct {
+ Object string `json:"object"`
+ TotalGranted float64 `json:"total_granted"`
+ TotalUsed float64 `json:"total_used"`
+ TotalAvailable float64 `json:"total_available"`
+}
+
+type OpenAIUsageResponse struct {
+ Object string `json:"object"`
+ //DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
+ TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
+}
+
+type OpenAISBUsageResponse struct {
+ Msg string `json:"msg"`
+ Data *struct {
+ Credit string `json:"credit"`
+ } `json:"data"`
+}
+
+type AIProxyUserOverviewResponse struct {
+ Success bool `json:"success"`
+ Message string `json:"message"`
+ ErrorCode int `json:"error_code"`
+ Data struct {
+ TotalPoints float64 `json:"totalPoints"`
+ } `json:"data"`
+}
+
+type API2GPTUsageResponse struct {
+ Object string `json:"object"`
+ TotalGranted float64 `json:"total_granted"`
+ TotalUsed float64 `json:"total_used"`
+ TotalRemaining float64 `json:"total_remaining"`
+}
+
+type APGC2DGPTUsageResponse struct {
+ //Grants interface{} `json:"grants"`
+ Object string `json:"object"`
+ TotalAvailable float64 `json:"total_available"`
+ TotalGranted float64 `json:"total_granted"`
+ TotalUsed float64 `json:"total_used"`
+}
+
+type SiliconFlowUsageResponse struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Status bool `json:"status"`
+ Data struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Image string `json:"image"`
+ Email string `json:"email"`
+ IsAdmin bool `json:"isAdmin"`
+ Balance string `json:"balance"`
+ Status string `json:"status"`
+ Introduction string `json:"introduction"`
+ Role string `json:"role"`
+ ChargeBalance string `json:"chargeBalance"`
+ TotalBalance string `json:"totalBalance"`
+ Category string `json:"category"`
+ } `json:"data"`
+}
+
+type DeepSeekUsageResponse struct {
+ IsAvailable bool `json:"is_available"`
+ BalanceInfos []struct {
+ Currency string `json:"currency"`
+ TotalBalance string `json:"total_balance"`
+ GrantedBalance string `json:"granted_balance"`
+ ToppedUpBalance string `json:"topped_up_balance"`
+ } `json:"balance_infos"`
+}
+
+type OpenRouterCreditResponse struct {
+ Data struct {
+ TotalCredits float64 `json:"total_credits"`
+ TotalUsage float64 `json:"total_usage"`
+ } `json:"data"`
+}
+
+// GetAuthHeader get auth header
+func GetAuthHeader(token string) http.Header {
+ h := http.Header{}
+ h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
+ return h
+}
+
+func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
+ req, err := http.NewRequest(method, url, nil)
+ if err != nil {
+ return nil, err
+ }
+ for k := range headers {
+ req.Header.Add(k, headers.Get(k))
+ }
+ res, err := service.GetHttpClient().Do(req)
+ if err != nil {
+ return nil, err
+ }
+ // close the body on every return path, including the early return on a non-200 status
+ defer res.Body.Close()
+ if res.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("status code: %d", res.StatusCode)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ return nil, err
+ }
+ return body, nil
+}
+
+func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
+ url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+
+ if err != nil {
+ return 0, err
+ }
+ response := OpenAICreditGrants{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ channel.UpdateBalance(response.TotalAvailable)
+ return response.TotalAvailable, nil
+}
+
+func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
+ url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+ response := OpenAISBUsageResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ if response.Data == nil {
+ return 0, errors.New(response.Msg)
+ }
+ balance, err := strconv.ParseFloat(response.Data.Credit, 64)
+ if err != nil {
+ return 0, err
+ }
+ channel.UpdateBalance(balance)
+ return balance, nil
+}
+
+func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
+ url := "https://aiproxy.io/api/report/getUserOverview"
+ headers := http.Header{}
+ headers.Add("Api-Key", channel.Key)
+ body, err := GetResponseBody("GET", url, channel, headers)
+ if err != nil {
+ return 0, err
+ }
+ response := AIProxyUserOverviewResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ if !response.Success {
+ return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
+ }
+ channel.UpdateBalance(response.Data.TotalPoints)
+ return response.Data.TotalPoints, nil
+}
+
+func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
+ url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+
+ if err != nil {
+ return 0, err
+ }
+ response := API2GPTUsageResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ channel.UpdateBalance(response.TotalRemaining)
+ return response.TotalRemaining, nil
+}
+
+func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
+ url := "https://api.siliconflow.cn/v1/user/info"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+ response := SiliconFlowUsageResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ if response.Code != 20000 {
+ return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
+ }
+ balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
+ if err != nil {
+ return 0, err
+ }
+ channel.UpdateBalance(balance)
+ return balance, nil
+}
+
+func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
+ url := "https://api.deepseek.com/user/balance"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+ response := DeepSeekUsageResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ index := -1
+ for i, balanceInfo := range response.BalanceInfos {
+ if balanceInfo.Currency == "CNY" {
+ index = i
+ break
+ }
+ }
+ if index == -1 {
+ return 0, errors.New("currency CNY not found")
+ }
+ balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
+ if err != nil {
+ return 0, err
+ }
+ channel.UpdateBalance(balance)
+ return balance, nil
+}
+
+func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
+ url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+ response := APGC2DGPTUsageResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ channel.UpdateBalance(response.TotalAvailable)
+ return response.TotalAvailable, nil
+}
+
+func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
+ url := "https://openrouter.ai/api/v1/credits"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+ response := OpenRouterCreditResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ balance := response.Data.TotalCredits - response.Data.TotalUsage
+ channel.UpdateBalance(balance)
+ return balance, nil
+}
+
+func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
+ url := "https://api.moonshot.cn/v1/users/me/balance"
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+
+ type MoonshotBalanceData struct {
+ AvailableBalance float64 `json:"available_balance"`
+ VoucherBalance float64 `json:"voucher_balance"`
+ CashBalance float64 `json:"cash_balance"`
+ }
+
+ type MoonshotBalanceResponse struct {
+ Code int `json:"code"`
+ Data MoonshotBalanceData `json:"data"`
+ Scode string `json:"scode"`
+ Status bool `json:"status"`
+ }
+
+ response := MoonshotBalanceResponse{}
+ err = json.Unmarshal(body, &response)
+ if err != nil {
+ return 0, err
+ }
+ if !response.Status || response.Code != 0 {
+ return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
+ }
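+ // Moonshot reports the balance in CNY; divide by setting.Price (assumed here to be the
+ // configured CNY-per-unit exchange price) so the stored balance is comparable to other channels.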
+ availableBalanceCny := response.Data.AvailableBalance
+ availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
+ channel.UpdateBalance(availableBalanceUsd)
+ return availableBalanceUsd, nil
+}
+
+func updateChannelBalance(channel *model.Channel) (float64, error) {
+ baseURL := constant.ChannelBaseURLs[channel.Type]
+ if channel.GetBaseURL() == "" {
+ channel.BaseURL = &baseURL
+ }
+ switch channel.Type {
+ case constant.ChannelTypeOpenAI:
+ if channel.GetBaseURL() != "" {
+ baseURL = channel.GetBaseURL()
+ }
+ case constant.ChannelTypeAzure:
+ return 0, errors.New("尚未实现")
+ case constant.ChannelTypeCustom:
+ baseURL = channel.GetBaseURL()
+ //case common.ChannelTypeOpenAISB:
+ // return updateChannelOpenAISBBalance(channel)
+ case constant.ChannelTypeAIProxy:
+ return updateChannelAIProxyBalance(channel)
+ case constant.ChannelTypeAPI2GPT:
+ return updateChannelAPI2GPTBalance(channel)
+ case constant.ChannelTypeAIGC2D:
+ return updateChannelAIGC2DBalance(channel)
+ case constant.ChannelTypeSiliconFlow:
+ return updateChannelSiliconFlowBalance(channel)
+ case constant.ChannelTypeDeepSeek:
+ return updateChannelDeepSeekBalance(channel)
+ case constant.ChannelTypeOpenRouter:
+ return updateChannelOpenRouterBalance(channel)
+ case constant.ChannelTypeMoonshot:
+ return updateChannelMoonshotBalance(channel)
+ default:
+ return 0, errors.New("尚未实现")
+ }
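+ // Generic OpenAI-compatible path: read the subscription hard limit, then subtract the usage
+ // reported for the current billing window (or the last 100 days when there is no payment method).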
+ url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
+
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+ subscription := OpenAISubscriptionResponse{}
+ err = json.Unmarshal(body, &subscription)
+ if err != nil {
+ return 0, err
+ }
+ now := time.Now()
+ startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
+ endDate := now.Format("2006-01-02")
+ if !subscription.HasPaymentMethod {
+ startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
+ }
+ url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
+ body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ return 0, err
+ }
+ usage := OpenAIUsageResponse{}
+ err = json.Unmarshal(body, &usage)
+ if err != nil {
+ return 0, err
+ }
+ balance := subscription.HardLimitUSD - usage.TotalUsage/100
+ channel.UpdateBalance(balance)
+ return balance, nil
+}
+
+func UpdateChannelBalance(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ channel, err := model.CacheGetChannel(id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if channel.ChannelInfo.IsMultiKey {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "多密钥渠道不支持余额查询",
+ })
+ return
+ }
+ balance, err := updateChannelBalance(channel)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "balance": balance,
+ })
+}
+
+func updateAllChannelsBalance() error {
+ channels, err := model.GetAllChannels(0, 0, true, false)
+ if err != nil {
+ return err
+ }
+ for _, channel := range channels {
+ if channel.Status != common.ChannelStatusEnabled {
+ continue
+ }
+ if channel.ChannelInfo.IsMultiKey {
+ continue // skip multi-key channels
+ }
+ // TODO: support Azure
+ //if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
+ // continue
+ //}
+ balance, err := updateChannelBalance(channel)
+ if err != nil {
+ continue
+ } else {
+ // err is nil & balance <= 0 means quota is used up
+ if balance <= 0 {
+ service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
+ }
+ }
+ time.Sleep(common.RequestInterval)
+ }
+ return nil
+}
+
+func UpdateAllChannelsBalance(c *gin.Context) {
+ // TODO: make it async
+ err := updateAllChannelsBalance()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func AutomaticallyUpdateChannels(frequency int) {
+ for {
+ time.Sleep(time.Duration(frequency) * time.Minute)
+ common.SysLog("updating all channels")
+ _ = updateAllChannelsBalance()
+ common.SysLog("channels update done")
+ }
+}
diff --git a/controller/channel-test.go b/controller/channel-test.go
new file mode 100644
index 00000000..8c4a26ae
--- /dev/null
+++ b/controller/channel-test.go
@@ -0,0 +1,464 @@
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/middleware"
+ "one-api/model"
+ "one-api/relay"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+
+ "github.com/gin-gonic/gin"
+)
+
+type testResult struct {
+ context *gin.Context
+ localErr error
+ newAPIError *types.NewAPIError
+}
+
+func testChannel(channel *model.Channel, testModel string) testResult {
+ tik := time.Now()
+ if channel.Type == constant.ChannelTypeMidjourney {
+ return testResult{
+ localErr: errors.New("midjourney channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ if channel.Type == constant.ChannelTypeMidjourneyPlus {
+ return testResult{
+ localErr: errors.New("midjourney plus channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ if channel.Type == constant.ChannelTypeSunoAPI {
+ return testResult{
+ localErr: errors.New("suno channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ if channel.Type == constant.ChannelTypeKling {
+ return testResult{
+ localErr: errors.New("kling channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ if channel.Type == constant.ChannelTypeJimeng {
+ return testResult{
+ localErr: errors.New("jimeng channel test is not supported"),
+ newAPIError: nil,
+ }
+ }
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ requestPath := "/v1/chat/completions"
+
+ // decide first whether this is an embedding model
+ if strings.Contains(strings.ToLower(testModel), "embedding") ||
+ strings.HasPrefix(testModel, "m3e") || // m3e series models
+ strings.Contains(testModel, "bge-") || // bge series models
+ strings.Contains(testModel, "embed") ||
+ channel.Type == constant.ChannelTypeMokaAI { // other embedding channels
+ requestPath = "/v1/embeddings" // use the embeddings endpoint instead
+ }
+
+ c.Request = &http.Request{
+ Method: "POST",
+ URL: &url.URL{Path: requestPath}, // use the path selected above
+ Body: nil,
+ Header: make(http.Header),
+ }
+
+ if testModel == "" {
+ if channel.TestModel != nil && *channel.TestModel != "" {
+ testModel = *channel.TestModel
+ } else {
+ if len(channel.GetModels()) > 0 {
+ testModel = channel.GetModels()[0]
+ } else {
+ testModel = "gpt-4o-mini"
+ }
+ }
+ }
+
+ cache, err := model.GetUserCache(1)
+ if err != nil {
+ return testResult{
+ localErr: err,
+ newAPIError: nil,
+ }
+ }
+ cache.WriteContext(c)
+
+ //c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
+ c.Request.Header.Set("Content-Type", "application/json")
+ c.Set("channel", channel.Type)
+ c.Set("base_url", channel.GetBaseURL())
+ group, _ := model.GetUserGroup(1, false)
+ c.Set("group", group)
+
+ newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
+ if newAPIError != nil {
+ return testResult{
+ context: c,
+ localErr: newAPIError,
+ newAPIError: newAPIError,
+ }
+ }
+
+ info := relaycommon.GenRelayInfo(c)
+
+ err = helper.ModelMappedHelper(c, info, nil)
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
+ }
+ }
+ testModel = info.UpstreamModelName
+
+ apiType, _ := common.ChannelType2APIType(channel.Type)
+ adaptor := relay.GetAdaptor(apiType)
+ if adaptor == nil {
+ return testResult{
+ context: c,
+ localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
+ newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
+ }
+ }
+
+ request := buildTestRequest(testModel)
+ // log a copy of info with the ApiKey stripped
+ logInfo := *info
+ logInfo.ApiKey = ""
+ common.SysLog(fmt.Sprintf("testing channel %d with model %s, info %+v", channel.Id, testModel, logInfo))
+
+ priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
+ }
+ }
+
+ adaptor.Init(info)
+
+ var convertedRequest any
+ // pick the conversion that matches the relay mode
+ if info.RelayMode == relayconstant.RelayModeEmbeddings {
+ // build an EmbeddingRequest
+ embeddingRequest := dto.EmbeddingRequest{
+ Input: request.Input,
+ Model: request.Model,
+ }
+ // use the embedding-specific conversion
+ convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
+ } else {
+ // all other request types (e.g. chat) keep the original logic
+ convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
+ }
+
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
+ }
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
+ }
+ }
+ requestBody := bytes.NewBuffer(jsonData)
+ c.Request.Body = io.NopCloser(requestBody)
+ resp, err := adaptor.DoRequest(c, info, requestBody)
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
+ }
+ }
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ if httpResp.StatusCode != http.StatusOK {
+ err := service.RelayErrorHandler(httpResp, true)
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
+ }
+ }
+ }
+ usageA, respErr := adaptor.DoResponse(c, httpResp, info)
+ if respErr != nil {
+ return testResult{
+ context: c,
+ localErr: respErr,
+ newAPIError: respErr,
+ }
+ }
+ if usageA == nil {
+ return testResult{
+ context: c,
+ localErr: errors.New("usage is nil"),
+ newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
+ }
+ }
+ usage := usageA.(*dto.Usage)
+ result := w.Result()
+ respBody, err := io.ReadAll(result.Body)
+ if err != nil {
+ return testResult{
+ context: c,
+ localErr: err,
+ newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
+ }
+ }
+ info.PromptTokens = usage.PromptTokens
+
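+ // bill the test like a regular request: token-based pricing unless the model uses a fixed per-call price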
+ quota := 0
+ if !priceData.UsePrice {
+ quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
+ quota = int(math.Round(float64(quota) * priceData.ModelRatio))
+ if priceData.ModelRatio != 0 && quota <= 0 {
+ quota = 1
+ }
+ } else {
+ quota = int(priceData.ModelPrice * common.QuotaPerUnit)
+ }
+ tok := time.Now()
+ milliseconds := tok.Sub(tik).Milliseconds()
+ consumedTime := float64(milliseconds) / 1000.0
+ other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
+ usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
+ ChannelId: channel.Id,
+ PromptTokens: usage.PromptTokens,
+ CompletionTokens: usage.CompletionTokens,
+ ModelName: info.OriginModelName,
+ TokenName: "模型测试",
+ Quota: quota,
+ Content: "模型测试",
+ UseTimeSeconds: int(consumedTime),
+ IsStream: false,
+ Group: info.UsingGroup,
+ Other: other,
+ })
+ common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
+ return testResult{
+ context: c,
+ localErr: nil,
+ newAPIError: nil,
+ }
+}
+
+func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
+ testRequest := &dto.GeneralOpenAIRequest{
+ Model: "", // this will be set later
+ Stream: false,
+ }
+
+ // decide first whether this is an embedding model
+ if strings.Contains(strings.ToLower(model), "embedding") || // generic embedding models
+ strings.HasPrefix(model, "m3e") || // m3e series models
+ strings.Contains(model, "bge-") {
+ testRequest.Model = model
+ // embedding request
+ testRequest.Input = []any{"hello world"} // []any because ParseInput in dto/openai_request.go cannot handle []string
+ return testRequest
+ }
+ // not an embedding model: build a minimal chat completion request
+ if strings.HasPrefix(model, "o") {
+ testRequest.MaxCompletionTokens = 10
+ } else if strings.Contains(model, "thinking") {
+ if !strings.Contains(model, "claude") {
+ testRequest.MaxTokens = 50
+ }
+ } else if strings.Contains(model, "gemini") {
+ testRequest.MaxTokens = 3000
+ } else {
+ testRequest.MaxTokens = 10
+ }
+
+ testMessage := dto.Message{
+ Role: "user",
+ Content: "hi",
+ }
+ testRequest.Model = model
+ testRequest.Messages = append(testRequest.Messages, testMessage)
+ return testRequest
+}
+
+func TestChannel(c *gin.Context) {
+ channelId, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ channel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ //defer func() {
+ // if channel.ChannelInfo.IsMultiKey {
+ // go func() { _ = channel.SaveChannelInfo() }()
+ // }
+ //}()
+ testModel := c.Query("model")
+ tik := time.Now()
+ result := testChannel(channel, testModel)
+ if result.localErr != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": result.localErr.Error(),
+ "time": 0.0,
+ })
+ return
+ }
+ tok := time.Now()
+ milliseconds := tok.Sub(tik).Milliseconds()
+ go channel.UpdateResponseTime(milliseconds)
+ consumedTime := float64(milliseconds) / 1000.0
+ if result.newAPIError != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": result.newAPIError.Error(),
+ "time": consumedTime,
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "time": consumedTime,
+ })
+ return
+}
+
+var testAllChannelsLock sync.Mutex
+var testAllChannelsRunning = false
+
+func testAllChannels(notify bool) error {
+
+ testAllChannelsLock.Lock()
+ if testAllChannelsRunning {
+ testAllChannelsLock.Unlock()
+ return errors.New("测试已在运行中")
+ }
+ testAllChannelsRunning = true
+ testAllChannelsLock.Unlock()
+ channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
+ if getChannelErr != nil {
+ // reset the running flag here, otherwise a failed fetch would block every later test run
+ testAllChannelsLock.Lock()
+ testAllChannelsRunning = false
+ testAllChannelsLock.Unlock()
+ return getChannelErr
+ }
+ var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
+ if disableThreshold == 0 {
+ disableThreshold = 10000000 // an impossible value, effectively never disable on latency
+ }
+ gopool.Go(func() {
+ // reset the running flag in a defer so a panic or early exit cannot leave it stuck
+ defer func() {
+ testAllChannelsLock.Lock()
+ testAllChannelsRunning = false
+ testAllChannelsLock.Unlock()
+ }()
+
+ for _, channel := range channels {
+ isChannelEnabled := channel.Status == common.ChannelStatusEnabled
+ tik := time.Now()
+ result := testChannel(channel, "")
+ tok := time.Now()
+ milliseconds := tok.Sub(tik).Milliseconds()
+
+ shouldBanChannel := false
+ newAPIError := result.newAPIError
+ // request error disables the channel
+ if newAPIError != nil {
+ shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
+ }
+
+ // only check the response time when the error check passed
+ if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
+ if milliseconds > disableThreshold {
+ err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
+ newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
+ shouldBanChannel = true
+ }
+ }
+
+ // disable channel
+ if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
+ go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+ }
+
+ // enable channel
+ if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
+ service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
+ }
+
+ channel.UpdateResponseTime(milliseconds)
+ time.Sleep(common.RequestInterval)
+ }
+
+ if notify {
+ service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
+ }
+ })
+ return nil
+}
+
+func TestAllChannels(c *gin.Context) {
+ err := testAllChannels(true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func AutomaticallyTestChannels(frequency int) {
+ if frequency <= 0 {
+ common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
+ return
+ }
+ for {
+ time.Sleep(time.Duration(frequency) * time.Minute)
+ common.SysLog("testing all channels")
+ _ = testAllChannels(false)
+ common.SysLog("channel test finished")
+ }
+}
diff --git a/controller/channel.go b/controller/channel.go
new file mode 100644
index 00000000..d3bfa202
--- /dev/null
+++ b/controller/channel.go
@@ -0,0 +1,916 @@
+package controller
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/model"
+ "strconv"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type OpenAIModel struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ OwnedBy string `json:"owned_by"`
+ Permission []struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ AllowCreateEngine bool `json:"allow_create_engine"`
+ AllowSampling bool `json:"allow_sampling"`
+ AllowLogprobs bool `json:"allow_logprobs"`
+ AllowSearchIndices bool `json:"allow_search_indices"`
+ AllowView bool `json:"allow_view"`
+ AllowFineTuning bool `json:"allow_fine_tuning"`
+ Organization string `json:"organization"`
+ Group string `json:"group"`
+ IsBlocking bool `json:"is_blocking"`
+ } `json:"permission"`
+ Root string `json:"root"`
+ Parent string `json:"parent"`
+}
+
+type OpenAIModelsResponse struct {
+ Data []OpenAIModel `json:"data"`
+ Success bool `json:"success"`
+}
+
+func parseStatusFilter(statusParam string) int {
+ switch strings.ToLower(statusParam) {
+ case "enabled", "1":
+ return common.ChannelStatusEnabled
+ case "disabled", "0":
+ return 0
+ default:
+ return -1
+ }
+}
+
+func GetAllChannels(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+ channelData := make([]*model.Channel, 0)
+ idSort, _ := strconv.ParseBool(c.Query("id_sort"))
+ enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+ statusParam := c.Query("status")
+ // statusFilter: -1 all, 1 enabled, 0 disabled (includes both auto- and manually-disabled)
+ statusFilter := parseStatusFilter(statusParam)
+ // type filter
+ typeStr := c.Query("type")
+ typeFilter := -1
+ if typeStr != "" {
+ if t, err := strconv.Atoi(typeStr); err == nil {
+ typeFilter = t
+ }
+ }
+
+ var total int64
+
+ if enableTagMode {
+ tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ for _, tag := range tags {
+ if tag == nil || *tag == "" {
+ continue
+ }
+ tagChannels, err := model.GetChannelsByTag(*tag, idSort)
+ if err != nil {
+ continue
+ }
+ filtered := make([]*model.Channel, 0)
+ for _, ch := range tagChannels {
+ if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
+ continue
+ }
+ if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
+ continue
+ }
+ if typeFilter >= 0 && ch.Type != typeFilter {
+ continue
+ }
+ filtered = append(filtered, ch)
+ }
+ channelData = append(channelData, filtered...)
+ }
+ total, _ = model.CountAllTags()
+ } else {
+ baseQuery := model.DB.Model(&model.Channel{})
+ if typeFilter >= 0 {
+ baseQuery = baseQuery.Where("type = ?", typeFilter)
+ }
+ if statusFilter == common.ChannelStatusEnabled {
+ baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
+ } else if statusFilter == 0 {
+ baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
+ }
+
+ baseQuery.Count(&total)
+
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+
+ err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ }
+
+ countQuery := model.DB.Model(&model.Channel{})
+ if statusFilter == common.ChannelStatusEnabled {
+ countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
+ } else if statusFilter == 0 {
+ countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
+ }
+ var results []struct {
+ Type int64
+ Count int64
+ }
+ _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
+ typeCounts := make(map[int64]int64)
+ for _, r := range results {
+ typeCounts[r.Type] = r.Count
+ }
+ common.ApiSuccess(c, gin.H{
+ "items": channelData,
+ "total": total,
+ "page": pageInfo.GetPage(),
+ "page_size": pageInfo.GetPageSize(),
+ "type_counts": typeCounts,
+ })
+ return
+}
+
+func FetchUpstreamModels(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ channel, err := model.GetChannelById(id, true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ baseURL := constant.ChannelBaseURLs[channel.Type]
+ if channel.GetBaseURL() != "" {
+ baseURL = channel.GetBaseURL()
+ }
+ url := fmt.Sprintf("%s/v1/models", baseURL)
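+ // a few providers expose their OpenAI-compatible model list under a different prefix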
+ switch channel.Type {
+ case constant.ChannelTypeGemini:
+ url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
+ case constant.ChannelTypeAli:
+ url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
+ }
+ body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ var result OpenAIModelsResponse
+ if err = json.Unmarshal(body, &result); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
+ })
+ return
+ }
+
+ var ids []string
+ for _, m := range result.Data { // rename to avoid shadowing the model package import
+ id := m.ID
+ if channel.Type == constant.ChannelTypeGemini {
+ id = strings.TrimPrefix(id, "models/")
+ }
+ ids = append(ids, id)
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": ids,
+ })
+}
+
+func FixChannelsAbilities(c *gin.Context) {
+ success, fails, err := model.FixAbility()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "success": success,
+ "fails": fails,
+ },
+ })
+}
+
+func SearchChannels(c *gin.Context) {
+ keyword := c.Query("keyword")
+ group := c.Query("group")
+ modelKeyword := c.Query("model")
+ statusParam := c.Query("status")
+ statusFilter := parseStatusFilter(statusParam)
+ idSort, _ := strconv.ParseBool(c.Query("id_sort"))
+ enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
+ channelData := make([]*model.Channel, 0)
+ if enableTagMode {
+ tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ for _, tag := range tags {
+ if tag != nil && *tag != "" {
+ tagChannel, err := model.GetChannelsByTag(*tag, idSort)
+ if err == nil {
+ channelData = append(channelData, tagChannel...)
+ }
+ }
+ }
+ } else {
+ channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ channelData = channels
+ }
+
+ if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
+ filtered := make([]*model.Channel, 0, len(channelData))
+ for _, ch := range channelData {
+ if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
+ continue
+ }
+ if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
+ continue
+ }
+ filtered = append(filtered, ch)
+ }
+ channelData = filtered
+ }
+
+ // calculate type counts for search results
+ typeCounts := make(map[int64]int64)
+ for _, channel := range channelData {
+ typeCounts[int64(channel.Type)]++
+ }
+
+ typeParam := c.Query("type")
+ typeFilter := -1
+ if typeParam != "" {
+ if tp, err := strconv.Atoi(typeParam); err == nil {
+ typeFilter = tp
+ }
+ }
+
+ if typeFilter >= 0 {
+ filtered := make([]*model.Channel, 0, len(channelData))
+ for _, ch := range channelData {
+ if ch.Type == typeFilter {
+ filtered = append(filtered, ch)
+ }
+ }
+ channelData = filtered
+ }
+
+ page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
+ pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
+ if page < 1 {
+ page = 1
+ }
+ if pageSize <= 0 {
+ pageSize = 20
+ }
+
+ total := len(channelData)
+ startIdx := (page - 1) * pageSize
+ if startIdx > total {
+ startIdx = total
+ }
+ endIdx := startIdx + pageSize
+ if endIdx > total {
+ endIdx = total
+ }
+
+ pagedData := channelData[startIdx:endIdx]
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "items": pagedData,
+ "total": total,
+ "type_counts": typeCounts,
+ },
+ })
+ return
+}
+
+func GetChannel(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ channel, err := model.GetChannelById(id, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": channel,
+ })
+ return
+}
+
+// validateChannel holds the validation shared by the add and update handlers.
+func validateChannel(channel *model.Channel, isAdd bool) error {
+ // reject a missing channel before dereferencing it below
+ if channel == nil {
+ return fmt.Errorf("channel cannot be empty")
+ }
+
+ // validate channel settings
+ if err := channel.ValidateSettings(); err != nil {
+ return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
+ }
+
+ // for add requests, the key must not be empty
+ if isAdd {
+ if channel.Key == "" {
+ return fmt.Errorf("channel cannot be empty")
+ }
+
+ // reject model names longer than 255 characters
+ for _, m := range channel.GetModels() {
+ if len(m) > 255 {
+ return fmt.Errorf("模型名称过长: %s", m)
+ }
+ }
+ }
+
+ // VertexAI-specific validation
+ if channel.Type == constant.ChannelTypeVertexAi {
+ if channel.Other == "" {
+ return fmt.Errorf("部署地区不能为空")
+ }
+
+ regionMap, err := common.StrToMap(channel.Other)
+ if err != nil {
+ return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
+ }
+
+ if regionMap["default"] == nil {
+ return fmt.Errorf("部署地区必须包含default字段")
+ }
+ }
+
+ return nil
+}
+
+type AddChannelRequest struct {
+ Mode string `json:"mode"`
+ MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+ Channel *model.Channel `json:"channel"`
+}
+
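+// getVertexArrayKeys parses a JSON array of Vertex AI credentials. Vertex keys are JSON
+// objects rather than plain strings, so batch input cannot simply be split on newlines.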
+func getVertexArrayKeys(keys string) ([]string, error) {
+ if keys == "" {
+ return nil, nil
+ }
+ var keyArray []interface{}
+ err := common.Unmarshal([]byte(keys), &keyArray)
+ if err != nil {
+ return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
+ }
+ cleanKeys := make([]string, 0, len(keyArray))
+ for _, key := range keyArray {
+ var keyStr string
+ switch v := key.(type) {
+ case string:
+ keyStr = strings.TrimSpace(v)
+ default:
+ bytes, err := json.Marshal(v)
+ if err != nil {
+ return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
+ }
+ keyStr = string(bytes)
+ }
+ if keyStr != "" {
+ cleanKeys = append(cleanKeys, keyStr)
+ }
+ }
+ if len(cleanKeys) == 0 {
+ return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
+ }
+ return cleanKeys, nil
+}
+
+func AddChannel(c *gin.Context) {
+ addChannelRequest := AddChannelRequest{}
+ err := c.ShouldBindJSON(&addChannelRequest)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // run the shared validation
+ if err := validateChannel(addChannelRequest.Channel, true); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
+ keys := make([]string, 0)
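+ // "single": one channel with one key; "batch": one channel per key line;
+ // "multi_to_single": a single multi-key channel that keeps every key (newline separated)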
+ switch addChannelRequest.Mode {
+ case "multi_to_single":
+ addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
+ addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
+ if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+ array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
+ addChannelRequest.Channel.Key = strings.Join(array, "\n")
+ } else {
+ cleanKeys := make([]string, 0)
+ for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
+ if key == "" {
+ continue
+ }
+ key = strings.TrimSpace(key)
+ cleanKeys = append(cleanKeys, key)
+ }
+ addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
+ addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
+ }
+ keys = []string{addChannelRequest.Channel.Key}
+ case "batch":
+ if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
+ // multi json
+ keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ keys = strings.Split(addChannelRequest.Channel.Key, "\n")
+ }
+ case "single":
+ keys = []string{addChannelRequest.Channel.Key}
+ default:
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "不支持的添加模式",
+ })
+ return
+ }
+
+ channels := make([]model.Channel, 0, len(keys))
+ for _, key := range keys {
+ if key == "" {
+ continue
+ }
+ localChannel := addChannelRequest.Channel
+ localChannel.Key = key
+ channels = append(channels, *localChannel)
+ }
+ err = model.BatchInsertChannels(channels)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func DeleteChannel(c *gin.Context) {
+ id, _ := strconv.Atoi(c.Param("id"))
+ channel := model.Channel{Id: id}
+ err := channel.Delete()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func DeleteDisabledChannel(c *gin.Context) {
+ rows, err := model.DeleteDisabledChannel()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": rows,
+ })
+ return
+}
+
+type ChannelTag struct {
+ Tag string `json:"tag"`
+ NewTag *string `json:"new_tag"`
+ Priority *int64 `json:"priority"`
+ Weight *uint `json:"weight"`
+ ModelMapping *string `json:"model_mapping"`
+ Models *string `json:"models"`
+ Groups *string `json:"groups"`
+}
+
+func DisableTagChannels(c *gin.Context) {
+ channelTag := ChannelTag{}
+ err := c.ShouldBindJSON(&channelTag)
+ if err != nil || channelTag.Tag == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ err = model.DisableChannelByTag(channelTag.Tag)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func EnableTagChannels(c *gin.Context) {
+ channelTag := ChannelTag{}
+ err := c.ShouldBindJSON(&channelTag)
+ if err != nil || channelTag.Tag == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ err = model.EnableChannelByTag(channelTag.Tag)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func EditTagChannels(c *gin.Context) {
+ channelTag := ChannelTag{}
+ err := c.ShouldBindJSON(&channelTag)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ if channelTag.Tag == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "tag不能为空",
+ })
+ return
+ }
+ err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+type ChannelBatch struct {
+ Ids []int `json:"ids"`
+ Tag *string `json:"tag"`
+}
+
+func DeleteChannelBatch(c *gin.Context) {
+ channelBatch := ChannelBatch{}
+ err := c.ShouldBindJSON(&channelBatch)
+ if err != nil || len(channelBatch.Ids) == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ err = model.BatchDeleteChannels(channelBatch.Ids)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": len(channelBatch.Ids),
+ })
+ return
+}
+
+type PatchChannel struct {
+ model.Channel
+ MultiKeyMode *string `json:"multi_key_mode"`
+}
+
+func UpdateChannel(c *gin.Context) {
+ channel := PatchChannel{}
+ err := c.ShouldBindJSON(&channel)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // run the shared validation
+ if err := validateChannel(&channel.Channel, false); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ // Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
+ originChannel, err := model.GetChannelById(channel.Id, false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ // Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
+ channel.ChannelInfo = originChannel.ChannelInfo
+
+ // If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
+ if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
+ channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
+ }
+ err = channel.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ channel.Key = ""
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": channel,
+ })
+ return
+}
+
+func FetchModels(c *gin.Context) {
+ var req struct {
+ BaseURL string `json:"base_url"`
+ Type int `json:"type"`
+ Key string `json:"key"`
+ }
+
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "message": "Invalid request",
+ })
+ return
+ }
+
+ baseURL := req.BaseURL
+ if baseURL == "" {
+ baseURL = constant.ChannelBaseURLs[req.Type]
+ }
+
+ client := &http.Client{}
+ url := fmt.Sprintf("%s/v1/models", baseURL)
+
+ request, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ // remove line breaks and extra spaces.
+ key := strings.TrimSpace(req.Key)
+ // If the key contains a line break, only take the first part.
+ key = strings.Split(key, "\n")[0]
+ request.Header.Set("Authorization", "Bearer "+key)
+
+ response, err := client.Do(request)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ defer response.Body.Close()
+ // check status code
+ if response.StatusCode != http.StatusOK {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "message": "Failed to fetch models",
+ })
+ return
+ }
+
+ var result struct {
+ Data []struct {
+ ID string `json:"id"`
+ } `json:"data"`
+ }
+
+ if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ var models []string
+ for _, m := range result.Data { // rename to avoid shadowing the model package import
+ models = append(models, m.ID)
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "data": models,
+ })
+}
+
+func BatchSetChannelTag(c *gin.Context) {
+ channelBatch := ChannelBatch{}
+ err := c.ShouldBindJSON(&channelBatch)
+ if err != nil || len(channelBatch.Ids) == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ model.InitChannelCache()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": len(channelBatch.Ids),
+ })
+ return
+}
+
+func GetTagModels(c *gin.Context) {
+ tag := c.Query("tag")
+ if tag == "" {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "message": "tag不能为空",
+ })
+ return
+ }
+
+ channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ var longestModels string
+ maxLength := 0
+
+ // Find the longest models string among all channels with the given tag
+ for _, channel := range channels {
+ if channel.Models != "" {
+ currentModels := strings.Split(channel.Models, ",")
+ if len(currentModels) > maxLength {
+ maxLength = len(currentModels)
+ longestModels = channel.Models
+ }
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": longestModels,
+ })
+ return
+}
+
+// CopyChannel handles cloning an existing channel with its key.
+// POST /api/channel/copy/:id
+// Optional query params:
+//
+// suffix - string appended to the original name (default "_复制")
+// reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
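+//
+// Example (hypothetical values):
+//
+//	POST /api/channel/copy/42?suffix=_backup&reset_balance=false
+//	=> {"success": true, "message": "", "data": {"id": <new channel id>}}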
+func CopyChannel(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
+ return
+ }
+
+ suffix := c.DefaultQuery("suffix", "_复制")
+ resetBalance := true
+ if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
+ if v, err := strconv.ParseBool(rbStr); err == nil {
+ resetBalance = v
+ }
+ }
+
+ // fetch original channel with key
+ origin, err := model.GetChannelById(id, true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ // clone channel
+ clone := *origin // shallow copy is sufficient as we will overwrite primitives
+ clone.Id = 0 // let DB auto-generate
+ clone.CreatedTime = common.GetTimestamp()
+ clone.Name = origin.Name + suffix
+ clone.TestTime = 0
+ clone.ResponseTime = 0
+ if resetBalance {
+ clone.Balance = 0
+ clone.UsedQuota = 0
+ }
+
+ // insert
+ if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ model.InitChannelCache()
+ // success
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
+}
diff --git a/controller/console_migrate.go b/controller/console_migrate.go
new file mode 100644
index 00000000..d25f199b
--- /dev/null
+++ b/controller/console_migrate.go
@@ -0,0 +1,103 @@
+// Legacy option keys handled for migration detection; this file will be removed in the next version.
+
+package controller
+
+import (
+ "encoding/json"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "github.com/gin-gonic/gin"
+)
+
+// MigrateConsoleSetting migrates the legacy console options into the console_setting.* keys.
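+// Mapping (old option key -> new key):
+//
+//	ApiInfo                        -> console_setting.api_info (trimmed to 50 entries)
+//	Announcements                  -> console_setting.announcements
+//	FAQ                            -> console_setting.faq (normalized to question/answer pairs, max 50)
+//	UptimeKumaUrl + UptimeKumaSlug -> console_setting.uptime_kuma_groups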
+func MigrateConsoleSetting(c *gin.Context) {
+ // load all options
+ opts, err := model.AllOption()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ // index options by key
+ valMap := map[string]string{}
+ for _, o := range opts {
+ valMap[o.Key] = o.Value
+ }
+
+ // migrate ApiInfo
+ if v := valMap["ApiInfo"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ if len(arr) > 50 {
+ arr = arr[:50]
+ }
+ bytes, _ := json.Marshal(arr)
+ model.UpdateOption("console_setting.api_info", string(bytes))
+ }
+ model.UpdateOption("ApiInfo", "")
+ }
+ // Announcements are copied as-is
+ if v := valMap["Announcements"]; v != "" {
+ model.UpdateOption("console_setting.announcements", v)
+ model.UpdateOption("Announcements", "")
+ }
+ // convert FAQ entries
+ if v := valMap["FAQ"]; v != "" {
+ var arr []map[string]interface{}
+ if err := json.Unmarshal([]byte(v), &arr); err == nil {
+ out := []map[string]interface{}{}
+ for _, item := range arr {
+ q, _ := item["question"].(string)
+ if q == "" {
+ q, _ = item["title"].(string)
+ }
+ a, _ := item["answer"].(string)
+ if a == "" {
+ a, _ = item["content"].(string)
+ }
+ if q != "" && a != "" {
+ out = append(out, map[string]interface{}{"question": q, "answer": a})
+ }
+ }
+ if len(out) > 50 {
+ out = out[:50]
+ }
+ bytes, _ := json.Marshal(out)
+ model.UpdateOption("console_setting.faq", string(bytes))
+ }
+ model.UpdateOption("FAQ", "")
+ }
+ // migrate Uptime Kuma settings into the new groups structure (console_setting.uptime_kuma_groups)
+ url := valMap["UptimeKumaUrl"]
+ slug := valMap["UptimeKumaSlug"]
+ if url != "" && slug != "" {
+ // only migrate when both the URL and the slug are present
+ groups := []map[string]interface{}{
+ {
+ "id": 1,
+ "categoryName": "old",
+ "url": url,
+ "slug": slug,
+ "description": "",
+ },
+ }
+ bytes, _ := json.Marshal(groups)
+ model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
+ }
+ // clear the old key values
+ if url != "" {
+ model.UpdateOption("UptimeKumaUrl", "")
+ }
+ if slug != "" {
+ model.UpdateOption("UptimeKumaSlug", "")
+ }
+
+ // delete the old option rows
+ oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
+ model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
+
+ // reload the option map
+ model.InitOptionMap()
+ common.SysLog("console setting migrated")
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
+}
\ No newline at end of file
diff --git a/controller/github.go b/controller/github.go
new file mode 100644
index 00000000..881d6dc1
--- /dev/null
+++ b/controller/github.go
@@ -0,0 +1,239 @@
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+ "time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+type GitHubOAuthResponse struct {
+ AccessToken string `json:"access_token"`
+ Scope string `json:"scope"`
+ TokenType string `json:"token_type"`
+}
+
+type GitHubUser struct {
+ Login string `json:"login"`
+ Name string `json:"name"`
+ Email string `json:"email"`
+}
+
+func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
+ if code == "" {
+ return nil, errors.New("无效的参数")
+ }
+ values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
+ jsonData, err := json.Marshal(values)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ client := http.Client{
+ Timeout: 5 * time.Second,
+ }
+ res, err := client.Do(req)
+ if err != nil {
+ common.SysLog(err.Error())
+ return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
+ }
+ defer res.Body.Close()
+ var oAuthResponse GitHubOAuthResponse
+ err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
+ if err != nil {
+ return nil, err
+ }
+ req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
+ res2, err := client.Do(req)
+ if err != nil {
+ common.SysLog(err.Error())
+ return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
+ }
+ defer res2.Body.Close()
+ var githubUser GitHubUser
+ err = json.NewDecoder(res2.Body).Decode(&githubUser)
+ if err != nil {
+ return nil, err
+ }
+ if githubUser.Login == "" {
+ return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
+ }
+ return &githubUser, nil
+}
+
+func GitHubOAuth(c *gin.Context) {
+ session := sessions.Default(c)
+ state := c.Query("state")
+ if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "state is empty or not same",
+ })
+ return
+ }
+ username := session.Get("username")
+ if username != nil {
+ GitHubBind(c)
+ return
+ }
+
+ if !common.GitHubOAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 GitHub 登录以及注册",
+ })
+ return
+ }
+ code := c.Query("code")
+ githubUser, err := getGitHubUserInfoByCode(code)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user := model.User{
+ GitHubId: githubUser.Login,
+ }
+ // IsGitHubIdAlreadyTaken is unscoped
+ if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
+ // FillUserByGitHubId is scoped
+ err := user.FillUserByGitHubId()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ // if user.Id == 0 , user has been deleted
+ if user.Id == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户已注销",
+ })
+ return
+ }
+ } else {
+ if common.RegisterEnabled {
+ user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
+ if githubUser.Name != "" {
+ user.DisplayName = githubUser.Name
+ } else {
+ user.DisplayName = "GitHub User"
+ }
+ user.Email = githubUser.Email
+ user.Role = common.RoleCommonUser
+ user.Status = common.UserStatusEnabled
+ affCode := session.Get("aff")
+ inviterId := 0
+ if affCode != nil {
+ inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
+ }
+
+ if err := user.Insert(inviterId); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员关闭了新用户注册",
+ })
+ return
+ }
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "用户已被封禁",
+ "success": false,
+ })
+ return
+ }
+ setupLogin(&user, c)
+}
+
+func GitHubBind(c *gin.Context) {
+ if !common.GitHubOAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 GitHub 登录以及注册",
+ })
+ return
+ }
+ code := c.Query("code")
+ githubUser, err := getGitHubUserInfoByCode(code)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user := model.User{
+ GitHubId: githubUser.Login,
+ }
+ if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该 GitHub 账户已被绑定",
+ })
+ return
+ }
+ session := sessions.Default(c)
+ id := session.Get("id")
+ // id := c.GetInt("id") // critical bug!
+ user.Id = id.(int)
+ err = user.FillUserById()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user.GitHubId = githubUser.Login
+ err = user.Update(false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "bind",
+ })
+ return
+}
+
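+// GenerateOAuthCode creates a random state string, stores it (and an optional aff code) in the session, and returns it for use in the OAuth flow.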
+func GenerateOAuthCode(c *gin.Context) {
+ session := sessions.Default(c)
+ state := common.GetRandomString(12)
+ affCode := c.Query("aff")
+ if affCode != "" {
+ session.Set("aff", affCode)
+ }
+ session.Set("oauth_state", state)
+ err := session.Save()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": state,
+ })
+}
diff --git a/controller/group.go b/controller/group.go
new file mode 100644
index 00000000..2565b6ea
--- /dev/null
+++ b/controller/group.go
@@ -0,0 +1,50 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/model"
+ "one-api/setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
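+// GetGroups returns the names of all groups that have a configured ratio.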
+func GetGroups(c *gin.Context) {
+ groupNames := make([]string, 0)
+ for groupName := range ratio_setting.GetGroupRatioCopy() {
+ groupNames = append(groupNames, groupName)
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": groupNames,
+ })
+}
+
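+// GetUserGroups returns the groups usable by the current user together with their ratios and descriptions.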
+func GetUserGroups(c *gin.Context) {
+ usableGroups := make(map[string]map[string]interface{})
+ userGroup := ""
+ userId := c.GetInt("id")
+ userGroup, _ = model.GetUserGroup(userId, false)
+ for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
+ // UserUsableGroups contains the groups that the user can use
+ userUsableGroups := setting.GetUserUsableGroups(userGroup)
+ if desc, ok := userUsableGroups[groupName]; ok {
+ usableGroups[groupName] = map[string]interface{}{
+ "ratio": ratio,
+ "desc": desc,
+ }
+ }
+ }
+ if setting.GroupInUserUsableGroups("auto") {
+ usableGroups["auto"] = map[string]interface{}{
+ "ratio": "自动",
+ "desc": setting.GetUsableGroupDescription("auto"),
+ }
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": usableGroups,
+ })
+}
diff --git a/controller/image.go b/controller/image.go
new file mode 100644
index 00000000..d6e8806a
--- /dev/null
+++ b/controller/image.go
@@ -0,0 +1,9 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+)
+
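+// GetImage is currently an empty placeholder handler.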
+func GetImage(c *gin.Context) {
+
+}
diff --git a/controller/linuxdo.go b/controller/linuxdo.go
new file mode 100644
index 00000000..65380b65
--- /dev/null
+++ b/controller/linuxdo.go
@@ -0,0 +1,259 @@
+package controller
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+type LinuxdoUser struct {
+ Id int `json:"id"`
+ Username string `json:"username"`
+ Name string `json:"name"`
+ Active bool `json:"active"`
+ TrustLevel int `json:"trust_level"`
+ Silenced bool `json:"silenced"`
+}
+
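+// LinuxDoBind binds the Linux DO account identified by the OAuth code to the currently logged-in user.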
+func LinuxDoBind(c *gin.Context) {
+ if !common.LinuxDOOAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 Linux DO 登录以及注册",
+ })
+ return
+ }
+
+ code := c.Query("code")
+ linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ user := model.User{
+ LinuxDOId: strconv.Itoa(linuxdoUser.Id),
+ }
+
+ if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该 Linux DO 账户已被绑定",
+ })
+ return
+ }
+
+ session := sessions.Default(c)
+ id := session.Get("id")
+ user.Id = id.(int)
+
+ err = user.FillUserById()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
+ err = user.Update(false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "bind",
+ })
+}
+
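+// getLinuxdoUserInfoByCode exchanges the authorization code for an access token using Basic auth, then fetches the Linux DO user profile.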
+func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
+ if code == "" {
+ return nil, errors.New("invalid code")
+ }
+
+ // Get access token using Basic auth
+ tokenEndpoint := "https://connect.linux.do/oauth2/token"
+ credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
+ basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
+
+ // Get redirect URI from request
+ scheme := "http"
+ if c.Request.TLS != nil {
+ scheme = "https"
+ }
+ redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
+
+ data := url.Values{}
+ data.Set("grant_type", "authorization_code")
+ data.Set("code", code)
+ data.Set("redirect_uri", redirectURI)
+
+ req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Authorization", basicAuth)
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ client := http.Client{Timeout: 5 * time.Second}
+ res, err := client.Do(req)
+ if err != nil {
+ return nil, errors.New("failed to connect to Linux DO server")
+ }
+ defer res.Body.Close()
+
+ var tokenRes struct {
+ AccessToken string `json:"access_token"`
+ Message string `json:"message"`
+ }
+ if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
+ return nil, err
+ }
+
+ if tokenRes.AccessToken == "" {
+ return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
+ }
+
+ // Get user info
+ userEndpoint := "https://connect.linux.do/api/user"
+ req, err = http.NewRequest("GET", userEndpoint, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
+ req.Header.Set("Accept", "application/json")
+
+ res2, err := client.Do(req)
+ if err != nil {
+ return nil, errors.New("failed to get user info from Linux DO")
+ }
+ defer res2.Body.Close()
+
+ var linuxdoUser LinuxdoUser
+ if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
+ return nil, err
+ }
+
+ if linuxdoUser.Id == 0 {
+ return nil, errors.New("invalid user info returned")
+ }
+
+ return &linuxdoUser, nil
+}
+
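+// LinuxdoOAuth handles the Linux DO OAuth callback: it surfaces upstream errors, validates the state, binds the account for logged-in users, and otherwise logs in or registers the user.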
+func LinuxdoOAuth(c *gin.Context) {
+ session := sessions.Default(c)
+
+ errorCode := c.Query("error")
+ if errorCode != "" {
+ errorDescription := c.Query("error_description")
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": errorDescription,
+ })
+ return
+ }
+
+ state := c.Query("state")
+ if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "state is empty or not same",
+ })
+ return
+ }
+
+ username := session.Get("username")
+ if username != nil {
+ LinuxDoBind(c)
+ return
+ }
+
+ if !common.LinuxDOOAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 Linux DO 登录以及注册",
+ })
+ return
+ }
+
+ code := c.Query("code")
+ linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ user := model.User{
+ LinuxDOId: strconv.Itoa(linuxdoUser.Id),
+ }
+
+ // Check if user exists
+ if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
+ err := user.FillUserByLinuxDOId()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ if user.Id == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户已注销",
+ })
+ return
+ }
+ } else {
+ if common.RegisterEnabled {
+ user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
+ user.DisplayName = linuxdoUser.Name
+ user.Role = common.RoleCommonUser
+ user.Status = common.UserStatusEnabled
+
+ affCode := session.Get("aff")
+ inviterId := 0
+ if affCode != nil {
+ inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
+ }
+
+ if err := user.Insert(inviterId); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员关闭了新用户注册",
+ })
+ return
+ }
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "用户已被封禁",
+ "success": false,
+ })
+ return
+ }
+
+ setupLogin(&user, c)
+}
diff --git a/controller/log.go b/controller/log.go
new file mode 100644
index 00000000..042fa725
--- /dev/null
+++ b/controller/log.go
@@ -0,0 +1,168 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetAllLogs(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+ logType, _ := strconv.Atoi(c.Query("type"))
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+ username := c.Query("username")
+ tokenName := c.Query("token_name")
+ modelName := c.Query("model_name")
+ channel, _ := strconv.Atoi(c.Query("channel"))
+ group := c.Query("group")
+ logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(logs)
+ common.ApiSuccess(c, pageInfo)
+ return
+}
+
+func GetUserLogs(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+ userId := c.GetInt("id")
+ logType, _ := strconv.Atoi(c.Query("type"))
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+ tokenName := c.Query("token_name")
+ modelName := c.Query("model_name")
+ group := c.Query("group")
+ logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(logs)
+ common.ApiSuccess(c, pageInfo)
+ return
+}
+
+func SearchAllLogs(c *gin.Context) {
+ keyword := c.Query("keyword")
+ logs, err := model.SearchAllLogs(keyword)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": logs,
+ })
+ return
+}
+
+func SearchUserLogs(c *gin.Context) {
+ keyword := c.Query("keyword")
+ userId := c.GetInt("id")
+ logs, err := model.SearchUserLogs(userId, keyword)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": logs,
+ })
+ return
+}
+
+func GetLogByKey(c *gin.Context) {
+ key := c.Query("key")
+ logs, err := model.GetLogByKey(key)
+ if err != nil {
+ c.JSON(200, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(200, gin.H{
+ "success": true,
+ "message": "",
+ "data": logs,
+ })
+}
+
+func GetLogsStat(c *gin.Context) {
+ logType, _ := strconv.Atoi(c.Query("type"))
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+ tokenName := c.Query("token_name")
+ username := c.Query("username")
+ modelName := c.Query("model_name")
+ channel, _ := strconv.Atoi(c.Query("channel"))
+ group := c.Query("group")
+ stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+ //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "quota": stat.Quota,
+ "rpm": stat.Rpm,
+ "tpm": stat.Tpm,
+ },
+ })
+ return
+}
+
+func GetLogsSelfStat(c *gin.Context) {
+ username := c.GetString("username")
+ logType, _ := strconv.Atoi(c.Query("type"))
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+ tokenName := c.Query("token_name")
+ modelName := c.Query("model_name")
+ channel, _ := strconv.Atoi(c.Query("channel"))
+ group := c.Query("group")
+ quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
+ //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
+ c.JSON(200, gin.H{
+ "success": true,
+ "message": "",
+ "data": gin.H{
+ "quota": quotaNum.Quota,
+ "rpm": quotaNum.Rpm,
+ "tpm": quotaNum.Tpm,
+ //"token": tokenNum,
+ },
+ })
+ return
+}
+
+func DeleteHistoryLogs(c *gin.Context) {
+ targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
+ if targetTimestamp == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "target timestamp is required",
+ })
+ return
+ }
+ count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": count,
+ })
+ return
+}
diff --git a/controller/midjourney.go b/controller/midjourney.go
new file mode 100644
index 00000000..02ad708f
--- /dev/null
+++ b/controller/midjourney.go
@@ -0,0 +1,263 @@
+package controller
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/service"
+ "one-api/setting"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
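+// UpdateMidjourneyTaskBulk runs forever, polling every 15 seconds for unfinished Midjourney tasks, querying each channel's upstream for task status and refunding quota for failed tasks.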
+func UpdateMidjourneyTaskBulk() {
+ //imageModel := "midjourney"
+ ctx := context.TODO()
+ for {
+ time.Sleep(time.Duration(15) * time.Second)
+
+ tasks := model.GetAllUnFinishTasks()
+ if len(tasks) == 0 {
+ continue
+ }
+
+ common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
+ taskChannelM := make(map[int][]string)
+ taskM := make(map[string]*model.Midjourney)
+ nullTaskIds := make([]int, 0)
+ for _, task := range tasks {
+ if task.MjId == "" {
+ // collect unfinished tasks without an mj_id; they will be marked as failed
+ nullTaskIds = append(nullTaskIds, task.Id)
+ continue
+ }
+ taskM[task.MjId] = task
+ taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
+ }
+ if len(nullTaskIds) > 0 {
+ err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
+ } else {
+ common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
+ }
+ }
+ if len(taskChannelM) == 0 {
+ continue
+ }
+
+ for channelId, taskIds := range taskChannelM {
+ common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+ if len(taskIds) == 0 {
+ continue
+ }
+ midjourneyChannel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
+ err := model.MjBulkUpdate(taskIds, map[string]any{
+ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if err != nil {
+ common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
+ }
+ continue
+ }
+ requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
+
+ body, _ := json.Marshal(map[string]any{
+ "ids": taskIds,
+ })
+ req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
+ continue
+ }
+ // set the request timeout
+ timeout := time.Second * 15
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ // rebuild the request with the timeout-bound context
+ req = req.WithContext(ctx)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("mj-api-secret", midjourneyChannel.Key)
+ resp, err := service.GetHttpClient().Do(req)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
+ continue
+ }
+ if resp.StatusCode != http.StatusOK {
+ common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ continue
+ }
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
+ continue
+ }
+ var responseItems []dto.MidjourneyDto
+ err = json.Unmarshal(responseBody, &responseItems)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+ continue
+ }
+ resp.Body.Close()
+ req.Body.Close()
+ cancel()
+
+ for _, responseItem := range responseItems {
+ task := taskM[responseItem.MjId]
+
+ useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
+ // if more than one hour has passed and progress is not 100%, treat the task as failed
+ if useTime > 3600000 && task.Progress != "100%" {
+ responseItem.FailReason = "上游任务超时(超过1小时)"
+ responseItem.Status = "FAILURE"
+ }
+ if !checkMjTaskNeedUpdate(task, responseItem) {
+ continue
+ }
+ task.Code = 1
+ task.Progress = responseItem.Progress
+ task.PromptEn = responseItem.PromptEn
+ task.State = responseItem.State
+ task.SubmitTime = responseItem.SubmitTime
+ task.StartTime = responseItem.StartTime
+ task.FinishTime = responseItem.FinishTime
+ task.ImageUrl = responseItem.ImageUrl
+ task.Status = responseItem.Status
+ task.FailReason = responseItem.FailReason
+ if responseItem.Properties != nil {
+ propertiesStr, _ := json.Marshal(responseItem.Properties)
+ task.Properties = string(propertiesStr)
+ }
+ if responseItem.Buttons != nil {
+ buttonStr, _ := json.Marshal(responseItem.Buttons)
+ task.Buttons = string(buttonStr)
+ }
+ shouldReturnQuota := false
+ if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
+ common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
+ task.Progress = "100%"
+ if task.Quota != 0 {
+ shouldReturnQuota = true
+ }
+ }
+ err = task.Update()
+ if err != nil {
+ common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
+ } else {
+ if shouldReturnQuota {
+ err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
+ if err != nil {
+ common.LogError(ctx, "fail to increase user quota: "+err.Error())
+ }
+ logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ }
+ }
+ }
+ }
+ }
+}
+
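+// checkMjTaskNeedUpdate reports whether any tracked field of the upstream task differs from the locally stored task.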
+func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
+ if oldTask.Code != 1 {
+ return true
+ }
+ if oldTask.Progress != newTask.Progress {
+ return true
+ }
+ if oldTask.PromptEn != newTask.PromptEn {
+ return true
+ }
+ if oldTask.State != newTask.State {
+ return true
+ }
+ if oldTask.SubmitTime != newTask.SubmitTime {
+ return true
+ }
+ if oldTask.StartTime != newTask.StartTime {
+ return true
+ }
+ if oldTask.FinishTime != newTask.FinishTime {
+ return true
+ }
+ if oldTask.ImageUrl != newTask.ImageUrl {
+ return true
+ }
+ if oldTask.Status != newTask.Status {
+ return true
+ }
+ if oldTask.FailReason != newTask.FailReason {
+ return true
+ }
+ if oldTask.FinishTime != newTask.FinishTime {
+ return true
+ }
+ if oldTask.Progress != "100%" && newTask.FailReason != "" {
+ return true
+ }
+
+ return false
+}
+
+func GetAllMidjourney(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+
+ // parse the remaining query parameters
+ queryParams := model.TaskQueryParams{
+ ChannelID: c.Query("channel_id"),
+ MjID: c.Query("mj_id"),
+ StartTimestamp: c.Query("start_timestamp"),
+ EndTimestamp: c.Query("end_timestamp"),
+ }
+
+ items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.CountAllTasks(queryParams)
+
+ if setting.MjForwardUrlEnabled {
+ for i, midjourney := range items {
+ midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
+ items[i] = midjourney
+ }
+ }
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
+}
+
+func GetUserMidjourney(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+
+ userId := c.GetInt("id")
+
+ queryParams := model.TaskQueryParams{
+ MjID: c.Query("mj_id"),
+ StartTimestamp: c.Query("start_timestamp"),
+ EndTimestamp: c.Query("end_timestamp"),
+ }
+
+ items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.CountAllUserTask(userId, queryParams)
+
+ if setting.MjForwardUrlEnabled {
+ for i, midjourney := range items {
+ midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
+ items[i] = midjourney
+ }
+ }
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
+}
diff --git a/controller/misc.go b/controller/misc.go
new file mode 100644
index 00000000..a3ed9be9
--- /dev/null
+++ b/controller/misc.go
@@ -0,0 +1,302 @@
+package controller
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/middleware"
+ "one-api/model"
+ "one-api/setting"
+ "one-api/setting/console_setting"
+ "one-api/setting/operation_setting"
+ "one-api/setting/system_setting"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func TestStatus(c *gin.Context) {
+ err := model.PingDB()
+ if err != nil {
+ c.JSON(http.StatusServiceUnavailable, gin.H{
+ "success": false,
+ "message": "数据库连接失败",
+ })
+ return
+ }
+ // collect HTTP request statistics
+ httpStats := middleware.GetStats()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "Server is running",
+ "http_stats": httpStats,
+ })
+ return
+}
+
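+// GetStatus returns the public site configuration and feature flags consumed by the frontend.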
+func GetStatus(c *gin.Context) {
+
+ cs := console_setting.GetConsoleSetting()
+
+ data := gin.H{
+ "version": common.Version,
+ "start_time": common.StartTime,
+ "email_verification": common.EmailVerificationEnabled,
+ "github_oauth": common.GitHubOAuthEnabled,
+ "github_client_id": common.GitHubClientId,
+ "linuxdo_oauth": common.LinuxDOOAuthEnabled,
+ "linuxdo_client_id": common.LinuxDOClientId,
+ "telegram_oauth": common.TelegramOAuthEnabled,
+ "telegram_bot_name": common.TelegramBotName,
+ "system_name": common.SystemName,
+ "logo": common.Logo,
+ "footer_html": common.Footer,
+ "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
+ "wechat_login": common.WeChatAuthEnabled,
+ "server_address": setting.ServerAddress,
+ "price": setting.Price,
+ "stripe_unit_price": setting.StripeUnitPrice,
+ "min_topup": setting.MinTopUp,
+ "stripe_min_topup": setting.StripeMinTopUp,
+ "turnstile_check": common.TurnstileCheckEnabled,
+ "turnstile_site_key": common.TurnstileSiteKey,
+ "top_up_link": common.TopUpLink,
+ "docs_link": operation_setting.GetGeneralSetting().DocsLink,
+ "quota_per_unit": common.QuotaPerUnit,
+ "display_in_currency": common.DisplayInCurrencyEnabled,
+ "enable_batch_update": common.BatchUpdateEnabled,
+ "enable_drawing": common.DrawingEnabled,
+ "enable_task": common.TaskEnabled,
+ "enable_data_export": common.DataExportEnabled,
+ "data_export_default_time": common.DataExportDefaultTime,
+ "default_collapse_sidebar": common.DefaultCollapseSidebar,
+ "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
+ "enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
+ "mj_notify_enabled": setting.MjNotifyEnabled,
+ "chats": setting.Chats,
+ "demo_site_enabled": operation_setting.DemoSiteEnabled,
+ "self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
+ "default_use_auto_group": setting.DefaultUseAutoGroup,
+ "pay_methods": setting.PayMethods,
+ "usd_exchange_rate": setting.USDExchangeRate,
+
+ // console panel toggles
+ "api_info_enabled": cs.ApiInfoEnabled,
+ "uptime_kuma_enabled": cs.UptimeKumaEnabled,
+ "announcements_enabled": cs.AnnouncementsEnabled,
+ "faq_enabled": cs.FAQEnabled,
+
+ "oidc_enabled": system_setting.GetOIDCSettings().Enabled,
+ "oidc_client_id": system_setting.GetOIDCSettings().ClientId,
+ "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
+ "setup": constant.Setup,
+ }
+
+ // inject optional content according to the enabled flags
+ if cs.ApiInfoEnabled {
+ data["api_info"] = console_setting.GetApiInfo()
+ }
+ if cs.AnnouncementsEnabled {
+ data["announcements"] = console_setting.GetAnnouncements()
+ }
+ if cs.FAQEnabled {
+ data["faq"] = console_setting.GetFAQ()
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": data,
+ })
+ return
+}
+
+func GetNotice(c *gin.Context) {
+ common.OptionMapRWMutex.RLock()
+ defer common.OptionMapRWMutex.RUnlock()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": common.OptionMap["Notice"],
+ })
+ return
+}
+
+func GetAbout(c *gin.Context) {
+ common.OptionMapRWMutex.RLock()
+ defer common.OptionMapRWMutex.RUnlock()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": common.OptionMap["About"],
+ })
+ return
+}
+
+func GetMidjourney(c *gin.Context) {
+ common.OptionMapRWMutex.RLock()
+ defer common.OptionMapRWMutex.RUnlock()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": common.OptionMap["Midjourney"],
+ })
+ return
+}
+
+func GetHomePageContent(c *gin.Context) {
+ common.OptionMapRWMutex.RLock()
+ defer common.OptionMapRWMutex.RUnlock()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": common.OptionMap["HomePageContent"],
+ })
+ return
+}
+
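+// SendEmailVerification validates the address against the domain whitelist and alias rules, then emails a verification code.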
+func SendEmailVerification(c *gin.Context) {
+ email := c.Query("email")
+ if err := common.Validate.Var(email, "required,email"); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ parts := strings.Split(email, "@")
+ if len(parts) != 2 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的邮箱地址",
+ })
+ return
+ }
+ localPart := parts[0]
+ domainPart := parts[1]
+ if common.EmailDomainRestrictionEnabled {
+ allowed := false
+ for _, domain := range common.EmailDomainWhitelist {
+ if domainPart == domain {
+ allowed = true
+ break
+ }
+ }
+ if !allowed {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
+ })
+ return
+ }
+ }
+ if common.EmailAliasRestrictionEnabled {
+ containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
+ if containsSpecialSymbols {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
+ })
+ return
+ }
+ }
+
+ if model.IsEmailAlreadyTaken(email) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "邮箱地址已被占用",
+ })
+ return
+ }
+ code := common.GenerateVerificationCode(6)
+ common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
+ subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
+ content := fmt.Sprintf("您好,你正在进行%s邮箱验证。
"+
+ "您的验证码为: %s
"+
+ "验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, code, common.VerificationValidMinutes)
+ err := common.SendEmail(subject, email, content)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func SendPasswordResetEmail(c *gin.Context) {
+ email := c.Query("email")
+ if err := common.Validate.Var(email, "required,email"); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ if !model.IsEmailAlreadyTaken(email) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该邮箱地址未注册",
+ })
+ return
+ }
+ code := common.GenerateVerificationCode(0)
+ common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
+ link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
+ subject := fmt.Sprintf("%s密码重置", common.SystemName)
+ content := fmt.Sprintf("您好,你正在进行%s密码重置。
"+
+ "点击 此处 进行密码重置。
"+
+ "如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s
"+
+ "重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, link, link, common.VerificationValidMinutes)
+ err := common.SendEmail(subject, email, content)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+type PasswordResetRequest struct {
+ Email string `json:"email"`
+ Token string `json:"token"`
+}
+
+func ResetPassword(c *gin.Context) {
+ var req PasswordResetRequest
+ err := json.NewDecoder(c.Request.Body).Decode(&req)
+ if req.Email == "" || req.Token == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "重置链接非法或已过期",
+ })
+ return
+ }
+ password := common.GenerateVerificationCode(12)
+ err = model.ResetUserPasswordByEmail(req.Email, password)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ common.DeleteKey(req.Email, common.PasswordResetPurpose)
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": password,
+ })
+ return
+}
diff --git a/controller/model.go b/controller/model.go
new file mode 100644
index 00000000..31a66b29
--- /dev/null
+++ b/controller/model.go
@@ -0,0 +1,216 @@
+package controller
+
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "github.com/samber/lo"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/relay"
+ "one-api/relay/channel/ai360"
+ "one-api/relay/channel/lingyiwanwu"
+ "one-api/relay/channel/minimax"
+ "one-api/relay/channel/moonshot"
+ relaycommon "one-api/relay/common"
+ "one-api/setting"
+)
+
+// https://platform.openai.com/docs/api-reference/models/list
+
+var openAIModels []dto.OpenAIModels
+var openAIModelsMap map[string]dto.OpenAIModels
+var channelId2Models map[int][]string
+
+func init() {
+ // https://platform.openai.com/docs/models/model-endpoint-compatibility
+ for i := 0; i < constant.APITypeDummy; i++ {
+ if i == constant.APITypeAIProxyLibrary {
+ continue
+ }
+ adaptor := relay.GetAdaptor(i)
+ channelName := adaptor.GetChannelName()
+ modelNames := adaptor.GetModelList()
+ for _, modelName := range modelNames {
+ openAIModels = append(openAIModels, dto.OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: channelName,
+ })
+ }
+ }
+ for _, modelName := range ai360.ModelList {
+ openAIModels = append(openAIModels, dto.OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: ai360.ChannelName,
+ })
+ }
+ for _, modelName := range moonshot.ModelList {
+ openAIModels = append(openAIModels, dto.OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: moonshot.ChannelName,
+ })
+ }
+ for _, modelName := range lingyiwanwu.ModelList {
+ openAIModels = append(openAIModels, dto.OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: lingyiwanwu.ChannelName,
+ })
+ }
+ for _, modelName := range minimax.ModelList {
+ openAIModels = append(openAIModels, dto.OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: minimax.ChannelName,
+ })
+ }
+ for modelName := range constant.MidjourneyModel2Action {
+ openAIModels = append(openAIModels, dto.OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "midjourney",
+ })
+ }
+ openAIModelsMap = make(map[string]dto.OpenAIModels)
+ for _, aiModel := range openAIModels {
+ openAIModelsMap[aiModel.Id] = aiModel
+ }
+ channelId2Models = make(map[int][]string)
+ for i := 1; i <= constant.ChannelTypeDummy; i++ {
+ apiType, success := common.ChannelType2APIType(i)
+ if !success || apiType == constant.APITypeAIProxyLibrary {
+ continue
+ }
+ meta := &relaycommon.RelayInfo{ChannelType: i}
+ adaptor := relay.GetAdaptor(apiType)
+ adaptor.Init(meta)
+ channelId2Models[i] = adaptor.GetModelList()
+ }
+ openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
+ return m.Id
+ })
+}
+
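+// ListModels returns the models available to the current token or user group in OpenAI-compatible format.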
+func ListModels(c *gin.Context) {
+ userOpenAiModels := make([]dto.OpenAIModels, 0)
+
+ modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
+ if modelLimitEnable {
+ s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
+ var tokenModelLimit map[string]bool
+ if ok {
+ tokenModelLimit = s.(map[string]bool)
+ } else {
+ tokenModelLimit = map[string]bool{}
+ }
+ for allowModel := range tokenModelLimit {
+ if oaiModel, ok := openAIModelsMap[allowModel]; ok {
+ oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
+ userOpenAiModels = append(userOpenAiModels, oaiModel)
+ } else {
+ userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
+ Id: allowModel,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "custom",
+ SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
+ })
+ }
+ }
+ } else {
+ userId := c.GetInt("id")
+ userGroup, err := model.GetUserGroup(userId, false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "get user group failed",
+ })
+ return
+ }
+ group := userGroup
+ tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
+ if tokenGroup != "" {
+ group = tokenGroup
+ }
+ var models []string
+ if tokenGroup == "auto" {
+ for _, autoGroup := range setting.AutoGroups {
+ groupModels := model.GetGroupEnabledModels(autoGroup)
+ for _, g := range groupModels {
+ if !common.StringsContains(models, g) {
+ models = append(models, g)
+ }
+ }
+ }
+ } else {
+ models = model.GetGroupEnabledModels(group)
+ }
+ for _, modelName := range models {
+ if oaiModel, ok := openAIModelsMap[modelName]; ok {
+ oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
+ userOpenAiModels = append(userOpenAiModels, oaiModel)
+ } else {
+ userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "custom",
+ SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
+ })
+ }
+ }
+ }
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": userOpenAiModels,
+ })
+}
+
+func ChannelListModels(c *gin.Context) {
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": openAIModels,
+ })
+}
+
+func DashboardListModels(c *gin.Context) {
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": channelId2Models,
+ })
+}
+
+func EnabledListModels(c *gin.Context) {
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": model.GetEnabledModels(),
+ })
+}
+
+func RetrieveModel(c *gin.Context) {
+ modelId := c.Param("model")
+ if aiModel, ok := openAIModelsMap[modelId]; ok {
+ c.JSON(200, aiModel)
+ } else {
+ openAIError := dto.OpenAIError{
+ Message: fmt.Sprintf("The model '%s' does not exist", modelId),
+ Type: "invalid_request_error",
+ Param: "model",
+ Code: "model_not_found",
+ }
+ c.JSON(200, gin.H{
+ "error": openAIError,
+ })
+ }
+}
diff --git a/controller/oidc.go b/controller/oidc.go
new file mode 100644
index 00000000..df8ea1c4
--- /dev/null
+++ b/controller/oidc.go
@@ -0,0 +1,228 @@
+package controller
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "one-api/common"
+ "one-api/model"
+ "one-api/setting"
+ "one-api/setting/system_setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+type OidcResponse struct {
+ AccessToken string `json:"access_token"`
+ IDToken string `json:"id_token"`
+ RefreshToken string `json:"refresh_token"`
+ TokenType string `json:"token_type"`
+ ExpiresIn int `json:"expires_in"`
+ Scope string `json:"scope"`
+}
+
+type OidcUser struct {
+ OpenID string `json:"sub"`
+ Email string `json:"email"`
+ Name string `json:"name"`
+ PreferredUsername string `json:"preferred_username"`
+ Picture string `json:"picture"`
+}
+
+func getOidcUserInfoByCode(code string) (*OidcUser, error) {
+ if code == "" {
+ return nil, errors.New("无效的参数")
+ }
+
+ values := url.Values{}
+ values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
+ values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
+ values.Set("code", code)
+ values.Set("grant_type", "authorization_code")
+ values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
+ formData := values.Encode()
+ req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+ client := http.Client{
+ Timeout: 5 * time.Second,
+ }
+ res, err := client.Do(req)
+ if err != nil {
+ common.SysLog(err.Error())
+ return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
+ }
+ defer res.Body.Close()
+ var oidcResponse OidcResponse
+ err = json.NewDecoder(res.Body).Decode(&oidcResponse)
+ if err != nil {
+ return nil, err
+ }
+
+ if oidcResponse.AccessToken == "" {
+ common.SysError("OIDC 获取 Token 失败,请检查设置!")
+ return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
+ }
+
+ req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
+ res2, err := client.Do(req)
+ if err != nil {
+ common.SysLog(err.Error())
+ return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
+ }
+ defer res2.Body.Close()
+ if res2.StatusCode != http.StatusOK {
+ common.SysError("OIDC 获取用户信息失败!请检查设置!")
+ return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
+ }
+
+ var oidcUser OidcUser
+ err = json.NewDecoder(res2.Body).Decode(&oidcUser)
+ if err != nil {
+ return nil, err
+ }
+ if oidcUser.OpenID == "" || oidcUser.Email == "" {
+ common.SysError("OIDC 获取用户信息为空!请检查设置!")
+ return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
+ }
+ return &oidcUser, nil
+}
+
+func OidcAuth(c *gin.Context) {
+ session := sessions.Default(c)
+ state := c.Query("state")
+ if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "state is empty or not same",
+ })
+ return
+ }
+ username := session.Get("username")
+ if username != nil {
+ OidcBind(c)
+ return
+ }
+ if !system_setting.GetOIDCSettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 OIDC 登录以及注册",
+ })
+ return
+ }
+ code := c.Query("code")
+ oidcUser, err := getOidcUserInfoByCode(code)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user := model.User{
+ OidcId: oidcUser.OpenID,
+ }
+ if model.IsOidcIdAlreadyTaken(user.OidcId) {
+ err := user.FillUserByOidcId()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ if common.RegisterEnabled {
+ user.Email = oidcUser.Email
+ if oidcUser.PreferredUsername != "" {
+ user.Username = oidcUser.PreferredUsername
+ } else {
+ user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
+ }
+ if oidcUser.Name != "" {
+ user.DisplayName = oidcUser.Name
+ } else {
+ user.DisplayName = "OIDC User"
+ }
+ err := user.Insert(0)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员关闭了新用户注册",
+ })
+ return
+ }
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "用户已被封禁",
+ "success": false,
+ })
+ return
+ }
+ setupLogin(&user, c)
+}
+
+func OidcBind(c *gin.Context) {
+ if !system_setting.GetOIDCSettings().Enabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 OIDC 登录以及注册",
+ })
+ return
+ }
+ code := c.Query("code")
+ oidcUser, err := getOidcUserInfoByCode(code)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user := model.User{
+ OidcId: oidcUser.OpenID,
+ }
+ if model.IsOidcIdAlreadyTaken(user.OidcId) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该 OIDC 账户已被绑定",
+ })
+ return
+ }
+ session := sessions.Default(c)
+ id := session.Get("id")
+ // id := c.GetInt("id") // critical bug!
+ user.Id = id.(int)
+ err = user.FillUserById()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user.OidcId = oidcUser.OpenID
+ err = user.Update(false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "bind",
+ })
+ return
+}
diff --git a/controller/option.go b/controller/option.go
new file mode 100644
index 00000000..decdb0d4
--- /dev/null
+++ b/controller/option.go
@@ -0,0 +1,171 @@
+package controller
+
+import (
+ "encoding/json"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "one-api/setting"
+ "one-api/setting/console_setting"
+ "one-api/setting/ratio_setting"
+ "one-api/setting/system_setting"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetOptions(c *gin.Context) {
+ var options []*model.Option
+ common.OptionMapRWMutex.Lock()
+ for k, v := range common.OptionMap {
+ if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
+ continue
+ }
+ options = append(options, &model.Option{
+ Key: k,
+ Value: common.Interface2String(v),
+ })
+ }
+ common.OptionMapRWMutex.Unlock()
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": options,
+ })
+ return
+}
+
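+// UpdateOption validates dependent settings for the given option key before persisting the new value.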
+func UpdateOption(c *gin.Context) {
+ var option model.Option
+ err := json.NewDecoder(c.Request.Body).Decode(&option)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ switch option.Key {
+ case "GitHubOAuthEnabled":
+ if option.Value == "true" && common.GitHubClientId == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
+ })
+ return
+ }
+ case "oidc.enabled":
+ if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret!",
+ })
+ return
+ }
+ case "LinuxDOOAuthEnabled":
+ if option.Value == "true" && common.LinuxDOClientId == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用 LinuxDO OAuth,请先填入 LinuxDO Client Id 以及 LinuxDO Client Secret!",
+ })
+ return
+ }
+ case "EmailDomainRestrictionEnabled":
+ if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
+ })
+ return
+ }
+ case "WeChatAuthEnabled":
+ if option.Value == "true" && common.WeChatServerAddress == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用微信登录,请先填入微信登录相关配置信息!",
+ })
+ return
+ }
+ case "TurnstileCheckEnabled":
+ if option.Value == "true" && common.TurnstileSiteKey == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
+ })
+
+ return
+ }
+ case "TelegramOAuthEnabled":
+ if option.Value == "true" && common.TelegramBotToken == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用 Telegram OAuth,请先填入 Telegram Bot Token!",
+ })
+ return
+ }
+ case "GroupRatio":
+ err = ratio_setting.CheckGroupRatio(option.Value)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "ModelRequestRateLimitGroup":
+ err = setting.CheckModelRequestRateLimitGroup(option.Value)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.api_info":
+ err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.announcements":
+ err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.faq":
+ err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "console_setting.uptime_kuma_groups":
+ err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ }
+ err = model.UpdateOption(option.Key, option.Value)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
diff --git a/controller/playground.go b/controller/playground.go
new file mode 100644
index 00000000..0073cf06
--- /dev/null
+++ b/controller/playground.go
@@ -0,0 +1,84 @@
+package controller
+
+import (
+ "errors"
+ "fmt"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/middleware"
+ "one-api/model"
+ "one-api/setting"
+ "one-api/types"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
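+// Playground relays a chat request issued from the console playground, selecting a channel for the requested group and model under a temporary token.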
+func Playground(c *gin.Context) {
+ var newAPIError *types.NewAPIError
+
+ defer func() {
+ if newAPIError != nil {
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "error": newAPIError.ToOpenAIError(),
+ })
+ }
+ }()
+
+ useAccessToken := c.GetBool("use_access_token")
+ if useAccessToken {
+ newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
+ return
+ }
+
+ playgroundRequest := &dto.PlayGroundRequest{}
+ err := common.UnmarshalBodyReusable(c, playgroundRequest)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
+ return
+ }
+
+ if playgroundRequest.Model == "" {
+ newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
+ return
+ }
+ c.Set("original_model", playgroundRequest.Model)
+ group := playgroundRequest.Group
+ userGroup := c.GetString("group")
+
+ if group == "" {
+ group = userGroup
+ } else {
+ if !setting.GroupInUserUsableGroups(group) && group != userGroup {
+ newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
+ return
+ }
+ c.Set("group", group)
+ }
+
+ userId := c.GetInt("id")
+
+ // Write user context to ensure acceptUnsetRatio is available
+ userCache, err := model.GetUserCache(userId)
+ if err != nil {
+ newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
+ return
+ }
+ userCache.WriteContext(c)
+
+ tempToken := &model.Token{
+ UserId: userId,
+ Name: fmt.Sprintf("playground-%s", group),
+ Group: group,
+ }
+ _ = middleware.SetupContextForToken(c, tempToken)
+ _, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
+ if newAPIError != nil {
+ return
+ }
+ //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
+ common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
+
+ Relay(c)
+}
diff --git a/controller/pricing.go b/controller/pricing.go
new file mode 100644
index 00000000..f27336b7
--- /dev/null
+++ b/controller/pricing.go
@@ -0,0 +1,71 @@
+package controller
+
+import (
+ "one-api/model"
+ "one-api/setting"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
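+// GetPricing returns model pricing together with the group ratios and usable groups for the current user.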
+func GetPricing(c *gin.Context) {
+ pricing := model.GetPricing()
+ userId, exists := c.Get("id")
+ usableGroup := map[string]string{}
+ groupRatio := map[string]float64{}
+ for s, f := range ratio_setting.GetGroupRatioCopy() {
+ groupRatio[s] = f
+ }
+ var group string
+ if exists {
+ user, err := model.GetUserCache(userId.(int))
+ if err == nil {
+ group = user.Group
+ for g := range groupRatio {
+ ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
+ if ok {
+ groupRatio[g] = ratio
+ }
+ }
+ }
+ }
+
+ usableGroup = setting.GetUserUsableGroups(group)
+ // check groupRatio contains usableGroup
+ for group := range ratio_setting.GetGroupRatioCopy() {
+ if _, ok := usableGroup[group]; !ok {
+ delete(groupRatio, group)
+ }
+ }
+
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": pricing,
+ "group_ratio": groupRatio,
+ "usable_group": usableGroup,
+ })
+}
+
+func ResetModelRatio(c *gin.Context) {
+ defaultStr := ratio_setting.DefaultModelRatio2JSONString()
+ err := model.UpdateOption("ModelRatio", defaultStr)
+ if err != nil {
+ c.JSON(200, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
+ if err != nil {
+ c.JSON(200, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(200, gin.H{
+ "success": true,
+ "message": "重置模型倍率成功",
+ })
+}
diff --git a/controller/ratio_config.go b/controller/ratio_config.go
new file mode 100644
index 00000000..6ddc3d9e
--- /dev/null
+++ b/controller/ratio_config.go
@@ -0,0 +1,24 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetRatioConfig(c *gin.Context) {
+ if !ratio_setting.IsExposeRatioEnabled() {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "倍率配置接口未启用",
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": ratio_setting.GetExposedData(),
+ })
+}
\ No newline at end of file
diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go
new file mode 100644
index 00000000..0453870d
--- /dev/null
+++ b/controller/ratio_sync.go
@@ -0,0 +1,474 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "one-api/common"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ defaultTimeoutSeconds = 10
+ defaultEndpoint = "/api/ratio_config"
+ maxConcurrentFetches = 8
+)
+
+var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
+
+type upstreamResult struct {
+ Name string `json:"name"`
+ Data map[string]any `json:"data,omitempty"`
+ Err string `json:"err,omitempty"`
+}
+
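+// FetchUpstreamRatios concurrently fetches ratio configurations from the given upstreams (or channels), normalizes both supported response formats, and returns the differences against the local configuration.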
+func FetchUpstreamRatios(c *gin.Context) {
+ var req dto.UpstreamRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+
+ if req.Timeout <= 0 {
+ req.Timeout = defaultTimeoutSeconds
+ }
+
+ var upstreams []dto.UpstreamDTO
+
+ if len(req.Upstreams) > 0 {
+ for _, u := range req.Upstreams {
+ if strings.HasPrefix(u.BaseURL, "http") {
+ if u.Endpoint == "" {
+ u.Endpoint = defaultEndpoint
+ }
+ u.BaseURL = strings.TrimRight(u.BaseURL, "/")
+ upstreams = append(upstreams, u)
+ }
+ }
+ } else if len(req.ChannelIDs) > 0 {
+ intIds := make([]int, 0, len(req.ChannelIDs))
+ for _, id64 := range req.ChannelIDs {
+ intIds = append(intIds, int(id64))
+ }
+ dbChannels, err := model.GetChannelsByIds(intIds)
+ if err != nil {
+ common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
+ c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
+ return
+ }
+ for _, ch := range dbChannels {
+ if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
+ upstreams = append(upstreams, dto.UpstreamDTO{
+ ID: ch.Id,
+ Name: ch.Name,
+ BaseURL: strings.TrimRight(base, "/"),
+ Endpoint: "",
+ })
+ }
+ }
+ }
+
+ if len(upstreams) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
+ return
+ }
+
+ var wg sync.WaitGroup
+ ch := make(chan upstreamResult, len(upstreams))
+
+ sem := make(chan struct{}, maxConcurrentFetches)
+
+ client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
+
+ for _, chn := range upstreams {
+ wg.Add(1)
+ go func(chItem dto.UpstreamDTO) {
+ defer wg.Done()
+
+ sem <- struct{}{}
+ defer func() { <-sem }()
+
+ endpoint := chItem.Endpoint
+ if endpoint == "" {
+ endpoint = defaultEndpoint
+ } else if !strings.HasPrefix(endpoint, "/") {
+ endpoint = "/" + endpoint
+ }
+ fullURL := chItem.BaseURL + endpoint
+
+ uniqueName := chItem.Name
+ if chItem.ID != 0 {
+ uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
+ defer cancel()
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
+ if err != nil {
+ common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
+
+ resp, err := client.Do(httpReq)
+ if err != nil {
+ common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
+ ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
+ return
+ }
+ // Two upstream response formats are supported:
+ // type1: /api/ratio_config -> data is a map[string]any containing model_ratio/completion_ratio/cache_ratio/model_price
+ // type2: /api/pricing -> data is a []Pricing list that must be converted into the same map layout as type1
+ var body struct {
+ Success bool `json:"success"`
+ Data json.RawMessage `json:"data"`
+ Message string `json:"message"`
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+ common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
+ return
+ }
+
+ if !body.Success {
+ ch <- upstreamResult{Name: uniqueName, Err: body.Message}
+ return
+ }
+
+ // try to parse the payload as type1
+ var type1Data map[string]any
+ if err := json.Unmarshal(body.Data, &type1Data); err == nil {
+ // treat it as type1 if it contains at least one of the ratioTypes keys
+ isType1 := false
+ for _, rt := range ratioTypes {
+ if _, ok := type1Data[rt]; ok {
+ isType1 = true
+ break
+ }
+ }
+ if isType1 {
+ ch <- upstreamResult{Name: uniqueName, Data: type1Data}
+ return
+ }
+ }
+
+ // not type1, so try to parse it as type2 (/api/pricing)
+ var pricingItems []struct {
+ ModelName string `json:"model_name"`
+ QuotaType int `json:"quota_type"`
+ ModelRatio float64 `json:"model_ratio"`
+ ModelPrice float64 `json:"model_price"`
+ CompletionRatio float64 `json:"completion_ratio"`
+ }
+ if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
+ common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
+ ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
+ return
+ }
+
+ modelRatioMap := make(map[string]float64)
+ completionRatioMap := make(map[string]float64)
+ modelPriceMap := make(map[string]float64)
+
+ for _, item := range pricingItems {
+ if item.QuotaType == 1 {
+ modelPriceMap[item.ModelName] = item.ModelPrice
+ } else {
+ modelRatioMap[item.ModelName] = item.ModelRatio
+ // completionRatio may be 0; assign it as-is to stay consistent with the upstream
+ completionRatioMap[item.ModelName] = item.CompletionRatio
+ }
+ }
+
+ converted := make(map[string]any)
+
+ if len(modelRatioMap) > 0 {
+ ratioAny := make(map[string]any, len(modelRatioMap))
+ for k, v := range modelRatioMap {
+ ratioAny[k] = v
+ }
+ converted["model_ratio"] = ratioAny
+ }
+
+ if len(completionRatioMap) > 0 {
+ compAny := make(map[string]any, len(completionRatioMap))
+ for k, v := range completionRatioMap {
+ compAny[k] = v
+ }
+ converted["completion_ratio"] = compAny
+ }
+
+ if len(modelPriceMap) > 0 {
+ priceAny := make(map[string]any, len(modelPriceMap))
+ for k, v := range modelPriceMap {
+ priceAny[k] = v
+ }
+ converted["model_price"] = priceAny
+ }
+
+ ch <- upstreamResult{Name: uniqueName, Data: converted}
+ }(chn)
+ }
+
+ wg.Wait()
+ close(ch)
+
+ localData := ratio_setting.GetExposedData()
+
+ var testResults []dto.TestResult
+ var successfulChannels []struct {
+ name string
+ data map[string]any
+ }
+
+ for r := range ch {
+ if r.Err != "" {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "error",
+ Error: r.Err,
+ })
+ } else {
+ testResults = append(testResults, dto.TestResult{
+ Name: r.Name,
+ Status: "success",
+ })
+ successfulChannels = append(successfulChannels, struct {
+ name string
+ data map[string]any
+ }{name: r.Name, data: r.Data})
+ }
+ }
+
+ differences := buildDifferences(localData, successfulChannels)
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "data": gin.H{
+ "differences": differences,
+ "test_results": testResults,
+ },
+ })
+}
+
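+// buildDifferences compares local ratio data with each upstream's data per model and ratio type, and flags low-confidence values that come from the pricing endpoint.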
+func buildDifferences(localData map[string]any, successfulChannels []struct {
+ name string
+ data map[string]any
+}) map[string]map[string]dto.DifferenceItem {
+ differences := make(map[string]map[string]dto.DifferenceItem)
+
+ allModels := make(map[string]struct{})
+
+ for _, ratioType := range ratioTypes {
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ for modelName := range localRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
+
+ for _, channel := range successfulChannels {
+ for _, ratioType := range ratioTypes {
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ for modelName := range upstreamRatio {
+ allModels[modelName] = struct{}{}
+ }
+ }
+ }
+ }
+
+ confidenceMap := make(map[string]map[string]bool)
+
+ // preprocessing: assess how trustworthy the data from the pricing endpoint is
+ for _, channel := range successfulChannels {
+ confidenceMap[channel.name] = make(map[string]bool)
+
+ modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
+ completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
+
+ if hasModelRatio && hasCompletionRatio {
+ // walk all models and check whether each one hits the "untrusted" condition
+ for modelName := range allModels {
+ // trusted by default
+ confidenceMap[channel.name][modelName] = true
+
+ // untrusted condition: model_ratio is 37.5 and completion_ratio is 1, which looks like an upstream default placeholder rather than real pricing
+ if modelRatioVal, ok := modelRatios[modelName]; ok {
+ if completionRatioVal, ok := completionRatios[modelName]; ok {
+ // convert to float64 before comparing
+ if modelRatioFloat, ok := modelRatioVal.(float64); ok {
+ if completionRatioFloat, ok := completionRatioVal.(float64); ok {
+ if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
+ confidenceMap[channel.name][modelName] = false
+ }
+ }
+ }
+ }
+ }
+ }
+ } else {
+ // data that did not come from the pricing endpoint is marked as fully trusted
+ for modelName := range allModels {
+ confidenceMap[channel.name][modelName] = true
+ }
+ }
+ }
+
+ for modelName := range allModels {
+ for _, ratioType := range ratioTypes {
+ var localValue interface{} = nil
+ if localRatioAny, ok := localData[ratioType]; ok {
+ if localRatio, ok := localRatioAny.(map[string]float64); ok {
+ if val, exists := localRatio[modelName]; exists {
+ localValue = val
+ }
+ }
+ }
+
+ upstreamValues := make(map[string]interface{})
+ confidenceValues := make(map[string]bool)
+ hasUpstreamValue := false
+ hasDifference := false
+
+ for _, channel := range successfulChannels {
+ var upstreamValue interface{} = nil
+
+ if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
+ if val, exists := upstreamRatio[modelName]; exists {
+ upstreamValue = val
+ hasUpstreamValue = true
+
+ if localValue != nil && localValue != val {
+ hasDifference = true
+ } else if localValue == val {
+ upstreamValue = "same"
+ }
+ }
+ }
+ if upstreamValue == nil && localValue == nil {
+ upstreamValue = "same"
+ }
+
+ if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
+ hasDifference = true
+ }
+
+ upstreamValues[channel.name] = upstreamValue
+
+ confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
+ }
+
+ shouldInclude := false
+
+ if localValue != nil {
+ if hasDifference {
+ shouldInclude = true
+ }
+ } else {
+ if hasUpstreamValue {
+ shouldInclude = true
+ }
+ }
+
+ if shouldInclude {
+ if differences[modelName] == nil {
+ differences[modelName] = make(map[string]dto.DifferenceItem)
+ }
+ differences[modelName][ratioType] = dto.DifferenceItem{
+ Current: localValue,
+ Upstreams: upstreamValues,
+ Confidence: confidenceValues,
+ }
+ }
+ }
+ }
+
+ channelHasDiff := make(map[string]bool)
+ for _, ratioMap := range differences {
+ for _, item := range ratioMap {
+ for chName, val := range item.Upstreams {
+ if val != nil && val != "same" {
+ channelHasDiff[chName] = true
+ }
+ }
+ }
+ }
+
+ for modelName, ratioMap := range differences {
+ for ratioType, item := range ratioMap {
+ for chName := range item.Upstreams {
+ if !channelHasDiff[chName] {
+ delete(item.Upstreams, chName)
+ delete(item.Confidence, chName)
+ }
+ }
+
+ allSame := true
+ for _, v := range item.Upstreams {
+ if v != "same" {
+ allSame = false
+ break
+ }
+ }
+ if len(item.Upstreams) == 0 || allSame {
+ delete(ratioMap, ratioType)
+ } else {
+ differences[modelName][ratioType] = item
+ }
+ }
+
+ if len(ratioMap) == 0 {
+ delete(differences, modelName)
+ }
+ }
+
+ return differences
+}
+
+func GetSyncableChannels(c *gin.Context) {
+ channels, err := model.GetAllChannels(0, 0, true, false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ var syncableChannels []dto.SyncableChannel
+ for _, channel := range channels {
+ if channel.GetBaseURL() != "" {
+ syncableChannels = append(syncableChannels, dto.SyncableChannel{
+ ID: channel.Id,
+ Name: channel.Name,
+ BaseURL: channel.GetBaseURL(),
+ Status: channel.Status,
+ })
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": syncableChannels,
+ })
+}
\ No newline at end of file
diff --git a/controller/redemption.go b/controller/redemption.go
new file mode 100644
index 00000000..83ec19ad
--- /dev/null
+++ b/controller/redemption.go
@@ -0,0 +1,193 @@
+package controller
+
+import (
+ "errors"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetAllRedemptions(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+ redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(redemptions)
+ common.ApiSuccess(c, pageInfo)
+ return
+}
+
+func SearchRedemptions(c *gin.Context) {
+ keyword := c.Query("keyword")
+ pageInfo := common.GetPageQuery(c)
+ redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(redemptions)
+ common.ApiSuccess(c, pageInfo)
+ return
+}
+
+func GetRedemption(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ redemption, err := model.GetRedemptionById(id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": redemption,
+ })
+ return
+}
+
+func AddRedemption(c *gin.Context) {
+ redemption := model.Redemption{}
+ err := c.ShouldBindJSON(&redemption)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "兑换码名称长度必须在1-20之间",
+ })
+ return
+ }
+ if redemption.Count <= 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "兑换码个数必须大于0",
+ })
+ return
+ }
+ if redemption.Count > 100 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "一次兑换码批量生成的个数不能大于 100",
+ })
+ return
+ }
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ var keys []string
+ for i := 0; i < redemption.Count; i++ {
+ key := common.GetUUID()
+ cleanRedemption := model.Redemption{
+ UserId: c.GetInt("id"),
+ Name: redemption.Name,
+ Key: key,
+ CreatedTime: common.GetTimestamp(),
+ Quota: redemption.Quota,
+ ExpiredTime: redemption.ExpiredTime,
+ }
+ err = cleanRedemption.Insert()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ "data": keys,
+ })
+ return
+ }
+ keys = append(keys, key)
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": keys,
+ })
+ return
+}
+
+func DeleteRedemption(c *gin.Context) {
+ id, _ := strconv.Atoi(c.Param("id"))
+ err := model.DeleteRedemptionById(id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func UpdateRedemption(c *gin.Context) {
+ statusOnly := c.Query("status_only")
+ redemption := model.Redemption{}
+ err := c.ShouldBindJSON(&redemption)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ cleanRedemption, err := model.GetRedemptionById(redemption.Id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if statusOnly == "" {
+ if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
+ c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
+ return
+ }
+ // If you add more fields, please also update redemption.Update()
+ cleanRedemption.Name = redemption.Name
+ cleanRedemption.Quota = redemption.Quota
+ cleanRedemption.ExpiredTime = redemption.ExpiredTime
+ }
+ if statusOnly != "" {
+ cleanRedemption.Status = redemption.Status
+ }
+ err = cleanRedemption.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": cleanRedemption,
+ })
+ return
+}
+
+func DeleteInvalidRedemption(c *gin.Context) {
+ rows, err := model.DeleteInvalidRedemptions()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": rows,
+ })
+ return
+}
+
+func validateExpiredTime(expired int64) error {
+ if expired != 0 && expired < common.GetTimestamp() {
+ return errors.New("过期时间不能早于当前时间")
+ }
+ return nil
+}
diff --git a/controller/relay.go b/controller/relay.go
new file mode 100644
index 00000000..b224b42c
--- /dev/null
+++ b/controller/relay.go
@@ -0,0 +1,476 @@
+package controller
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ constant2 "one-api/constant"
+ "one-api/dto"
+ "one-api/middleware"
+ "one-api/model"
+ "one-api/relay"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+)
+
+func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
+ var err *types.NewAPIError
+ switch relayMode {
+ case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
+ err = relay.ImageHelper(c)
+ case relayconstant.RelayModeAudioSpeech, relayconstant.RelayModeAudioTranslation, relayconstant.RelayModeAudioTranscription:
+ err = relay.AudioHelper(c)
+ case relayconstant.RelayModeRerank:
+ err = relay.RerankHelper(c, relayMode)
+ case relayconstant.RelayModeEmbeddings:
+ err = relay.EmbeddingHelper(c)
+ case relayconstant.RelayModeResponses:
+ err = relay.ResponsesHelper(c)
+ case relayconstant.RelayModeGemini:
+ err = relay.GeminiHelper(c)
+ default:
+ err = relay.TextHelper(c)
+ }
+
+ if constant.ErrorLogEnabled && err != nil {
+ // persist the error log to the database
+ userId := c.GetInt("id")
+ tokenName := c.GetString("token_name")
+ modelName := c.GetString("original_model")
+ tokenId := c.GetInt("token_id")
+ userGroup := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ other := make(map[string]interface{})
+ other["error_type"] = err.ErrorType
+ other["error_code"] = err.GetErrorCode()
+ other["status_code"] = err.StatusCode
+ other["channel_id"] = channelId
+ other["channel_name"] = c.GetString("channel_name")
+ other["channel_type"] = c.GetInt("channel_type")
+
+ model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
+ }
+
+ return err
+}
+
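+// Relay dispatches an OpenAI-compatible request to a channel, retrying on other channels up to
+// common.RetryTimes when shouldRetry allows it, and reports channel failures asynchronously so
+// that misbehaving channels can be auto-disabled.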
+func Relay(c *gin.Context) {
+ relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+ requestId := c.GetString(common.RequestIdKey)
+ group := c.GetString("group")
+ originalModel := c.GetString("original_model")
+ var newAPIError *types.NewAPIError
+
+ for i := 0; i <= common.RetryTimes; i++ {
+ channel, err := getChannel(c, group, originalModel, i)
+ if err != nil {
+ common.LogError(c, err.Error())
+ newAPIError = err
+ break
+ }
+
+ newAPIError = relayRequest(c, relayMode, channel)
+
+ if newAPIError == nil {
+ return // request handled successfully
+ }
+
+ go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+
+ if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
+ break
+ }
+ }
+ useChannel := c.GetStringSlice("use_channel")
+ if len(useChannel) > 1 {
+ retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+ common.LogInfo(c, retryLogStr)
+ }
+
+ if newAPIError != nil {
+ //if newAPIError.StatusCode == http.StatusTooManyRequests {
+ // common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
+ // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
+ //}
+ newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "error": newAPIError.ToOpenAIError(),
+ })
+ }
+}
+
+var upgrader = websocket.Upgrader{
+ Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
+ CheckOrigin: func(r *http.Request) bool {
+ return true // allow cross-origin requests
+ },
+}
+
+func WssRelay(c *gin.Context) {
+ // upgrade the HTTP connection to a WebSocket connection
+ ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
+ if err != nil {
+ // Upgrade has already written an HTTP error response and ws is nil here, so log and return
+ common.LogError(c, "websocket upgrade failed: "+err.Error())
+ return
+ }
+ defer ws.Close()
+
+ relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
+ requestId := c.GetString(common.RequestIdKey)
+ group := c.GetString("group")
+ //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+ originalModel := c.GetString("original_model")
+ var newAPIError *types.NewAPIError
+
+ for i := 0; i <= common.RetryTimes; i++ {
+ channel, err := getChannel(c, group, originalModel, i)
+ if err != nil {
+ common.LogError(c, err.Error())
+ newAPIError = err
+ break
+ }
+
+ newAPIError = wssRequest(c, ws, relayMode, channel)
+
+ if newAPIError == nil {
+ return // request handled successfully
+ }
+
+ go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+
+ if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
+ break
+ }
+ }
+ useChannel := c.GetStringSlice("use_channel")
+ if len(useChannel) > 1 {
+ retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+ common.LogInfo(c, retryLogStr)
+ }
+
+ if newAPIError != nil {
+ //if newAPIError.StatusCode == http.StatusTooManyRequests {
+ // newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
+ //}
+ newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+ helper.WssError(c, ws, newAPIError.ToOpenAIError())
+ }
+}
+
+func RelayClaude(c *gin.Context) {
+ //relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+ requestId := c.GetString(common.RequestIdKey)
+ group := c.GetString("group")
+ originalModel := c.GetString("original_model")
+ var newAPIError *types.NewAPIError
+
+ for i := 0; i <= common.RetryTimes; i++ {
+ channel, err := getChannel(c, group, originalModel, i)
+ if err != nil {
+ common.LogError(c, err.Error())
+ newAPIError = err
+ break
+ }
+
+ newAPIError = claudeRequest(c, channel)
+
+ if newAPIError == nil {
+ return // request handled successfully
+ }
+
+ go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+
+ if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
+ break
+ }
+ }
+ useChannel := c.GetStringSlice("use_channel")
+ if len(useChannel) > 1 {
+ retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+ common.LogInfo(c, retryLogStr)
+ }
+
+ if newAPIError != nil {
+ newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
+ c.JSON(newAPIError.StatusCode, gin.H{
+ "type": "error",
+ "error": newAPIError.ToClaudeError(),
+ })
+ }
+}
+
+func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
+ addUsedChannel(c, channel.Id)
+ requestBody, _ := common.GetRequestBody(c)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ return relayHandler(c, relayMode)
+}
+
+func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
+ addUsedChannel(c, channel.Id)
+ requestBody, _ := common.GetRequestBody(c)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ return relay.WssHelper(c, ws)
+}
+
+func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
+ addUsedChannel(c, channel.Id)
+ requestBody, _ := common.GetRequestBody(c)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ return relay.ClaudeHelper(c)
+}
+
+func addUsedChannel(c *gin.Context, channelId int) {
+ useChannel := c.GetStringSlice("use_channel")
+ useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+ c.Set("use_channel", useChannel)
+}
+
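+// getChannel returns the channel to use for the current attempt. On the first attempt
+// (retryCount == 0) it reuses the channel that the distribution middleware already selected and
+// stored in the gin context; on retries it asks the cache for another channel that can serve the
+// requested model in the given group, then re-runs the channel context setup.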
+func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
+ if retryCount == 0 {
+ autoBan := c.GetBool("auto_ban")
+ autoBanInt := 1
+ if !autoBan {
+ autoBanInt = 0
+ }
+ return &model.Channel{
+ Id: c.GetInt("channel_id"),
+ Type: c.GetInt("channel_type"),
+ Name: c.GetString("channel_name"),
+ AutoBan: &autoBanInt,
+ }, nil
+ }
+ channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
+ if err != nil {
+ if group == "auto" {
+ return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
+ }
+ return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
+ }
+ newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+ if newAPIError != nil {
+ return nil, newAPIError
+ }
+ return channel, nil
+}
+
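+// shouldRetry decides whether the request should be retried on another channel: channel-level
+// errors always retry, local errors never do, and otherwise the decision follows the status code
+// (429/307/5xx retry, except 504/524 timeouts; 400 retries only for Anthropic channels; 408 and
+// 2xx never retry).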
+func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
+ if openaiErr == nil {
+ return false
+ }
+ if types.IsChannelError(openaiErr) {
+ return true
+ }
+ if types.IsLocalError(openaiErr) {
+ return false
+ }
+ if retryTimes <= 0 {
+ return false
+ }
+ if _, ok := c.Get("specific_channel_id"); ok {
+ return false
+ }
+ if openaiErr.StatusCode == http.StatusTooManyRequests {
+ return true
+ }
+ if openaiErr.StatusCode == 307 {
+ return true
+ }
+ if openaiErr.StatusCode/100 == 5 {
+ // do not retry on gateway timeouts
+ if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
+ return false
+ }
+ return true
+ }
+ if openaiErr.StatusCode == http.StatusBadRequest {
+ channelType := c.GetInt("channel_type")
+ if channelType == constant.ChannelTypeAnthropic {
+ return true
+ }
+ return false
+ }
+ if openaiErr.StatusCode == 408 {
+ // do not retry on Azure request timeouts (408)
+ return false
+ }
+ if openaiErr.StatusCode/100 == 2 {
+ return false
+ }
+ return true
+}
+
+func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
+ // do not use the gin context to get channel info here: when processed asynchronously the context may already describe a different channel
+ common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
+ if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+ service.DisableChannel(channelError, err.Error())
+ }
+}
+
+func RelayMidjourney(c *gin.Context) {
+ relayMode := c.GetInt("relay_mode")
+ var err *dto.MidjourneyResponse
+ switch relayMode {
+ case relayconstant.RelayModeMidjourneyNotify:
+ err = relay.RelayMidjourneyNotify(c)
+ case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
+ err = relay.RelayMidjourneyTask(c, relayMode)
+ case relayconstant.RelayModeMidjourneyTaskImageSeed:
+ err = relay.RelayMidjourneyTaskImageSeed(c)
+ case relayconstant.RelayModeSwapFace:
+ err = relay.RelaySwapFace(c)
+ default:
+ err = relay.RelayMidjourneySubmit(c, relayMode)
+ }
+ //err = relayMidjourneySubmit(c, relayMode)
+ log.Println(err)
+ if err != nil {
+ statusCode := http.StatusBadRequest
+ if err.Code == 30 {
+ err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
+ statusCode = http.StatusTooManyRequests
+ }
+ c.JSON(statusCode, gin.H{
+ "description": fmt.Sprintf("%s %s", err.Description, err.Result),
+ "type": "upstream_error",
+ "code": err.Code,
+ })
+ channelId := c.GetInt("channel_id")
+ common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
+ }
+}
+
+func RelayNotImplemented(c *gin.Context) {
+ err := dto.OpenAIError{
+ Message: "API not implemented",
+ Type: "new_api_error",
+ Param: "",
+ Code: "api_not_implemented",
+ }
+ c.JSON(http.StatusNotImplemented, gin.H{
+ "error": err,
+ })
+}
+
+func RelayNotFound(c *gin.Context) {
+ err := dto.OpenAIError{
+ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
+ Type: "invalid_request_error",
+ Param: "",
+ Code: "",
+ }
+ c.JSON(http.StatusNotFound, gin.H{
+ "error": err,
+ })
+}
+
+func RelayTask(c *gin.Context) {
+ retryTimes := common.RetryTimes
+ channelId := c.GetInt("channel_id")
+ relayMode := c.GetInt("relay_mode")
+ group := c.GetString("group")
+ originalModel := c.GetString("original_model")
+ c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
+ taskErr := taskRelayHandler(c, relayMode)
+ if taskErr == nil {
+ retryTimes = 0
+ }
+ for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
+ channel, newAPIError := getChannel(c, group, originalModel, i)
+ if newAPIError != nil {
+ common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
+ taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
+ break
+ }
+ channelId = channel.Id
+ useChannel := c.GetStringSlice("use_channel")
+ useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+ c.Set("use_channel", useChannel)
+ common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+ //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+ requestBody, _ := common.GetRequestBody(c)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+ taskErr = taskRelayHandler(c, relayMode)
+ }
+ useChannel := c.GetStringSlice("use_channel")
+ if len(useChannel) > 1 {
+ retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+ common.LogInfo(c, retryLogStr)
+ }
+ if taskErr != nil {
+ if taskErr.StatusCode == http.StatusTooManyRequests {
+ taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+ }
+ c.JSON(taskErr.StatusCode, taskErr)
+ }
+}
+
+func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
+ var err *dto.TaskError
+ switch relayMode {
+ case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
+ err = relay.RelayTaskFetch(c, relayMode)
+ default:
+ err = relay.RelayTaskSubmit(c, relayMode)
+ }
+ return err
+}
+
+func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
+ if taskErr == nil {
+ return false
+ }
+ if retryTimes <= 0 {
+ return false
+ }
+ if _, ok := c.Get("specific_channel_id"); ok {
+ return false
+ }
+ if taskErr.StatusCode == http.StatusTooManyRequests {
+ return true
+ }
+ if taskErr.StatusCode == 307 {
+ return true
+ }
+ if taskErr.StatusCode/100 == 5 {
+ // do not retry on gateway timeouts
+ if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
+ return false
+ }
+ return true
+ }
+ if taskErr.StatusCode == http.StatusBadRequest {
+ return false
+ }
+ if taskErr.StatusCode == 408 {
+ // do not retry on Azure request timeouts (408)
+ return false
+ }
+ if taskErr.LocalError {
+ return false
+ }
+ if taskErr.StatusCode/100 == 2 {
+ return false
+ }
+ return true
+}
diff --git a/controller/setup.go b/controller/setup.go
new file mode 100644
index 00000000..8943a1a0
--- /dev/null
+++ b/controller/setup.go
@@ -0,0 +1,181 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/model"
+ "one-api/setting/operation_setting"
+ "time"
+)
+
+type Setup struct {
+ Status bool `json:"status"`
+ RootInit bool `json:"root_init"`
+ DatabaseType string `json:"database_type"`
+}
+
+type SetupRequest struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+ ConfirmPassword string `json:"confirmPassword"`
+ SelfUseModeEnabled bool `json:"SelfUseModeEnabled"`
+ DemoSiteEnabled bool `json:"DemoSiteEnabled"`
+}
+
+func GetSetup(c *gin.Context) {
+ setup := Setup{
+ Status: constant.Setup,
+ }
+ if constant.Setup {
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": setup,
+ })
+ return
+ }
+ setup.RootInit = model.RootUserExists()
+ if common.UsingMySQL {
+ setup.DatabaseType = "mysql"
+ }
+ if common.UsingPostgreSQL {
+ setup.DatabaseType = "postgres"
+ }
+ if common.UsingSQLite {
+ setup.DatabaseType = "sqlite"
+ }
+ c.JSON(200, gin.H{
+ "success": true,
+ "data": setup,
+ })
+}
+
+func PostSetup(c *gin.Context) {
+ // Check if setup is already completed
+ if constant.Setup {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "系统已经初始化完成",
+ })
+ return
+ }
+
+ // Check if root user already exists
+ rootExists := model.RootUserExists()
+
+ var req SetupRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "请求参数有误",
+ })
+ return
+ }
+
+ // If root doesn't exist, validate and create admin account
+ if !rootExists {
+ // Validate username length: max 12 characters to align with model.User validation
+ if len(req.Username) > 12 {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "用户名长度不能超过12个字符",
+ })
+ return
+ }
+ // Validate password
+ if req.Password != req.ConfirmPassword {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "两次输入的密码不一致",
+ })
+ return
+ }
+
+ if len(req.Password) < 8 {
+ c.JSON(400, gin.H{
+ "success": false,
+ "message": "密码长度至少为8个字符",
+ })
+ return
+ }
+
+ // Create root user
+ hashedPassword, err := common.Password2Hash(req.Password)
+ if err != nil {
+ c.JSON(500, gin.H{
+ "success": false,
+ "message": "系统错误: " + err.Error(),
+ })
+ return
+ }
+ rootUser := model.User{
+ Username: req.Username,
+ Password: hashedPassword,
+ Role: common.RoleRootUser,
+ Status: common.UserStatusEnabled,
+ DisplayName: "Root User",
+ AccessToken: nil,
+ Quota: 100000000,
+ }
+ err = model.DB.Create(&rootUser).Error
+ if err != nil {
+ c.JSON(500, gin.H{
+ "success": false,
+ "message": "创建管理员账号失败: " + err.Error(),
+ })
+ return
+ }
+ }
+
+ // Set operation modes
+ operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
+ operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
+
+ // Save operation modes to database for persistence
+ err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
+ if err != nil {
+ c.JSON(500, gin.H{
+ "success": false,
+ "message": "保存自用模式设置失败: " + err.Error(),
+ })
+ return
+ }
+
+ err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
+ if err != nil {
+ c.JSON(500, gin.H{
+ "success": false,
+ "message": "保存演示站点模式设置失败: " + err.Error(),
+ })
+ return
+ }
+
+ setup := model.Setup{
+ Version: common.Version,
+ InitializedAt: time.Now().Unix(),
+ }
+ err = model.DB.Create(&setup).Error
+ if err != nil {
+ c.JSON(500, gin.H{
+ "success": false,
+ "message": "系统初始化失败: " + err.Error(),
+ })
+ return
+ }
+
+ // Mark setup as completed only after the setup record has been persisted
+ constant.Setup = true
+
+ c.JSON(200, gin.H{
+ "success": true,
+ "message": "系统初始化成功",
+ })
+}
+
+func boolToString(b bool) string {
+ if b {
+ return "true"
+ }
+ return "false"
+}
diff --git a/controller/swag_video.go b/controller/swag_video.go
new file mode 100644
index 00000000..185fd515
--- /dev/null
+++ b/controller/swag_video.go
@@ -0,0 +1,116 @@
+package controller
+
+import (
+ "github.com/gin-gonic/gin"
+)
+
+// VideoGenerations
+// @Summary 生成视频
+// @Description 调用视频生成接口生成视频
+// @Description 支持多种视频生成服务:
+// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
+// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body dto.VideoRequest true "视频生成请求参数"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations [post]
+func VideoGenerations(c *gin.Context) {
+}
+
+// VideoGenerationsTaskId
+// @Summary 查询视频
+// @Description 根据任务ID查询视频生成任务的状态和结果
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Security BearerAuth
+// @Param task_id path string true "Task ID"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /v1/video/generations/{task_id} [get]
+func VideoGenerationsTaskId(c *gin.Context) {
+}
+
+// KlingText2VideoGenerations
+// @Summary 可灵文生视频
+// @Description 调用可灵AI文生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body KlingText2VideoRequest true "视频生成请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/text2video [post]
+func KlingText2VideoGenerations(c *gin.Context) {
+}
+
+type KlingText2VideoRequest struct {
+ ModelName string `json:"model_name,omitempty" example:"kling-v1"`
+ Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
+ NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+ CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
+ Mode string `json:"mode,omitempty" example:"std"`
+ CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
+ Duration string `json:"duration,omitempty" example:"5"`
+ CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+ ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
+}
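+
+// An illustrative text2video request body assembled from the example tags above
+// (values are placeholders, not defaults enforced by the API):
+//
+// {
+//   "model_name": "kling-v1",
+//   "prompt": "A cat playing piano in the garden",
+//   "negative_prompt": "blurry, low quality",
+//   "cfg_scale": 0.7,
+//   "mode": "std",
+//   "aspect_ratio": "16:9",
+//   "duration": "5",
+//   "callback_url": "https://your.domain/callback",
+//   "external_task_id": "custom-task-001"
+// }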
+
+type KlingCameraControl struct {
+ Type string `json:"type,omitempty" example:"simple"`
+ Config *KlingCameraConfig `json:"config,omitempty"`
+}
+
+type KlingCameraConfig struct {
+ Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
+ Vertical float64 `json:"vertical,omitempty" example:"0"`
+ Pan float64 `json:"pan,omitempty" example:"0"`
+ Tilt float64 `json:"tilt,omitempty" example:"0"`
+ Roll float64 `json:"roll,omitempty" example:"0"`
+ Zoom float64 `json:"zoom,omitempty" example:"0"`
+}
+
+// KlingImage2VideoGenerations
+// @Summary 可灵官方-图生视频
+// @Description 调用可灵AI图生视频接口,生成视频内容
+// @Tags Video
+// @Accept json
+// @Produce json
+// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
+// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
+// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
+// @Failure 400 {object} dto.OpenAIError "请求参数错误"
+// @Failure 401 {object} dto.OpenAIError "未授权"
+// @Failure 403 {object} dto.OpenAIError "无权限"
+// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
+// @Router /kling/v1/videos/image2video [post]
+func KlingImage2VideoGenerations(c *gin.Context) {
+}
+
+type KlingImage2VideoRequest struct {
+ ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
+ Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
+ Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
+ NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
+ CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
+ Mode string `json:"mode,omitempty" example:"std"`
+ CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
+ Duration string `json:"duration,omitempty" example:"5"`
+ CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
+ ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
+}
diff --git a/controller/task.go b/controller/task.go
new file mode 100644
index 00000000..78674d8b
--- /dev/null
+++ b/controller/task.go
@@ -0,0 +1,273 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/relay"
+ "sort"
+ "strconv"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/samber/lo"
+)
+
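+// UpdateTaskBulk is a background poller: every 15 seconds it loads up to 500 unfinished
+// asynchronous tasks, marks tasks without a task ID as failed, groups the rest by platform and
+// channel, and dispatches them to the platform-specific update functions.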
+func UpdateTaskBulk() {
+ //recover
+ //imageModel := "midjourney"
+ for {
+ time.Sleep(time.Duration(15) * time.Second)
+ common.SysLog("任务进度轮询开始")
+ ctx := context.TODO()
+ allTasks := model.GetAllUnFinishSyncTasks(500)
+ platformTask := make(map[constant.TaskPlatform][]*model.Task)
+ for _, t := range allTasks {
+ platformTask[t.Platform] = append(platformTask[t.Platform], t)
+ }
+ for platform, tasks := range platformTask {
+ if len(tasks) == 0 {
+ continue
+ }
+ taskChannelM := make(map[int][]string)
+ taskM := make(map[string]*model.Task)
+ nullTaskIds := make([]int64, 0)
+ for _, task := range tasks {
+ if task.TaskID == "" {
+ // collect unfinished tasks that have no task ID; they will be marked as failed
+ nullTaskIds = append(nullTaskIds, task.ID)
+ continue
+ }
+ taskM[task.TaskID] = task
+ taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
+ }
+ if len(nullTaskIds) > 0 {
+ err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+ } else {
+ common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+ }
+ }
+ if len(taskChannelM) == 0 {
+ continue
+ }
+
+ UpdateTaskByPlatform(platform, taskChannelM, taskM)
+ }
+ common.SysLog("任务进度轮询完成")
+ }
+}
+
+func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+ switch platform {
+ case constant.TaskPlatformMidjourney:
+ //_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
+ case constant.TaskPlatformSuno:
+ _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
+ case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
+ _ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
+ default:
+ common.SysLog("未知平台")
+ }
+}
+
+func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+ for channelId, taskIds := range taskChannelM {
+ err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+ common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+ if len(taskIds) == 0 {
+ return nil
+ }
+ channel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+ err = model.TaskBulkUpdate(taskIds, map[string]any{
+ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if err != nil {
+ common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
+ }
+ return err
+ }
+ adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
+ if adaptor == nil {
+ return errors.New("adaptor not found")
+ }
+ resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
+ "ids": taskIds,
+ })
+ if err != nil {
+ common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
+ return err
+ }
+ if resp.StatusCode != http.StatusOK {
+ common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+ }
+ defer resp.Body.Close()
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
+ return err
+ }
+ var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+ err = json.Unmarshal(responseBody, &responseItems)
+ if err != nil {
+ common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+ return err
+ }
+ if !responseItems.IsSuccess() {
+ common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
+ return err
+ }
+
+ for _, responseItem := range responseItems.Data {
+ task := taskM[responseItem.TaskID]
+ if !checkTaskNeedUpdate(task, responseItem) {
+ continue
+ }
+
+ task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+ task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+ task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+ task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+ task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+ if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+ common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+ task.Progress = "100%"
+ //err = model.CacheUpdateUserQuota(task.UserId) ?
+ if err != nil {
+ common.LogError(ctx, "error update user quota cache: "+err.Error())
+ } else {
+ quota := task.Quota
+ if quota != 0 {
+ err = model.IncreaseUserQuota(task.UserId, quota, false)
+ if err != nil {
+ common.LogError(ctx, "fail to increase user quota: "+err.Error())
+ }
+ logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ }
+ }
+ }
+ if responseItem.Status == model.TaskStatusSuccess {
+ task.Progress = "100%"
+ }
+ task.Data = responseItem.Data
+
+ err = task.Update()
+ if err != nil {
+ common.SysError("UpdateMidjourneyTask task error: " + err.Error())
+ }
+ }
+ return nil
+}
+
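+// checkTaskNeedUpdate reports whether the locally stored task differs from the response item in
+// any tracked field. The final check marshals both Data payloads and compares their sorted
+// bytes, a coarse comparison that detects content changes while ignoring ordering.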
+func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+
+ if oldTask.SubmitTime != newTask.SubmitTime {
+ return true
+ }
+ if oldTask.StartTime != newTask.StartTime {
+ return true
+ }
+ if oldTask.FinishTime != newTask.FinishTime {
+ return true
+ }
+ if string(oldTask.Status) != newTask.Status {
+ return true
+ }
+ if oldTask.FailReason != newTask.FailReason {
+ return true
+ }
+
+ if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+ return true
+ }
+
+ oldData, _ := json.Marshal(oldTask.Data)
+ newData, _ := json.Marshal(newTask.Data)
+
+ sort.Slice(oldData, func(i, j int) bool {
+ return oldData[i] < oldData[j]
+ })
+ sort.Slice(newData, func(i, j int) bool {
+ return newData[i] < newData[j]
+ })
+
+ if string(oldData) != string(newData) {
+ return true
+ }
+ return false
+}
+
+func GetAllTask(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+ // parse the remaining query parameters
+ queryParams := model.SyncTaskQueryParams{
+ Platform: constant.TaskPlatform(c.Query("platform")),
+ TaskID: c.Query("task_id"),
+ Status: c.Query("status"),
+ Action: c.Query("action"),
+ StartTimestamp: startTimestamp,
+ EndTimestamp: endTimestamp,
+ ChannelID: c.Query("channel_id"),
+ }
+
+ items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.TaskCountAllTasks(queryParams)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
+}
+
+func GetUserTask(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+
+ userId := c.GetInt("id")
+
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+
+ queryParams := model.SyncTaskQueryParams{
+ Platform: constant.TaskPlatform(c.Query("platform")),
+ TaskID: c.Query("task_id"),
+ Status: c.Query("status"),
+ Action: c.Query("action"),
+ StartTimestamp: startTimestamp,
+ EndTimestamp: endTimestamp,
+ }
+
+ items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
+ total := model.TaskCountAllUserTask(userId, queryParams)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(items)
+ common.ApiSuccess(c, pageInfo)
+}
diff --git a/controller/task_video.go b/controller/task_video.go
new file mode 100644
index 00000000..b62978a7
--- /dev/null
+++ b/controller/task_video.go
@@ -0,0 +1,138 @@
+package controller
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/model"
+ "one-api/relay"
+ "one-api/relay/channel"
+ "time"
+)
+
+func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+ for channelId, taskIds := range taskChannelM {
+ if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
+ common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
+ }
+ }
+ return nil
+}
+
+func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+ common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
+ if len(taskIds) == 0 {
+ return nil
+ }
+ cacheGetChannel, err := model.CacheGetChannel(channelId)
+ if err != nil {
+ errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
+ "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
+ "status": "FAILURE",
+ "progress": "100%",
+ })
+ if errUpdate != nil {
+ common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
+ }
+ return fmt.Errorf("CacheGetChannel failed: %w", err)
+ }
+ adaptor := relay.GetTaskAdaptor(platform)
+ if adaptor == nil {
+ return fmt.Errorf("video adaptor not found")
+ }
+ for _, taskId := range taskIds {
+ if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
+ common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
+ }
+ }
+ return nil
+}
+
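+// updateVideoSingleTask fetches the latest state of one video task from the upstream channel,
+// maps the returned status to a progress percentage, refunds the user's quota when the task has
+// failed, and persists the raw response body together with the updated task record.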
+func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
+ baseURL := constant.ChannelBaseURLs[channel.Type]
+ if channel.GetBaseURL() != "" {
+ baseURL = channel.GetBaseURL()
+ }
+
+ task := taskM[taskId]
+ if task == nil {
+ common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
+ return fmt.Errorf("task %s not found", taskId)
+ }
+ resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
+ "task_id": taskId,
+ "action": task.Action,
+ })
+ if err != nil {
+ return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
+ }
+ //if resp.StatusCode != http.StatusOK {
+ //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
+ //}
+ defer resp.Body.Close()
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
+ }
+
+ taskResult, err := adaptor.ParseTaskResult(responseBody)
+ if err != nil {
+ return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
+ }
+ //if taskResult.Code != 0 {
+ // return fmt.Errorf("video task fetch failed for task %s", taskId)
+ //}
+
+ now := time.Now().Unix()
+ if taskResult.Status == "" {
+ return fmt.Errorf("task %s status is empty", taskId)
+ }
+ task.Status = model.TaskStatus(taskResult.Status)
+ switch taskResult.Status {
+ case model.TaskStatusSubmitted:
+ task.Progress = "10%"
+ case model.TaskStatusQueued:
+ task.Progress = "20%"
+ case model.TaskStatusInProgress:
+ task.Progress = "30%"
+ if task.StartTime == 0 {
+ task.StartTime = now
+ }
+ case model.TaskStatusSuccess:
+ task.Progress = "100%"
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ task.FailReason = taskResult.Url
+ case model.TaskStatusFailure:
+ task.Status = model.TaskStatusFailure
+ task.Progress = "100%"
+ if task.FinishTime == 0 {
+ task.FinishTime = now
+ }
+ task.FailReason = taskResult.Reason
+ common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
+ quota := task.Quota
+ if quota != 0 {
+ if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
+ common.LogError(ctx, "Failed to increase user quota: "+err.Error())
+ }
+ logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
+ model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+ }
+ default:
+ return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
+ }
+ if taskResult.Progress != "" {
+ task.Progress = taskResult.Progress
+ }
+
+ task.Data = responseBody
+ if err := task.Update(); err != nil {
+ common.SysError("UpdateVideoTask task error: " + err.Error())
+ }
+
+ return nil
+}
diff --git a/controller/telegram.go b/controller/telegram.go
new file mode 100644
index 00000000..8d07fc94
--- /dev/null
+++ b/controller/telegram.go
@@ -0,0 +1,124 @@
+package controller
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "sort"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+func TelegramBind(c *gin.Context) {
+ if !common.TelegramOAuthEnabled {
+ c.JSON(200, gin.H{
+ "message": "管理员未开启通过 Telegram 登录以及注册",
+ "success": false,
+ })
+ return
+ }
+ params := c.Request.URL.Query()
+ if !checkTelegramAuthorization(params, common.TelegramBotToken) {
+ c.JSON(200, gin.H{
+ "message": "无效的请求",
+ "success": false,
+ })
+ return
+ }
+ telegramId := params["id"][0]
+ if model.IsTelegramIdAlreadyTaken(telegramId) {
+ c.JSON(200, gin.H{
+ "message": "该 Telegram 账户已被绑定",
+ "success": false,
+ })
+ return
+ }
+
+ session := sessions.Default(c)
+ id := session.Get("id")
+ user := model.User{Id: id.(int)}
+ if err := user.FillUserById(); err != nil {
+ c.JSON(200, gin.H{
+ "message": err.Error(),
+ "success": false,
+ })
+ return
+ }
+ if user.Id == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户已注销",
+ })
+ return
+ }
+ user.TelegramId = telegramId
+ if err := user.Update(false); err != nil {
+ c.JSON(200, gin.H{
+ "message": err.Error(),
+ "success": false,
+ })
+ return
+ }
+
+ c.Redirect(302, "/setting")
+}
+
+func TelegramLogin(c *gin.Context) {
+ if !common.TelegramOAuthEnabled {
+ c.JSON(200, gin.H{
+ "message": "管理员未开启通过 Telegram 登录以及注册",
+ "success": false,
+ })
+ return
+ }
+ params := c.Request.URL.Query()
+ if !checkTelegramAuthorization(params, common.TelegramBotToken) {
+ c.JSON(200, gin.H{
+ "message": "无效的请求",
+ "success": false,
+ })
+ return
+ }
+
+ telegramId := params["id"][0]
+ user := model.User{TelegramId: telegramId}
+ if err := user.FillUserByTelegramId(); err != nil {
+ c.JSON(200, gin.H{
+ "message": err.Error(),
+ "success": false,
+ })
+ return
+ }
+ setupLogin(&user, c)
+}
+
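+// checkTelegramAuthorization verifies Telegram Login Widget callback data: all parameters except
+// "hash" are joined as sorted "key=value" lines, signed with HMAC-SHA256 using SHA256(bot token)
+// as the key, and the hex digest is compared with the supplied hash.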
+func checkTelegramAuthorization(params map[string][]string, token string) bool {
+ strs := []string{}
+ var hash = ""
+ for k, v := range params {
+ if k == "hash" {
+ hash = v[0]
+ continue
+ }
+ strs = append(strs, k+"="+v[0])
+ }
+ sort.Strings(strs)
+ var imploded = ""
+ for _, s := range strs {
+ if imploded != "" {
+ imploded += "\n"
+ }
+ imploded += s
+ }
+ sha256hash := sha256.New()
+ io.WriteString(sha256hash, token)
+ hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
+ io.WriteString(hmachash, imploded)
+ ss := hex.EncodeToString(hmachash.Sum(nil))
+ return hash == ss
+}
diff --git a/controller/token.go b/controller/token.go
new file mode 100644
index 00000000..62eb5474
--- /dev/null
+++ b/controller/token.go
@@ -0,0 +1,236 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetAllTokens(c *gin.Context) {
+ userId := c.GetInt("id")
+ pageInfo := common.GetPageQuery(c)
+ tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ total, _ := model.CountUserTokens(userId)
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(tokens)
+ common.ApiSuccess(c, pageInfo)
+ return
+}
+
+func SearchTokens(c *gin.Context) {
+ userId := c.GetInt("id")
+ keyword := c.Query("keyword")
+ token := c.Query("token")
+ tokens, err := model.SearchUserTokens(userId, keyword, token)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": tokens,
+ })
+ return
+}
+
+func GetToken(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ userId := c.GetInt("id")
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ token, err := model.GetTokenByIds(id, userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": token,
+ })
+ return
+}
+
+func GetTokenStatus(c *gin.Context) {
+ tokenId := c.GetInt("token_id")
+ userId := c.GetInt("id")
+ token, err := model.GetTokenByIds(tokenId, userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ expiredAt := token.ExpiredTime
+ if expiredAt == -1 {
+ expiredAt = 0
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "object": "credit_summary",
+ "total_granted": token.RemainQuota,
+ "total_used": 0, // not supported currently
+ "total_available": token.RemainQuota,
+ "expires_at": expiredAt * 1000,
+ })
+}
+
+func AddToken(c *gin.Context) {
+ token := model.Token{}
+ err := c.ShouldBindJSON(&token)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if len(token.Name) > 30 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "令牌名称过长",
+ })
+ return
+ }
+ key, err := common.GenerateKey()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "生成令牌失败",
+ })
+ common.SysError("failed to generate token key: " + err.Error())
+ return
+ }
+ cleanToken := model.Token{
+ UserId: c.GetInt("id"),
+ Name: token.Name,
+ Key: key,
+ CreatedTime: common.GetTimestamp(),
+ AccessedTime: common.GetTimestamp(),
+ ExpiredTime: token.ExpiredTime,
+ RemainQuota: token.RemainQuota,
+ UnlimitedQuota: token.UnlimitedQuota,
+ ModelLimitsEnabled: token.ModelLimitsEnabled,
+ ModelLimits: token.ModelLimits,
+ AllowIps: token.AllowIps,
+ Group: token.Group,
+ }
+ err = cleanToken.Insert()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func DeleteToken(c *gin.Context) {
+ id, _ := strconv.Atoi(c.Param("id"))
+ userId := c.GetInt("id")
+ err := model.DeleteTokenById(id, userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func UpdateToken(c *gin.Context) {
+ userId := c.GetInt("id")
+ statusOnly := c.Query("status_only")
+ token := model.Token{}
+ err := c.ShouldBindJSON(&token)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if len(token.Name) > 30 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "令牌名称过长",
+ })
+ return
+ }
+ cleanToken, err := model.GetTokenByIds(token.Id, userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if token.Status == common.TokenStatusEnabled {
+ if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
+ })
+ return
+ }
+ if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
+ })
+ return
+ }
+ }
+ if statusOnly != "" {
+ cleanToken.Status = token.Status
+ } else {
+ // If you add more fields, please also update token.Update()
+ cleanToken.Name = token.Name
+ cleanToken.ExpiredTime = token.ExpiredTime
+ cleanToken.RemainQuota = token.RemainQuota
+ cleanToken.UnlimitedQuota = token.UnlimitedQuota
+ cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
+ cleanToken.ModelLimits = token.ModelLimits
+ cleanToken.AllowIps = token.AllowIps
+ cleanToken.Group = token.Group
+ }
+ err = cleanToken.Update()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": cleanToken,
+ })
+ return
+}
+
+type TokenBatch struct {
+ Ids []int `json:"ids"`
+}
+
+func DeleteTokenBatch(c *gin.Context) {
+ tokenBatch := TokenBatch{}
+ if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "参数错误",
+ })
+ return
+ }
+ userId := c.GetInt("id")
+ count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": count,
+ })
+}
diff --git a/controller/topup.go b/controller/topup.go
new file mode 100644
index 00000000..827dda39
--- /dev/null
+++ b/controller/topup.go
@@ -0,0 +1,265 @@
+package controller
+
+import (
+ "fmt"
+ "log"
+ "net/url"
+ "one-api/common"
+ "one-api/model"
+ "one-api/service"
+ "one-api/setting"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/Calcium-Ion/go-epay/epay"
+ "github.com/gin-gonic/gin"
+ "github.com/samber/lo"
+ "github.com/shopspring/decimal"
+)
+
+type EpayRequest struct {
+ Amount int64 `json:"amount"`
+ PaymentMethod string `json:"payment_method"`
+ TopUpCode string `json:"top_up_code"`
+}
+
+type AmountRequest struct {
+ Amount int64 `json:"amount"`
+ TopUpCode string `json:"top_up_code"`
+}
+
+func GetEpayClient() *epay.Client {
+ if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
+ return nil
+ }
+ withUrl, err := epay.NewClient(&epay.Config{
+ PartnerID: setting.EpayId,
+ Key: setting.EpayKey,
+ }, setting.PayAddress)
+ if err != nil {
+ return nil
+ }
+ return withUrl
+}
+
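+// getPayMoney converts the requested quota amount into the amount of money to charge:
+// amount (divided by QuotaPerUnit when amounts are expressed in raw quota) * Price * top-up
+// group ratio. Illustrative numbers (assumed, not defaults): amount=10 units, Price=7.0,
+// ratio=1.0 -> 70.00.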
+func getPayMoney(amount int64, group string) float64 {
+ dAmount := decimal.NewFromInt(amount)
+
+ if !common.DisplayInCurrencyEnabled {
+ dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+ dAmount = dAmount.Div(dQuotaPerUnit)
+ }
+
+ topupGroupRatio := common.GetTopupGroupRatio(group)
+ if topupGroupRatio == 0 {
+ topupGroupRatio = 1
+ }
+
+ dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
+ dPrice := decimal.NewFromFloat(setting.Price)
+
+ payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
+
+ return payMoney.InexactFloat64()
+}
+
+func getMinTopup() int64 {
+ minTopup := setting.MinTopUp
+ if !common.DisplayInCurrencyEnabled {
+ dMinTopup := decimal.NewFromInt(int64(minTopup))
+ dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+ minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart())
+ }
+ return int64(minTopup)
+}
+
+func RequestEpay(c *gin.Context) {
+ var req EpayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+ if req.Amount < getMinTopup() {
+ c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
+ return
+ }
+
+ id := c.GetInt("id")
+ group, err := model.GetUserGroup(id, true)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
+ return
+ }
+ payMoney := getPayMoney(req.Amount, group)
+ if payMoney < 0.01 {
+ c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
+ return
+ }
+
+ if !setting.ContainsPayMethod(req.PaymentMethod) {
+ c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
+ return
+ }
+
+ callBackAddress := service.GetCallbackAddress()
+ returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
+ notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
+ tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
+ tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
+ client := GetEpayClient()
+ if client == nil {
+ c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
+ return
+ }
+ uri, params, err := client.Purchase(&epay.PurchaseArgs{
+ Type: req.PaymentMethod,
+ ServiceTradeNo: tradeNo,
+ Name: fmt.Sprintf("TUC%d", req.Amount),
+ Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
+ Device: epay.PC,
+ NotifyUrl: notifyUrl,
+ ReturnUrl: returnUrl,
+ })
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
+ return
+ }
+ amount := req.Amount
+ if !common.DisplayInCurrencyEnabled {
+ dAmount := decimal.NewFromInt(int64(amount))
+ dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+ amount = dAmount.Div(dQuotaPerUnit).IntPart()
+ }
+ topUp := &model.TopUp{
+ UserId: id,
+ Amount: amount,
+ Money: payMoney,
+ TradeNo: tradeNo,
+ CreateTime: time.Now().Unix(),
+ Status: "pending",
+ }
+ err = topUp.Insert()
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
+ return
+ }
+ c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
+}
+
+// tradeNo lock
+var orderLocks sync.Map
+var createLock sync.Mutex
+
+// LockOrder 尝试对给定订单号加锁
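+// 采用双重检查加锁:先在全局锁 createLock 保护下为订单号创建独立的互斥锁,再对该订单锁加锁,
+// 避免同一订单的支付回调被并发重复处理。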
+func LockOrder(tradeNo string) {
+ lock, ok := orderLocks.Load(tradeNo)
+ if !ok {
+ createLock.Lock()
+ defer createLock.Unlock()
+ lock, ok = orderLocks.Load(tradeNo)
+ if !ok {
+ lock = new(sync.Mutex)
+ orderLocks.Store(tradeNo, lock)
+ }
+ }
+ lock.(*sync.Mutex).Lock()
+}
+
+// UnlockOrder 释放给定订单号的锁
+func UnlockOrder(tradeNo string) {
+ lock, ok := orderLocks.Load(tradeNo)
+ if ok {
+ lock.(*sync.Mutex).Unlock()
+ }
+}
+
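+// EpayNotify 处理易支付异步回调:验签通过后按订单号加锁,
+// 将 pending 状态的订单置为 success,并按 QuotaPerUnit 换算后增加用户额度。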
+func EpayNotify(c *gin.Context) {
+ params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
+ r[t] = c.Request.URL.Query().Get(t)
+ return r
+ }, map[string]string{})
+ client := GetEpayClient()
+ if client == nil {
+ log.Println("易支付回调失败 未找到配置信息")
+ if _, err := c.Writer.Write([]byte("fail")); err != nil {
+ log.Println("易支付回调写入失败")
+ }
+ return
+ }
+ verifyInfo, err := client.Verify(params)
+ if err == nil && verifyInfo.VerifyStatus {
+ _, err := c.Writer.Write([]byte("success"))
+ if err != nil {
+ log.Println("易支付回调写入失败")
+ }
+ } else {
+ _, err := c.Writer.Write([]byte("fail"))
+ if err != nil {
+ log.Println("易支付回调写入失败")
+ }
+ log.Println("易支付回调签名验证失败")
+ return
+ }
+
+ if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
+ log.Println(verifyInfo)
+ LockOrder(verifyInfo.ServiceTradeNo)
+ defer UnlockOrder(verifyInfo.ServiceTradeNo)
+ topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
+ if topUp == nil {
+ log.Printf("易支付回调未找到订单: %v", verifyInfo)
+ return
+ }
+ if topUp.Status == "pending" {
+ topUp.Status = "success"
+ err := topUp.Update()
+ if err != nil {
+ log.Printf("易支付回调更新订单失败: %v", topUp)
+ return
+ }
+ //user, _ := model.GetUserById(topUp.UserId, false)
+ //user.Quota += topUp.Amount * 500000
+ dAmount := decimal.NewFromInt(int64(topUp.Amount))
+ dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+ quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
+ err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
+ if err != nil {
+ log.Printf("易支付回调更新用户失败: %v", topUp)
+ return
+ }
+ log.Printf("易支付回调更新用户成功 %v", topUp)
+ model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
+ }
+ } else {
+ log.Printf("易支付异常回调: %v", verifyInfo)
+ }
+}
+
+func RequestAmount(c *gin.Context) {
+ var req AmountRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+
+ if req.Amount < getMinTopup() {
+ c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
+ return
+ }
+ id := c.GetInt("id")
+ group, err := model.GetUserGroup(id, true)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
+ return
+ }
+ payMoney := getPayMoney(req.Amount, group)
+ if payMoney <= 0.01 {
+ c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
+ return
+ }
+ c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
+}
diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go
new file mode 100644
index 00000000..eb320809
--- /dev/null
+++ b/controller/topup_stripe.go
@@ -0,0 +1,275 @@
+package controller
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "one-api/setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stripe/stripe-go/v81"
+ "github.com/stripe/stripe-go/v81/checkout/session"
+ "github.com/stripe/stripe-go/v81/webhook"
+ "github.com/thanhpk/randstr"
+)
+
+const (
+ PaymentMethodStripe = "stripe"
+)
+
+var stripeAdaptor = &StripeAdaptor{}
+
+type StripePayRequest struct {
+ Amount int64 `json:"amount"`
+ PaymentMethod string `json:"payment_method"`
+}
+
+type StripeAdaptor struct {
+}
+
+func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
+ if req.Amount < getStripeMinTopup() {
+ c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
+ return
+ }
+ id := c.GetInt("id")
+ group, err := model.GetUserGroup(id, true)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
+ return
+ }
+ payMoney := getStripePayMoney(float64(req.Amount), group)
+ if payMoney <= 0.01 {
+ c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
+ return
+ }
+ c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
+}
+
+func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
+ if req.PaymentMethod != PaymentMethodStripe {
+ c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
+ return
+ }
+ if req.Amount < getStripeMinTopup() {
+ c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
+ return
+ }
+ if req.Amount > 10000 {
+ c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
+ return
+ }
+
+ id := c.GetInt("id")
+ user, _ := model.GetUserById(id, false)
+ chargedMoney := GetChargedAmount(float64(req.Amount), *user)
+
+ reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
+ referenceId := "ref_" + common.Sha1([]byte(reference))
+
+ payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
+ if err != nil {
+ log.Println("获取Stripe Checkout支付链接失败", err)
+ c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
+ return
+ }
+
+ topUp := &model.TopUp{
+ UserId: id,
+ Amount: req.Amount,
+ Money: chargedMoney,
+ TradeNo: referenceId,
+ CreateTime: time.Now().Unix(),
+ Status: common.TopUpStatusPending,
+ }
+ err = topUp.Insert()
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
+ return
+ }
+ c.JSON(200, gin.H{
+ "message": "success",
+ "data": gin.H{
+ "pay_link": payLink,
+ },
+ })
+}
+
+func RequestStripeAmount(c *gin.Context) {
+ var req StripePayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+ stripeAdaptor.RequestAmount(c, &req)
+}
+
+func RequestStripePay(c *gin.Context) {
+ var req StripePayRequest
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
+ return
+ }
+ stripeAdaptor.RequestPay(c, &req)
+}
+
+func StripeWebhook(c *gin.Context) {
+ payload, err := io.ReadAll(c.Request.Body)
+ if err != nil {
+ log.Printf("解析Stripe Webhook参数失败: %v\n", err)
+ c.AbortWithStatus(http.StatusServiceUnavailable)
+ return
+ }
+
+ signature := c.GetHeader("Stripe-Signature")
+ endpointSecret := setting.StripeWebhookSecret
+ event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
+ IgnoreAPIVersionMismatch: true,
+ })
+
+ if err != nil {
+ log.Printf("Stripe Webhook验签失败: %v\n", err)
+ c.AbortWithStatus(http.StatusBadRequest)
+ return
+ }
+
+ switch event.Type {
+ case stripe.EventTypeCheckoutSessionCompleted:
+ sessionCompleted(event)
+ case stripe.EventTypeCheckoutSessionExpired:
+ sessionExpired(event)
+ default:
+ log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
+ }
+
+ c.Status(http.StatusOK)
+}
+
+func sessionCompleted(event stripe.Event) {
+ customerId := event.GetObjectValue("customer")
+ referenceId := event.GetObjectValue("client_reference_id")
+ status := event.GetObjectValue("status")
+ if "complete" != status {
+ log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
+ return
+ }
+
+ err := model.Recharge(referenceId, customerId)
+ if err != nil {
+ log.Println(err.Error(), referenceId)
+ return
+ }
+
+ total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
+ currency := strings.ToUpper(event.GetObjectValue("currency"))
+ log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
+}
+
+func sessionExpired(event stripe.Event) {
+ referenceId := event.GetObjectValue("client_reference_id")
+ status := event.GetObjectValue("status")
+ if "expired" != status {
+ log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
+ return
+ }
+
+ if len(referenceId) == 0 {
+ log.Println("未提供支付单号")
+ return
+ }
+
+ topUp := model.GetTopUpByTradeNo(referenceId)
+ if topUp == nil {
+ log.Println("充值订单不存在", referenceId)
+ return
+ }
+
+ if topUp.Status != common.TopUpStatusPending {
+ log.Println("充值订单状态错误", referenceId)
+ return
+ }
+
+ topUp.Status = common.TopUpStatusExpired
+ err := topUp.Update()
+ if err != nil {
+ log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
+ return
+ }
+
+ log.Println("充值订单已过期", referenceId)
+}
+
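+// genStripeLink 基于配置的 StripePriceId 创建 Stripe Checkout Session,数量为充值数量;
+// 已绑定 Stripe 客户时复用 customerId,否则按邮箱创建新客户,返回支付链接。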
+func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
+ if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
+ return "", fmt.Errorf("无效的Stripe API密钥")
+ }
+
+ stripe.Key = setting.StripeApiSecret
+
+ params := &stripe.CheckoutSessionParams{
+ ClientReferenceID: stripe.String(referenceId),
+ SuccessURL: stripe.String(setting.ServerAddress + "/log"),
+ CancelURL: stripe.String(setting.ServerAddress + "/topup"),
+ LineItems: []*stripe.CheckoutSessionLineItemParams{
+ {
+ Price: stripe.String(setting.StripePriceId),
+ Quantity: stripe.Int64(amount),
+ },
+ },
+ Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
+ }
+
+ if "" == customerId {
+ if "" != email {
+ params.CustomerEmail = stripe.String(email)
+ }
+
+ params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
+ } else {
+ params.Customer = stripe.String(customerId)
+ }
+
+ result, err := session.New(params)
+ if err != nil {
+ return "", err
+ }
+
+ return result.URL, nil
+}
+
+func GetChargedAmount(count float64, user model.User) float64 {
+ topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
+ if topUpGroupRatio == 0 {
+ topUpGroupRatio = 1
+ }
+
+ return count * topUpGroupRatio
+}
+
+func getStripePayMoney(amount float64, group string) float64 {
+ if !common.DisplayInCurrencyEnabled {
+ amount = amount / common.QuotaPerUnit
+ }
+ // Using float64 for monetary calculations is acceptable here due to the small amounts involved
+ topupGroupRatio := common.GetTopupGroupRatio(group)
+ if topupGroupRatio == 0 {
+ topupGroupRatio = 1
+ }
+ payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
+ return payMoney
+}
+
+func getStripeMinTopup() int64 {
+ minTopup := setting.StripeMinTopUp
+ if !common.DisplayInCurrencyEnabled {
+ minTopup = minTopup * int(common.QuotaPerUnit)
+ }
+ return int64(minTopup)
+}
diff --git a/controller/uptime_kuma.go b/controller/uptime_kuma.go
new file mode 100644
index 00000000..05d6297e
--- /dev/null
+++ b/controller/uptime_kuma.go
@@ -0,0 +1,154 @@
+package controller
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "one-api/setting/console_setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "golang.org/x/sync/errgroup"
+)
+
+const (
+ requestTimeout = 30 * time.Second
+ httpTimeout = 10 * time.Second
+ uptimeKeySuffix = "_24"
+ apiStatusPath = "/api/status-page/"
+ apiHeartbeatPath = "/api/status-page/heartbeat/"
+)
+
+type Monitor struct {
+ Name string `json:"name"`
+ Uptime float64 `json:"uptime"`
+ Status int `json:"status"`
+ Group string `json:"group,omitempty"`
+}
+
+type UptimeGroupResult struct {
+ CategoryName string `json:"categoryName"`
+ Monitors []Monitor `json:"monitors"`
+}
+
+func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return err
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return errors.New("non-200 status")
+ }
+
+ return json.NewDecoder(resp.Body).Decode(dest)
+}
+
+func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
+ url, _ := groupConfig["url"].(string)
+ slug, _ := groupConfig["slug"].(string)
+ categoryName, _ := groupConfig["categoryName"].(string)
+
+ result := UptimeGroupResult{
+ CategoryName: categoryName,
+ Monitors: []Monitor{},
+ }
+
+ if url == "" || slug == "" {
+ return result
+ }
+
+ baseURL := strings.TrimSuffix(url, "/")
+
+ var statusData struct {
+ PublicGroupList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ MonitorList []struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ } `json:"monitorList"`
+ } `json:"publicGroupList"`
+ }
+
+ var heartbeatData struct {
+ HeartbeatList map[string][]struct {
+ Status int `json:"status"`
+ } `json:"heartbeatList"`
+ UptimeList map[string]float64 `json:"uptimeList"`
+ }
+
+ g, gCtx := errgroup.WithContext(ctx)
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
+ })
+ g.Go(func() error {
+ return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
+ })
+
+ if g.Wait() != nil {
+ return result
+ }
+
+ for _, pg := range statusData.PublicGroupList {
+ if len(pg.MonitorList) == 0 {
+ continue
+ }
+
+ for _, m := range pg.MonitorList {
+ monitor := Monitor{
+ Name: m.Name,
+ Group: pg.Name,
+ }
+
+ monitorID := strconv.Itoa(m.ID)
+
+ if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
+ monitor.Uptime = uptime
+ }
+
+ if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
+ monitor.Status = heartbeats[0].Status
+ }
+
+ result.Monitors = append(result.Monitors, monitor)
+ }
+ }
+
+ return result
+}
+
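+// GetUptimeKumaStatus 并发拉取各 Uptime Kuma 状态页分组的监控数据并聚合返回;
+// 单个分组拉取失败时仅返回该分组的空监控列表,不影响其他分组。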
+func GetUptimeKumaStatus(c *gin.Context) {
+ groups := console_setting.GetUptimeKumaGroups()
+ if len(groups) == 0 {
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
+ defer cancel()
+
+ client := &http.Client{Timeout: httpTimeout}
+ results := make([]UptimeGroupResult, len(groups))
+
+ g, gCtx := errgroup.WithContext(ctx)
+ for i, group := range groups {
+ i, group := i, group
+ g.Go(func() error {
+ results[i] = fetchGroupData(gCtx, client, group)
+ return nil
+ })
+ }
+
+ g.Wait()
+ c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
+}
\ No newline at end of file
diff --git a/controller/usedata.go b/controller/usedata.go
new file mode 100644
index 00000000..4adee50f
--- /dev/null
+++ b/controller/usedata.go
@@ -0,0 +1,52 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetAllQuotaDates(c *gin.Context) {
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+ username := c.Query("username")
+ dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": dates,
+ })
+ return
+}
+
+func GetUserQuotaDates(c *gin.Context) {
+ userId := c.GetInt("id")
+ startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+ endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+ // 判断时间跨度是否超过 1 个月
+ if endTimestamp-startTimestamp > 2592000 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "时间跨度不能超过 1 个月",
+ })
+ return
+ }
+ dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": dates,
+ })
+ return
+}
diff --git a/controller/user.go b/controller/user.go
new file mode 100644
index 00000000..292ed8c6
--- /dev/null
+++ b/controller/user.go
@@ -0,0 +1,956 @@
+package controller
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/setting"
+ "strconv"
+ "strings"
+ "sync"
+
+ "one-api/constant"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+type LoginRequest struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+}
+
+func Login(c *gin.Context) {
+ if !common.PasswordLoginEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "管理员关闭了密码登录",
+ "success": false,
+ })
+ return
+ }
+ var loginRequest LoginRequest
+ err := json.NewDecoder(c.Request.Body).Decode(&loginRequest)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "无效的参数",
+ "success": false,
+ })
+ return
+ }
+ username := loginRequest.Username
+ password := loginRequest.Password
+ if username == "" || password == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "无效的参数",
+ "success": false,
+ })
+ return
+ }
+ user := model.User{
+ Username: username,
+ Password: password,
+ }
+ err = user.ValidateAndFill()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": err.Error(),
+ "success": false,
+ })
+ return
+ }
+ setupLogin(&user, c)
+}
+
+// setup session & cookies and then return user info
+func setupLogin(user *model.User, c *gin.Context) {
+ session := sessions.Default(c)
+ session.Set("id", user.Id)
+ session.Set("username", user.Username)
+ session.Set("role", user.Role)
+ session.Set("status", user.Status)
+ session.Set("group", user.Group)
+ err := session.Save()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "无法保存会话信息,请重试",
+ "success": false,
+ })
+ return
+ }
+ cleanUser := model.User{
+ Id: user.Id,
+ Username: user.Username,
+ DisplayName: user.DisplayName,
+ Role: user.Role,
+ Status: user.Status,
+ Group: user.Group,
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "message": "",
+ "success": true,
+ "data": cleanUser,
+ })
+}
+
+func Logout(c *gin.Context) {
+ session := sessions.Default(c)
+ session.Clear()
+ err := session.Save()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": err.Error(),
+ "success": false,
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "message": "",
+ "success": true,
+ })
+}
+
+func Register(c *gin.Context) {
+ if !common.RegisterEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "管理员关闭了新用户注册",
+ "success": false,
+ })
+ return
+ }
+ if !common.PasswordRegisterEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
+ "success": false,
+ })
+ return
+ }
+ var user model.User
+ err := json.NewDecoder(c.Request.Body).Decode(&user)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ if err := common.Validate.Struct(&user); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "输入不合法 " + err.Error(),
+ })
+ return
+ }
+ if common.EmailVerificationEnabled {
+ if user.Email == "" || user.VerificationCode == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员开启了邮箱验证,请输入邮箱地址和验证码",
+ })
+ return
+ }
+ if !common.VerifyCodeWithKey(user.Email, user.VerificationCode, common.EmailVerificationPurpose) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "验证码错误或已过期",
+ })
+ return
+ }
+ }
+ exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "数据库错误,请稍后重试",
+ })
+ common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
+ return
+ }
+ if exist {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户名已存在,或已注销",
+ })
+ return
+ }
+ affCode := user.AffCode // this code is the inviter's code, not the user's own code
+ inviterId, _ := model.GetUserIdByAffCode(affCode)
+ cleanUser := model.User{
+ Username: user.Username,
+ Password: user.Password,
+ DisplayName: user.Username,
+ InviterId: inviterId,
+ }
+ if common.EmailVerificationEnabled {
+ cleanUser.Email = user.Email
+ }
+ if err := cleanUser.Insert(inviterId); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // 获取插入后的用户ID
+ var insertedUser model.User
+ if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户注册失败或用户ID获取失败",
+ })
+ return
+ }
+ // 生成默认令牌
+ if constant.GenerateDefaultToken {
+ key, err := common.GenerateKey()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "生成默认令牌失败",
+ })
+ common.SysError("failed to generate token key: " + err.Error())
+ return
+ }
+ // 生成默认令牌
+ token := model.Token{
+ UserId: insertedUser.Id, // 使用插入后的用户ID
+ Name: cleanUser.Username + "的初始令牌",
+ Key: key,
+ CreatedTime: common.GetTimestamp(),
+ AccessedTime: common.GetTimestamp(),
+ ExpiredTime: -1, // 永不过期
+ RemainQuota: 500000, // 示例额度
+ UnlimitedQuota: true,
+ ModelLimitsEnabled: false,
+ }
+ if setting.DefaultUseAutoGroup {
+ token.Group = "auto"
+ }
+ if err := token.Insert(); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "创建默认令牌失败",
+ })
+ return
+ }
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func GetAllUsers(c *gin.Context) {
+ pageInfo := common.GetPageQuery(c)
+ users, total, err := model.GetAllUsers(pageInfo)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(users)
+
+ common.ApiSuccess(c, pageInfo)
+ return
+}
+
+func SearchUsers(c *gin.Context) {
+ keyword := c.Query("keyword")
+ group := c.Query("group")
+ pageInfo := common.GetPageQuery(c)
+ users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(users)
+ common.ApiSuccess(c, pageInfo)
+ return
+}
+
+func GetUser(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user, err := model.GetUserById(id, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ myRole := c.GetInt("role")
+ if myRole <= user.Role && myRole != common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权获取同级或更高等级用户的信息",
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": user,
+ })
+ return
+}
+
+func GenerateAccessToken(c *gin.Context) {
+ id := c.GetInt("id")
+ user, err := model.GetUserById(id, true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ // generate a key with random length between 29 and 32 (29 + rand offset in [0, 4))
+ randI := common.GetRandomInt(4)
+ key, err := common.GenerateRandomKey(29 + randI)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "生成失败",
+ })
+ common.SysError("failed to generate key: " + err.Error())
+ return
+ }
+ user.SetAccessToken(key)
+
+ if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "请重试,系统生成的 UUID 竟然重复了!",
+ })
+ return
+ }
+
+ if err := user.Update(false); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": user.AccessToken,
+ })
+ return
+}
+
+type TransferAffQuotaRequest struct {
+ Quota int `json:"quota" binding:"required"`
+}
+
+func TransferAffQuota(c *gin.Context) {
+ id := c.GetInt("id")
+ user, err := model.GetUserById(id, true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ tran := TransferAffQuotaRequest{}
+ if err := c.ShouldBindJSON(&tran); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ err = user.TransferAffQuotaToQuota(tran.Quota)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "划转失败 " + err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "划转成功",
+ })
+}
+
+func GetAffCode(c *gin.Context) {
+ id := c.GetInt("id")
+ user, err := model.GetUserById(id, true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if user.AffCode == "" {
+ user.AffCode = common.GetRandomString(4)
+ if err := user.Update(false); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": user.AffCode,
+ })
+ return
+}
+
+func GetSelf(c *gin.Context) {
+ id := c.GetInt("id")
+ user, err := model.GetUserById(id, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
+ user.Remark = ""
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": user,
+ })
+ return
+}
+
+func GetUserModels(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ id = c.GetInt("id")
+ }
+ user, err := model.GetUserCache(id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ groups := setting.GetUserUsableGroups(user.Group)
+ var models []string
+ for group := range groups {
+ for _, g := range model.GetGroupEnabledModels(group) {
+ if !common.StringsContains(models, g) {
+ models = append(models, g)
+ }
+ }
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": models,
+ })
+ return
+}
+
+func UpdateUser(c *gin.Context) {
+ var updatedUser model.User
+ err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
+ if err != nil || updatedUser.Id == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ if updatedUser.Password == "" {
+ updatedUser.Password = "$I_LOVE_U" // make Validator happy :)
+ }
+ if err := common.Validate.Struct(&updatedUser); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "输入不合法 " + err.Error(),
+ })
+ return
+ }
+ originUser, err := model.GetUserById(updatedUser.Id, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ myRole := c.GetInt("role")
+ if myRole <= originUser.Role && myRole != common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权更新同权限等级或更高权限等级的用户信息",
+ })
+ return
+ }
+ if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
+ })
+ return
+ }
+ if updatedUser.Password == "$I_LOVE_U" {
+ updatedUser.Password = "" // rollback to what it should be
+ }
+ updatePassword := updatedUser.Password != ""
+ if err := updatedUser.Edit(updatePassword); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if originUser.Quota != updatedUser.Quota {
+ model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func UpdateSelf(c *gin.Context) {
+ var user model.User
+ err := json.NewDecoder(c.Request.Body).Decode(&user)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ if user.Password == "" {
+ user.Password = "$I_LOVE_U" // make Validator happy :)
+ }
+ if err := common.Validate.Struct(&user); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "输入不合法 " + err.Error(),
+ })
+ return
+ }
+
+ cleanUser := model.User{
+ Id: c.GetInt("id"),
+ Username: user.Username,
+ Password: user.Password,
+ DisplayName: user.DisplayName,
+ }
+ if user.Password == "$I_LOVE_U" {
+ user.Password = "" // rollback to what it should be
+ cleanUser.Password = ""
+ }
+ updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if err := cleanUser.Update(updatePassword); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) {
+ var currentUser *model.User
+ currentUser, err = model.GetUserById(userId, true)
+ if err != nil {
+ return
+ }
+ if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) {
+ err = fmt.Errorf("原密码错误")
+ return
+ }
+ if newPassword == "" {
+ return
+ }
+ updatePassword = true
+ return
+}
+
+func DeleteUser(c *gin.Context) {
+ id, err := strconv.Atoi(c.Param("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ originUser, err := model.GetUserById(id, false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ myRole := c.GetInt("role")
+ if myRole <= originUser.Role {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权删除同权限等级或更高权限等级的用户",
+ })
+ return
+ }
+ err = model.HardDeleteUserById(id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+}
+
+func DeleteSelf(c *gin.Context) {
+ id := c.GetInt("id")
+ user, _ := model.GetUserById(id, false)
+
+ if user.Role == common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "不能删除超级管理员账户",
+ })
+ return
+ }
+
+ err := model.DeleteUserById(id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+func CreateUser(c *gin.Context) {
+ var user model.User
+ err := json.NewDecoder(c.Request.Body).Decode(&user)
+ user.Username = strings.TrimSpace(user.Username)
+ if err != nil || user.Username == "" || user.Password == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ if err := common.Validate.Struct(&user); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "输入不合法 " + err.Error(),
+ })
+ return
+ }
+ if user.DisplayName == "" {
+ user.DisplayName = user.Username
+ }
+ myRole := c.GetInt("role")
+ if user.Role >= myRole {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法创建权限大于等于自己的用户",
+ })
+ return
+ }
+ // Even for admin users, we cannot fully trust them!
+ cleanUser := model.User{
+ Username: user.Username,
+ Password: user.Password,
+ DisplayName: user.DisplayName,
+ }
+ if err := cleanUser.Insert(0); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+type ManageRequest struct {
+ Id int `json:"id"`
+ Action string `json:"action"`
+}
+
+// ManageUser Only admin user can do this
+func ManageUser(c *gin.Context) {
+ var req ManageRequest
+ err := json.NewDecoder(c.Request.Body).Decode(&req)
+
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+ user := model.User{
+ Id: req.Id,
+ }
+ // Fill attributes
+ model.DB.Unscoped().Where(&user).First(&user)
+ if user.Id == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户不存在",
+ })
+ return
+ }
+ myRole := c.GetInt("role")
+ if myRole <= user.Role && myRole != common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权更新同权限等级或更高权限等级的用户信息",
+ })
+ return
+ }
+ switch req.Action {
+ case "disable":
+ user.Status = common.UserStatusDisabled
+ if user.Role == common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法禁用超级管理员用户",
+ })
+ return
+ }
+ case "enable":
+ user.Status = common.UserStatusEnabled
+ case "delete":
+ if user.Role == common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法删除超级管理员用户",
+ })
+ return
+ }
+ if err := user.Delete(); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ case "promote":
+ if myRole != common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "普通管理员用户无法提升其他用户为管理员",
+ })
+ return
+ }
+ if user.Role >= common.RoleAdminUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该用户已经是管理员",
+ })
+ return
+ }
+ user.Role = common.RoleAdminUser
+ case "demote":
+ if user.Role == common.RoleRootUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法降级超级管理员用户",
+ })
+ return
+ }
+ if user.Role == common.RoleCommonUser {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该用户已经是普通用户",
+ })
+ return
+ }
+ user.Role = common.RoleCommonUser
+ }
+
+ if err := user.Update(false); err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ clearUser := model.User{
+ Role: user.Role,
+ Status: user.Status,
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": clearUser,
+ })
+ return
+}
+
+func EmailBind(c *gin.Context) {
+ email := c.Query("email")
+ code := c.Query("code")
+ if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "验证码错误或已过期",
+ })
+ return
+ }
+ session := sessions.Default(c)
+ id := session.Get("id")
+ user := model.User{
+ Id: id.(int),
+ }
+ err := user.FillUserById()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user.Email = email
+ // no need to check if this email already taken, because we have used verification code to check it
+ err = user.Update(false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
+type topUpRequest struct {
+ Key string `json:"key"`
+}
+
+var topUpLock = sync.Mutex{}
+
+func TopUp(c *gin.Context) {
+ topUpLock.Lock()
+ defer topUpLock.Unlock()
+ req := topUpRequest{}
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ id := c.GetInt("id")
+ quota, err := model.Redeem(req.Key, id)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": quota,
+ })
+ return
+}
+
+type UpdateUserSettingRequest struct {
+ QuotaWarningType string `json:"notify_type"`
+ QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
+ WebhookUrl string `json:"webhook_url,omitempty"`
+ WebhookSecret string `json:"webhook_secret,omitempty"`
+ NotificationEmail string `json:"notification_email,omitempty"`
+ AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
+ RecordIpLog bool `json:"record_ip_log"`
+}
+
+func UpdateUserSetting(c *gin.Context) {
+ var req UpdateUserSettingRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+
+ // 验证预警类型
+ if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的预警类型",
+ })
+ return
+ }
+
+ // 验证预警阈值
+ if req.QuotaWarningThreshold <= 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "预警阈值必须大于0",
+ })
+ return
+ }
+
+ // 如果是webhook类型,验证webhook地址
+ if req.QuotaWarningType == dto.NotifyTypeWebhook {
+ if req.WebhookUrl == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "Webhook地址不能为空",
+ })
+ return
+ }
+ // 验证URL格式
+ if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的Webhook地址",
+ })
+ return
+ }
+ }
+
+ // 如果是邮件类型,验证邮箱地址
+ if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
+ // 验证邮箱格式
+ if !strings.Contains(req.NotificationEmail, "@") {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的邮箱地址",
+ })
+ return
+ }
+ }
+
+ userId := c.GetInt("id")
+ user, err := model.GetUserById(userId, true)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+
+ // 构建设置
+ settings := dto.UserSetting{
+ NotifyType: req.QuotaWarningType,
+ QuotaWarningThreshold: req.QuotaWarningThreshold,
+ AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
+ RecordIpLog: req.RecordIpLog,
+ }
+
+ // 如果是webhook类型,添加webhook相关设置
+ if req.QuotaWarningType == dto.NotifyTypeWebhook {
+ settings.WebhookUrl = req.WebhookUrl
+ if req.WebhookSecret != "" {
+ settings.WebhookSecret = req.WebhookSecret
+ }
+ }
+
+ // 如果提供了通知邮箱,添加到设置中
+ if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
+ settings.NotificationEmail = req.NotificationEmail
+ }
+
+ // 更新用户设置
+ user.SetSetting(settings)
+ if err := user.Update(false); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "更新设置失败: " + err.Error(),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "设置已更新",
+ })
+}
diff --git a/controller/wechat.go b/controller/wechat.go
new file mode 100644
index 00000000..9a4bdfed
--- /dev/null
+++ b/controller/wechat.go
@@ -0,0 +1,168 @@
+package controller
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+ "time"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+type wechatLoginResponse struct {
+ Success bool `json:"success"`
+ Message string `json:"message"`
+ Data string `json:"data"`
+}
+
+func getWeChatIdByCode(code string) (string, error) {
+ if code == "" {
+ return "", errors.New("无效的参数")
+ }
+ req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("Authorization", common.WeChatServerToken)
+ client := http.Client{
+ Timeout: 5 * time.Second,
+ }
+ httpResponse, err := client.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer httpResponse.Body.Close()
+ var res wechatLoginResponse
+ err = json.NewDecoder(httpResponse.Body).Decode(&res)
+ if err != nil {
+ return "", err
+ }
+ if !res.Success {
+ return "", errors.New(res.Message)
+ }
+ if res.Data == "" {
+ return "", errors.New("验证码错误或已过期")
+ }
+ return res.Data, nil
+}
+
+func WeChatAuth(c *gin.Context) {
+ if !common.WeChatAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "管理员未开启通过微信登录以及注册",
+ "success": false,
+ })
+ return
+ }
+ code := c.Query("code")
+ wechatId, err := getWeChatIdByCode(code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": err.Error(),
+ "success": false,
+ })
+ return
+ }
+ user := model.User{
+ WeChatId: wechatId,
+ }
+ if model.IsWeChatIdAlreadyTaken(wechatId) {
+ err := user.FillUserByWeChatId()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ if user.Id == 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户已注销",
+ })
+ return
+ }
+ } else {
+ if common.RegisterEnabled {
+ user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
+ user.DisplayName = "WeChat User"
+ user.Role = common.RoleCommonUser
+ user.Status = common.UserStatusEnabled
+
+ if err := user.Insert(0); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员关闭了新用户注册",
+ })
+ return
+ }
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "用户已被封禁",
+ "success": false,
+ })
+ return
+ }
+ setupLogin(&user, c)
+}
+
+func WeChatBind(c *gin.Context) {
+ if !common.WeChatAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "管理员未开启通过微信登录以及注册",
+ "success": false,
+ })
+ return
+ }
+ code := c.Query("code")
+ wechatId, err := getWeChatIdByCode(code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": err.Error(),
+ "success": false,
+ })
+ return
+ }
+ if model.IsWeChatIdAlreadyTaken(wechatId) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该微信账号已被绑定",
+ })
+ return
+ }
+ session := sessions.Default(c)
+ id := session.Get("id")
+ user := model.User{
+ Id: id.(int),
+ }
+ err = user.FillUserById()
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ user.WeChatId = wechatId
+ err = user.Update(false)
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 00000000..57ad0b30
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,52 @@
+version: '3.4'
+
+services:
+ new-api:
+ image: calciumion/new-api:latest
+ container_name: new-api
+ restart: always
+ command: --log-dir /app/logs
+ ports:
+ - "3000:3000"
+ volumes:
+ - ./data:/data
+ - ./logs:/app/logs
+ environment:
+ - SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
+ - REDIS_CONN_STRING=redis://redis
+ - TZ=Asia/Shanghai
+ - ERROR_LOG_ENABLED=true # 是否启用错误日志记录
+ # - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
+ # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
+ # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
+ # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
+ # - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
+
+ depends_on:
+ - redis
+ - mysql
+ healthcheck:
+ test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+
+ redis:
+ image: redis:latest
+ container_name: redis
+ restart: always
+
+ mysql:
+ image: mysql:8.2
+ container_name: mysql
+ restart: always
+ environment:
+ MYSQL_ROOT_PASSWORD: 123456 # Ensure this matches the password in SQL_DSN
+ MYSQL_DATABASE: new-api
+ volumes:
+ - mysql_data:/var/lib/mysql
+ # ports:
+ # - "3306:3306" # If you want to access MySQL from outside Docker, uncomment
+
+volumes:
+ mysql_data:
diff --git a/docs/api/api_auth.md b/docs/api/api_auth.md
new file mode 100644
index 00000000..798ca374
--- /dev/null
+++ b/docs/api/api_auth.md
@@ -0,0 +1,53 @@
+# API 鉴权文档
+
+## 认证方式
+
+### Access Token
+
+对于需要鉴权的 API 接口,必须同时提供以下两个请求头来进行 Access Token 认证:
+
+1. **请求头中的 `Authorization` 字段**
+
+ 将 Access Token 放置于 HTTP 请求头部的 `Authorization` 字段中,格式如下:
+
+ ```
+ Authorization: <your_access_token>
+ ```
+
+ 其中 `<your_access_token>` 需要替换为实际的 Access Token 值。
+
+2. **请求头中的 `New-Api-User` 字段**
+
+ 将用户 ID 放置于 HTTP 请求头部的 `New-Api-User` 字段中,格式如下:
+
+ ```
+ New-Api-User: <your_user_id>
+ ```
+
+ 其中 `<your_user_id>` 需要替换为实际的用户 ID。
+
+**注意:**
+
+* **必须同时提供 `Authorization` 和 `New-Api-User` 两个请求头才能通过鉴权。**
+* 如果只提供其中一个请求头,或者两个请求头都未提供,则会返回 `401 Unauthorized` 错误。
+* 如果 `Authorization` 中的 Access Token 无效,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,access token 无效”。
+* 如果 `New-Api-User` 中的用户 ID 与 Access Token 不匹配,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,与登录用户不匹配,请重新登录”。
+* 如果没有提供 `New-Api-User` 请求头,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,未提供 New-Api-User”。
+* 如果 `New-Api-User` 请求头格式错误,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,New-Api-User 格式错误”。
+* 如果用户已被禁用,则会返回 `403 Forbidden` 错误,并提示“用户已被封禁”。
+* 如果用户权限不足,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,权限不足”。
+* 如果用户信息无效,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,用户信息无效”。
+
+## Curl 示例
+
+假设您的 Access Token 为 `access_token`,用户 ID 为 `123`,要访问的 API 接口为 `/api/user/self`,则可以使用以下 curl 命令:
+
+```bash
+curl -X GET \
+ -H "Authorization: access_token" \
+ -H "New-Api-User: 123" \
+ https://your-domain.com/api/user/self
+```
+
+请将 `access_token`、`123` 和 `https://your-domain.com` 替换为实际的值。
+
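+## Go 示例
+
+以下是一个最小的 Go 调用示意,与上面的 curl 示例等价:假设 Access Token 为 `access_token`、用户 ID 为 `123`、服务地址为 `https://your-domain.com`(均为占位符,请替换为实际值),演示如何同时携带 `Authorization` 与 `New-Api-User` 两个请求头访问 `/api/user/self`。
+
+```go
+package main
+
+import (
+    "fmt"
+    "io"
+    "net/http"
+)
+
+func main() {
+    req, err := http.NewRequest(http.MethodGet, "https://your-domain.com/api/user/self", nil)
+    if err != nil {
+        panic(err)
+    }
+    // 必须同时携带这两个请求头,否则会返回 401
+    req.Header.Set("Authorization", "access_token")
+    req.Header.Set("New-Api-User", "123")
+
+    resp, err := http.DefaultClient.Do(req)
+    if err != nil {
+        panic(err)
+    }
+    defer resp.Body.Close()
+
+    body, _ := io.ReadAll(resp.Body)
+    fmt.Println(resp.StatusCode, string(body))
+}
+```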
diff --git a/docs/api/web_api.md b/docs/api/web_api.md
new file mode 100644
index 00000000..e64fd359
--- /dev/null
+++ b/docs/api/web_api.md
@@ -0,0 +1,197 @@
+# New API – Web 界面后端接口文档
+
+> 本文档汇总了 **New API** 后端提供给前端 Web 界面的全部 REST 接口(不含 *Relay* 相关接口)。
+>
+> 接口前缀统一为 `https://<服务器地址>`,以下仅列出 **路径**、**HTTP 方法**、**鉴权要求** 与 **功能简介**。
+>
+> 鉴权级别说明:
+> * **公开** – 不需要登录即可调用
+> * **用户** – 需携带用户 Token(`middleware.UserAuth`)
+> * **管理员** – 需管理员 Token(`middleware.AdminAuth`)
+> * **Root** – 仅限最高权限 Root 用户(`middleware.RootAuth`)
+
+---
+
+## 1. 初始化 / 系统状态
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/setup | 公开 | 获取系统初始化状态 |
+| POST | /api/setup | 公开 | 完成首次安装向导 |
+| GET | /api/status | 公开 | 获取运行状态摘要 |
+| GET | /api/uptime/status | 公开 | Uptime-Kuma 兼容状态探针 |
+| GET | /api/status/test | 管理员 | 测试后端与依赖组件是否正常 |
+
+## 2. 公共信息
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/models | 用户 | 获取前端可用模型列表 |
+| GET | /api/notice | 公开 | 获取公告栏内容 |
+| GET | /api/about | 公开 | 关于页面信息 |
+| GET | /api/home_page_content | 公开 | 首页自定义内容 |
+| GET | /api/pricing | 可匿名/用户 | 价格与套餐信息 |
+| GET | /api/ratio_config | 公开 | 模型倍率配置(仅公开字段) |
+
+## 3. 邮件 / 身份验证
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/verification | 公开 (限流) | 发送邮箱验证邮件 |
+| GET | /api/reset_password | 公开 (限流) | 发送重置密码邮件 |
+| POST | /api/user/reset | 公开 | 提交重置密码请求 |
+
+## 4. OAuth / 第三方登录
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
+| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
+| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
+| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
+| GET | /api/oauth/wechat/bind | 公开 | 微信账户绑定 |
+| GET | /api/oauth/email/bind | 公开 | 邮箱绑定 |
+| GET | /api/oauth/telegram/login | 公开 | Telegram 登录 |
+| GET | /api/oauth/telegram/bind | 公开 | Telegram 账户绑定 |
+| GET | /api/oauth/state | 公开 | 获取随机 state(防 CSRF) |
+
+## 5. 用户模块
+### 5.1 账号注册/登录
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| POST | /api/user/register | 公开 | 注册新账号 |
+| POST | /api/user/login | 公开 | 用户登录 |
+| GET | /api/user/logout | 用户 | 退出登录 |
+| GET | /api/user/epay/notify | 公开 | Epay 支付回调 |
+| GET | /api/user/groups | 公开 | 列出所有分组(无鉴权版) |
+
+### 5.2 用户自身操作 (需登录)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/user/self/groups | 用户 | 获取自己所在分组 |
+| GET | /api/user/self | 用户 | 获取个人资料 |
+| GET | /api/user/models | 用户 | 获取模型可见性 |
+| PUT | /api/user/self | 用户 | 修改个人资料 |
+| DELETE | /api/user/self | 用户 | 注销账号 |
+| GET | /api/user/token | 用户 | 生成用户级别 Access Token |
+| GET | /api/user/aff | 用户 | 获取推广码信息 |
+| POST | /api/user/topup | 用户 | 余额直充 |
+| POST | /api/user/pay | 用户 | 提交支付订单 |
+| POST | /api/user/amount | 用户 | 余额支付 |
+| POST | /api/user/aff_transfer | 用户 | 推广额度转账 |
+| PUT | /api/user/setting | 用户 | 更新用户设置 |
+
+### 5.3 管理员用户管理
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/user/ | 管理员 | 获取全部用户列表 |
+| GET | /api/user/search | 管理员 | 搜索用户 |
+| GET | /api/user/:id | 管理员 | 获取单个用户信息 |
+| POST | /api/user/ | 管理员 | 创建用户 |
+| POST | /api/user/manage | 管理员 | 冻结/重置等管理操作 |
+| PUT | /api/user/ | 管理员 | 更新用户 |
+| DELETE | /api/user/:id | 管理员 | 删除用户 |
+
+## 6. 站点选项 (Root)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/option/ | Root | 获取全局配置 |
+| PUT | /api/option/ | Root | 更新全局配置 |
+| POST | /api/option/rest_model_ratio | Root | 重置模型倍率 |
+| POST | /api/option/migrate_console_setting | Root | 迁移旧版控制台配置 |
+
+## 7. 模型倍率同步 (Root)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/ratio_sync/channels | Root | 获取可同步渠道列表 |
+| POST | /api/ratio_sync/fetch | Root | 从上游拉取倍率 |
+
+## 8. 渠道管理 (管理员)
+| 方法 | 路径 | 说明 |
+|------|------|------|
+| GET | /api/channel/ | 获取渠道列表 |
+| GET | /api/channel/search | 搜索渠道 |
+| GET | /api/channel/models | 查询渠道模型能力 |
+| GET | /api/channel/models_enabled | 查询启用模型能力 |
+| GET | /api/channel/:id | 获取单个渠道 |
+| GET | /api/channel/test | 批量测试渠道连通性 |
+| GET | /api/channel/test/:id | 单个渠道测试 |
+| GET | /api/channel/update_balance | 批量刷新余额 |
+| GET | /api/channel/update_balance/:id | 单个刷新余额 |
+| POST | /api/channel/ | 新增渠道 |
+| PUT | /api/channel/ | 更新渠道 |
+| DELETE | /api/channel/disabled | 删除已禁用渠道 |
+| POST | /api/channel/tag/disabled | 批量禁用标签渠道 |
+| POST | /api/channel/tag/enabled | 批量启用标签渠道 |
+| PUT | /api/channel/tag | 编辑渠道标签 |
+| DELETE | /api/channel/:id | 删除渠道 |
+| POST | /api/channel/batch | 批量删除渠道 |
+| POST | /api/channel/fix | 修复渠道能力表 |
+| GET | /api/channel/fetch_models/:id | 拉取单渠道模型 |
+| POST | /api/channel/fetch_models | 拉取全部渠道模型 |
+| POST | /api/channel/batch/tag | 批量设置渠道标签 |
+| GET | /api/channel/tag/models | 根据标签获取模型 |
+| POST | /api/channel/copy/:id | 复制渠道 |
+
+## 9. Token 管理
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/token/ | 用户 | 获取全部 Token |
+| GET | /api/token/search | 用户 | 搜索 Token |
+| GET | /api/token/:id | 用户 | 获取单个 Token |
+| POST | /api/token/ | 用户 | 创建 Token |
+| PUT | /api/token/ | 用户 | 更新 Token |
+| DELETE | /api/token/:id | 用户 | 删除 Token |
+| POST | /api/token/batch | 用户 | 批量删除 Token |
+
+## 10. 兑换码管理 (管理员)
+| 方法 | 路径 | 说明 |
+|------|------|------|
+| GET | /api/redemption/ | 获取兑换码列表 |
+| GET | /api/redemption/search | 搜索兑换码 |
+| GET | /api/redemption/:id | 获取单个兑换码 |
+| POST | /api/redemption/ | 创建兑换码 |
+| PUT | /api/redemption/ | 更新兑换码 |
+| DELETE | /api/redemption/invalid | 删除无效兑换码 |
+| DELETE | /api/redemption/:id | 删除兑换码 |
+
+## 11. 日志
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/log/ | 管理员 | 获取全部日志 |
+| DELETE | /api/log/ | 管理员 | 删除历史日志 |
+| GET | /api/log/stat | 管理员 | 日志统计 |
+| GET | /api/log/self/stat | 用户 | 我的日志统计 |
+| GET | /api/log/search | 管理员 | 搜索全部日志 |
+| GET | /api/log/self | 用户 | 获取我的日志 |
+| GET | /api/log/self/search | 用户 | 搜索我的日志 |
+| GET | /api/log/token | 公开 | 根据 Token 查询日志(支持 CORS) |
+
+## 12. 数据统计
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/data/ | 管理员 | 全站用量按日期统计 |
+| GET | /api/data/self | 用户 | 我的用量按日期统计 |
+
+## 13. 分组
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/group/ | 管理员 | 获取全部分组列表 |
+
+## 14. Midjourney 任务
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/mj/self | 用户 | 获取自己的 MJ 任务 |
+| GET | /api/mj/ | 管理员 | 获取全部 MJ 任务 |
+
+## 15. 任务中心
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /api/task/self | 用户 | 获取我的任务 |
+| GET | /api/task/ | 管理员 | 获取全部任务 |
+
+## 16. 账户计费面板 (Dashboard)
+| 方法 | 路径 | 鉴权 | 说明 |
+|------|------|------|------|
+| GET | /dashboard/billing/subscription | 用户 Token | 获取订阅额度信息 |
+| GET | /v1/dashboard/billing/subscription | 同上 | 兼容 OpenAI SDK 路径 |
+| GET | /dashboard/billing/usage | 用户 Token | 获取使用量信息 |
+| GET | /v1/dashboard/billing/usage | 同上 | 兼容 OpenAI SDK 路径 |
+
+---
+
+> **更新日期**:2025.07.17
diff --git a/docs/channel/other_setting.md b/docs/channel/other_setting.md
new file mode 100644
index 00000000..43341660
--- /dev/null
+++ b/docs/channel/other_setting.md
@@ -0,0 +1,33 @@
+# 渠道额外设置说明
+
+该配置用于设置一些额外的渠道参数,可以通过 JSON 对象进行配置。主要包含以下三个设置项:
+
+1. force_format
+ - 用于标识是否对数据进行强制格式化为 OpenAI 格式
+ - 类型为布尔值,设置为 true 时启用强制格式化
+
+2. proxy
+ - 用于配置网络代理
+ - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
+
+3. thinking_to_content
+ - 用于标识是否将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回
+ - 类型为布尔值,设置为 true 时启用思考内容转换
+
+--------------------------------------------------------------
+
+## JSON 格式示例
+
+以下是一个示例配置,启用强制格式化并设置了代理地址:
+
+```json
+{
+ "force_format": true,
+ "thinking_to_content": true,
+ "proxy": "socks5://xxxxxxx"
+}
+```
+
+--------------------------------------------------------------
+
+通过调整上述 JSON 配置中的值,可以灵活控制渠道的额外行为,比如是否进行格式化以及使用特定的网络代理。
diff --git a/docs/images/aliyun.png b/docs/images/aliyun.png
new file mode 100644
index 00000000..6266bfbf
Binary files /dev/null and b/docs/images/aliyun.png differ
diff --git a/docs/images/cherry-studio.png b/docs/images/cherry-studio.png
new file mode 100644
index 00000000..a58a7713
Binary files /dev/null and b/docs/images/cherry-studio.png differ
diff --git a/docs/images/io-net.png b/docs/images/io-net.png
new file mode 100644
index 00000000..fb47534d
Binary files /dev/null and b/docs/images/io-net.png differ
diff --git a/docs/images/pku.png b/docs/images/pku.png
new file mode 100644
index 00000000..a058c3ce
Binary files /dev/null and b/docs/images/pku.png differ
diff --git a/docs/images/ucloud.png b/docs/images/ucloud.png
new file mode 100644
index 00000000..16cca764
Binary files /dev/null and b/docs/images/ucloud.png differ
diff --git a/docs/installation/BT.md b/docs/installation/BT.md
new file mode 100644
index 00000000..b4ea5b2f
--- /dev/null
+++ b/docs/installation/BT.md
@@ -0,0 +1,3 @@
+会话密钥通过环境变量 SESSION_SECRET 设置。
+
+
diff --git a/docs/models/Midjourney.md b/docs/models/Midjourney.md
new file mode 100644
index 00000000..478115b7
--- /dev/null
+++ b/docs/models/Midjourney.md
@@ -0,0 +1,82 @@
+# Midjourney Proxy API文档
+
+**简介**:Midjourney Proxy API文档
+
+## 接口列表
+支持的接口如下(列表下方附有一个简单的调用示意):
++ [x] /mj/submit/imagine
++ [x] /mj/submit/change
++ [x] /mj/submit/blend
++ [x] /mj/submit/describe
++ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**)
++ [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址)
++ [x] /task/list-by-condition
++ [x] /mj/submit/action (仅midjourney-proxy-plus支持,下同)
++ [x] /mj/submit/modal
++ [x] /mj/submit/shorten
++ [x] /mj/task/{id}/image-seed
++ [x] /mj/insight-face/swap (InsightFace)
+
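+下面给出一个示意性的调用流程(仅作参考,具体请求与响应字段以实际部署的 midjourney-proxy / new api 版本为准):先调用 `/mj/submit/imagine` 提交绘图任务,再用返回的任务 ID 调用 `/mj/task/{id}/fetch` 查询结果。示例中的服务地址与密钥均为占位符。
+
+```go
+package main
+
+import (
+    "bytes"
+    "encoding/json"
+    "fmt"
+    "net/http"
+)
+
+func main() {
+    base := "http://localhost:3000" // 占位符:替换为实际服务地址
+    key := "sk-xxxx"                // 占位符:替换为实际令牌
+
+    // 提交 imagine 任务(字段以实际部署版本为准)
+    payload, _ := json.Marshal(map[string]string{"prompt": "a cute cat"})
+    req, err := http.NewRequest(http.MethodPost, base+"/mj/submit/imagine", bytes.NewReader(payload))
+    if err != nil {
+        panic(err)
+    }
+    req.Header.Set("Content-Type", "application/json")
+    req.Header.Set("Authorization", "Bearer "+key)
+
+    resp, err := http.DefaultClient.Do(req)
+    if err != nil {
+        panic(err)
+    }
+    defer resp.Body.Close()
+
+    var result struct {
+        Code        int    `json:"code"`
+        Description string `json:"description"`
+        Result      string `json:"result"` // 任务 ID,可用于 GET /mj/task/{id}/fetch 查询进度
+    }
+    _ = json.NewDecoder(resp.Body).Decode(&result)
+    fmt.Println(result.Code, result.Description, result.Result)
+}
+```
+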
+## 模型列表
+
+### midjourney-proxy支持
+
+- mj_imagine (绘图)
+- mj_variation (变换)
+- mj_reroll (重绘)
+- mj_blend (混合)
+- mj_upscale (放大)
+- mj_describe (图生文)
+
+### 仅midjourney-proxy-plus支持
+
+- mj_zoom (比例变焦)
+- mj_shorten (提示词缩短)
+- mj_modal (窗口提交,局部重绘和自定义比例变焦必须和mj_modal一同添加)
+- mj_inpaint (局部重绘提交,必须和mj_modal一同添加)
+- mj_custom_zoom (自定义比例变焦,必须和mj_modal一同添加)
+- mj_high_variation (强变换)
+- mj_low_variation (弱变换)
+- mj_pan (平移)
+- swap_face (换脸)
+
+## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
+```json
+{
+ "mj_imagine": 0.1,
+ "mj_variation": 0.1,
+ "mj_reroll": 0.1,
+ "mj_blend": 0.1,
+ "mj_modal": 0.1,
+ "mj_zoom": 0.1,
+ "mj_shorten": 0.1,
+ "mj_high_variation": 0.1,
+ "mj_low_variation": 0.1,
+ "mj_pan": 0.1,
+ "mj_inpaint": 0,
+ "mj_custom_zoom": 0,
+ "mj_describe": 0.05,
+ "mj_upscale": 0.05,
+ "swap_face": 0.05
+}
+```
+其中mj_inpaint和mj_custom_zoom的价格设置为0,是因为这两个模型需要搭配mj_modal使用,所以价格由mj_modal决定。
+
+## 渠道设置
+
+### 对接 midjourney-proxy(plus)
+
+1. 部署 Midjourney-Proxy,并配置好 midjourney 账号等(强烈建议设置密钥),[项目地址](https://github.com/novicezk/midjourney-proxy)
+2. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy**,如果是 plus 版本选择**Midjourney Proxy Plus**,模型请参考上方模型列表
+3. **代理**填写midjourney-proxy部署的地址,例如:http://localhost:8080
+4. 密钥填写midjourney-proxy的密钥,如果没有设置密钥,可以随便填
+
+### 对接上游new api
+
+1. 在渠道管理中添加渠道,渠道类型选择**Midjourney Proxy Plus**,模型请参考上方模型列表
+2. **代理**填写上游new api的地址,例如:http://localhost:3000
+3. 密钥填写上游new api的密钥
\ No newline at end of file
diff --git a/docs/models/Rerank.md b/docs/models/Rerank.md
new file mode 100644
index 00000000..dc57d99b
--- /dev/null
+++ b/docs/models/Rerank.md
@@ -0,0 +1,62 @@
+# Rerank API文档
+
+**简介**:Rerank API文档
+
+## 接入Dify
+模型供应商选择Jina,按要求填写模型信息即可接入Dify。
+
+## 请求方式
+
+Post: /v1/rerank
+
+Request:
+
+```json
+{
+ "model": "jina-reranker-v2-base-multilingual",
+ "query": "What is the capital of the United States?",
+ "top_n": 3,
+ "documents": [
+ "Carson City is the capital city of the American state of Nevada.",
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+ "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
+ "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
+ "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
+ ]
+}
+```
+
+Response:
+
+```json
+{
+ "results": [
+ {
+ "document": {
+ "text": "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district."
+ },
+ "index": 2,
+ "relevance_score": 0.9999702
+ },
+ {
+ "document": {
+ "text": "Carson City is the capital city of the American state of Nevada."
+ },
+ "index": 0,
+ "relevance_score": 0.67800725
+ },
+ {
+ "document": {
+ "text": "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages."
+ },
+ "index": 3,
+ "relevance_score": 0.02800752
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 158,
+ "completion_tokens": 0,
+ "total_tokens": 158
+ }
+}
+```
\ No newline at end of file
diff --git a/docs/models/Suno.md b/docs/models/Suno.md
new file mode 100644
index 00000000..840ca8e4
--- /dev/null
+++ b/docs/models/Suno.md
@@ -0,0 +1,44 @@
+# Suno API Documentation
+
+**Overview**: API documentation for the Suno integration
+
+## Endpoints
+The following endpoints are supported:
++ [x] /suno/submit/music
++ [x] /suno/submit/lyrics
++ [x] /suno/fetch
++ [x] /suno/fetch/:id
+
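+An inspiration-mode request to /suno/submit/music only needs gpt_description_prompt, while custom mode supplies the lyrics and metadata directly. The body below is an illustrative sketch based on the SunoSubmitReq fields added in this PR; the accepted mv (model version) values depend on the upstream Suno API deployment:
+
+```json
+{
+  "prompt": "[Verse]\nCity lights are fading slow\n[Chorus]\nWe keep driving on",
+  "mv": "chirp-v4",
+  "title": "Fading Lights",
+  "tags": "lofi, chill",
+  "make_instrumental": false
+}
+```
+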
+## Model List
+
+### Supported by Suno API
+
+- suno_music (custom mode, inspiration mode, continuation)
+- suno_lyrics (lyrics generation)
+
+
+## Model Pricing (set under Settings - Operation Settings - Model Fixed Price)
+```json
+{
+ "suno_music": 0.3,
+ "suno_lyrics": 0.01
+}
+```
+
+## Channel Setup
+
+### Connecting to Suno API
+
+1. Deploy Suno API and configure the Suno account (setting an API key is strongly recommended). [Project repository](https://github.com/Suno-API/Suno-API)
+2. In channel management, add a channel and set the channel type to **Suno API**; for models, refer to the model list above.
+3. Set **Proxy** to the address where Suno API is deployed, e.g. http://localhost:8080
+4. Set the key to the Suno API key; if no key was configured, any value will do
+
+### Connecting to an upstream new api
+
+1. In channel management, add a channel and set the channel type to **Suno API** (or any type, as long as the channel's models include those from the model list above).
+2. Set **Proxy** to the upstream new api address, e.g. http://localhost:3000
+3. Set the key to the upstream new api key
\ No newline at end of file
diff --git a/dto/audio.go b/dto/audio.go
new file mode 100644
index 00000000..c36b3da5
--- /dev/null
+++ b/dto/audio.go
@@ -0,0 +1,34 @@
+package dto
+
+type AudioRequest struct {
+ Model string `json:"model"`
+ Input string `json:"input"`
+ Voice string `json:"voice"`
+ Speed float64 `json:"speed,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+}
+
+type AudioResponse struct {
+ Text string `json:"text"`
+}
+
+type WhisperVerboseJSONResponse struct {
+ Task string `json:"task,omitempty"`
+ Language string `json:"language,omitempty"`
+ Duration float64 `json:"duration,omitempty"`
+ Text string `json:"text,omitempty"`
+ Segments []Segment `json:"segments,omitempty"`
+}
+
+type Segment struct {
+ Id int `json:"id"`
+ Seek int `json:"seek"`
+ Start float64 `json:"start"`
+ End float64 `json:"end"`
+ Text string `json:"text"`
+ Tokens []int `json:"tokens"`
+ Temperature float64 `json:"temperature"`
+ AvgLogprob float64 `json:"avg_logprob"`
+ CompressionRatio float64 `json:"compression_ratio"`
+ NoSpeechProb float64 `json:"no_speech_prob"`
+}
diff --git a/dto/channel_settings.go b/dto/channel_settings.go
new file mode 100644
index 00000000..871d6716
--- /dev/null
+++ b/dto/channel_settings.go
@@ -0,0 +1,7 @@
+package dto
+
+type ChannelSettings struct {
+ ForceFormat bool `json:"force_format,omitempty"`
+ ThinkingToContent bool `json:"thinking_to_content,omitempty"`
+ Proxy string `json:"proxy"`
+}
diff --git a/dto/claude.go b/dto/claude.go
new file mode 100644
index 00000000..1a7eacb1
--- /dev/null
+++ b/dto/claude.go
@@ -0,0 +1,337 @@
+package dto
+
+import (
+ "encoding/json"
+ "one-api/common"
+ "one-api/types"
+)
+
+type ClaudeMetadata struct {
+ UserId string `json:"user_id"`
+}
+
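+// ClaudeMediaMessage is a loose union type reused for content blocks, streaming
+// deltas, and tool-use payloads; which fields are populated depends on Type.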
+type ClaudeMediaMessage struct {
+ Type string `json:"type,omitempty"`
+ Text *string `json:"text,omitempty"`
+ Model string `json:"model,omitempty"`
+ Source *ClaudeMessageSource `json:"source,omitempty"`
+ Usage *ClaudeUsage `json:"usage,omitempty"`
+ StopReason *string `json:"stop_reason,omitempty"`
+ PartialJson *string `json:"partial_json,omitempty"`
+ Role string `json:"role,omitempty"`
+ Thinking string `json:"thinking,omitempty"`
+ Signature string `json:"signature,omitempty"`
+ Delta string `json:"delta,omitempty"`
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+ // tool_calls
+ Id string `json:"id,omitempty"`
+ Name string `json:"name,omitempty"`
+ Input any `json:"input,omitempty"`
+ Content any `json:"content,omitempty"`
+ ToolUseId string `json:"tool_use_id,omitempty"`
+}
+
+func (c *ClaudeMediaMessage) SetText(s string) {
+ c.Text = &s
+}
+
+func (c *ClaudeMediaMessage) GetText() string {
+ if c.Text == nil {
+ return ""
+ }
+ return *c.Text
+}
+
+func (c *ClaudeMediaMessage) IsStringContent() bool {
+ if c.Content == nil {
+ return false
+ }
+ _, ok := c.Content.(string)
+ if ok {
+ return true
+ }
+ return false
+}
+
+func (c *ClaudeMediaMessage) GetStringContent() string {
+ if c.Content == nil {
+ return ""
+ }
+ switch c.Content.(type) {
+ case string:
+ return c.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range c.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
+ return ""
+}
+
+func (c *ClaudeMediaMessage) GetJsonRowString() string {
+ jsonContent, _ := json.Marshal(c)
+ return string(jsonContent)
+}
+
+func (c *ClaudeMediaMessage) SetContent(content any) {
+ c.Content = content
+}
+
+func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
+ mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.Content)
+ return mediaContent
+}
+
+type ClaudeMessageSource struct {
+ Type string `json:"type"`
+ MediaType string `json:"media_type,omitempty"`
+ Data any `json:"data,omitempty"`
+ Url string `json:"url,omitempty"`
+}
+
+type ClaudeMessage struct {
+ Role string `json:"role"`
+ Content any `json:"content"`
+}
+
+func (c *ClaudeMessage) IsStringContent() bool {
+ if c.Content == nil {
+ return false
+ }
+ _, ok := c.Content.(string)
+ return ok
+}
+
+func (c *ClaudeMessage) GetStringContent() string {
+ if c.Content == nil {
+ return ""
+ }
+ switch c.Content.(type) {
+ case string:
+ return c.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range c.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
+ return ""
+}
+
+func (c *ClaudeMessage) SetStringContent(content string) {
+ c.Content = content
+}
+
+func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
+ return common.Any2Type[[]ClaudeMediaMessage](c.Content)
+}
+
+type Tool struct {
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ InputSchema map[string]interface{} `json:"input_schema"`
+}
+
+type InputSchema struct {
+ Type string `json:"type"`
+ Properties any `json:"properties,omitempty"`
+ Required any `json:"required,omitempty"`
+}
+
+type ClaudeWebSearchTool struct {
+ Type string `json:"type"`
+ Name string `json:"name"`
+ MaxUses int `json:"max_uses,omitempty"`
+ UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"`
+}
+
+type ClaudeWebSearchUserLocation struct {
+ Type string `json:"type"`
+ Timezone string `json:"timezone,omitempty"`
+ Country string `json:"country,omitempty"`
+ Region string `json:"region,omitempty"`
+ City string `json:"city,omitempty"`
+}
+
+type ClaudeToolChoice struct {
+ Type string `json:"type"`
+ Name string `json:"name,omitempty"`
+ DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
+}
+
+type ClaudeRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt,omitempty"`
+ System any `json:"system,omitempty"`
+ Messages []ClaudeMessage `json:"messages,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ //ClaudeMetadata `json:"metadata,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ Thinking *Thinking `json:"thinking,omitempty"`
+}
+
+// AddTool appends a tool to the request.
+func (c *ClaudeRequest) AddTool(tool any) {
+ if c.Tools == nil {
+ c.Tools = make([]any, 0)
+ }
+
+ switch tools := c.Tools.(type) {
+ case []any:
+ c.Tools = append(tools, tool)
+ default:
+		// Tools is not []any; reinitialize it as []any containing only the new tool
+ c.Tools = []any{tool}
+ }
+}
+
+// GetTools returns the tool list, or nil if none has been set.
+func (c *ClaudeRequest) GetTools() []any {
+ if c.Tools == nil {
+ return nil
+ }
+
+ switch tools := c.Tools.(type) {
+ case []any:
+ return tools
+ default:
+ return nil
+ }
+}
+
+// ProcessTools splits a mixed tool list into regular tools and web-search tools using type assertions.
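+//
+// Illustrative usage (a sketch; assumes the tools were collected via GetTools):
+//
+//	normalTools, webSearchTools := ProcessTools(req.GetTools())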
+func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
+ var normalTools []*Tool
+ var webSearchTools []*ClaudeWebSearchTool
+
+ for _, tool := range tools {
+ switch t := tool.(type) {
+ case *Tool:
+ normalTools = append(normalTools, t)
+ case *ClaudeWebSearchTool:
+ webSearchTools = append(webSearchTools, t)
+ case Tool:
+ normalTools = append(normalTools, &t)
+ case ClaudeWebSearchTool:
+ webSearchTools = append(webSearchTools, &t)
+ default:
+			// unknown tool type, skip
+ continue
+ }
+ }
+
+ return normalTools, webSearchTools
+}
+
+type Thinking struct {
+ Type string `json:"type"`
+ BudgetTokens *int `json:"budget_tokens,omitempty"`
+}
+
+func (c *Thinking) GetBudgetTokens() int {
+ if c.BudgetTokens == nil {
+ return 0
+ }
+ return *c.BudgetTokens
+}
+
+func (c *ClaudeRequest) IsStringSystem() bool {
+ _, ok := c.System.(string)
+ return ok
+}
+
+func (c *ClaudeRequest) GetStringSystem() string {
+ if c.IsStringSystem() {
+ return c.System.(string)
+ }
+ return ""
+}
+
+func (c *ClaudeRequest) SetStringSystem(system string) {
+ c.System = system
+}
+
+func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
+ mediaContent, _ := common.Any2Type[[]ClaudeMediaMessage](c.System)
+ return mediaContent
+}
+
+type ClaudeError struct {
+ Type string `json:"type,omitempty"`
+ Message string `json:"message,omitempty"`
+}
+
+type ClaudeErrorWithStatusCode struct {
+ Error ClaudeError `json:"error"`
+ StatusCode int `json:"status_code"`
+ LocalError bool
+}
+
+type ClaudeResponse struct {
+ Id string `json:"id,omitempty"`
+ Type string `json:"type"`
+ Role string `json:"role,omitempty"`
+ Content []ClaudeMediaMessage `json:"content,omitempty"`
+ Completion string `json:"completion,omitempty"`
+ StopReason string `json:"stop_reason,omitempty"`
+ Model string `json:"model,omitempty"`
+ Error *types.ClaudeError `json:"error,omitempty"`
+ Usage *ClaudeUsage `json:"usage,omitempty"`
+ Index *int `json:"index,omitempty"`
+ ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
+ Delta *ClaudeMediaMessage `json:"delta,omitempty"`
+ Message *ClaudeMediaMessage `json:"message,omitempty"`
+}
+
+// set index
+func (c *ClaudeResponse) SetIndex(i int) {
+ c.Index = &i
+}
+
+// get index
+func (c *ClaudeResponse) GetIndex() int {
+ if c.Index == nil {
+ return 0
+ }
+ return *c.Index
+}
+
+type ClaudeUsage struct {
+ InputTokens int `json:"input_tokens"`
+ CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
+ CacheReadInputTokens int `json:"cache_read_input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"`
+}
+
+type ClaudeServerToolUse struct {
+ WebSearchRequests int `json:"web_search_requests"`
+}
diff --git a/dto/dalle.go b/dto/dalle.go
new file mode 100644
index 00000000..ce2f6361
--- /dev/null
+++ b/dto/dalle.go
@@ -0,0 +1,29 @@
+package dto
+
+import "encoding/json"
+
+type ImageRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt" binding:"required"`
+ N int `json:"n,omitempty"`
+ Size string `json:"size,omitempty"`
+ Quality string `json:"quality,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ Style string `json:"style,omitempty"`
+ User string `json:"user,omitempty"`
+ ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
+ Background string `json:"background,omitempty"`
+ Moderation string `json:"moderation,omitempty"`
+ OutputFormat string `json:"output_format,omitempty"`
+ Watermark *bool `json:"watermark,omitempty"`
+}
+
+type ImageResponse struct {
+ Data []ImageData `json:"data"`
+ Created int64 `json:"created"`
+}
+type ImageData struct {
+ Url string `json:"url"`
+ B64Json string `json:"b64_json"`
+ RevisedPrompt string `json:"revised_prompt"`
+}
diff --git a/dto/embedding.go b/dto/embedding.go
new file mode 100644
index 00000000..9d722292
--- /dev/null
+++ b/dto/embedding.go
@@ -0,0 +1,57 @@
+package dto
+
+type EmbeddingOptions struct {
+ Seed int `json:"seed,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP *float64 `json:"top_p,omitempty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty *float64 `json:"presence_penalty,omitempty"`
+ NumPredict int `json:"num_predict,omitempty"`
+ NumCtx int `json:"num_ctx,omitempty"`
+}
+
+type EmbeddingRequest struct {
+ Model string `json:"model"`
+ Input any `json:"input"`
+ EncodingFormat string `json:"encoding_format,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
+ User string `json:"user,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+}
+
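+// ParseInput normalizes Input into a []string: a single string becomes a one-element
+// slice, an array keeps only its string items, and anything else yields nil.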
+func (r EmbeddingRequest) ParseInput() []string {
+ if r.Input == nil {
+ return nil
+ }
+ var input []string
+ switch r.Input.(type) {
+ case string:
+ input = []string{r.Input.(string)}
+ case []any:
+ input = make([]string, 0, len(r.Input.([]any)))
+ for _, item := range r.Input.([]any) {
+ if str, ok := item.(string); ok {
+ input = append(input, str)
+ }
+ }
+ }
+ return input
+}
+
+type EmbeddingResponseItem struct {
+ Object string `json:"object"`
+ Index int `json:"index"`
+ Embedding []float64 `json:"embedding"`
+}
+
+type EmbeddingResponse struct {
+ Object string `json:"object"`
+ Data []EmbeddingResponseItem `json:"data"`
+ Model string `json:"model"`
+ Usage `json:"usage"`
+}
diff --git a/dto/error.go b/dto/error.go
new file mode 100644
index 00000000..d7f6824d
--- /dev/null
+++ b/dto/error.go
@@ -0,0 +1,57 @@
+package dto
+
+import "one-api/types"
+
+type OpenAIError struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+ Param string `json:"param"`
+ Code any `json:"code"`
+}
+
+type OpenAIErrorWithStatusCode struct {
+ Error OpenAIError `json:"error"`
+ StatusCode int `json:"status_code"`
+ LocalError bool
+}
+
+type GeneralErrorResponse struct {
+ Error types.OpenAIError `json:"error"`
+ Message string `json:"message"`
+ Msg string `json:"msg"`
+ Err string `json:"err"`
+ ErrorMsg string `json:"error_msg"`
+ Header struct {
+ Message string `json:"message"`
+ } `json:"header"`
+ Response struct {
+ Error struct {
+ Message string `json:"message"`
+ } `json:"error"`
+ } `json:"response"`
+}
+
+func (e GeneralErrorResponse) ToMessage() string {
+ if e.Error.Message != "" {
+ return e.Error.Message
+ }
+ if e.Message != "" {
+ return e.Message
+ }
+ if e.Msg != "" {
+ return e.Msg
+ }
+ if e.Err != "" {
+ return e.Err
+ }
+ if e.ErrorMsg != "" {
+ return e.ErrorMsg
+ }
+ if e.Header.Message != "" {
+ return e.Header.Message
+ }
+ if e.Response.Error.Message != "" {
+ return e.Response.Error.Message
+ }
+ return ""
+}
diff --git a/dto/file_data.go b/dto/file_data.go
new file mode 100644
index 00000000..d5cf0f68
--- /dev/null
+++ b/dto/file_data.go
@@ -0,0 +1,8 @@
+package dto
+
+type LocalFileData struct {
+ MimeType string
+ Base64Data string
+ Url string
+ Size int64
+}
diff --git a/dto/midjourney.go b/dto/midjourney.go
new file mode 100644
index 00000000..6fbcb357
--- /dev/null
+++ b/dto/midjourney.go
@@ -0,0 +1,107 @@
+package dto
+
+//type SimpleMjRequest struct {
+// Prompt string `json:"prompt"`
+// CustomId string `json:"customId"`
+// Action string `json:"action"`
+// Content string `json:"content"`
+//}
+
+type SwapFaceRequest struct {
+ SourceBase64 string `json:"sourceBase64"`
+ TargetBase64 string `json:"targetBase64"`
+}
+
+type MidjourneyRequest struct {
+ Prompt string `json:"prompt"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
+ NotifyHook string `json:"notifyHook"`
+ Action string `json:"action"`
+ Index int `json:"index"`
+ State string `json:"state"`
+ TaskId string `json:"taskId"`
+ Base64Array []string `json:"base64Array"`
+ Content string `json:"content"`
+ MaskBase64 string `json:"maskBase64"`
+}
+
+type MidjourneyResponse struct {
+ Code int `json:"code"`
+ Description string `json:"description"`
+ Properties interface{} `json:"properties"`
+ Result string `json:"result"`
+}
+
+type MidjourneyUploadResponse struct {
+ Code int `json:"code"`
+ Description string `json:"description"`
+ Result []string `json:"result"`
+}
+
+type MidjourneyResponseWithStatusCode struct {
+ StatusCode int `json:"statusCode"`
+ Response MidjourneyResponse
+}
+
+type MidjourneyDto struct {
+ MjId string `json:"id"`
+ Action string `json:"action"`
+ CustomId string `json:"customId"`
+ BotType string `json:"botType"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"promptEn"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submitTime"`
+ StartTime int64 `json:"startTime"`
+ FinishTime int64 `json:"finishTime"`
+ ImageUrl string `json:"imageUrl"`
+ VideoUrl string `json:"videoUrl"`
+ VideoUrls []ImgUrls `json:"videoUrls"`
+ Status string `json:"status"`
+ Progress string `json:"progress"`
+ FailReason string `json:"failReason"`
+ Buttons any `json:"buttons"`
+ MaskBase64 string `json:"maskBase64"`
+ Properties *Properties `json:"properties"`
+}
+
+type ImgUrls struct {
+ Url string `json:"url"`
+}
+
+type MidjourneyStatus struct {
+ Status int `json:"status"`
+}
+type MidjourneyWithoutStatus struct {
+ Id int `json:"id"`
+ Code int `json:"code"`
+ UserId int `json:"user_id" gorm:"index"`
+ Action string `json:"action"`
+ MjId string `json:"mj_id" gorm:"index"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"prompt_en"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submit_time"`
+ StartTime int64 `json:"start_time"`
+ FinishTime int64 `json:"finish_time"`
+ ImageUrl string `json:"image_url"`
+ Progress string `json:"progress"`
+ FailReason string `json:"fail_reason"`
+ ChannelId int `json:"channel_id"`
+}
+
+type ActionButton struct {
+ CustomId any `json:"customId"`
+ Emoji any `json:"emoji"`
+ Label any `json:"label"`
+ Type any `json:"type"`
+ Style any `json:"style"`
+}
+
+type Properties struct {
+ FinalPrompt string `json:"finalPrompt"`
+ FinalZhPrompt string `json:"finalZhPrompt"`
+}
diff --git a/dto/notify.go b/dto/notify.go
new file mode 100644
index 00000000..b75cec70
--- /dev/null
+++ b/dto/notify.go
@@ -0,0 +1,25 @@
+package dto
+
+type Notify struct {
+ Type string `json:"type"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Values []interface{} `json:"values"`
+}
+
+const ContentValueParam = "{{value}}"
+
+const (
+ NotifyTypeQuotaExceed = "quota_exceed"
+ NotifyTypeChannelUpdate = "channel_update"
+ NotifyTypeChannelTest = "channel_test"
+)
+
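+// NewNotify builds a Notify. Content may embed the ContentValueParam placeholder
+// ("{{value}}"), which the notification sender is expected to fill from values in order.
+// Illustrative usage (a sketch):
+//
+//	n := NewNotify(NotifyTypeQuotaExceed, "Quota warning",
+//	    "Remaining quota has dropped below {{value}}", []interface{}{"$5.00"})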
+func NewNotify(t string, title string, content string, values []interface{}) Notify {
+ return Notify{
+ Type: t,
+ Title: title,
+ Content: content,
+ Values: values,
+ }
+}
diff --git a/dto/openai_request.go b/dto/openai_request.go
new file mode 100644
index 00000000..88d3bd6c
--- /dev/null
+++ b/dto/openai_request.go
@@ -0,0 +1,655 @@
+package dto
+
+import (
+ "encoding/json"
+ "one-api/common"
+ "strings"
+)
+
+type ResponseFormat struct {
+ Type string `json:"type,omitempty"`
+ JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
+}
+
+type FormatJsonSchema struct {
+ Description string `json:"description,omitempty"`
+ Name string `json:"name"`
+ Schema any `json:"schema,omitempty"`
+ Strict any `json:"strict,omitempty"`
+}
+
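+// GeneralOpenAIRequest is a superset of the OpenAI chat, completion, and embedding
+// request bodies, with extra vendor-specific fields (Ali, Doubao, xAI, OpenRouter)
+// passed through as-is.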
+type GeneralOpenAIRequest struct {
+ Model string `json:"model,omitempty"`
+ Messages []Message `json:"messages,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+ Prefix any `json:"prefix,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ StreamOptions *StreamOptions `json:"stream_options,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop any `json:"stop,omitempty"`
+ N int `json:"n,omitempty"`
+ Input any `json:"input,omitempty"`
+ Instruction string `json:"instruction,omitempty"`
+ Size string `json:"size,omitempty"`
+ Functions json.RawMessage `json:"functions,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
+ EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
+ Tools []ToolCallRequest `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ User string `json:"user,omitempty"`
+ LogProbs bool `json:"logprobs,omitempty"`
+ TopLogProbs int `json:"top_logprobs,omitempty"`
+ Dimensions int `json:"dimensions,omitempty"`
+ Modalities json.RawMessage `json:"modalities,omitempty"`
+ Audio json.RawMessage `json:"audio,omitempty"`
+ EnableThinking any `json:"enable_thinking,omitempty"` // ali
+ THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
+ ExtraBody json.RawMessage `json:"extra_body,omitempty"`
+ SearchParameters any `json:"search_parameters,omitempty"` //xai
+ WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
+ // OpenRouter Params
+ Usage json.RawMessage `json:"usage,omitempty"`
+ Reasoning json.RawMessage `json:"reasoning,omitempty"`
+ // Ali Qwen Params
+ VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
+}
+
+func (r *GeneralOpenAIRequest) ToMap() map[string]any {
+ result := make(map[string]any)
+ data, _ := common.Marshal(r)
+ _ = common.Unmarshal(data, &result)
+ return result
+}
+
+type ToolCallRequest struct {
+ ID string `json:"id,omitempty"`
+ Type string `json:"type"`
+ Function FunctionRequest `json:"function"`
+}
+
+type FunctionRequest struct {
+ Description string `json:"description,omitempty"`
+ Name string `json:"name"`
+ Parameters any `json:"parameters,omitempty"`
+ Arguments string `json:"arguments,omitempty"`
+}
+
+type StreamOptions struct {
+ IncludeUsage bool `json:"include_usage,omitempty"`
+}
+
+func (r *GeneralOpenAIRequest) GetMaxTokens() int {
+ return int(r.MaxTokens)
+}
+
+func (r *GeneralOpenAIRequest) ParseInput() []string {
+ if r.Input == nil {
+ return nil
+ }
+ var input []string
+ switch r.Input.(type) {
+ case string:
+ input = []string{r.Input.(string)}
+ case []any:
+ input = make([]string, 0, len(r.Input.([]any)))
+ for _, item := range r.Input.([]any) {
+ if str, ok := item.(string); ok {
+ input = append(input, str)
+ }
+ }
+ }
+ return input
+}
+
+type Message struct {
+ Role string `json:"role"`
+ Content any `json:"content"`
+ Name *string `json:"name,omitempty"`
+ Prefix *bool `json:"prefix,omitempty"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ Reasoning string `json:"reasoning,omitempty"`
+ ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
+ ToolCallId string `json:"tool_call_id,omitempty"`
+ parsedContent []MediaContent
+ //parsedStringContent *string
+}
+
+type MediaContent struct {
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+ ImageUrl any `json:"image_url,omitempty"`
+ InputAudio any `json:"input_audio,omitempty"`
+ File any `json:"file,omitempty"`
+ VideoUrl any `json:"video_url,omitempty"`
+ // OpenRouter Params
+ CacheControl json.RawMessage `json:"cache_control,omitempty"`
+}
+
+func (m *MediaContent) GetImageMedia() *MessageImageUrl {
+ if m.ImageUrl != nil {
+ if _, ok := m.ImageUrl.(*MessageImageUrl); ok {
+ return m.ImageUrl.(*MessageImageUrl)
+ }
+ if itemMap, ok := m.ImageUrl.(map[string]any); ok {
+ out := &MessageImageUrl{
+ Url: common.Interface2String(itemMap["url"]),
+ Detail: common.Interface2String(itemMap["detail"]),
+ MimeType: common.Interface2String(itemMap["mime_type"]),
+ }
+ return out
+ }
+ }
+ return nil
+}
+
+func (m *MediaContent) GetInputAudio() *MessageInputAudio {
+ if m.InputAudio != nil {
+ if _, ok := m.InputAudio.(*MessageInputAudio); ok {
+ return m.InputAudio.(*MessageInputAudio)
+ }
+ if itemMap, ok := m.InputAudio.(map[string]any); ok {
+ out := &MessageInputAudio{
+ Data: common.Interface2String(itemMap["data"]),
+ Format: common.Interface2String(itemMap["format"]),
+ }
+ return out
+ }
+ }
+ return nil
+}
+
+func (m *MediaContent) GetFile() *MessageFile {
+ if m.File != nil {
+ if _, ok := m.File.(*MessageFile); ok {
+ return m.File.(*MessageFile)
+ }
+ if itemMap, ok := m.File.(map[string]any); ok {
+ out := &MessageFile{
+ FileName: common.Interface2String(itemMap["file_name"]),
+ FileData: common.Interface2String(itemMap["file_data"]),
+ FileId: common.Interface2String(itemMap["file_id"]),
+ }
+ return out
+ }
+ }
+ return nil
+}
+
+type MessageImageUrl struct {
+ Url string `json:"url"`
+ Detail string `json:"detail"`
+ MimeType string
+}
+
+func (m *MessageImageUrl) IsRemoteImage() bool {
+ return strings.HasPrefix(m.Url, "http")
+}
+
+type MessageInputAudio struct {
+ Data string `json:"data"` //base64
+ Format string `json:"format"`
+}
+
+type MessageFile struct {
+ FileName string `json:"filename,omitempty"`
+ FileData string `json:"file_data,omitempty"`
+ FileId string `json:"file_id,omitempty"`
+}
+
+type MessageVideoUrl struct {
+ Url string `json:"url"`
+}
+
+const (
+ ContentTypeText = "text"
+ ContentTypeImageURL = "image_url"
+ ContentTypeInputAudio = "input_audio"
+ ContentTypeFile = "file"
+	ContentTypeVideoUrl = "video_url" // Alibaba Bailian video understanding
+)
+
+func (m *Message) GetPrefix() bool {
+ if m.Prefix == nil {
+ return false
+ }
+ return *m.Prefix
+}
+
+func (m *Message) SetPrefix(prefix bool) {
+ m.Prefix = &prefix
+}
+
+func (m *Message) ParseToolCalls() []ToolCallRequest {
+ if m.ToolCalls == nil {
+ return nil
+ }
+ var toolCalls []ToolCallRequest
+ if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
+ return toolCalls
+ }
+ return toolCalls
+}
+
+func (m *Message) SetToolCalls(toolCalls any) {
+ toolCallsJson, _ := json.Marshal(toolCalls)
+ m.ToolCalls = toolCallsJson
+}
+
+func (m *Message) StringContent() string {
+ switch m.Content.(type) {
+ case string:
+ return m.Content.(string)
+ case []any:
+ var contentStr string
+ for _, contentItem := range m.Content.([]any) {
+ contentMap, ok := contentItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ if contentMap["type"] == ContentTypeText {
+ if subStr, ok := contentMap["text"].(string); ok {
+ contentStr += subStr
+ }
+ }
+ }
+ return contentStr
+ }
+
+ return ""
+}
+
+func (m *Message) SetNullContent() {
+ m.Content = nil
+ m.parsedContent = nil
+}
+
+func (m *Message) SetStringContent(content string) {
+ m.Content = content
+ m.parsedContent = nil
+}
+
+func (m *Message) SetMediaContent(content []MediaContent) {
+ m.Content = content
+ m.parsedContent = content
+}
+
+func (m *Message) IsStringContent() bool {
+ _, ok := m.Content.(string)
+ if ok {
+ return true
+ }
+ return false
+}
+
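+// ParseContent normalizes Message.Content into a []MediaContent slice, whether the
+// client sent a plain string or an OpenAI-style array of content parts.
+// Illustrative usage (a sketch):
+//
+//	for _, part := range msg.ParseContent() {
+//	    if part.Type == ContentTypeImageURL {
+//	        _ = part.GetImageMedia() // nil when no usable image_url is present
+//	    }
+//	}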
+func (m *Message) ParseContent() []MediaContent {
+ if m.Content == nil {
+ return nil
+ }
+ if len(m.parsedContent) > 0 {
+ return m.parsedContent
+ }
+
+ var contentList []MediaContent
+	// first try to interpret the content as a plain string
+ content, ok := m.Content.(string)
+ if ok {
+ contentList = []MediaContent{{
+ Type: ContentTypeText,
+ Text: content,
+ }}
+ m.parsedContent = contentList
+ return contentList
+ }
+
+	// then try to interpret the content as an array of parts
+ //var arrayContent []map[string]interface{}
+
+ arrayContent, ok := m.Content.([]any)
+ if !ok {
+ return contentList
+ }
+
+ for _, contentItemAny := range arrayContent {
+ mediaItem, ok := contentItemAny.(MediaContent)
+ if ok {
+ contentList = append(contentList, mediaItem)
+ continue
+ }
+
+ contentItem, ok := contentItemAny.(map[string]any)
+ if !ok {
+ continue
+ }
+ contentType, ok := contentItem["type"].(string)
+ if !ok {
+ continue
+ }
+
+ switch contentType {
+ case ContentTypeText:
+ if text, ok := contentItem["text"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeText,
+ Text: text,
+ })
+ }
+
+ case ContentTypeImageURL:
+ imageUrl := contentItem["image_url"]
+ temp := &MessageImageUrl{
+ Detail: "high",
+ }
+ switch v := imageUrl.(type) {
+ case string:
+ temp.Url = v
+ case map[string]interface{}:
+ url, ok1 := v["url"].(string)
+ detail, ok2 := v["detail"].(string)
+ if ok2 {
+ temp.Detail = detail
+ }
+ if ok1 {
+ temp.Url = url
+ }
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeImageURL,
+ ImageUrl: temp,
+ })
+
+ case ContentTypeInputAudio:
+ if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
+ data, ok1 := audioData["data"].(string)
+ format, ok2 := audioData["format"].(string)
+ if ok1 && ok2 {
+ temp := &MessageInputAudio{
+ Data: data,
+ Format: format,
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeInputAudio,
+ InputAudio: temp,
+ })
+ }
+ }
+ case ContentTypeFile:
+ if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
+ fileId, ok3 := fileData["file_id"].(string)
+ if ok3 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileId: fileId,
+ },
+ })
+ } else {
+ fileName, ok1 := fileData["filename"].(string)
+ fileDataStr, ok2 := fileData["file_data"].(string)
+ if ok1 && ok2 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileName: fileName,
+ FileData: fileDataStr,
+ },
+ })
+ }
+ }
+ }
+ case ContentTypeVideoUrl:
+ if videoUrl, ok := contentItem["video_url"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeVideoUrl,
+ VideoUrl: &MessageVideoUrl{
+ Url: videoUrl,
+ },
+ })
+ }
+ }
+ }
+
+ if len(contentList) > 0 {
+ m.parsedContent = contentList
+ }
+ return contentList
+}
+
+// old code
+/*func (m *Message) StringContent() string {
+ if m.parsedStringContent != nil {
+ return *m.parsedStringContent
+ }
+
+ var stringContent string
+ if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+ m.parsedStringContent = &stringContent
+ return stringContent
+ }
+
+ contentStr := new(strings.Builder)
+ arrayContent := m.ParseContent()
+ for _, content := range arrayContent {
+ if content.Type == ContentTypeText {
+ contentStr.WriteString(content.Text)
+ }
+ }
+ stringContent = contentStr.String()
+ m.parsedStringContent = &stringContent
+
+ return stringContent
+}
+
+func (m *Message) SetNullContent() {
+ m.Content = nil
+ m.parsedStringContent = nil
+ m.parsedContent = nil
+}
+
+func (m *Message) SetStringContent(content string) {
+ jsonContent, _ := json.Marshal(content)
+ m.Content = jsonContent
+ m.parsedStringContent = &content
+ m.parsedContent = nil
+}
+
+func (m *Message) SetMediaContent(content []MediaContent) {
+ jsonContent, _ := json.Marshal(content)
+ m.Content = jsonContent
+ m.parsedContent = nil
+ m.parsedStringContent = nil
+}
+
+func (m *Message) IsStringContent() bool {
+ if m.parsedStringContent != nil {
+ return true
+ }
+ var stringContent string
+ if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+ m.parsedStringContent = &stringContent
+ return true
+ }
+ return false
+}
+
+func (m *Message) ParseContent() []MediaContent {
+ if m.parsedContent != nil {
+ return m.parsedContent
+ }
+
+ var contentList []MediaContent
+
+	// first try to parse the content as a string
+ var stringContent string
+ if err := json.Unmarshal(m.Content, &stringContent); err == nil {
+ contentList = []MediaContent{{
+ Type: ContentTypeText,
+ Text: stringContent,
+ }}
+ m.parsedContent = contentList
+ return contentList
+ }
+
+	// then try to parse the content as an array
+ var arrayContent []map[string]interface{}
+ if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
+ for _, contentItem := range arrayContent {
+ contentType, ok := contentItem["type"].(string)
+ if !ok {
+ continue
+ }
+
+ switch contentType {
+ case ContentTypeText:
+ if text, ok := contentItem["text"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeText,
+ Text: text,
+ })
+ }
+
+ case ContentTypeImageURL:
+ imageUrl := contentItem["image_url"]
+ temp := &MessageImageUrl{
+ Detail: "high",
+ }
+ switch v := imageUrl.(type) {
+ case string:
+ temp.Url = v
+ case map[string]interface{}:
+ url, ok1 := v["url"].(string)
+ detail, ok2 := v["detail"].(string)
+ if ok2 {
+ temp.Detail = detail
+ }
+ if ok1 {
+ temp.Url = url
+ }
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeImageURL,
+ ImageUrl: temp,
+ })
+
+ case ContentTypeInputAudio:
+ if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
+ data, ok1 := audioData["data"].(string)
+ format, ok2 := audioData["format"].(string)
+ if ok1 && ok2 {
+ temp := &MessageInputAudio{
+ Data: data,
+ Format: format,
+ }
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeInputAudio,
+ InputAudio: temp,
+ })
+ }
+ }
+ case ContentTypeFile:
+ if fileData, ok := contentItem["file"].(map[string]interface{}); ok {
+ fileId, ok3 := fileData["file_id"].(string)
+ if ok3 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileId: fileId,
+ },
+ })
+ } else {
+ fileName, ok1 := fileData["filename"].(string)
+ fileDataStr, ok2 := fileData["file_data"].(string)
+ if ok1 && ok2 {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeFile,
+ File: &MessageFile{
+ FileName: fileName,
+ FileData: fileDataStr,
+ },
+ })
+ }
+ }
+ }
+ case ContentTypeVideoUrl:
+ if videoUrl, ok := contentItem["video_url"].(string); ok {
+ contentList = append(contentList, MediaContent{
+ Type: ContentTypeVideoUrl,
+ VideoUrl: &MessageVideoUrl{
+ Url: videoUrl,
+ },
+ })
+ }
+ }
+ }
+ }
+
+ if len(contentList) > 0 {
+ m.parsedContent = contentList
+ }
+ return contentList
+}*/
+
+type WebSearchOptions struct {
+ SearchContextSize string `json:"search_context_size,omitempty"`
+ UserLocation json.RawMessage `json:"user_location,omitempty"`
+}
+
+// https://platform.openai.com/docs/api-reference/responses/create
+type OpenAIResponsesRequest struct {
+ Model string `json:"model"`
+ Input json.RawMessage `json:"input,omitempty"`
+ Include json.RawMessage `json:"include,omitempty"`
+ Instructions json.RawMessage `json:"instructions,omitempty"`
+ MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
+ Metadata json.RawMessage `json:"metadata,omitempty"`
+ ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
+ PreviousResponseID string `json:"previous_response_id,omitempty"`
+ Reasoning *Reasoning `json:"reasoning,omitempty"`
+ ServiceTier string `json:"service_tier,omitempty"`
+ Store bool `json:"store,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ Text json.RawMessage `json:"text,omitempty"`
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
+	Tools []map[string]any `json:"tools,omitempty"` // kept as a map: only a few fields need handling, and MCP tool parameters are too open-ended to model
+ TopP float64 `json:"top_p,omitempty"`
+ Truncation string `json:"truncation,omitempty"`
+ User string `json:"user,omitempty"`
+ MaxToolCalls uint `json:"max_tool_calls,omitempty"`
+ Prompt json.RawMessage `json:"prompt,omitempty"`
+}
+
+type Reasoning struct {
+ Effort string `json:"effort,omitempty"`
+ Summary string `json:"summary,omitempty"`
+}
+
+//type ResponsesToolsCall struct {
+// Type string `json:"type"`
+// // Web Search
+// UserLocation json.RawMessage `json:"user_location,omitempty"`
+// SearchContextSize string `json:"search_context_size,omitempty"`
+// // File Search
+// VectorStoreIds []string `json:"vector_store_ids,omitempty"`
+// MaxNumResults uint `json:"max_num_results,omitempty"`
+// Filters json.RawMessage `json:"filters,omitempty"`
+// // Computer Use
+// DisplayWidth uint `json:"display_width,omitempty"`
+// DisplayHeight uint `json:"display_height,omitempty"`
+// Environment string `json:"environment,omitempty"`
+// // Function
+// Name string `json:"name,omitempty"`
+// Description string `json:"description,omitempty"`
+// Parameters json.RawMessage `json:"parameters,omitempty"`
+// Function json.RawMessage `json:"function,omitempty"`
+// Container json.RawMessage `json:"container,omitempty"`
+//}
diff --git a/dto/openai_response.go b/dto/openai_response.go
new file mode 100644
index 00000000..4e534823
--- /dev/null
+++ b/dto/openai_response.go
@@ -0,0 +1,278 @@
+package dto
+
+import (
+ "encoding/json"
+ "one-api/types"
+)
+
+type SimpleResponse struct {
+ Usage `json:"usage"`
+ Error *OpenAIError `json:"error"`
+}
+
+type TextResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Choices []OpenAITextResponseChoice `json:"choices"`
+ Usage `json:"usage"`
+}
+
+type OpenAITextResponseChoice struct {
+ Index int `json:"index"`
+ Message `json:"message"`
+ FinishReason string `json:"finish_reason"`
+}
+
+type OpenAITextResponse struct {
+ Id string `json:"id"`
+ Model string `json:"model"`
+ Object string `json:"object"`
+ Created any `json:"created"`
+ Choices []OpenAITextResponseChoice `json:"choices"`
+ Error *types.OpenAIError `json:"error,omitempty"`
+ Usage `json:"usage"`
+}
+
+type OpenAIEmbeddingResponseItem struct {
+ Object string `json:"object"`
+ Index int `json:"index"`
+ Embedding []float64 `json:"embedding"`
+}
+
+type OpenAIEmbeddingResponse struct {
+ Object string `json:"object"`
+ Data []OpenAIEmbeddingResponseItem `json:"data"`
+ Model string `json:"model"`
+ Usage `json:"usage"`
+}
+
+type FlexibleEmbeddingResponseItem struct {
+ Object string `json:"object"`
+ Index int `json:"index"`
+ Embedding any `json:"embedding"`
+}
+
+type FlexibleEmbeddingResponse struct {
+ Object string `json:"object"`
+ Data []FlexibleEmbeddingResponseItem `json:"data"`
+ Model string `json:"model"`
+ Usage `json:"usage"`
+}
+
+type ChatCompletionsStreamResponseChoice struct {
+ Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta,omitempty"`
+ Logprobs *any `json:"logprobs"`
+ FinishReason *string `json:"finish_reason"`
+ Index int `json:"index"`
+}
+
+type ChatCompletionsStreamResponseChoiceDelta struct {
+ Content *string `json:"content,omitempty"`
+ ReasoningContent *string `json:"reasoning_content,omitempty"`
+ Reasoning *string `json:"reasoning,omitempty"`
+ Role string `json:"role,omitempty"`
+ ToolCalls []ToolCallResponse `json:"tool_calls,omitempty"`
+}
+
+func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
+ c.Content = &s
+}
+
+func (c *ChatCompletionsStreamResponseChoiceDelta) GetContentString() string {
+ if c.Content == nil {
+ return ""
+ }
+ return *c.Content
+}
+
+func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string {
+ if c.ReasoningContent == nil && c.Reasoning == nil {
+ return ""
+ }
+ if c.ReasoningContent != nil {
+ return *c.ReasoningContent
+ }
+ return *c.Reasoning
+}
+
+func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
+ c.ReasoningContent = &s
+ c.Reasoning = &s
+}
+
+type ToolCallResponse struct {
+ // Index is not nil only in chat completion chunk object
+ Index *int `json:"index,omitempty"`
+ ID string `json:"id,omitempty"`
+ Type any `json:"type"`
+ Function FunctionResponse `json:"function"`
+}
+
+func (c *ToolCallResponse) SetIndex(i int) {
+ c.Index = &i
+}
+
+type FunctionResponse struct {
+ Description string `json:"description,omitempty"`
+ Name string `json:"name,omitempty"`
+ // call function with arguments in JSON format
+ Parameters any `json:"parameters,omitempty"` // request
+ Arguments string `json:"arguments"` // response
+}
+
+type ChatCompletionsStreamResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ SystemFingerprint *string `json:"system_fingerprint"`
+ Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+ Usage *Usage `json:"usage"`
+}
+
+func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
+ if len(c.Choices) == 0 {
+ return false
+ }
+ return len(c.Choices[0].Delta.ToolCalls) > 0
+}
+
+func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
+ if c.IsToolCall() {
+ return &c.Choices[0].Delta.ToolCalls[0]
+ }
+ return nil
+}
+
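+// Copy returns a shallow copy: the Choices slice is duplicated, but pointer fields
+// (SystemFingerprint, Usage, and pointers inside each delta) are shared with the original.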
+func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
+ choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
+ copy(choices, c.Choices)
+ return &ChatCompletionsStreamResponse{
+ Id: c.Id,
+ Object: c.Object,
+ Created: c.Created,
+ Model: c.Model,
+ SystemFingerprint: c.SystemFingerprint,
+ Choices: choices,
+ Usage: c.Usage,
+ }
+}
+
+func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
+ if c.SystemFingerprint == nil {
+ return ""
+ }
+ return *c.SystemFingerprint
+}
+
+func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
+ c.SystemFingerprint = &s
+}
+
+type ChatCompletionsStreamResponseSimple struct {
+ Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+ Usage *Usage `json:"usage"`
+}
+
+type CompletionsStreamResponse struct {
+ Choices []struct {
+ Text string `json:"text"`
+ FinishReason string `json:"finish_reason"`
+ } `json:"choices"`
+}
+
+type Usage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"`
+
+ PromptTokensDetails InputTokenDetails `json:"prompt_tokens_details"`
+ CompletionTokenDetails OutputTokenDetails `json:"completion_tokens_details"`
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
+ // OpenRouter Params
+ Cost any `json:"cost,omitempty"`
+}
+
+type InputTokenDetails struct {
+ CachedTokens int `json:"cached_tokens"`
+ CachedCreationTokens int `json:"-"`
+ TextTokens int `json:"text_tokens"`
+ AudioTokens int `json:"audio_tokens"`
+ ImageTokens int `json:"image_tokens"`
+}
+
+type OutputTokenDetails struct {
+ TextTokens int `json:"text_tokens"`
+ AudioTokens int `json:"audio_tokens"`
+ ReasoningTokens int `json:"reasoning_tokens"`
+}
+
+type OpenAIResponsesResponse struct {
+ ID string `json:"id"`
+ Object string `json:"object"`
+ CreatedAt int `json:"created_at"`
+ Status string `json:"status"`
+ Error *types.OpenAIError `json:"error,omitempty"`
+ IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
+ Instructions string `json:"instructions"`
+ MaxOutputTokens int `json:"max_output_tokens"`
+ Model string `json:"model"`
+ Output []ResponsesOutput `json:"output"`
+ ParallelToolCalls bool `json:"parallel_tool_calls"`
+ PreviousResponseID string `json:"previous_response_id"`
+ Reasoning *Reasoning `json:"reasoning"`
+ Store bool `json:"store"`
+ Temperature float64 `json:"temperature"`
+ ToolChoice string `json:"tool_choice"`
+ Tools []map[string]any `json:"tools"`
+ TopP float64 `json:"top_p"`
+ Truncation string `json:"truncation"`
+ Usage *Usage `json:"usage"`
+ User json.RawMessage `json:"user"`
+ Metadata json.RawMessage `json:"metadata"`
+}
+
+type IncompleteDetails struct {
+ Reasoning string `json:"reasoning"`
+}
+
+type ResponsesOutput struct {
+ Type string `json:"type"`
+ ID string `json:"id"`
+ Status string `json:"status"`
+ Role string `json:"role"`
+ Content []ResponsesOutputContent `json:"content"`
+}
+
+type ResponsesOutputContent struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ Annotations []interface{} `json:"annotations"`
+}
+
+const (
+ BuildInToolWebSearchPreview = "web_search_preview"
+ BuildInToolFileSearch = "file_search"
+)
+
+const (
+ BuildInCallWebSearchCall = "web_search_call"
+)
+
+const (
+ ResponsesOutputTypeItemAdded = "response.output_item.added"
+ ResponsesOutputTypeItemDone = "response.output_item.done"
+)
+
+// ResponsesStreamResponse models a streaming event from /v1/responses.
+type ResponsesStreamResponse struct {
+ Type string `json:"type"`
+ Response *OpenAIResponsesResponse `json:"response,omitempty"`
+ Delta string `json:"delta,omitempty"`
+ Item *ResponsesOutput `json:"item,omitempty"`
+}
diff --git a/dto/playground.go b/dto/playground.go
new file mode 100644
index 00000000..47eddaec
--- /dev/null
+++ b/dto/playground.go
@@ -0,0 +1,6 @@
+package dto
+
+type PlayGroundRequest struct {
+ Model string `json:"model,omitempty"`
+ Group string `json:"group,omitempty"`
+}
diff --git a/dto/pricing.go b/dto/pricing.go
new file mode 100644
index 00000000..0f317d9d
--- /dev/null
+++ b/dto/pricing.go
@@ -0,0 +1,11 @@
+package dto
+
+import "one-api/constant"
+
+type OpenAIModels struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int `json:"created"`
+ OwnedBy string `json:"owned_by"`
+ SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
+}
diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go
new file mode 100644
index 00000000..6315f31a
--- /dev/null
+++ b/dto/ratio_sync.go
@@ -0,0 +1,38 @@
+package dto
+
+type UpstreamDTO struct {
+ ID int `json:"id,omitempty"`
+ Name string `json:"name" binding:"required"`
+ BaseURL string `json:"base_url" binding:"required"`
+ Endpoint string `json:"endpoint"`
+}
+
+type UpstreamRequest struct {
+ ChannelIDs []int64 `json:"channel_ids"`
+ Upstreams []UpstreamDTO `json:"upstreams"`
+ Timeout int `json:"timeout"`
+}
+
+// TestResult is the result of an upstream connectivity test.
+type TestResult struct {
+ Name string `json:"name"`
+ Status string `json:"status"`
+ Error string `json:"error,omitempty"`
+}
+
+// DifferenceItem describes one differing ratio entry.
+// Current is the local value and may be nil.
+// Upstreams maps each channel name to its upstream value: a concrete number, "same", or nil.
+type DifferenceItem struct {
+ Current interface{} `json:"current"`
+ Upstreams map[string]interface{} `json:"upstreams"`
+ Confidence map[string]bool `json:"confidence"`
+}
+
+type SyncableChannel struct {
+ ID int `json:"id"`
+ Name string `json:"name"`
+ BaseURL string `json:"base_url"`
+ Status int `json:"status"`
+}
\ No newline at end of file
diff --git a/dto/realtime.go b/dto/realtime.go
new file mode 100644
index 00000000..32a69056
--- /dev/null
+++ b/dto/realtime.go
@@ -0,0 +1,88 @@
+package dto
+
+import "one-api/types"
+
+const (
+ RealtimeEventTypeError = "error"
+ RealtimeEventTypeSessionUpdate = "session.update"
+ RealtimeEventTypeConversationCreate = "conversation.item.create"
+ RealtimeEventTypeResponseCreate = "response.create"
+ RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
+)
+
+const (
+ RealtimeEventTypeResponseDone = "response.done"
+ RealtimeEventTypeSessionUpdated = "session.updated"
+ RealtimeEventTypeSessionCreated = "session.created"
+ RealtimeEventResponseAudioDelta = "response.audio.delta"
+ RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta"
+ RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
+ RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done"
+ RealtimeEventConversationItemCreated = "conversation.item.created"
+)
+
+type RealtimeEvent struct {
+ EventId string `json:"event_id"`
+ Type string `json:"type"`
+ //PreviousItemId string `json:"previous_item_id"`
+ Session *RealtimeSession `json:"session,omitempty"`
+ Item *RealtimeItem `json:"item,omitempty"`
+ Error *types.OpenAIError `json:"error,omitempty"`
+ Response *RealtimeResponse `json:"response,omitempty"`
+ Delta string `json:"delta,omitempty"`
+ Audio string `json:"audio,omitempty"`
+}
+
+type RealtimeResponse struct {
+ Usage *RealtimeUsage `json:"usage"`
+}
+
+type RealtimeUsage struct {
+ TotalTokens int `json:"total_tokens"`
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ InputTokenDetails InputTokenDetails `json:"input_token_details"`
+ OutputTokenDetails OutputTokenDetails `json:"output_token_details"`
+}
+
+type RealtimeSession struct {
+ Modalities []string `json:"modalities"`
+ Instructions string `json:"instructions"`
+ Voice string `json:"voice"`
+ InputAudioFormat string `json:"input_audio_format"`
+ OutputAudioFormat string `json:"output_audio_format"`
+ InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
+ TurnDetection interface{} `json:"turn_detection"`
+ Tools []RealTimeTool `json:"tools"`
+ ToolChoice string `json:"tool_choice"`
+ Temperature float64 `json:"temperature"`
+ //MaxResponseOutputTokens int `json:"max_response_output_tokens"`
+}
+
+type InputAudioTranscription struct {
+ Model string `json:"model"`
+}
+
+type RealTimeTool struct {
+ Type string `json:"type"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters any `json:"parameters"`
+}
+
+type RealtimeItem struct {
+ Id string `json:"id"`
+ Type string `json:"type"`
+ Status string `json:"status"`
+ Role string `json:"role"`
+ Content []RealtimeContent `json:"content"`
+ Name *string `json:"name,omitempty"`
+ ToolCalls any `json:"tool_calls,omitempty"`
+ CallId string `json:"call_id,omitempty"`
+}
+type RealtimeContent struct {
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+ Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
+ Transcript string `json:"transcript,omitempty"`
+}
diff --git a/dto/rerank.go b/dto/rerank.go
new file mode 100644
index 00000000..5ea68cba
--- /dev/null
+++ b/dto/rerank.go
@@ -0,0 +1,33 @@
+package dto
+
+type RerankRequest struct {
+ Documents []any `json:"documents"`
+ Query string `json:"query"`
+ Model string `json:"model"`
+ TopN int `json:"top_n,omitempty"`
+ ReturnDocuments *bool `json:"return_documents,omitempty"`
+ MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
+ OverLapTokens int `json:"overlap_tokens,omitempty"`
+}
+
+func (r *RerankRequest) GetReturnDocuments() bool {
+ if r.ReturnDocuments == nil {
+ return false
+ }
+ return *r.ReturnDocuments
+}
+
+type RerankResponseResult struct {
+ Document any `json:"document,omitempty"`
+ Index int `json:"index"`
+ RelevanceScore float64 `json:"relevance_score"`
+}
+
+type RerankDocument struct {
+ Text any `json:"text"`
+}
+
+type RerankResponse struct {
+ Results []RerankResponseResult `json:"results"`
+ Usage Usage `json:"usage"`
+}
diff --git a/dto/sensitive.go b/dto/sensitive.go
new file mode 100644
index 00000000..0bfbc6fb
--- /dev/null
+++ b/dto/sensitive.go
@@ -0,0 +1,6 @@
+package dto
+
+type SensitiveResponse struct {
+ SensitiveWords []string `json:"sensitive_words"`
+ Content string `json:"content"`
+}
diff --git a/dto/suno.go b/dto/suno.go
new file mode 100644
index 00000000..a6bb3eba
--- /dev/null
+++ b/dto/suno.go
@@ -0,0 +1,129 @@
+package dto
+
+import (
+ "encoding/json"
+)
+
+type TaskData interface {
+ SunoDataResponse | []SunoDataResponse | string | any
+}
+
+type SunoSubmitReq struct {
+ GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Mv string `json:"mv,omitempty"`
+ Title string `json:"title,omitempty"`
+ Tags string `json:"tags,omitempty"`
+ ContinueAt float64 `json:"continue_at,omitempty"`
+ TaskID string `json:"task_id,omitempty"`
+ ContinueClipId string `json:"continue_clip_id,omitempty"`
+ MakeInstrumental bool `json:"make_instrumental"`
+}
+
+type FetchReq struct {
+ IDs []string `json:"ids"`
+}
+
+type SunoDataResponse struct {
+ TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
+	Action     string          `json:"action" gorm:"type:varchar(40);index"` // task type: song, lyrics, description-mode
+	Status     string          `json:"status" gorm:"type:varchar(20);index"` // task status: submitted, queueing, processing, success, failed
+ FailReason string `json:"fail_reason"`
+ SubmitTime int64 `json:"submit_time" gorm:"index"`
+ StartTime int64 `json:"start_time" gorm:"index"`
+ FinishTime int64 `json:"finish_time" gorm:"index"`
+ Data json.RawMessage `json:"data" gorm:"type:json"`
+}
+
+type SunoSong struct {
+ ID string `json:"id"`
+ VideoURL string `json:"video_url"`
+ AudioURL string `json:"audio_url"`
+ ImageURL string `json:"image_url"`
+ ImageLargeURL string `json:"image_large_url"`
+ MajorModelVersion string `json:"major_model_version"`
+ ModelName string `json:"model_name"`
+ Status string `json:"status"`
+ Title string `json:"title"`
+ Text string `json:"text"`
+ Metadata SunoMetadata `json:"metadata"`
+}
+
+type SunoMetadata struct {
+ Tags string `json:"tags"`
+ Prompt string `json:"prompt"`
+ GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
+ AudioPromptID interface{} `json:"audio_prompt_id"`
+ Duration interface{} `json:"duration"`
+ ErrorType interface{} `json:"error_type"`
+ ErrorMessage interface{} `json:"error_message"`
+}
+
+type SunoLyrics struct {
+ ID string `json:"id"`
+ Status string `json:"status"`
+ Title string `json:"title"`
+ Text string `json:"text"`
+}
+
+const TaskSuccessCode = "success"
+
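+// TaskResponse is the generic envelope returned by the upstream task API.
+// Illustrative decoding (a sketch, assuming body holds a fetch response):
+//
+//	var resp TaskResponse[[]SunoDataResponse]
+//	if err := json.Unmarshal(body, &resp); err == nil && resp.IsSuccess() {
+//	    // use resp.Data
+//	}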
+type TaskResponse[T TaskData] struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+ Data T `json:"data"`
+}
+
+func (t *TaskResponse[T]) IsSuccess() bool {
+ return t.Code == TaskSuccessCode
+}
+
+type TaskDto struct {
+	TaskID     string          `json:"task_id"` // third-party id (song id / task id); may be absent
+	Action     string          `json:"action"`  // task type: song, lyrics, description-mode
+	Status     string          `json:"status"`  // task status: submitted, queueing, processing, success, failed
+ FailReason string `json:"fail_reason"`
+ SubmitTime int64 `json:"submit_time"`
+ StartTime int64 `json:"start_time"`
+ FinishTime int64 `json:"finish_time"`
+ Progress string `json:"progress"`
+ Data json.RawMessage `json:"data"`
+}
+
+type SunoGoAPISubmitReq struct {
+ CustomMode bool `json:"custom_mode"`
+
+ Input SunoGoAPISubmitReqInput `json:"input"`
+
+ NotifyHook string `json:"notify_hook,omitempty"`
+}
+
+type SunoGoAPISubmitReqInput struct {
+ GptDescriptionPrompt string `json:"gpt_description_prompt"`
+ Prompt string `json:"prompt"`
+ Mv string `json:"mv"`
+ Title string `json:"title"`
+ Tags string `json:"tags"`
+ ContinueAt float64 `json:"continue_at"`
+ TaskID string `json:"task_id"`
+ ContinueClipId string `json:"continue_clip_id"`
+ MakeInstrumental bool `json:"make_instrumental"`
+}
+
+type GoAPITaskResponse[T any] struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data T `json:"data"`
+ ErrorMessage string `json:"error_message,omitempty"`
+}
+
+type GoAPITaskResponseData struct {
+ TaskID string `json:"task_id"`
+}
+
+type GoAPIFetchResponseData struct {
+ TaskID string `json:"task_id"`
+ Status string `json:"status"`
+ Input string `json:"input"`
+ Clips map[string]SunoSong `json:"clips"`
+}
diff --git a/dto/task.go b/dto/task.go
new file mode 100644
index 00000000..afc186b4
--- /dev/null
+++ b/dto/task.go
@@ -0,0 +1,10 @@
+package dto
+
+type TaskError struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+ Data any `json:"data"`
+ StatusCode int `json:"-"`
+ LocalError bool `json:"-"`
+ Error error `json:"-"`
+}
diff --git a/dto/user_settings.go b/dto/user_settings.go
new file mode 100644
index 00000000..2e1a1541
--- /dev/null
+++ b/dto/user_settings.go
@@ -0,0 +1,16 @@
+package dto
+
+type UserSetting struct {
+	NotifyType            string  `json:"notify_type,omitempty"`                    // notification type for quota warnings
+	QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"`        // quota warning threshold
+	WebhookUrl            string  `json:"webhook_url,omitempty"`                    // webhook URL
+	WebhookSecret         string  `json:"webhook_secret,omitempty"`                 // webhook secret
+	NotificationEmail     string  `json:"notification_email,omitempty"`             // notification email address
+	AcceptUnsetRatioModel bool    `json:"accept_unset_model_ratio_model,omitempty"` // whether to accept models without a configured price
+	RecordIpLog           bool    `json:"record_ip_log,omitempty"`                  // whether to record client IP in request and error logs
+}
+
+var (
+	NotifyTypeEmail   = "email"   // notify via email
+	NotifyTypeWebhook = "webhook" // notify via webhook
+)
diff --git a/dto/video.go b/dto/video.go
new file mode 100644
index 00000000..5b48146a
--- /dev/null
+++ b/dto/video.go
@@ -0,0 +1,47 @@
+package dto
+
+type VideoRequest struct {
+ Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
+ Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
+ Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
+ Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
+ Width int `json:"width" example:"512"` // Video width
+ Height int `json:"height" example:"512"` // Video height
+ Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
+ Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
+ N int `json:"n,omitempty" example:"1"` // Number of videos to generate
+ ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
+ User string `json:"user,omitempty" example:"user-1234"` // User identifier
+ Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
+}
+
+// VideoResponse is the response returned after a video generation task is submitted
+type VideoResponse struct {
+ TaskId string `json:"task_id"`
+ Status string `json:"status"`
+}
+
+// VideoTaskResponse is the response for querying the status of a video generation task
+type VideoTaskResponse struct {
+ TaskId string `json:"task_id" example:"abcd1234efgh"` // Task ID
+ Status string `json:"status" example:"succeeded"` // Task status
+ Url string `json:"url,omitempty"` // Video resource URL (on success)
+ Format string `json:"format,omitempty" example:"mp4"` // Video format
+ Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // Result metadata
+ Error *VideoTaskError `json:"error,omitempty"` // Error info (on failure)
+}
+
+// VideoTaskMetadata describes the metadata of a generated video
+type VideoTaskMetadata struct {
+ Duration float64 `json:"duration" example:"5.0"` // Actual duration of the generated video
+ Fps int `json:"fps" example:"30"` // Actual frame rate
+ Width int `json:"width" example:"512"` // Actual width
+ Height int `json:"height" example:"512"` // Actual height
+ Seed int `json:"seed" example:"20231234"` // Random seed used
+}
+
+// VideoTaskError describes error information for a failed video task
+type VideoTaskError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
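
An illustrative construction of a VideoRequest (editorial sketch, not part of the diff): common parameters go in the typed fields, while vendor-specific knobs ride in the untyped Metadata map. Only a subset of the struct's fields is used, and all values are placeholders.

package main

import (
    "encoding/json"
    "fmt"
)

// A trimmed mirror of dto.VideoRequest as added above.
type VideoRequest struct {
    Model    string         `json:"model,omitempty"`
    Prompt   string         `json:"prompt,omitempty"`
    Duration float64        `json:"duration"`
    Width    int            `json:"width"`
    Height   int            `json:"height"`
    Metadata map[string]any `json:"metadata,omitempty"`
}

func main() {
    req := VideoRequest{
        Model:    "kling-v1",
        Prompt:   "an astronaut stands up and walks away",
        Duration: 5,
        Width:    512,
        Height:   512,
        // Vendor-specific parameters are passed through untyped.
        Metadata: map[string]any{"negative_prompt": "blurry"},
    }
    body, _ := json.MarshalIndent(req, "", "  ")
    fmt.Println(string(body))
}
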
diff --git a/go.mod b/go.mod
new file mode 100644
index 00000000..94873c88
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,98 @@
+module one-api
+
+// +heroku goVersion go1.18
+go 1.23.4
+
+require (
+ github.com/Calcium-Ion/go-epay v0.0.4
+ github.com/andybalholm/brotli v1.1.1
+ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
+ github.com/aws/aws-sdk-go-v2 v1.26.1
+ github.com/aws/aws-sdk-go-v2/credentials v1.17.11
+ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
+ github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
+ github.com/gin-contrib/cors v1.7.2
+ github.com/gin-contrib/gzip v0.0.6
+ github.com/gin-contrib/sessions v0.0.5
+ github.com/gin-contrib/static v0.0.1
+ github.com/gin-gonic/gin v1.9.1
+ github.com/glebarez/sqlite v1.9.0
+ github.com/go-playground/validator/v10 v10.20.0
+ github.com/go-redis/redis/v8 v8.11.5
+ github.com/golang-jwt/jwt v3.2.2+incompatible
+ github.com/google/uuid v1.6.0
+ github.com/gorilla/websocket v1.5.0
+ github.com/joho/godotenv v1.5.1
+ github.com/pkg/errors v0.9.1
+ github.com/samber/lo v1.39.0
+ github.com/shirou/gopsutil v3.21.11+incompatible
+ github.com/shopspring/decimal v1.4.0
+ github.com/stripe/stripe-go/v81 v81.4.0
+ github.com/thanhpk/randstr v1.0.6
+ github.com/tiktoken-go/tokenizer v0.6.2
+ golang.org/x/crypto v0.35.0
+ golang.org/x/image v0.23.0
+ golang.org/x/net v0.35.0
+ golang.org/x/sync v0.11.0
+ gorm.io/driver/mysql v1.4.3
+ gorm.io/driver/postgres v1.5.2
+ gorm.io/gorm v1.25.2
+)
+
+require (
+ github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
+ github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
+ github.com/aws/smithy-go v1.20.2 // indirect
+ github.com/bytedance/sonic v1.11.6 // indirect
+ github.com/bytedance/sonic/loader v0.1.1 // indirect
+ github.com/cespare/xxhash/v2 v2.3.0 // indirect
+ github.com/cloudwego/base64x v0.1.4 // indirect
+ github.com/cloudwego/iasm v0.2.0 // indirect
+ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
+ github.com/dlclark/regexp2 v1.11.5 // indirect
+ github.com/dustin/go-humanize v1.0.1 // indirect
+ github.com/gabriel-vasile/mimetype v1.4.3 // indirect
+ github.com/gin-contrib/sse v0.1.0 // indirect
+ github.com/glebarez/go-sqlite v1.21.2 // indirect
+ github.com/go-ole/go-ole v1.2.6 // indirect
+ github.com/go-playground/locales v0.14.1 // indirect
+ github.com/go-playground/universal-translator v0.18.1 // indirect
+ github.com/go-sql-driver/mysql v1.7.0 // indirect
+ github.com/goccy/go-json v0.10.2 // indirect
+ github.com/google/go-cmp v0.6.0 // indirect
+ github.com/gorilla/context v1.1.1 // indirect
+ github.com/gorilla/securecookie v1.1.1 // indirect
+ github.com/gorilla/sessions v1.2.1 // indirect
+ github.com/jackc/pgpassfile v1.0.0 // indirect
+ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
+ github.com/jackc/pgx/v5 v5.7.1 // indirect
+ github.com/jackc/puddle/v2 v2.2.2 // indirect
+ github.com/jinzhu/inflection v1.0.0 // indirect
+ github.com/jinzhu/now v1.1.5 // indirect
+ github.com/json-iterator/go v1.1.12 // indirect
+ github.com/klauspost/cpuid/v2 v2.2.9 // indirect
+ github.com/leodido/go-urn v1.4.0 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/mitchellh/mapstructure v1.5.0 // indirect
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
+ github.com/modern-go/reflect2 v1.0.2 // indirect
+ github.com/pelletier/go-toml/v2 v2.2.1 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+ github.com/tklauser/go-sysconf v0.3.12 // indirect
+ github.com/tklauser/numcpus v0.6.1 // indirect
+ github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
+ github.com/ugorji/go/codec v1.2.12 // indirect
+ github.com/yusufpapurcu/wmi v1.2.3 // indirect
+ golang.org/x/arch v0.12.0 // indirect
+ golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
+ golang.org/x/sys v0.30.0 // indirect
+ golang.org/x/text v0.22.0 // indirect
+ google.golang.org/protobuf v1.34.2 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+ modernc.org/libc v1.22.5 // indirect
+ modernc.org/mathutil v1.5.0 // indirect
+ modernc.org/memory v1.5.0 // indirect
+ modernc.org/sqlite v1.23.1 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 00000000..74eecd4c
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,293 @@
+github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A=
+github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U=
+github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
+github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
+github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs=
+github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
+github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
+github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
+github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
+github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
+github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
+github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
+github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
+github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
+github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
+github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
+github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
+github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
+github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
+github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
+github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
+github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
+github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
+github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
+github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
+github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
+github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
+github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
+github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
+github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
+github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
+github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
+github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw=
+github.com/gin-contrib/cors v1.7.2/go.mod h1:SUJVARKgQ40dmrzgXEVxj2m7Ig1v1qIboQkPDTQ9t2E=
+github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4=
+github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk=
+github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE=
+github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY=
+github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
+github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
+github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Swm1U=
+github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs=
+github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
+github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
+github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
+github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
+github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
+github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
+github.com/glebarez/sqlite v1.9.0 h1:Aj6bPA12ZEx5GbSF6XADmCkYXlljPNUY+Zf1EQxynXs=
+github.com/glebarez/sqlite v1.9.0/go.mod h1:YBYCoyupOao60lzp1MVBLEjZfgkq0tdB1voAQ09K9zw=
+github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
+github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
+github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
+github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
+github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs=
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
+github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
+github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
+github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
+github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
+github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
+github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
+github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
+github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
+github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
+github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
+github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
+github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
+github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
+github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
+github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
+github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
+github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
+github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
+github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
+github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
+github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
+github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
+github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
+github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
+github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
+github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
+github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
+github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
+github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
+github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
+github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
+github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
+github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
+github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
+github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
+github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
+github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
+github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
+github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
+github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
+github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
+github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
+github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
+github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
+github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
+github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
+github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
+github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
+github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
+github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
+github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
+github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
+github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
+github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
+github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
+github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
+github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
+github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
+github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
+github.com/pelletier/go-toml/v2 v2.2.1 h1:9TA9+T8+8CUCO2+WYnDLCgrYi9+omqKXyjDtosvtEhg=
+github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
+github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
+github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
+github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
+github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
+github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
+github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
+github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
+github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
+github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJUzCLbw=
+github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
+github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
+github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
+github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
+github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
+github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
+github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
+github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
+github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
+github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
+github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
+github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
+github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M=
+github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
+github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
+github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
+github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
+github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw=
+github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
+golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
+golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
+golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
+golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
+golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
+golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
+golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
+golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
+golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
+golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
+golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
+golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
+golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
+google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
+google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
+google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
+gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
+gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
+gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
+gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
+gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
+gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
+gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
+modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
+modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
+modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
+modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
+modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
+modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
+modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
+modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
+nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
+rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/i18n/zh-cn.json b/i18n/zh-cn.json
new file mode 100644
index 00000000..7b57b51a
--- /dev/null
+++ b/i18n/zh-cn.json
@@ -0,0 +1,1041 @@
+{
+ "未登录或登录已过期,请重新登录": "未登录或登录已过期,请重新登录",
+ "登 录": "登 录",
+ "使用 微信 继续": "使用 微信 继续",
+ "使用 GitHub 继续": "使用 GitHub 继续",
+ "使用 LinuxDO 继续": "使用 LinuxDO 继续",
+ "使用 邮箱或用户名 登录": "使用 邮箱或用户名 登录",
+ "没有账户?": "没有账户?",
+ "用户名或邮箱": "用户名或邮箱",
+ "请输入您的用户名或邮箱地址": "请输入您的用户名或邮箱地址",
+ "请输入您的密码": "请输入您的密码",
+ "继续": "继续",
+ "忘记密码?": "忘记密码?",
+ "其他登录选项": "其他登录选项",
+ "微信扫码登录": "微信扫码登录",
+ "登录": "登录",
+ "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)",
+ "验证码": "验证码",
+ "处理中...": "处理中...",
+ "绑定成功!": "绑定成功!",
+ "登录成功!": "登录成功!",
+ "操作失败,重定向至登录界面中...": "操作失败,重定向至登录界面中...",
+ "出现错误,第 ${count} 次重试中...": "出现错误,第 ${count} 次重试中...",
+ "无效的重置链接,请重新发起密码重置请求": "无效的重置链接,请重新发起密码重置请求",
+ "密码已重置并已复制到剪贴板:": "密码已重置并已复制到剪贴板:",
+ "密码重置确认": "密码重置确认",
+ "等待获取邮箱信息...": "等待获取邮箱信息...",
+ "新密码": "新密码",
+ "密码已复制到剪贴板:": "密码已复制到剪贴板:",
+ "密码重置完成": "密码重置完成",
+ "确认重置密码": "确认重置密码",
+ "返回登录": "返回登录",
+ "请输入邮箱地址": "请输入邮箱地址",
+ "请稍后几秒重试,Turnstile 正在检查用户环境!": "请稍后几秒重试,Turnstile 正在检查用户环境!",
+ "重置邮件发送成功,请检查邮箱!": "重置邮件发送成功,请检查邮箱!",
+ "密码重置": "密码重置",
+ "请输入您的邮箱地址": "请输入您的邮箱地址",
+ "重试": "重试",
+ "想起来了?": "想起来了?",
+ "注 册": "注 册",
+ "使用 用户名 注册": "使用 用户名 注册",
+ "已有账户?": "已有账户?",
+ "用户名": "用户名",
+ "请输入用户名": "请输入用户名",
+ "输入密码,最短 8 位,最长 20 位": "输入密码,最短 8 位,最长 20 位",
+ "确认密码": "确认密码",
+ "输入邮箱地址": "输入邮箱地址",
+ "获取验证码": "获取验证码",
+ "输入验证码": "输入验证码",
+ "或": "或",
+ "其他注册选项": "其他注册选项",
+ "加载中...": "加载中...",
+ "复制代码": "复制代码",
+ "代码已复制到剪贴板": "代码已复制到剪贴板",
+ "复制失败,请手动复制": "复制失败,请手动复制",
+ "显示更多": "显示更多",
+ "关于我们": "关于我们",
+ "关于项目": "关于项目",
+ "联系我们": "联系我们",
+ "功能特性": "功能特性",
+ "快速开始": "快速开始",
+ "安装指南": "安装指南",
+ "API 文档": "API 文档",
+ "基于New API的项目": "基于New API的项目",
+ "版权所有": "版权所有",
+ "设计与开发由": "设计与开发由",
+ "首页": "首页",
+ "控制台": "控制台",
+ "文档": "文档",
+ "关于": "关于",
+ "注销成功!": "注销成功!",
+ "个人设置": "个人设置",
+ "API令牌": "API令牌",
+ "退出": "退出",
+ "关闭侧边栏": "关闭侧边栏",
+ "打开侧边栏": "打开侧边栏",
+ "关闭菜单": "关闭菜单",
+ "打开菜单": "打开菜单",
+ "演示站点": "演示站点",
+ "自用模式": "自用模式",
+ "系统公告": "系统公告",
+ "切换主题": "切换主题",
+ "切换语言": "切换语言",
+ "暂无公告": "暂无公告",
+ "暂无系统公告": "暂无系统公告",
+ "今日关闭": "今日关闭",
+ "关闭公告": "关闭公告",
+ "数据看板": "数据看板",
+ "绘图日志": "绘图日志",
+ "任务日志": "任务日志",
+ "渠道": "渠道",
+ "兑换码": "兑换码",
+ "用户管理": "用户管理",
+ "操练场": "操练场",
+ "聊天": "聊天",
+ "管理员": "管理员",
+ "个人中心": "个人中心",
+ "展开侧边栏": "展开侧边栏",
+ "AI 对话": "AI 对话",
+ "选择模型开始对话": "选择模型开始对话",
+ "显示调试": "显示调试",
+ "请输入您的问题...": "请输入您的问题...",
+ "已复制到剪贴板": "已复制到剪贴板",
+ "复制失败": "复制失败",
+ "正在构造请求体预览...": "正在构造请求体预览...",
+ "暂无请求数据": "暂无请求数据",
+ "暂无响应数据": "暂无响应数据",
+ "内容较大,已启用性能优化模式": "内容较大,已启用性能优化模式",
+ "内容较大,部分功能可能受限": "内容较大,部分功能可能受限",
+ "已复制": "已复制",
+ "正在处理大内容...": "正在处理大内容...",
+ "显示完整内容": "显示完整内容",
+ "收起": "收起",
+ "配置已导出到下载文件夹": "配置已导出到下载文件夹",
+ "导出配置失败: ": "导出配置失败: ",
+ "确认导入配置": "确认导入配置",
+ "导入的配置将覆盖当前设置,是否继续?": "导入的配置将覆盖当前设置,是否继续?",
+ "取消": "取消",
+ "配置导入成功": "配置导入成功",
+ "导入配置失败: ": "导入配置失败: ",
+ "重置配置": "重置配置",
+ "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?": "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?",
+ "重置选项": "重置选项",
+ "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。": "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。",
+ "同时重置消息": "同时重置消息",
+ "仅重置配置": "仅重置配置",
+ "配置和消息已全部重置": "配置和消息已全部重置",
+ "配置已重置,对话消息已保留": "配置已重置,对话消息已保留",
+ "已有保存的配置": "已有保存的配置",
+ "暂无保存的配置": "暂无保存的配置",
+ "导出配置": "导出配置",
+ "导入配置": "导入配置",
+ "导出": "导出",
+ "导入": "导入",
+ "调试信息": "调试信息",
+ "预览请求体": "预览请求体",
+ "实际请求体": "实际请求体",
+ "预览更新": "预览更新",
+ "最后请求": "最后请求",
+ "操作暂时被禁用": "操作暂时被禁用",
+ "复制": "复制",
+ "编辑": "编辑",
+ "切换为System角色": "切换为System角色",
+ "切换为Assistant角色": "切换为Assistant角色",
+ "删除": "删除",
+ "请求发生错误": "请求发生错误",
+ "系统消息": "系统消息",
+ "请输入消息内容...": "请输入消息内容...",
+ "保存": "保存",
+ "模型配置": "模型配置",
+ "分组": "分组",
+ "请选择分组": "请选择分组",
+ "请选择模型": "请选择模型",
+ "思考中...": "思考中...",
+ "思考过程": "思考过程",
+ "选择同步渠道": "选择同步渠道",
+ "搜索渠道名称或地址": "搜索渠道名称或地址",
+ "暂无渠道": "暂无渠道",
+ "暂无选择": "暂无选择",
+ "无搜索结果": "无搜索结果",
+ "公告已更新": "公告已更新",
+ "公告更新失败": "公告更新失败",
+ "系统名称已更新": "系统名称已更新",
+ "系统名称更新失败": "系统名称更新失败",
+ "系统信息": "系统信息",
+ "当前版本": "当前版本",
+ "检查更新": "检查更新",
+ "启动时间": "启动时间",
+ "通用设置": "通用设置",
+ "设置公告": "设置公告",
+ "个性化设置": "个性化设置",
+ "系统名称": "系统名称",
+ "在此输入系统名称": "在此输入系统名称",
+ "设置系统名称": "设置系统名称",
+ "Logo 图片地址": "Logo 图片地址",
+ "在此输入 Logo 图片地址": "在此输入 Logo 图片地址",
+ "首页内容": "首页内容",
+ "设置首页内容": "设置首页内容",
+ "设置关于": "设置关于",
+ "页脚": "页脚",
+ "设置页脚": "设置页脚",
+ "详情": "详情",
+ "刷新失败": "刷新失败",
+ "令牌已重置并已复制到剪贴板": "令牌已重置并已复制到剪贴板",
+ "加载模型列表失败": "加载模型列表失败",
+ "系统令牌已复制到剪切板": "系统令牌已复制到剪切板",
+ "请输入你的账户名以确认删除!": "请输入你的账户名以确认删除!",
+ "账户已删除!": "账户已删除!",
+ "微信账户绑定成功!": "微信账户绑定成功!",
+ "请输入原密码!": "请输入原密码!",
+ "请输入新密码!": "请输入新密码!",
+ "新密码需要和原密码不一致!": "新密码需要和原密码不一致!",
+ "两次输入的密码不一致!": "两次输入的密码不一致!",
+ "密码修改成功!": "密码修改成功!",
+ "验证码发送成功,请检查邮箱!": "验证码发送成功,请检查邮箱!",
+ "请输入邮箱验证码!": "请输入邮箱验证码!",
+ "邮箱账户绑定成功!": "邮箱账户绑定成功!",
+ "无法复制到剪贴板,请手动复制": "无法复制到剪贴板,请手动复制",
+ "设置保存成功": "设置保存成功",
+ "设置保存失败": "设置保存失败",
+ "超级管理员": "超级管理员",
+ "普通用户": "普通用户",
+ "当前余额": "当前余额",
+ "历史消耗": "历史消耗",
+ "请求次数": "请求次数",
+ "默认": "默认",
+ "可用模型": "可用模型",
+ "模型列表": "模型列表",
+ "点击模型名称可复制": "点击模型名称可复制",
+ "没有可用模型": "没有可用模型",
+ "该分类下没有可用模型": "该分类下没有可用模型",
+ "更多": "更多",
+ "个模型": "个模型",
+ "账户绑定": "账户绑定",
+ "未绑定": "未绑定",
+ "修改绑定": "修改绑定",
+ "微信": "微信",
+ "已绑定": "已绑定",
+ "未启用": "未启用",
+ "绑定": "绑定",
+ "安全设置": "安全设置",
+ "系统访问令牌": "系统访问令牌",
+ "用于API调用的身份验证令牌,请妥善保管": "用于API调用的身份验证令牌,请妥善保管",
+ "生成令牌": "生成令牌",
+ "密码管理": "密码管理",
+ "定期更改密码可以提高账户安全性": "定期更改密码可以提高账户安全性",
+ "修改密码": "修改密码",
+ "此操作不可逆,所有数据将被永久删除": "此操作不可逆,所有数据将被永久删除",
+ "删除账户": "删除账户",
+ "其他设置": "其他设置",
+ "通知设置": "通知设置",
+ "邮件通知": "邮件通知",
+ "通过邮件接收通知": "通过邮件接收通知",
+ "Webhook通知": "Webhook通知",
+ "通过HTTP请求接收通知": "通过HTTP请求接收通知",
+ "请输入Webhook地址,例如: https://example.com/webhook": "请输入Webhook地址,例如: https://example.com/webhook",
+ "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求",
+ "接口凭证(可选)": "接口凭证(可选)",
+ "请输入密钥": "请输入密钥",
+ "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性",
+ "通知邮箱": "通知邮箱",
+ "留空则使用账号绑定的邮箱": "留空则使用账号绑定的邮箱",
+ "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱",
+ "额度预警阈值": "额度预警阈值",
+ "请输入预警额度": "请输入预警额度",
+ "当剩余额度低于此数值时,系统将通过选择的方式发送通知": "当剩余额度低于此数值时,系统将通过选择的方式发送通知",
+ "接受未设置价格模型": "接受未设置价格模型",
+ "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用": "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用",
+ "IP记录": "IP记录",
+ "记录请求与错误日志 IP": "记录请求与错误日志 IP",
+ "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址": "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址",
+ "绑定邮箱地址": "绑定邮箱地址",
+ "重新发送": "重新发送",
+ "绑定微信账户": "绑定微信账户",
+ "删除账户确认": "删除账户确认",
+ "您正在删除自己的帐户,将清空所有数据且不可恢复": "您正在删除自己的帐户,将清空所有数据且不可恢复",
+ "请输入您的用户名以确认删除": "请输入您的用户名以确认删除",
+ "输入你的账户名{{username}}以确认删除": "输入你的账户名{{username}}以确认删除",
+ "原密码": "原密码",
+ "请输入原密码": "请输入原密码",
+ "请输入新密码": "请输入新密码",
+ "确认新密码": "确认新密码",
+ "请再次输入新密码": "请再次输入新密码",
+ "模型倍率设置": "模型倍率设置",
+ "可视化倍率设置": "可视化倍率设置",
+ "未设置倍率模型": "未设置倍率模型",
+ "上游倍率同步": "上游倍率同步",
+ "未知类型": "未知类型",
+ "标签聚合": "标签聚合",
+ "已启用": "已启用",
+ "自动禁用": "自动禁用",
+ "未知状态": "未知状态",
+ "未测试": "未测试",
+ "名称": "名称",
+ "类型": "类型",
+ "状态": "状态",
+ ",时间:": ",时间:",
+ "响应时间": "响应时间",
+ "已用/剩余": "已用/剩余",
+ "剩余额度$": "剩余额度$",
+ ",点击更新": ",点击更新",
+ "已用额度": "已用额度",
+ "修改子渠道优先级": "修改子渠道优先级",
+ "确定要修改所有子渠道优先级为 ": "确定要修改所有子渠道优先级为 ",
+ "权重": "权重",
+ "修改子渠道权重": "修改子渠道权重",
+ "确定要修改所有子渠道权重为 ": "确定要修改所有子渠道权重为 ",
+ "确定是否要删除此渠道?": "确定是否要删除此渠道?",
+ "此修改将不可逆": "此修改将不可逆",
+ "确定是否要复制此渠道?": "确定是否要复制此渠道?",
+ "复制渠道的所有信息": "复制渠道的所有信息",
+ "测试单个渠道操作项目组": "测试单个渠道操作项目组",
+ "禁用": "禁用",
+ "启用": "启用",
+ "启用全部": "启用全部",
+ "禁用全部": "禁用全部",
+ "重置": "重置",
+ "全选": "全选",
+ "_复制": "_复制",
+ "渠道未找到,请刷新页面后重试。": "渠道未找到,请刷新页面后重试。",
+ "渠道复制成功": "渠道复制成功",
+ "渠道复制失败: ": "渠道复制失败: ",
+ "操作成功完成!": "操作成功完成!",
+ "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。",
+ "已停止测试": "已停止测试",
+ "全部": "全部",
+ "请先选择要设置标签的渠道!": "请先选择要设置标签的渠道!",
+ "标签不能为空!": "标签不能为空!",
+ "已为 ${count} 个渠道设置标签!": "已为 ${count} 个渠道设置标签!",
+ "已成功开始测试所有已启用通道,请刷新页面查看结果。": "已成功开始测试所有已启用通道,请刷新页面查看结果。",
+ "已删除所有禁用渠道,共计 ${data} 个": "已删除所有禁用渠道,共计 ${data} 个",
+ "已更新完毕所有已启用通道余额!": "已更新完毕所有已启用通道余额!",
+ "通道 ${name} 余额更新成功!": "通道 ${name} 余额更新成功!",
+ "已删除 ${data} 个通道!": "已删除 ${data} 个通道!",
+ "已修复 ${data} 个通道!": "已修复 ${data} 个通道!",
+ "确定是否要删除所选通道?": "确定是否要删除所选通道?",
+ "删除所选通道": "删除所选通道",
+ "批量设置标签": "批量设置标签",
+ "确定要测试所有通道吗?": "确定要测试所有通道吗?",
+ "测试所有通道": "测试所有通道",
+ "确定要更新所有已启用通道余额吗?": "确定要更新所有已启用通道余额吗?",
+ "更新所有已启用通道余额": "更新所有已启用通道余额",
+ "确定是否要删除禁用通道?": "确定是否要删除禁用通道?",
+ "删除禁用通道": "删除禁用通道",
+ "确定是否要修复数据库一致性?": "确定是否要修复数据库一致性?",
+ "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用",
+ "批量操作": "批量操作",
+ "使用ID排序": "使用ID排序",
+ "开启批量操作": "开启批量操作",
+ "标签聚合模式": "标签聚合模式",
+ "刷新": "刷新",
+ "列设置": "列设置",
+ "搜索渠道的 ID,名称,密钥和API地址 ...": "搜索渠道的 ID,名称,密钥和API地址 ...",
+ "模型关键字": "模型关键字",
+ "选择分组": "选择分组",
+ "查询": "查询",
+ "第 {{start}} - {{end}} 条,共 {{total}} 条": "第 {{start}} - {{end}} 条,共 {{total}} 条",
+ "搜索无结果": "搜索无结果",
+ "请输入要设置的标签名称": "请输入要设置的标签名称",
+ "请输入标签名称": "请输入标签名称",
+ "已选择 ${count} 个渠道": "已选择 ${count} 个渠道",
+ "共": "共",
+ "停止测试": "停止测试",
+ "测试中...": "测试中...",
+ "批量测试${count}个模型": "批量测试${count}个模型",
+ "搜索模型...": "搜索模型...",
+ "模型名称": "模型名称",
+ "测试中": "测试中",
+ "未开始": "未开始",
+ "失败": "失败",
+ "请求时长: ${time}s": "请求时长: ${time}s",
+ "充值": "充值",
+ "消费": "消费",
+ "系统": "系统",
+ "错误": "错误",
+ "流": "流",
+ "非流": "非流",
+ "请求并计费模型": "请求并计费模型",
+ "实际模型": "实际模型",
+ "用户": "用户",
+ "用时/首字": "用时/首字",
+ "提示": "提示",
+ "花费": "花费",
+ "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录": "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录",
+ "确定": "确定",
+ "用户信息": "用户信息",
+ "渠道信息": "渠道信息",
+ "语音输入": "语音输入",
+ "文字输入": "文字输入",
+ "文字输出": "文字输出",
+ "缓存创建 Tokens": "缓存创建 Tokens",
+ "日志详情": "日志详情",
+ "消耗额度": "消耗额度",
+ "开始时间": "开始时间",
+ "结束时间": "结束时间",
+ "用户名称": "用户名称",
+ "日志类型": "日志类型",
+ "绘图": "绘图",
+ "放大": "放大",
+ "变换": "变换",
+ "强变换": "强变换",
+ "平移": "平移",
+ "图生文": "图生文",
+ "图混合": "图混合",
+ "重绘": "重绘",
+ "局部重绘-提交": "局部重绘-提交",
+ "自定义变焦-提交": "自定义变焦-提交",
+ "窗口处理": "窗口处理",
+ "未知": "未知",
+ "已提交": "已提交",
+ "等待中": "等待中",
+ "重复提交": "重复提交",
+ "成功": "成功",
+ "未启动": "未启动",
+ "执行中": "执行中",
+ "窗口等待": "窗口等待",
+ "秒": "秒",
+ "提交时间": "提交时间",
+ "花费时间": "花费时间",
+ "任务ID": "任务ID",
+ "提交结果": "提交结果",
+ "任务状态": "任务状态",
+ "结果图片": "结果图片",
+ "查看图片": "查看图片",
+ "无": "无",
+ "失败原因": "失败原因",
+ "已复制:": "已复制:",
+ "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。": "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。",
+ "Midjourney 任务记录": "Midjourney 任务记录",
+ "任务 ID": "任务 ID",
+ "按次计费": "按次计费",
+ "按量计费": "按量计费",
+ "您的分组可以使用该模型": "您的分组可以使用该模型",
+ "可用性": "可用性",
+ "计费类型": "计费类型",
+ "当前查看的分组为:{{group}},倍率为:{{ratio}}": "当前查看的分组为:{{group}},倍率为:{{ratio}}",
+ "倍率": "倍率",
+ "倍率是为了方便换算不同价格的模型": "倍率是为了方便换算不同价格的模型",
+ "模型倍率": "模型倍率",
+ "补全倍率": "补全倍率",
+ "分组倍率": "分组倍率",
+ "模型价格": "模型价格",
+ "补全": "补全",
+ "模糊搜索模型名称": "模糊搜索模型名称",
+ "复制选中模型": "复制选中模型",
+ "模型定价": "模型定价",
+ "当前分组": "当前分组",
+ "未登录,使用默认分组倍率": "未登录,使用默认分组倍率",
+ "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)": "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)",
+ "已过期": "已过期",
+ "未使用": "未使用",
+ "已禁用": "已禁用",
+ "创建时间": "创建时间",
+ "过期时间": "过期时间",
+ "永不过期": "永不过期",
+ "确定是否要删除此兑换码?": "确定是否要删除此兑换码?",
+ "查看": "查看",
+ "已复制到剪贴板!": "已复制到剪贴板!",
+ "兑换码可以批量生成和分发,适合用于推广活动或批量充值。": "兑换码可以批量生成和分发,适合用于推广活动或批量充值。",
+ "添加兑换码": "添加兑换码",
+ "请至少选择一个兑换码!": "请至少选择一个兑换码!",
+ "复制所选兑换码到剪贴板": "复制所选兑换码到剪贴板",
+ "确定清除所有失效兑换码?": "确定清除所有失效兑换码?",
+ "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。",
+ "已删除 {{count}} 条失效兑换码": "已删除 {{count}} 条失效兑换码",
+ "关键字(id或者名称)": "关键字(id或者名称)",
+ "生成音乐": "生成音乐",
+ "生成歌词": "生成歌词",
+ "生成视频": "生成视频",
+ "排队中": "排队中",
+ "正在提交": "正在提交",
+ "平台": "平台",
+ "点击预览视频": "点击预览视频",
+ "任务记录": "任务记录",
+ "渠道 ID": "渠道 ID",
+ "已启用:限制模型": "已启用:限制模型",
+ "已耗尽": "已耗尽",
+ "剩余额度": "剩余额度",
+ "聊天链接配置错误,请联系管理员": "聊天链接配置错误,请联系管理员",
+ "令牌详情": "令牌详情",
+ "确定是否要删除此令牌?": "确定是否要删除此令牌?",
+ "项目操作按钮组": "项目操作按钮组",
+ "请联系管理员配置聊天链接": "请联系管理员配置聊天链接",
+ "令牌用于API访问认证,可以设置额度限制和模型权限。": "令牌用于API访问认证,可以设置额度限制和模型权限。",
+ "添加令牌": "添加令牌",
+ "请至少选择一个令牌!": "请至少选择一个令牌!",
+ "复制所选令牌到剪贴板": "复制所选令牌到剪贴板",
+ "搜索关键字": "搜索关键字",
+ "未知身份": "未知身份",
+ "已封禁": "已封禁",
+ "统计信息": "统计信息",
+ "剩余": "剩余",
+ "调用": "调用",
+ "邀请信息": "邀请信息",
+ "收益": "收益",
+ "无邀请人": "无邀请人",
+ "已注销": "已注销",
+ "确定要提升此用户吗?": "确定要提升此用户吗?",
+ "此操作将提升用户的权限级别": "此操作将提升用户的权限级别",
+ "确定要降级此用户吗?": "确定要降级此用户吗?",
+ "此操作将降低用户的权限级别": "此操作将降低用户的权限级别",
+ "确定是否要注销此用户?": "确定是否要注销此用户?",
+ "相当于删除用户,此修改将不可逆": "相当于删除用户,此修改将不可逆",
+ "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。": "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。",
+ "添加用户": "添加用户",
+ "支持搜索用户的 ID、用户名、显示名称和邮箱地址": "支持搜索用户的 ID、用户名、显示名称和邮箱地址",
+ "全部模型": "全部模型",
+ "智谱": "智谱",
+ "通义千问": "通义千问",
+ "文心一言": "文心一言",
+ "腾讯混元": "腾讯混元",
+ "360智脑": "360智脑",
+ "豆包": "豆包",
+ "用户分组": "用户分组",
+ "专属倍率": "专属倍率",
+ "输入价格:${{price}} / 1M tokens{{audioPrice}}": "输入价格:${{price}} / 1M tokens{{audioPrice}}",
+ "Web搜索价格:${{price}} / 1K 次": "Web搜索价格:${{price}} / 1K 次",
+ "文件搜索价格:${{price}} / 1K 次": "文件搜索价格:${{price}} / 1K 次",
+ "仅供参考,以实际扣费为准": "仅供参考,以实际扣费为准",
+ "价格:${{price}} * {{ratioType}}:{{ratio}}": "价格:${{price}} * {{ratioType}}:{{ratio}}",
+ "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}": "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}",
+ "提示价格:${{price}} / 1M tokens": "提示价格:${{price}} / 1M tokens",
+ "模型价格 ${{price}},{{ratioType}} {{ratio}}": "模型价格 ${{price}},{{ratioType}} {{ratio}}",
+ "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}": "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}",
+ "不是合法的 JSON 字符串": "不是合法的 JSON 字符串",
+ "请求发生错误: ": "请求发生错误: ",
+ "解析响应数据时发生错误": "解析响应数据时发生错误",
+ "连接已断开": "连接已断开",
+ "建立连接时发生错误": "建立连接时发生错误",
+ "加载模型失败": "加载模型失败",
+ "加载分组失败": "加载分组失败",
+ "消息已复制到剪贴板": "消息已复制到剪贴板",
+ "确认删除": "确认删除",
+ "确定要删除这条消息吗?": "确定要删除这条消息吗?",
+ "已删除消息及其回复": "已删除消息及其回复",
+ "消息已删除": "消息已删除",
+ "消息已编辑": "消息已编辑",
+ "检测到该消息后有AI回复,是否删除后续回复并重新生成?": "检测到该消息后有AI回复,是否删除后续回复并重新生成?",
+ "重新生成": "重新生成",
+ "消息已更新": "消息已更新",
+ "加载关于内容失败...": "加载关于内容失败...",
+ "可在设置页面设置关于内容,支持 HTML & Markdown": "可在设置页面设置关于内容,支持 HTML & Markdown",
+ "New API项目仓库地址:": "New API项目仓库地址:",
+ "| 基于": "| 基于",
+ "本项目根据": "本项目根据",
+ "MIT许可证": "MIT许可证",
+ "授权,需在遵守": "授权,需在遵守",
+ "Apache-2.0协议": "Apache-2.0协议",
+ "管理员暂时未设置任何关于内容": "管理员暂时未设置任何关于内容",
+ "仅支持 OpenAI 接口格式": "仅支持 OpenAI 接口格式",
+ "请填写密钥": "请填写密钥",
+ "获取模型列表成功": "获取模型列表成功",
+ "获取模型列表失败": "获取模型列表失败",
+ "请填写渠道名称和渠道密钥!": "请填写渠道名称和渠道密钥!",
+ "请至少选择一个模型!": "请至少选择一个模型!",
+ "模型映射必须是合法的 JSON 格式!": "模型映射必须是合法的 JSON 格式!",
+ "提交失败,请勿重复提交!": "提交失败,请勿重复提交!",
+ "渠道创建成功!": "渠道创建成功!",
+ "已新增 {{count}} 个模型:{{list}}": "已新增 {{count}} 个模型:{{list}}",
+ "未发现新增模型": "未发现新增模型",
+ "新建": "新建",
+ "更新渠道信息": "更新渠道信息",
+ "创建新的渠道": "创建新的渠道",
+ "基本信息": "基本信息",
+ "渠道的基本配置信息": "渠道的基本配置信息",
+ "请选择渠道类型": "请选择渠道类型",
+ "请为渠道命名": "请为渠道命名",
+ "请输入密钥,一行一个": "请输入密钥,一行一个",
+ "批量创建": "批量创建",
+ "API 配置": "API 配置",
+ "API 地址和相关配置": "API 地址和相关配置",
+ "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"",
+ "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com",
+ "请输入默认 API 版本,例如:2025-04-01-preview": "请输入默认 API 版本,例如:2025-04-01-preview",
+ "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。": "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。",
+ "完整的 Base URL,支持变量{model}": "完整的 Base URL,支持变量{model}",
+ "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions": "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions",
+ "Dify渠道只适配chatflow和agent,并且agent不支持图片!": "Dify渠道只适配chatflow和agent,并且agent不支持图片!",
+ "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/": "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/",
+ "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写",
+ "私有部署地址": "私有部署地址",
+ "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi": "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi",
+ "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用": "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用",
+ "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com": "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com",
+ "模型选择和映射设置": "模型选择和映射设置",
+ "模型": "模型",
+ "请选择该渠道所支持的模型": "请选择该渠道所支持的模型",
+ "填入相关模型": "填入相关模型",
+ "填入所有模型": "填入所有模型",
+ "获取模型列表": "获取模型列表",
+ "清除所有模型": "清除所有模型",
+ "输入自定义模型名称": "输入自定义模型名称",
+ "模型重定向": "模型重定向",
+ "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:",
+ "填入模板": "填入模板",
+ "默认测试模型": "默认测试模型",
+ "不填则为模型列表第一个": "不填则为模型列表第一个",
+ "渠道的高级配置选项": "渠道的高级配置选项",
+ "请选择可以使用该渠道的分组": "请选择可以使用该渠道的分组",
+ "请在系统设置页面编辑分组倍率以添加新的分组:": "请在系统设置页面编辑分组倍率以添加新的分组:",
+ "部署地区": "部署地区",
+ "知识库 ID": "知识库 ID",
+ "渠道标签": "渠道标签",
+ "渠道优先级": "渠道优先级",
+ "渠道权重": "渠道权重",
+ "渠道额外设置": "渠道额外设置",
+ "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:",
+ "参数覆盖": "参数覆盖",
+ "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:",
+ "请输入组织org-xxx": "请输入组织org-xxx",
+ "组织,可选,不填则为默认组织": "组织,可选,不填则为默认组织",
+ "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道": "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道",
+ "状态码复写(仅影响本地判断,不修改返回到上游的状态码)": "状态码复写(仅影响本地判断,不修改返回到上游的状态码)",
+ "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:",
+ "编辑标签": "编辑标签",
+ "标签信息": "标签信息",
+ "标签的基本配置": "标签的基本配置",
+ "所有编辑均为覆盖操作,留空则不更改": "所有编辑均为覆盖操作,留空则不更改",
+ "标签名称": "标签名称",
+ "请输入新标签,留空则解散标签": "请输入新标签,留空则解散标签",
+ "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。": "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。",
+ "请选择该渠道所支持的模型,留空则不更改": "请选择该渠道所支持的模型,留空则不更改",
+ "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改",
+ "清空重定向": "清空重定向",
+ "分组设置": "分组设置",
+ "用户分组配置": "用户分组配置",
+ "请选择可以使用该渠道的分组,留空则不更改": "请选择可以使用该渠道的分组,留空则不更改",
+ "正在跳转...": "正在跳转...",
+ "小时": "小时",
+ "周": "周",
+ "模型调用次数占比": "模型调用次数占比",
+ "模型消耗分布": "模型消耗分布",
+ "总计": "总计",
+ "早上好": "早上好",
+ "中午好": "中午好",
+ "下午好": "下午好",
+ "账户数据": "账户数据",
+ "使用统计": "使用统计",
+ "统计次数": "统计次数",
+ "资源消耗": "资源消耗",
+ "统计额度": "统计额度",
+ "性能指标": "性能指标",
+ "平均RPM": "平均RPM",
+ "复制成功": "复制成功",
+ "进行中": "进行中",
+ "异常": "异常",
+ "正常": "正常",
+ "可用率": "可用率",
+ "有异常": "有异常",
+ "高延迟": "高延迟",
+ "维护中": "维护中",
+ "暂无监控数据": "暂无监控数据",
+ "搜索条件": "搜索条件",
+ "时间粒度": "时间粒度",
+ "模型数据分析": "模型数据分析",
+ "消耗分布": "消耗分布",
+ "调用次数分布": "调用次数分布",
+ "API信息": "API信息",
+ "暂无API信息": "暂无API信息",
+ "请联系管理员在系统设置中配置API信息": "请联系管理员在系统设置中配置API信息",
+ "显示最新20条": "显示最新20条",
+ "请联系管理员在系统设置中配置公告信息": "请联系管理员在系统设置中配置公告信息",
+ "暂无常见问答": "暂无常见问答",
+ "请联系管理员在系统设置中配置常见问答": "请联系管理员在系统设置中配置常见问答",
+ "服务可用性": "服务可用性",
+ "请联系管理员在系统设置中配置Uptime": "请联系管理员在系统设置中配置Uptime",
+ "加载首页内容失败...": "加载首页内容失败...",
+ "统一的大模型接口网关": "统一的大模型接口网关",
+ "更好的价格,更好的稳定性,无需订阅": "更好的价格,更好的稳定性,无需订阅",
+ "开始使用": "开始使用",
+ "支持众多的大模型供应商": "支持众多的大模型供应商",
+ "页面未找到,请检查您的浏览器地址是否正确": "页面未找到,请检查您的浏览器地址是否正确",
+ "登录过期,请重新登录!": "登录过期,请重新登录!",
+ "兑换码更新成功!": "兑换码更新成功!",
+ "兑换码创建成功!": "兑换码创建成功!",
+ "兑换码创建成功": "兑换码创建成功",
+ "兑换码创建成功,是否下载兑换码?": "兑换码创建成功,是否下载兑换码?",
+ "兑换码将以文本文件的形式下载,文件名为兑换码的名称。": "兑换码将以文本文件的形式下载,文件名为兑换码的名称。",
+ "更新兑换码信息": "更新兑换码信息",
+ "创建新的兑换码": "创建新的兑换码",
+ "设置兑换码的基本信息": "设置兑换码的基本信息",
+ "请输入名称": "请输入名称",
+ "选择过期时间(可选,留空为永久)": "选择过期时间(可选,留空为永久)",
+ "额度设置": "额度设置",
+ "设置兑换码的额度和数量": "设置兑换码的额度和数量",
+ "请输入额度": "请输入额度",
+ "生成数量": "生成数量",
+ "请输入生成数量": "请输入生成数量",
+ "你似乎并没有修改什么": "你似乎并没有修改什么",
+ "部分保存失败,请重试": "部分保存失败,请重试",
+ "保存成功": "保存成功",
+ "保存失败,请重试": "保存失败,请重试",
+ "请检查输入": "请检查输入",
+ "聊天配置": "聊天配置",
+ "为一个 JSON 文本": "为一个 JSON 文本",
+ "保存聊天设置": "保存聊天设置",
+ "设置已保存": "设置已保存",
+ "API地址": "API地址",
+ "说明": "说明",
+ "颜色": "颜色",
+ "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)": "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)",
+ "批量删除": "批量删除",
+ "保存设置": "保存设置",
+ "添加API": "添加API",
+ "请输入API地址": "请输入API地址",
+ "如:香港线路": "如:香港线路",
+ "请输入线路描述": "请输入线路描述",
+ "如:大带宽批量分析图片推荐": "如:大带宽批量分析图片推荐",
+ "请输入说明": "请输入说明",
+ "标识颜色": "标识颜色",
+ "确定要删除此API信息吗?": "确定要删除此API信息吗?",
+ "警告": "警告",
+ "发布时间": "发布时间",
+ "操作": "操作",
+ "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)": "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)",
+ "添加公告": "添加公告",
+ "编辑公告": "编辑公告",
+ "公告内容": "公告内容",
+ "请输入公告内容": "请输入公告内容",
+ "请选择发布日期": "请选择发布日期",
+ "公告类型": "公告类型",
+ "说明信息": "说明信息",
+ "可选,公告的补充说明": "可选,公告的补充说明",
+ "确定要删除此公告吗?": "确定要删除此公告吗?",
+ "数据看板设置": "数据看板设置",
+ "启用数据看板(实验性)": "启用数据看板(实验性)",
+ "数据看板更新间隔": "数据看板更新间隔",
+ "设置过短会影响数据库性能": "设置过短会影响数据库性能",
+ "数据看板默认时间粒度": "数据看板默认时间粒度",
+ "仅修改展示粒度,统计精确到小时": "仅修改展示粒度,统计精确到小时",
+ "保存数据看板设置": "保存数据看板设置",
+ "问题标题": "问题标题",
+ "回答内容": "回答内容",
+ "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)": "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)",
+ "添加问答": "添加问答",
+ "编辑问答": "编辑问答",
+ "请输入问题标题": "请输入问题标题",
+ "请输入回答内容": "请输入回答内容",
+ "确定要删除此问答吗?": "确定要删除此问答吗?",
+ "分类名称": "分类名称",
+ "Uptime Kuma地址": "Uptime Kuma地址",
+ "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)": "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)",
+ "编辑分类": "编辑分类",
+ "添加分类": "添加分类",
+ "请输入分类名称,如:OpenAI、Claude等": "请输入分类名称,如:OpenAI、Claude等",
+ "请输入分类名称": "请输入分类名称",
+ "请输入Uptime Kuma服务地址,如:https://status.example.com": "请输入Uptime Kuma服务地址,如:https://status.example.com",
+ "请输入Uptime Kuma地址": "请输入Uptime Kuma地址",
+ "请输入状态页面的Slug,如:my-status": "请输入状态页面的Slug,如:my-status",
+ "请输入状态页面Slug": "请输入状态页面Slug",
+ "确定要删除此分类吗?": "确定要删除此分类吗?",
+ "绘图设置": "绘图设置",
+ "启用绘图功能": "启用绘图功能",
+ "允许回调(会泄露服务器 IP 地址)": "允许回调(会泄露服务器 IP 地址)",
+ "允许 AccountFilter 参数": "允许 AccountFilter 参数",
+ "开启之后会清除用户提示词中的": "开启之后会清除用户提示词中的",
+ "以及": "以及",
+ "检测必须等待绘图成功才能进行放大等操作": "检测必须等待绘图成功才能进行放大等操作",
+ "保存绘图设置": "保存绘图设置",
+ "Claude设置": "Claude设置",
+ "Claude请求头覆盖": "Claude请求头覆盖",
+ "为一个 JSON 文本,例如:": "为一个 JSON 文本,例如:",
+ "缺省 MaxTokens": "缺省 MaxTokens",
+ "启用Claude思考适配(-thinking后缀)": "启用Claude思考适配(-thinking后缀)",
+ "思考适配 BudgetTokens 百分比": "思考适配 BudgetTokens 百分比",
+ "0.1-1之间的小数": "0.1-1之间的小数",
+ "Gemini设置": "Gemini设置",
+ "Gemini安全设置": "Gemini安全设置",
+ "default为默认设置,可单独设置每个模型的版本": "default为默认设置,可单独设置每个模型的版本",
+ "例如:": "例如:",
+ "Gemini思考适配设置": "Gemini思考适配设置",
+ "启用Gemini思考后缀适配": "启用Gemini思考后缀适配",
+ "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀",
+ "0.002-1之间的小数": "0.002-1之间的小数",
+ "全局设置": "全局设置",
+ "启用请求透传": "启用请求透传",
+ "连接保活设置": "连接保活设置",
+ "启用Ping间隔": "启用Ping间隔",
+ "Ping间隔(秒)": "Ping间隔(秒)",
+ "新用户初始额度": "新用户初始额度",
+ "请求预扣费额度": "请求预扣费额度",
+ "请求结束后多退少补": "请求结束后多退少补",
+ "邀请新用户奖励额度": "邀请新用户奖励额度",
+ "新用户使用邀请码奖励额度": "新用户使用邀请码奖励额度",
+ "例如:1000": "例如:1000",
+ "保存额度设置": "保存额度设置",
+ "例如发卡网站的购买链接": "例如发卡网站的购买链接",
+ "文档地址": "文档地址",
+ "单位美元额度": "单位美元额度",
+ "一单位货币能兑换的额度": "一单位货币能兑换的额度",
+ "失败重试次数": "失败重试次数",
+ "以货币形式显示额度": "以货币形式显示额度",
+ "额度查询接口返回令牌额度而非用户额度": "额度查询接口返回令牌额度而非用户额度",
+ "默认折叠侧边栏": "默认折叠侧边栏",
+ "开启后不限制:必须设置模型倍率": "开启后不限制:必须设置模型倍率",
+ "保存通用设置": "保存通用设置",
+ "请选择日志记录时间": "请选择日志记录时间",
+ "条日志已清理!": "条日志已清理!",
+ "日志清理失败:": "日志清理失败:",
+ "启用额度消费日志记录": "启用额度消费日志记录",
+ "日志记录时间": "日志记录时间",
+ "清除历史日志": "清除历史日志",
+ "保存日志设置": "保存日志设置",
+ "监控设置": "监控设置",
+ "测试所有渠道的最长响应时间": "测试所有渠道的最长响应时间",
+ "额度提醒阈值": "额度提醒阈值",
+ "低于此额度时将发送邮件提醒用户": "低于此额度时将发送邮件提醒用户",
+ "失败时自动禁用通道": "失败时自动禁用通道",
+ "成功时自动启用通道": "成功时自动启用通道",
+ "自动禁用关键词": "自动禁用关键词",
+ "一行一个,不区分大小写": "一行一个,不区分大小写",
+ "屏蔽词过滤设置": "屏蔽词过滤设置",
+ "启用屏蔽词过滤功能": "启用屏蔽词过滤功能",
+ "启用 Prompt 检查": "启用 Prompt 检查",
+ "一行一个屏蔽词,不需要符号分割": "一行一个屏蔽词,不需要符号分割",
+ "保存屏蔽词过滤设置": "保存屏蔽词过滤设置",
+ "更新成功": "更新成功",
+ "更新失败": "更新失败",
+ "服务器地址": "服务器地址",
+ "更新服务器地址": "更新服务器地址",
+ "请先填写服务器地址": "请先填写服务器地址",
+ "充值分组倍率不是合法的 JSON 字符串": "充值分组倍率不是合法的 JSON 字符串",
+ "充值方式设置不是合法的 JSON 字符串": "充值方式设置不是合法的 JSON 字符串",
+ "支付设置": "支付设置",
+ "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)": "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)",
+ "例如:https://yourdomain.com": "例如:https://yourdomain.com",
+ "易支付商户ID": "易支付商户ID",
+ "易支付商户密钥": "易支付商户密钥",
+ "敏感信息不会发送到前端显示": "敏感信息不会发送到前端显示",
+ "回调地址": "回调地址",
+ "充值价格(x元/美金)": "充值价格(x元/美金)",
+ "例如:7,就是7元/美金": "例如:7,就是7元/美金",
+ "最低充值美元数量": "最低充值美元数量",
+ "例如:2,就是最低充值2$": "例如:2,就是最低充值2$",
+ "为一个 JSON 文本,键为组名称,值为倍率": "为一个 JSON 文本,键为组名称,值为倍率",
+ "充值方式设置": "充值方式设置",
+ "更新支付设置": "更新支付设置",
+ "模型请求速率限制": "模型请求速率限制",
+ "启用用户模型请求速率限制(可能会影响高并发性能)": "启用用户模型请求速率限制(可能会影响高并发性能)",
+ "分钟": "分钟",
+ "频率限制的周期(分钟)": "频率限制的周期(分钟)",
+ "用户每周期最多请求次数": "用户每周期最多请求次数",
+ "包括失败请求的次数,0代表不限制": "包括失败请求的次数,0代表不限制",
+ "用户每周期最多请求完成次数": "用户每周期最多请求完成次数",
+ "只包括请求成功的次数": "只包括请求成功的次数",
+ "分组速率限制": "分组速率限制",
+ "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}",
+ "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。": "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。",
+ "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。": "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。",
+ "分组速率配置优先级高于全局速率限制。": "分组速率配置优先级高于全局速率限制。",
+ "限制周期统一使用上方配置的“限制周期”值。": "限制周期统一使用上方配置的“限制周期”值。",
+ "保存模型速率限制": "保存模型速率限制",
+ "保存失败": "保存失败",
+ "为一个 JSON 文本,键为分组名称,值为倍率": "为一个 JSON 文本,键为分组名称,值为倍率",
+ "用户可选分组": "用户可选分组",
+ "为一个 JSON 文本,键为分组名称,值为分组描述": "为一个 JSON 文本,键为分组名称,值为分组描述",
+ "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择",
+ "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]": "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]",
+ "模型固定价格": "模型固定价格",
+ "一次调用消耗多少刀,优先级大于模型倍率": "一次调用消耗多少刀,优先级大于模型倍率",
+ "为一个 JSON 文本,键为模型名称,值为倍率": "为一个 JSON 文本,键为模型名称,值为倍率",
+ "模型补全倍率(仅对自定义模型有效)": "模型补全倍率(仅对自定义模型有效)",
+ "仅对自定义模型有效": "仅对自定义模型有效",
+ "保存模型倍率设置": "保存模型倍率设置",
+ "确定重置模型倍率吗?": "确定重置模型倍率吗?",
+ "重置模型倍率": "重置模型倍率",
+ "获取启用模型失败:": "获取启用模型失败:",
+ "获取启用模型失败": "获取启用模型失败",
+ "JSON解析错误:": "JSON解析错误:",
+ "保存失败:": "保存失败:",
+ "输入模型倍率": "输入模型倍率",
+ "输入补全倍率": "输入补全倍率",
+ "请输入数字": "请输入数字",
+ "模型名称已存在": "模型名称已存在",
+ "请先选择需要批量设置的模型": "请先选择需要批量设置的模型",
+ "请输入模型倍率和补全倍率": "请输入模型倍率和补全倍率",
+ "请输入有效的数字": "请输入有效的数字",
+ "请输入填充值": "请输入填充值",
+ "批量设置成功": "批量设置成功",
+ "已为 {{count}} 个模型设置{{type}}": "已为 {{count}} 个模型设置{{type}}",
+ "模型倍率和补全倍率": "模型倍率和补全倍率",
+ "添加模型": "添加模型",
+ "批量设置": "批量设置",
+ "应用更改": "应用更改",
+ "搜索模型名称": "搜索模型名称",
+ "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除": "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除",
+ "定价模式": "定价模式",
+ "固定价格": "固定价格",
+ "固定价格(每次)": "固定价格(每次)",
+ "输入每次价格": "输入每次价格",
+ "输入补全价格": "输入补全价格",
+ "批量设置模型参数": "批量设置模型参数",
+ "设置类型": "设置类型",
+ "模型倍率和补全倍率同时设置": "模型倍率和补全倍率同时设置",
+ "模型倍率值": "模型倍率值",
+ "请输入模型倍率": "请输入模型倍率",
+ "补全倍率值": "补全倍率值",
+ "请输入补全倍率": "请输入补全倍率",
+ "请输入数值": "请输入数值",
+ "将为选中的 ": "将为选中的 ",
+ " 个模型设置相同的值": " 个模型设置相同的值",
+ "当前设置类型: ": "当前设置类型: ",
+ "默认补全倍率": "默认补全倍率",
+ "添加成功": "添加成功",
+ "价格设置方式": "价格设置方式",
+ "按倍率设置": "按倍率设置",
+ "按价格设置": "按价格设置",
+ "输入价格": "输入价格",
+ "输出价格": "输出价格",
+ "获取渠道失败:": "获取渠道失败:",
+ "请至少选择一个渠道": "请至少选择一个渠道",
+ "后端请求失败": "后端请求失败",
+ "部分渠道测试失败:": "部分渠道测试失败:",
+ "未找到差异化倍率,无需同步": "未找到差异化倍率,无需同步",
+ "请求后端接口失败:": "请求后端接口失败:",
+ "同步成功": "同步成功",
+ "部分保存失败": "部分保存失败",
+ "未找到匹配的模型": "未找到匹配的模型",
+ "暂无差异化倍率显示": "暂无差异化倍率显示",
+ "请先选择同步渠道": "请先选择同步渠道",
+ "倍率类型": "倍率类型",
+ "缓存倍率": "缓存倍率",
+ "当前值": "当前值",
+ "未设置": "未设置",
+ "与本地相同": "与本地相同",
+ "运营设置": "运营设置",
+ "聊天设置": "聊天设置",
+ "速率限制设置": "速率限制设置",
+ "模型相关设置": "模型相关设置",
+ "系统设置": "系统设置",
+ "仪表盘设置": "仪表盘设置",
+ "获取初始化状态失败": "获取初始化状态失败",
+ "表单引用错误,请刷新页面重试": "表单引用错误,请刷新页面重试",
+ "请输入管理员用户名": "请输入管理员用户名",
+ "密码长度至少为8个字符": "密码长度至少为8个字符",
+ "两次输入的密码不一致": "两次输入的密码不一致",
+ "系统初始化成功,正在跳转...": "系统初始化成功,正在跳转...",
+ "初始化失败,请重试": "初始化失败,请重试",
+ "系统初始化失败,请重试": "系统初始化失败,请重试",
+ "系统初始化": "系统初始化",
+ "欢迎使用,请完成以下设置以开始使用系统": "欢迎使用,请完成以下设置以开始使用系统",
+ "数据库信息": "数据库信息",
+ "管理员账号": "管理员账号",
+ "设置系统管理员的登录信息": "设置系统管理员的登录信息",
+ "管理员账号已经初始化过,请继续设置其他参数": "管理员账号已经初始化过,请继续设置其他参数",
+ "密码": "密码",
+ "请输入管理员密码": "请输入管理员密码",
+ "请确认管理员密码": "请确认管理员密码",
+ "选择适合您使用场景的模式": "选择适合您使用场景的模式",
+ "对外运营模式": "对外运营模式",
+ "适用于为多个用户提供服务的场景": "适用于为多个用户提供服务的场景",
+ "默认模式": "默认模式",
+ "适用于个人使用的场景,不需要设置模型价格": "适用于个人使用的场景,不需要设置模型价格",
+ "无需计费": "无需计费",
+ "演示站点模式": "演示站点模式",
+ "适用于展示系统功能的场景,提供基础功能演示": "适用于展示系统功能的场景,提供基础功能演示",
+ "初始化系统": "初始化系统",
+ "使用模式说明": "使用模式说明",
+ "我已了解": "我已了解",
+ "默认模式,适用于为多个用户提供服务的场景。": "默认模式,适用于为多个用户提供服务的场景。",
+ "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。": "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。",
+ "多用户支持": "多用户支持",
+ "适用于个人使用的场景。": "适用于个人使用的场景。",
+ "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。": "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。",
+ "个人使用": "个人使用",
+ "适用于展示系统功能的场景。": "适用于展示系统功能的场景。",
+ "提供基础功能演示,方便用户了解系统特性。": "提供基础功能演示,方便用户了解系统特性。",
+ "体验试用": "体验试用",
+ "自动选择": "自动选择",
+ "过期时间格式错误!": "过期时间格式错误!",
+ "令牌更新成功!": "令牌更新成功!",
+ "令牌创建成功,请在列表页面点击复制获取令牌!": "令牌创建成功,请在列表页面点击复制获取令牌!",
+ "更新令牌信息": "更新令牌信息",
+ "创建新的令牌": "创建新的令牌",
+ "设置令牌的基本信息": "设置令牌的基本信息",
+ "请选择过期时间": "请选择过期时间",
+ "一天": "一天",
+ "一个月": "一个月",
+ "设置令牌可用额度和数量": "设置令牌可用额度和数量",
+ "新建数量": "新建数量",
+ "请选择或输入创建令牌的数量": "请选择或输入创建令牌的数量",
+ "20个": "20个",
+ "100个": "100个",
+ "取消无限额度": "取消无限额度",
+ "设为无限额度": "设为无限额度",
+ "设置令牌的访问限制": "设置令牌的访问限制",
+ "IP白名单": "IP白名单",
+ "允许的IP,一行一个,不填写则不限制": "允许的IP,一行一个,不填写则不限制",
+ "请勿过度信任此功能,IP可能被伪造": "请勿过度信任此功能,IP可能被伪造",
+ "勾选启用模型限制后可选择": "勾选启用模型限制后可选择",
+ "非必要,不建议启用模型限制": "非必要,不建议启用模型限制",
+ "分组信息": "分组信息",
+ "设置令牌的分组": "设置令牌的分组",
+ "令牌分组,默认为用户的分组": "令牌分组,默认为用户的分组",
+ "管理员未设置用户可选分组": "管理员未设置用户可选分组",
+ "请输入兑换码!": "请输入兑换码!",
+ "兑换成功!": "兑换成功!",
+ "成功兑换额度:": "成功兑换额度:",
+ "请求失败": "请求失败",
+ "超级管理员未设置充值链接!": "超级管理员未设置充值链接!",
+ "管理员未开启在线充值!": "管理员未开启在线充值!",
+ "充值数量不能小于": "充值数量不能小于",
+ "支付请求失败": "支付请求失败",
+ "划转金额最低为": "划转金额最低为",
+ "邀请链接已复制到剪切板": "邀请链接已复制到剪切板",
+ "支付方式配置错误, 请联系管理员": "支付方式配置错误, 请联系管理员",
+ "划转邀请额度": "划转邀请额度",
+ "可用邀请额度": "可用邀请额度",
+ "划转额度": "划转额度",
+ "充值确认": "充值确认",
+ "充值数量": "充值数量",
+ "实付金额": "实付金额",
+ "支付方式": "支付方式",
+ "在线充值": "在线充值",
+ "快速方便的充值方式": "快速方便的充值方式",
+ "选择充值额度": "选择充值额度",
+ "实付": "实付",
+ "或输入自定义金额": "或输入自定义金额",
+ "充值数量,最低 ": "充值数量,最低 ",
+ "选择支付方式": "选择支付方式",
+ "处理中": "处理中",
+ "兑换码充值": "兑换码充值",
+ "使用兑换码快速充值": "使用兑换码快速充值",
+ "请输入兑换码": "请输入兑换码",
+ "兑换中...": "兑换中...",
+ "兑换": "兑换",
+ "邀请奖励": "邀请奖励",
+ "邀请好友获得额外奖励": "邀请好友获得额外奖励",
+ "待使用收益": "待使用收益",
+ "总收益": "总收益",
+ "邀请人数": "邀请人数",
+ "邀请链接": "邀请链接",
+ "邀请好友注册,好友充值后您可获得相应奖励": "邀请好友注册,好友充值后您可获得相应奖励",
+ "通过划转功能将奖励额度转入到您的账户余额中": "通过划转功能将奖励额度转入到您的账户余额中",
+ "邀请的好友越多,获得的奖励越多": "邀请的好友越多,获得的奖励越多",
+ "用户名和密码不能为空!": "用户名和密码不能为空!",
+ "用户账户创建成功!": "用户账户创建成功!",
+ "提交": "提交",
+ "创建新用户账户": "创建新用户账户",
+ "请输入显示名称": "请输入显示名称",
+ "请输入密码": "请输入密码",
+ "请输入备注(仅管理员可见)": "请输入备注(仅管理员可见)",
+ "编辑用户": "编辑用户",
+ "用户的基本账户信息": "用户的基本账户信息",
+ "请输入新的用户名": "请输入新的用户名",
+ "请输入新的密码,最短 8 位": "请输入新的密码,最短 8 位",
+ "显示名称": "显示名称",
+ "请输入新的显示名称": "请输入新的显示名称",
+ "权限设置": "权限设置",
+ "用户分组和额度管理": "用户分组和额度管理",
+ "请输入新的剩余额度": "请输入新的剩余额度",
+ "添加额度": "添加额度",
+ "第三方账户绑定状态(只读)": "第三方账户绑定状态(只读)",
+ "已绑定的 GitHub 账户": "已绑定的 GitHub 账户",
+ "已绑定的 OIDC 账户": "已绑定的 OIDC 账户",
+ "已绑定的微信账户": "已绑定的微信账户",
+ "已绑定的邮箱账户": "已绑定的邮箱账户",
+ "已绑定的 Telegram 账户": "已绑定的 Telegram 账户",
+ "新额度": "新额度",
+ "需要添加的额度(支持负数)": "需要添加的额度(支持负数)"
+}
\ No newline at end of file
diff --git a/main.go b/main.go
new file mode 100644
index 00000000..ca3da601
--- /dev/null
+++ b/main.go
@@ -0,0 +1,210 @@
+package main
+
+import (
+ "embed"
+ "fmt"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/controller"
+ "one-api/middleware"
+ "one-api/model"
+ "one-api/router"
+ "one-api/service"
+ "one-api/setting/ratio_setting"
+ "os"
+ "strconv"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-contrib/sessions/cookie"
+ "github.com/gin-gonic/gin"
+ "github.com/joho/godotenv"
+
+ _ "net/http/pprof"
+)
+
+//go:embed web/dist
+var buildFS embed.FS
+
+//go:embed web/dist/index.html
+var indexPage []byte
+
+func main() {
+
+ err := InitResources()
+ if err != nil {
+ common.FatalLog("failed to initialize resources: " + err.Error())
+ return
+ }
+
+ common.SysLog("New API " + common.Version + " started")
+ if os.Getenv("GIN_MODE") != "debug" {
+ gin.SetMode(gin.ReleaseMode)
+ }
+ if common.DebugEnabled {
+ common.SysLog("running in debug mode")
+ }
+
+ defer func() {
+ err := model.CloseDB()
+ if err != nil {
+ common.FatalLog("failed to close database: " + err.Error())
+ }
+ }()
+
+ if common.RedisEnabled {
+ // for compatibility with old versions
+ common.MemoryCacheEnabled = true
+ }
+ if common.MemoryCacheEnabled {
+ common.SysLog("memory cache enabled")
+ common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
+
+ // Add panic recovery and retry for InitChannelCache
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
+ // Retry once
+ _, _, fixErr := model.FixAbility()
+ if fixErr != nil {
+ common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+ }
+ }
+ }()
+ model.InitChannelCache()
+ }()
+
+ go model.SyncChannelCache(common.SyncFrequency)
+ }
+
+ // Periodically sync option settings (hot reload of configuration)
+ go model.SyncOptions(common.SyncFrequency)
+
+ // Periodically update dashboard quota data
+ go model.UpdateQuotaData()
+
+ if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
+ frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
+ if err != nil {
+ common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
+ }
+ go controller.AutomaticallyUpdateChannels(frequency)
+ }
+ if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
+ frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
+ if err != nil {
+ common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
+ }
+ go controller.AutomaticallyTestChannels(frequency)
+ }
+ if common.IsMasterNode && constant.UpdateTask {
+ gopool.Go(func() {
+ controller.UpdateMidjourneyTaskBulk()
+ })
+ gopool.Go(func() {
+ controller.UpdateTaskBulk()
+ })
+ }
+ if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
+ common.BatchUpdateEnabled = true
+ common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+ model.InitBatchUpdater()
+ }
+
+ if os.Getenv("ENABLE_PPROF") == "true" {
+ gopool.Go(func() {
+ log.Println(http.ListenAndServe("0.0.0.0:8005", nil))
+ })
+ go common.Monitor()
+ common.SysLog("pprof enabled")
+ }
+
+ // Initialize HTTP server
+ server := gin.New()
+ server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
+ common.SysError(fmt.Sprintf("panic detected: %v", err))
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "error": gin.H{
+ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
+ "type": "new_api_panic",
+ },
+ })
+ }))
+ // This will cause SSE not to work!!!
+ //server.Use(gzip.Gzip(gzip.DefaultCompression))
+ server.Use(middleware.RequestId())
+ middleware.SetUpLogger(server)
+ // Initialize session store
+ store := cookie.NewStore([]byte(common.SessionSecret))
+ store.Options(sessions.Options{
+ Path: "/",
+ MaxAge: 2592000, // 30 days
+ HttpOnly: true,
+ Secure: false,
+ SameSite: http.SameSiteStrictMode,
+ })
+ server.Use(sessions.Sessions("session", store))
+
+ router.SetRouter(server, buildFS, indexPage)
+ var port = os.Getenv("PORT")
+ if port == "" {
+ port = strconv.Itoa(*common.Port)
+ }
+ err = server.Run(":" + port)
+ if err != nil {
+ common.FatalLog("failed to start HTTP server: " + err.Error())
+ }
+}
+
+func InitResources() error {
+ // Load the .env file, then initialize logging, settings, HTTP clients, the databases and Redis.
+ err := godotenv.Load(".env")
+ if err != nil {
+ common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
+ common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
+ }
+
+ // Load environment variables
+ common.InitEnv()
+
+ common.SetupLogger()
+
+ // Initialize model settings
+ ratio_setting.InitRatioSettings()
+
+ service.InitHttpClient()
+
+ service.InitTokenEncoders()
+
+ // Initialize SQL Database
+ err = model.InitDB()
+ if err != nil {
+ common.FatalLog("failed to initialize database: " + err.Error())
+ return err
+ }
+
+ model.CheckSetup()
+
+ // Initialize options, should after model.InitDB()
+ model.InitOptionMap()
+
+ // Initialize model pricing data
+ model.GetPricing()
+
+ // Initialize the log database
+ err = model.InitLogDB()
+ if err != nil {
+ return err
+ }
+
+ // Initialize Redis
+ err = common.InitRedisClient()
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/makefile b/makefile
new file mode 100644
index 00000000..cbc4ea6a
--- /dev/null
+++ b/makefile
@@ -0,0 +1,14 @@
+FRONTEND_DIR = ./web
+BACKEND_DIR = .
+
+.PHONY: all build-frontend start-backend
+
+all: build-frontend start-backend
+
+build-frontend:
+ @echo "Building frontend..."
+ @cd $(FRONTEND_DIR) && bun install && DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
+
+start-backend:
+ @echo "Starting backend dev server..."
+ @cd $(BACKEND_DIR) && go run main.go &
diff --git a/middleware/auth.go b/middleware/auth.go
new file mode 100644
index 00000000..a158318c
--- /dev/null
+++ b/middleware/auth.go
@@ -0,0 +1,286 @@
+package middleware
+
+import (
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+ "strings"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
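+// validUserInfo reports whether the authenticated user data is usable:
+// the username must be non-empty and the role must be a recognized role value.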
+func validUserInfo(username string, role int) bool {
+ // check username is empty
+ if strings.TrimSpace(username) == "" {
+ return false
+ }
+ if !common.IsValidateRole(role) {
+ return false
+ }
+ return true
+}
+
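+// authHelper authenticates the request either from the login session or from an Authorization access token,
+// verifies that the New-Api-User header matches the authenticated user, rejects disabled users and users
+// below minRole, and stores username/role/id/group in the request context.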
+func authHelper(c *gin.Context, minRole int) {
+ session := sessions.Default(c)
+ username := session.Get("username")
+ role := session.Get("role")
+ id := session.Get("id")
+ status := session.Get("status")
+ useAccessToken := false
+ if username == nil {
+ // Check access token
+ accessToken := c.Request.Header.Get("Authorization")
+ if accessToken == "" {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "无权进行此操作,未登录且未提供 access token",
+ })
+ c.Abort()
+ return
+ }
+ user := model.ValidateAccessToken(accessToken)
+ if user != nil && user.Username != "" {
+ if !validUserInfo(user.Username, user.Role) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权进行此操作,用户信息无效",
+ })
+ c.Abort()
+ return
+ }
+ // Token is valid
+ username = user.Username
+ role = user.Role
+ id = user.Id
+ status = user.Status
+ useAccessToken = true
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权进行此操作,access token 无效",
+ })
+ c.Abort()
+ return
+ }
+ }
+ // get header New-Api-User
+ apiUserIdStr := c.Request.Header.Get("New-Api-User")
+ if apiUserIdStr == "" {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "无权进行此操作,未提供 New-Api-User",
+ })
+ c.Abort()
+ return
+ }
+ apiUserId, err := strconv.Atoi(apiUserIdStr)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "无权进行此操作,New-Api-User 格式错误",
+ })
+ c.Abort()
+ return
+ }
+ if id != apiUserId {
+ c.JSON(http.StatusUnauthorized, gin.H{
+ "success": false,
+ "message": "无权进行此操作,New-Api-User 与登录用户不匹配",
+ })
+ c.Abort()
+ return
+ }
+ if status.(int) == common.UserStatusDisabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "用户已被封禁",
+ })
+ c.Abort()
+ return
+ }
+ if role.(int) < minRole {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权进行此操作,权限不足",
+ })
+ c.Abort()
+ return
+ }
+ if !validUserInfo(username.(string), role.(int)) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无权进行此操作,用户信息无效",
+ })
+ c.Abort()
+ return
+ }
+ c.Set("username", username)
+ c.Set("role", role)
+ c.Set("id", id)
+ c.Set("group", session.Get("group"))
+ c.Set("use_access_token", useAccessToken)
+
+ //userCache, err := model.GetUserCache(id.(int))
+ //if err != nil {
+ // c.JSON(http.StatusOK, gin.H{
+ // "success": false,
+ // "message": err.Error(),
+ // })
+ // c.Abort()
+ // return
+ //}
+ //userCache.WriteContext(c)
+
+ c.Next()
+}
+
+func TryUserAuth() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ session := sessions.Default(c)
+ id := session.Get("id")
+ if id != nil {
+ c.Set("id", id)
+ }
+ c.Next()
+ }
+}
+
+func UserAuth() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ authHelper(c, common.RoleCommonUser)
+ }
+}
+
+func AdminAuth() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ authHelper(c, common.RoleAdminUser)
+ }
+}
+
+func RootAuth() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ authHelper(c, common.RoleRootUser)
+ }
+}
+
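+// WssAuth is currently a no-op placeholder for WebSocket authentication.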
+func WssAuth(c *gin.Context) {
+
+}
+
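+// TokenAuth authenticates relay requests by API token. It normalizes the key from several sources
+// (Sec-WebSocket-Protocol for realtime, x-api-key for /v1/messages, query/header for Gemini paths),
+// strips the "sk-" prefix, splits off an optional channel-id suffix, validates the token and the owning user,
+// and writes token/user information into the request context.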
+func TokenAuth() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ // 先检测是否为ws
+ if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" {
+ // Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1
+ // read sk from Sec-WebSocket-Protocol
+ key := c.Request.Header.Get("Sec-WebSocket-Protocol")
+ parts := strings.Split(key, ",")
+ for _, part := range parts {
+ part = strings.TrimSpace(part)
+ if strings.HasPrefix(part, "openai-insecure-api-key") {
+ key = strings.TrimPrefix(part, "openai-insecure-api-key.")
+ break
+ }
+ }
+ c.Request.Header.Set("Authorization", "Bearer "+key)
+ }
+ // 检查path包含/v1/messages
+ if strings.Contains(c.Request.URL.Path, "/v1/messages") {
+ // 从x-api-key中获取key
+ key := c.Request.Header.Get("x-api-key")
+ if key != "" {
+ c.Request.Header.Set("Authorization", "Bearer "+key)
+ }
+ }
+ // gemini api 从query中获取key
+ if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
+ skKey := c.Query("key")
+ if skKey != "" {
+ c.Request.Header.Set("Authorization", "Bearer "+skKey)
+ }
+ // 从x-goog-api-key header中获取key
+ xGoogKey := c.Request.Header.Get("x-goog-api-key")
+ if xGoogKey != "" {
+ c.Request.Header.Set("Authorization", "Bearer "+xGoogKey)
+ }
+ }
+ key := c.Request.Header.Get("Authorization")
+ parts := make([]string, 0)
+ key = strings.TrimPrefix(key, "Bearer ")
+ if key == "" || key == "midjourney-proxy" {
+ key = c.Request.Header.Get("mj-api-secret")
+ key = strings.TrimPrefix(key, "Bearer ")
+ key = strings.TrimPrefix(key, "sk-")
+ parts = strings.Split(key, "-")
+ key = parts[0]
+ } else {
+ key = strings.TrimPrefix(key, "sk-")
+ parts = strings.Split(key, "-")
+ key = parts[0]
+ }
+ token, err := model.ValidateUserToken(key)
+ if token != nil {
+ id := c.GetInt("id")
+ if id == 0 {
+ c.Set("id", token.UserId)
+ }
+ }
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
+ return
+ }
+ userCache, err := model.GetUserCache(token.UserId)
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
+ return
+ }
+ userEnabled := userCache.Status == common.UserStatusEnabled
+ if !userEnabled {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "用户已被封禁")
+ return
+ }
+
+ userCache.WriteContext(c)
+
+ err = SetupContextForToken(c, token, parts...)
+ if err != nil {
+ return
+ }
+ c.Next()
+ }
+}
+
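+// SetupContextForToken copies token attributes (quota, model limits, IP allow-list, group) into the gin context;
+// an optional second key segment selects a specific channel, which is only allowed for admin users.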
+func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
+ if token == nil {
+ return fmt.Errorf("token is nil")
+ }
+ c.Set("id", token.UserId)
+ c.Set("token_id", token.Id)
+ c.Set("token_key", token.Key)
+ c.Set("token_name", token.Name)
+ c.Set("token_unlimited_quota", token.UnlimitedQuota)
+ if !token.UnlimitedQuota {
+ c.Set("token_quota", token.RemainQuota)
+ }
+ if token.ModelLimitsEnabled {
+ c.Set("token_model_limit_enabled", true)
+ c.Set("token_model_limit", token.GetModelLimitsMap())
+ } else {
+ c.Set("token_model_limit_enabled", false)
+ }
+ c.Set("allow_ips", token.GetIpLimitsMap())
+ c.Set("token_group", token.Group)
+ if len(parts) > 1 {
+ if model.IsAdmin(token.UserId) {
+ c.Set("specific_channel_id", parts[1])
+ } else {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
+ return fmt.Errorf("普通用户不支持指定渠道")
+ }
+ }
+ return nil
+}
diff --git a/middleware/cache.go b/middleware/cache.go
new file mode 100644
index 00000000..979734ab
--- /dev/null
+++ b/middleware/cache.go
@@ -0,0 +1,16 @@
+package middleware
+
+import (
+ "github.com/gin-gonic/gin"
+)
+
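+// Cache sets Cache-Control headers: no-cache for "/" and a one-week max-age for every other path it is mounted on.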
+func Cache() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ if c.Request.RequestURI == "/" {
+ c.Header("Cache-Control", "no-cache")
+ } else {
+ c.Header("Cache-Control", "max-age=604800") // one week
+ }
+ c.Next()
+ }
+}
diff --git a/middleware/cors.go b/middleware/cors.go
new file mode 100644
index 00000000..d2a109ab
--- /dev/null
+++ b/middleware/cors.go
@@ -0,0 +1,15 @@
+package middleware
+
+import (
+ "github.com/gin-contrib/cors"
+ "github.com/gin-gonic/gin"
+)
+
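+// CORS returns a permissive CORS policy: all origins, credentials allowed, common methods and any header.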
+func CORS() gin.HandlerFunc {
+ config := cors.DefaultConfig()
+ config.AllowAllOrigins = true
+ config.AllowCredentials = true
+ config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
+ config.AllowHeaders = []string{"*"}
+ return cors.New(config)
+}
diff --git a/middleware/distributor.go b/middleware/distributor.go
new file mode 100644
index 00000000..a6889e39
--- /dev/null
+++ b/middleware/distributor.go
@@ -0,0 +1,331 @@
+package middleware
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ relayconstant "one-api/relay/constant"
+ "one-api/service"
+ "one-api/setting"
+ "one-api/setting/ratio_setting"
+ "one-api/types"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+type ModelRequest struct {
+ Model string `json:"model"`
+ Group string `json:"group,omitempty"`
+}
+
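+// Distribute resolves which channel should serve the request: it enforces the token IP allow-list,
+// resolves the effective group, honors a token-pinned channel id if present, applies token model limits,
+// otherwise picks a random satisfied channel for the requested model, then stores the selection in the context.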
+func Distribute() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
+ if len(allowIpsMap) != 0 {
+ clientIp := c.ClientIP()
+ if _, ok := allowIpsMap[clientIp]; !ok {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "您的 IP 不在令牌允许访问的列表中")
+ return
+ }
+ }
+ var channel *model.Channel
+ channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
+ modelRequest, shouldSelectChannel, err := getModelRequest(c)
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
+ return
+ }
+ userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
+ tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
+ if tokenGroup != "" {
+ // check common.UserUsableGroups[userGroup]
+ if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
+ return
+ }
+ // check group in common.GroupRatio
+ if !ratio_setting.ContainsGroupRatio(tokenGroup) {
+ if tokenGroup != "auto" {
+ abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
+ return
+ }
+ }
+ userGroup = tokenGroup
+ }
+ common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
+ if ok {
+ id, err := strconv.Atoi(channelId.(string))
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+ return
+ }
+ channel, err = model.GetChannelById(id, true)
+ if err != nil {
+ abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的渠道 Id")
+ return
+ }
+ if channel.Status != common.ChannelStatusEnabled {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该渠道已被禁用")
+ return
+ }
+ } else {
+ // Select a channel for the user
+ // check token model mapping
+ modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
+ if modelLimitEnable {
+ s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
+ var tokenModelLimit map[string]bool
+ if ok {
+ tokenModelLimit = s.(map[string]bool)
+ } else {
+ tokenModelLimit = map[string]bool{}
+ }
+ if tokenModelLimit != nil {
+ if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
+ return
+ }
+ } else {
+ // token model limit is empty, all models are not allowed
+ abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
+ return
+ }
+ }
+
+ if shouldSelectChannel {
+ var selectGroup string
+ channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
+ if err != nil {
+ showGroup := userGroup
+ if userGroup == "auto" {
+ showGroup = fmt.Sprintf("auto(%s)", selectGroup)
+ }
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
+ // 如果错误,但是渠道不为空,说明是数据库一致性问题
+ if channel != nil {
+ common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
+ message = "数据库一致性已被破坏,请联系管理员"
+ }
+ // 如果错误,而且渠道为空,说明是没有可用渠道
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
+ return
+ }
+ if channel == nil {
+ abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
+ return
+ }
+ }
+ }
+ common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
+ SetupContextForSelectedChannel(c, channel, modelRequest.Model)
+ c.Next()
+ }
+}
+
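+// getModelRequest extracts the requested model (and relay mode) from the URL path, query, form or JSON body,
+// filling in sensible defaults for moderation, image, audio and realtime endpoints. The second return value
+// reports whether a channel actually needs to be selected for this request.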
+func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
+ var modelRequest ModelRequest
+ shouldSelectChannel := true
+ var err error
+ if strings.Contains(c.Request.URL.Path, "/mj/") {
+ relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
+ relayMode == relayconstant.RelayModeMidjourneyNotify ||
+ relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
+ shouldSelectChannel = false
+ } else {
+ midjourneyRequest := dto.MidjourneyRequest{}
+ err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
+ if err != nil {
+ return nil, false, err
+ }
+ midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
+ if mjErr != nil {
+ return nil, false, errors.New(mjErr.Description)
+ }
+ if midjourneyModel == "" {
+ if !success {
+ return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
+ } else {
+ // task fetch, task fetch by condition, notify
+ shouldSelectChannel = false
+ }
+ }
+ modelRequest.Model = midjourneyModel
+ }
+ c.Set("relay_mode", relayMode)
+ } else if strings.Contains(c.Request.URL.Path, "/suno/") {
+ relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeSunoFetch ||
+ relayMode == relayconstant.RelayModeSunoFetchByID {
+ shouldSelectChannel = false
+ } else {
+ modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
+ modelRequest.Model = modelName
+ }
+ c.Set("platform", string(constant.TaskPlatformSuno))
+ c.Set("relay_mode", relayMode)
+ } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ var platform string
+ var relayMode int
+ if strings.HasPrefix(modelRequest.Model, "jimeng") {
+ platform = string(constant.TaskPlatformJimeng)
+ relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeJimengFetchByID {
+ shouldSelectChannel = false
+ }
+ } else {
+ platform = string(constant.TaskPlatformKling)
+ relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
+ if relayMode == relayconstant.RelayModeKlingFetchByID {
+ shouldSelectChannel = false
+ }
+ }
+ c.Set("platform", platform)
+ c.Set("relay_mode", relayMode)
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
+ // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
+ relayMode := relayconstant.RelayModeGemini
+ modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
+ if modelName != "" {
+ modelRequest.Model = modelName
+ }
+ c.Set("relay_mode", relayMode)
+ } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ }
+ if err != nil {
+ return nil, false, errors.New("无效的请求, " + err.Error())
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
+ //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+ modelRequest.Model = c.Query("model")
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "text-moderation-stable"
+ }
+ }
+ if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = c.Param("model")
+ }
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
+ modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
+ modelRequest.Model = common.GetStringIfEmpty(c.PostForm("model"), "gpt-image-1")
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+ relayMode := relayconstant.RelayModeAudioSpeech
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+ modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+ modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
+ modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
+ relayMode = relayconstant.RelayModeAudioTranslation
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
+ modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model"))
+ modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
+ relayMode = relayconstant.RelayModeAudioTranscription
+ }
+ c.Set("relay_mode", relayMode)
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
+ // playground chat completions
+ err = common.UnmarshalBodyReusable(c, &modelRequest)
+ if err != nil {
+ return nil, false, errors.New("无效的请求, " + err.Error())
+ }
+ common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
+ }
+ return &modelRequest, shouldSelectChannel, nil
+}
+
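+// SetupContextForSelectedChannel writes the selected channel's metadata (id, type, key, base URL, mappings,
+// and channel-specific extras such as api_version/region/bot_id) into the gin context for later relay handlers.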
+func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
+ c.Set("original_model", modelName) // for retry
+ if channel == nil {
+ return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
+ }
+ common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
+ common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
+ common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
+ common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
+ common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
+ common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
+ if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
+ common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
+ }
+ common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
+ common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
+ common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
+
+ key, index, newAPIError := channel.GetNextEnabledKey()
+ if newAPIError != nil {
+ return newAPIError
+ }
+ if channel.ChannelInfo.IsMultiKey {
+ common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
+ common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
+ }
+ // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
+ common.SetContextKey(c, constant.ContextKeyChannelKey, key)
+ common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
+
+ // TODO: api_version统一
+ switch channel.Type {
+ case constant.ChannelTypeAzure:
+ c.Set("api_version", channel.Other)
+ case constant.ChannelTypeVertexAi:
+ c.Set("region", channel.Other)
+ case constant.ChannelTypeXunfei:
+ c.Set("api_version", channel.Other)
+ case constant.ChannelTypeGemini:
+ c.Set("api_version", channel.Other)
+ case constant.ChannelTypeAli:
+ c.Set("plugin", channel.Other)
+ case constant.ChannelCloudflare:
+ c.Set("api_version", channel.Other)
+ case constant.ChannelTypeMokaAI:
+ c.Set("api_version", channel.Other)
+ case constant.ChannelTypeCoze:
+ c.Set("bot_id", channel.Other)
+ }
+ return nil
+}
+
+// extractModelNameFromGeminiPath extracts the model name from a Gemini API URL path.
+// Input format: /v1beta/models/gemini-2.0-flash:generateContent
+// Output: gemini-2.0-flash
+func extractModelNameFromGeminiPath(path string) string {
+ // 查找 "/models/" 的位置
+ modelsPrefix := "/models/"
+ modelsIndex := strings.Index(path, modelsPrefix)
+ if modelsIndex == -1 {
+ return ""
+ }
+
+ // 从 "/models/" 之后开始提取
+ startIndex := modelsIndex + len(modelsPrefix)
+ if startIndex >= len(path) {
+ return ""
+ }
+
+ // 查找 ":" 的位置,模型名在 ":" 之前
+ colonIndex := strings.Index(path[startIndex:], ":")
+ if colonIndex == -1 {
+ // 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
+ return path[startIndex:]
+ }
+
+ // 返回模型名部分
+ return path[startIndex : startIndex+colonIndex]
+}
diff --git a/middleware/gzip.go b/middleware/gzip.go
new file mode 100644
index 00000000..5b9d566a
--- /dev/null
+++ b/middleware/gzip.go
@@ -0,0 +1,38 @@
+package middleware
+
+import (
+ "compress/gzip"
+ "github.com/andybalholm/brotli"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+)
+
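+// DecompressRequestMiddleware transparently decompresses gzip- or brotli-encoded request bodies before further handling.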
+func DecompressRequestMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if c.Request.Body == nil || c.Request.Method == http.MethodGet {
+ c.Next()
+ return
+ }
+ switch c.GetHeader("Content-Encoding") {
+ case "gzip":
+ gzipReader, err := gzip.NewReader(c.Request.Body)
+ if err != nil {
+ c.AbortWithStatus(http.StatusBadRequest)
+ return
+ }
+ defer gzipReader.Close()
+
+ // Replace the request body with the decompressed data
+ c.Request.Body = io.NopCloser(gzipReader)
+ c.Request.Header.Del("Content-Encoding")
+ case "br":
+ reader := brotli.NewReader(c.Request.Body)
+ c.Request.Body = io.NopCloser(reader)
+ c.Request.Header.Del("Content-Encoding")
+ }
+
+ // Continue processing the request
+ c.Next()
+ }
+}
diff --git a/middleware/kling_adapter.go b/middleware/kling_adapter.go
new file mode 100644
index 00000000..3d4943d2
--- /dev/null
+++ b/middleware/kling_adapter.go
@@ -0,0 +1,47 @@
+package middleware
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "one-api/common"
+ "one-api/constant"
+
+ "github.com/gin-gonic/gin"
+)
+
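+// KlingRequestConvert rewrites a Kling-style video request (model_name/prompt/...) into the unified
+// /v1/video/generations format, keeping the original payload under "metadata".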
+func KlingRequestConvert() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ var originalReq map[string]interface{}
+ if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
+ c.Next()
+ return
+ }
+
+ model, _ := originalReq["model_name"].(string)
+ prompt, _ := originalReq["prompt"].(string)
+
+ unifiedReq := map[string]interface{}{
+ "model": model,
+ "prompt": prompt,
+ "metadata": originalReq,
+ }
+
+ jsonData, err := json.Marshal(unifiedReq)
+ if err != nil {
+ c.Next()
+ return
+ }
+
+ // Rewrite request body and path
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
+ c.Request.URL.Path = "/v1/video/generations"
+ if image, ok := originalReq["image"]; !ok || image == "" {
+ c.Set("action", constant.TaskActionTextGenerate)
+ }
+
+ // We have to reset the request body for the next handlers
+ c.Set(common.KeyRequestBody, jsonData)
+ c.Next()
+ }
+}
diff --git a/middleware/logger.go b/middleware/logger.go
new file mode 100644
index 00000000..02f2e0a9
--- /dev/null
+++ b/middleware/logger.go
@@ -0,0 +1,25 @@
+package middleware
+
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+)
+
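+// SetUpLogger installs a gin access logger whose lines include the per-request id set by RequestId().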
+func SetUpLogger(server *gin.Engine) {
+ server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
+ var requestID string
+ if param.Keys != nil {
+ requestID = param.Keys[common.RequestIdKey].(string)
+ }
+ return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
+ param.TimeStamp.Format("2006/01/02 - 15:04:05"),
+ requestID,
+ param.StatusCode,
+ param.Latency,
+ param.ClientIP,
+ param.Method,
+ param.Path,
+ )
+ }))
+}
diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go
new file mode 100644
index 00000000..14d9a737
--- /dev/null
+++ b/middleware/model-rate-limit.go
@@ -0,0 +1,199 @@
+package middleware
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/common/limiter"
+ "one-api/constant"
+ "one-api/setting"
+ "strconv"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/go-redis/redis/v8"
+)
+
+const (
+ ModelRequestRateLimitCountMark = "MRRL"
+ ModelRequestRateLimitSuccessCountMark = "MRRLS"
+)
+
+// checkRedisRateLimit reports whether another request is allowed for the key within the Redis-backed
+// time window; maxCount == 0 means unlimited.
+func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
+ // 如果maxCount为0,表示不限制
+ if maxCount == 0 {
+ return true, nil
+ }
+
+ // 获取当前计数
+ length, err := rdb.LLen(ctx, key).Result()
+ if err != nil {
+ return false, err
+ }
+
+ // 如果未达到限制,允许请求
+ if length < int64(maxCount) {
+ return true, nil
+ }
+
+ // 检查时间窗口
+ oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
+ oldTime, err := time.Parse(timeFormat, oldTimeStr)
+ if err != nil {
+ return false, err
+ }
+
+ nowTimeStr := time.Now().Format(timeFormat)
+ nowTime, err := time.Parse(timeFormat, nowTimeStr)
+ if err != nil {
+ return false, err
+ }
+ // 如果在时间窗口内已达到限制,拒绝请求
+ subTime := nowTime.Sub(oldTime).Seconds()
+ if int64(subTime) < duration {
+ rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
+ return false, nil
+ }
+
+ return true, nil
+}
+
+// recordRedisRequest pushes the current timestamp onto the key's Redis list, trims it to maxCount entries
+// and refreshes the key's expiry (no-op when maxCount is 0).
+func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
+ // 如果maxCount为0,不记录请求
+ if maxCount == 0 {
+ return
+ }
+
+ now := time.Now().Format(timeFormat)
+ rdb.LPush(ctx, key, now)
+ rdb.LTrim(ctx, key, 0, int64(maxCount-1))
+ rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
+}
+
+// redisRateLimitHandler enforces the per-user successful-request limit and the token-bucket total-request limit using Redis.
+func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ userId := strconv.Itoa(c.GetInt("id"))
+ ctx := context.Background()
+ rdb := common.RDB
+
+ // 1. 检查成功请求数限制
+ successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
+ allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
+ if err != nil {
+ fmt.Println("检查成功请求数限制失败:", err.Error())
+ abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+ return
+ }
+ if !allowed {
+ abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
+ return
+ }
+
+ //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
+ if totalMaxCount > 0 {
+ totalKey := fmt.Sprintf("rateLimit:%s", userId)
+ // 初始化
+ tb := limiter.New(ctx, rdb)
+ allowed, err = tb.Allow(
+ ctx,
+ totalKey,
+ limiter.WithCapacity(int64(totalMaxCount)*duration),
+ limiter.WithRate(int64(totalMaxCount)),
+ limiter.WithRequested(duration),
+ )
+
+ if err != nil {
+ fmt.Println("检查总请求数限制失败:", err.Error())
+ abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+ return
+ }
+
+ if !allowed {
+ abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+ return
+ }
+ }
+
+ // 4. 处理请求
+ c.Next()
+
+ // 5. 如果请求成功,记录成功请求
+ if c.Writer.Status() < 400 {
+ recordRedisRequest(ctx, rdb, successKey, successMaxCount)
+ }
+ }
+}
+
+// memoryRateLimitHandler enforces the same limits with the in-memory rate limiter.
+func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+ inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
+
+ return func(c *gin.Context) {
+ userId := strconv.Itoa(c.GetInt("id"))
+ totalKey := ModelRequestRateLimitCountMark + userId
+ successKey := ModelRequestRateLimitSuccessCountMark + userId
+
+ // 1. 检查总请求数限制(当totalMaxCount为0时跳过)
+ if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
+ c.Status(http.StatusTooManyRequests)
+ c.Abort()
+ return
+ }
+
+ // 2. 检查成功请求数限制
+ // 使用一个临时key来检查限制,这样可以避免实际记录
+ checkKey := successKey + "_check"
+ if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
+ c.Status(http.StatusTooManyRequests)
+ c.Abort()
+ return
+ }
+
+ // 3. 处理请求
+ c.Next()
+
+ // 4. 如果请求成功,记录到实际的成功请求计数中
+ if c.Writer.Status() < 400 {
+ inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
+ }
+ }
+}
+
+// ModelRequestRateLimit is the per-user model-request rate-limit middleware; group-specific limits, when configured, override the global counts.
+func ModelRequestRateLimit() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ // 在每个请求时检查是否启用限流
+ if !setting.ModelRequestRateLimitEnabled {
+ c.Next()
+ return
+ }
+
+ // 计算限流参数
+ duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
+ totalMaxCount := setting.ModelRequestRateLimitCount
+ successMaxCount := setting.ModelRequestRateLimitSuccessCount
+
+ // 获取分组
+ group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
+ if group == "" {
+ group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
+ }
+
+ //获取分组的限流配置
+ groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
+ if found {
+ totalMaxCount = groupTotalCount
+ successMaxCount = groupSuccessCount
+ }
+
+ // 根据存储类型选择并执行限流处理器
+ if common.RedisEnabled {
+ redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+ } else {
+ memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
+ }
+ }
+}
diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go
new file mode 100644
index 00000000..e38fb8f6
--- /dev/null
+++ b/middleware/rate-limit.go
@@ -0,0 +1,113 @@
+package middleware
+
+import (
+ "context"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/common"
+ "time"
+)
+
+var timeFormat = "2006-01-02T15:04:05.000Z"
+
+var inMemoryRateLimiter common.InMemoryRateLimiter
+
+var defNext = func(c *gin.Context) {
+ c.Next()
+}
+
+func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
+ ctx := context.Background()
+ rdb := common.RDB
+ key := "rateLimit:" + mark + c.ClientIP()
+ listLength, err := rdb.LLen(ctx, key).Result()
+ if err != nil {
+ fmt.Println(err.Error())
+ c.Status(http.StatusInternalServerError)
+ c.Abort()
+ return
+ }
+ if listLength < int64(maxRequestNum) {
+ rdb.LPush(ctx, key, time.Now().Format(timeFormat))
+ rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ } else {
+ oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
+ oldTime, err := time.Parse(timeFormat, oldTimeStr)
+ if err != nil {
+ fmt.Println(err)
+ c.Status(http.StatusInternalServerError)
+ c.Abort()
+ return
+ }
+ nowTimeStr := time.Now().Format(timeFormat)
+ nowTime, err := time.Parse(timeFormat, nowTimeStr)
+ if err != nil {
+ fmt.Println(err)
+ c.Status(http.StatusInternalServerError)
+ c.Abort()
+ return
+ }
+ // time.Since will return negative number!
+ // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
+ if int64(nowTime.Sub(oldTime).Seconds()) < duration {
+ rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ c.Status(http.StatusTooManyRequests)
+ c.Abort()
+ return
+ } else {
+ rdb.LPush(ctx, key, time.Now().Format(timeFormat))
+ rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
+ rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ }
+ }
+}
+
+func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
+ key := mark + c.ClientIP()
+ if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
+ c.Status(http.StatusTooManyRequests)
+ c.Abort()
+ return
+ }
+}
+
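+// rateLimitFactory returns a middleware that rate-limits by client IP, backed by Redis when enabled
+// and by the in-memory limiter otherwise.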
+func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
+ if common.RedisEnabled {
+ return func(c *gin.Context) {
+ redisRateLimiter(c, maxRequestNum, duration, mark)
+ }
+ } else {
+ // It's safe to call multi times.
+ inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+ return func(c *gin.Context) {
+ memoryRateLimiter(c, maxRequestNum, duration, mark)
+ }
+ }
+}
+
+func GlobalWebRateLimit() func(c *gin.Context) {
+ if common.GlobalWebRateLimitEnable {
+ return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
+ }
+ return defNext
+}
+
+func GlobalAPIRateLimit() func(c *gin.Context) {
+ if common.GlobalApiRateLimitEnable {
+ return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
+ }
+ return defNext
+}
+
+func CriticalRateLimit() func(c *gin.Context) {
+ return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
+}
+
+func DownloadRateLimit() func(c *gin.Context) {
+ return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
+}
+
+func UploadRateLimit() func(c *gin.Context) {
+ return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
+}
diff --git a/middleware/recover.go b/middleware/recover.go
new file mode 100644
index 00000000..51fc7190
--- /dev/null
+++ b/middleware/recover.go
@@ -0,0 +1,28 @@
+package middleware
+
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/common"
+ "runtime/debug"
+)
+
+func RelayPanicRecover() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ defer func() {
+ if err := recover(); err != nil {
+ common.SysError(fmt.Sprintf("panic detected: %v", err))
+ common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "error": gin.H{
+ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
+ "type": "new_api_panic",
+ },
+ })
+ c.Abort()
+ }
+ }()
+ c.Next()
+ }
+}
diff --git a/middleware/request-id.go b/middleware/request-id.go
new file mode 100644
index 00000000..e623be7a
--- /dev/null
+++ b/middleware/request-id.go
@@ -0,0 +1,18 @@
+package middleware
+
+import (
+ "context"
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+)
+
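+// RequestId attaches a generated request id to the gin context, the request context and the response header.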
+func RequestId() func(c *gin.Context) {
+ return func(c *gin.Context) {
+ id := common.GetTimeString() + common.GetRandomString(8)
+ c.Set(common.RequestIdKey, id)
+ ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
+ c.Request = c.Request.WithContext(ctx)
+ c.Header(common.RequestIdKey, id)
+ c.Next()
+ }
+}
diff --git a/middleware/stats.go b/middleware/stats.go
new file mode 100644
index 00000000..1c97983f
--- /dev/null
+++ b/middleware/stats.go
@@ -0,0 +1,41 @@
+package middleware
+
+import (
+ "sync/atomic"
+
+ "github.com/gin-gonic/gin"
+)
+
+// HTTPStats 存储HTTP统计信息
+type HTTPStats struct {
+ activeConnections int64
+}
+
+var globalStats = &HTTPStats{}
+
+// StatsMiddleware 统计中间件
+func StatsMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ // 增加活跃连接数
+ atomic.AddInt64(&globalStats.activeConnections, 1)
+
+ // 确保在请求结束时减少连接数
+ defer func() {
+ atomic.AddInt64(&globalStats.activeConnections, -1)
+ }()
+
+ c.Next()
+ }
+}
+
+// StatsInfo 统计信息结构
+type StatsInfo struct {
+ ActiveConnections int64 `json:"active_connections"`
+}
+
+// GetStats 获取统计信息
+func GetStats() StatsInfo {
+ return StatsInfo{
+ ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
+ }
+}
\ No newline at end of file
diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go
new file mode 100644
index 00000000..26688810
--- /dev/null
+++ b/middleware/turnstile-check.go
@@ -0,0 +1,80 @@
+package middleware
+
+import (
+ "encoding/json"
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "net/url"
+ "one-api/common"
+)
+
+type turnstileCheckResponse struct {
+ Success bool `json:"success"`
+}
+
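+// TurnstileCheck verifies the Cloudflare Turnstile token (taken from the "turnstile" query parameter)
+// against the siteverify endpoint and caches the result in the session, when the check is enabled.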
+func TurnstileCheck() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if common.TurnstileCheckEnabled {
+ session := sessions.Default(c)
+ turnstileChecked := session.Get("turnstile")
+ if turnstileChecked != nil {
+ c.Next()
+ return
+ }
+ response := c.Query("turnstile")
+ if response == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "Turnstile token 为空",
+ })
+ c.Abort()
+ return
+ }
+ rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
+ "secret": {common.TurnstileSecretKey},
+ "response": {response},
+ "remoteip": {c.ClientIP()},
+ })
+ if err != nil {
+ common.SysError(err.Error())
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ c.Abort()
+ return
+ }
+ defer rawRes.Body.Close()
+ var res turnstileCheckResponse
+ err = json.NewDecoder(rawRes.Body).Decode(&res)
+ if err != nil {
+ common.SysError(err.Error())
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ c.Abort()
+ return
+ }
+ if !res.Success {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "Turnstile 校验失败,请刷新重试!",
+ })
+ c.Abort()
+ return
+ }
+ session.Set("turnstile", true)
+ err = session.Save()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "无法保存会话信息,请重试",
+ "success": false,
+ })
+ c.Abort()
+ return
+ }
+ }
+ c.Next()
+ }
+}
diff --git a/middleware/utils.go b/middleware/utils.go
new file mode 100644
index 00000000..082f5657
--- /dev/null
+++ b/middleware/utils.go
@@ -0,0 +1,29 @@
+package middleware
+
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "one-api/common"
+)
+
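+// abortWithOpenAiMessage aborts the request with an OpenAI-style error body that embeds the request id, and logs the message.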
+func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
+ userId := c.GetInt("id")
+ c.JSON(statusCode, gin.H{
+ "error": gin.H{
+ "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
+ "type": "new_api_error",
+ },
+ })
+ c.Abort()
+ common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
+}
+
+func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
+ c.JSON(statusCode, gin.H{
+ "description": description,
+ "type": "new_api_error",
+ "code": code,
+ })
+ c.Abort()
+ common.LogError(c.Request.Context(), description)
+}
diff --git a/model/ability.go b/model/ability.go
new file mode 100644
index 00000000..f36ff764
--- /dev/null
+++ b/model/ability.go
@@ -0,0 +1,320 @@
+package model
+
+import (
+ "errors"
+ "fmt"
+ "one-api/common"
+ "strings"
+ "sync"
+
+ "github.com/samber/lo"
+ "gorm.io/gorm"
+ "gorm.io/gorm/clause"
+)
+
+type Ability struct {
+ Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
+ Model string `json:"model" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
+ ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
+ Enabled bool `json:"enabled"`
+ Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
+ Weight uint `json:"weight" gorm:"default:0;index"`
+ Tag *string `json:"tag" gorm:"index"`
+}
+
+type AbilityWithChannel struct {
+ Ability
+ ChannelType int `json:"channel_type"`
+}
+
+func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
+ var abilities []AbilityWithChannel
+ err := DB.Table("abilities").
+ Select("abilities.*, channels.type as channel_type").
+ Joins("left join channels on abilities.channel_id = channels.id").
+ Where("abilities.enabled = ?", true).
+ Scan(&abilities).Error
+ return abilities, err
+}
+
+func GetGroupEnabledModels(group string) []string {
+ var models []string
+ // Find distinct models
+ DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
+ return models
+}
+
+func GetEnabledModels() []string {
+ var models []string
+ // Find distinct models
+ DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
+ return models
+}
+
+func GetAllEnableAbilities() []Ability {
+ var abilities []Ability
+ DB.Find(&abilities, "enabled = ?", true)
+ return abilities
+}
+
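+// getPriority returns the priority tier to use for the given group/model: the highest priority on the
+// first attempt, then progressively lower tiers as retry increases (falling back to the lowest tier).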
+func getPriority(group string, model string, retry int) (int, error) {
+
+ var priorities []int
+ err := DB.Model(&Ability{}).
+ Select("DISTINCT(priority)").
+ Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
+ Order("priority DESC"). // 按优先级降序排序
+ Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
+
+ if err != nil {
+ // 处理错误
+ return 0, err
+ }
+
+ if len(priorities) == 0 {
+ // 如果没有查询到优先级,则返回错误
+ return 0, errors.New("数据库一致性被破坏")
+ }
+
+ // 确定要使用的优先级
+ var priorityToUse int
+ if retry >= len(priorities) {
+ // 如果重试次数大于优先级数,则使用最小的优先级
+ priorityToUse = priorities[len(priorities)-1]
+ } else {
+ priorityToUse = priorities[retry]
+ }
+ return priorityToUse, nil
+}
+
+func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
+ maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
+ channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
+ if retry != 0 {
+ priority, err := getPriority(group, model, retry)
+ if err != nil {
+ return nil, err
+ } else {
+ channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
+ }
+ }
+
+ return channelQuery, nil
+}
+
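+// GetRandomSatisfiedChannel picks an enabled channel for the group/model at the priority tier implied by
+// retry, using weighted random selection (each weight is offset by +10 so zero-weight channels can still win).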
+func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+ var abilities []Ability
+
+ var err error = nil
+ channelQuery, err := getChannelQuery(group, model, retry)
+ if err != nil {
+ return nil, err
+ }
+ err = channelQuery.Order("weight DESC").Find(&abilities).Error
+ if err != nil {
+ return nil, err
+ }
+ channel := Channel{}
+ if len(abilities) > 0 {
+ // Sum the weights; each weight is offset by +10 so zero-weight channels can still be picked
+ weightSum := uint(0)
+ for _, ability_ := range abilities {
+ weightSum += ability_.Weight + 10
+ }
+ // Randomly choose one
+ weight := common.GetRandomInt(int(weightSum))
+ for _, ability_ := range abilities {
+ weight -= int(ability_.Weight) + 10
+ //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
+ if weight <= 0 {
+ channel.Id = ability_.ChannelId
+ break
+ }
+ }
+ } else {
+ return nil, errors.New("channel not found")
+ }
+ err = DB.First(&channel, "id = ?", channel.Id).Error
+ return &channel, err
+}
+
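+// AddAbilities inserts one ability row per (group, model) combination of the channel, deduplicated and
+// written in chunks of 50 with conflicting rows ignored.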
+func (channel *Channel) AddAbilities() error {
+ models_ := strings.Split(channel.Models, ",")
+ groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
+ abilities := make([]Ability, 0, len(models_))
+ for _, model := range models_ {
+ for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
+ ability := Ability{
+ Group: group,
+ Model: model,
+ ChannelId: channel.Id,
+ Enabled: channel.Status == common.ChannelStatusEnabled,
+ Priority: channel.Priority,
+ Weight: uint(channel.GetWeight()),
+ Tag: channel.Tag,
+ }
+ abilities = append(abilities, ability)
+ }
+ }
+ if len(abilities) == 0 {
+ return nil
+ }
+ for _, chunk := range lo.Chunk(abilities, 50) {
+ err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (channel *Channel) DeleteAbilities() error {
+ return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
+}
+
+// UpdateAbilities updates abilities of this channel.
+// Make sure the channel is completed before calling this function.
+func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
+ isNewTx := false
+ // 如果没有传入事务,创建新的事务
+ if tx == nil {
+ tx = DB.Begin()
+ if tx.Error != nil {
+ return tx.Error
+ }
+ isNewTx = true
+ defer func() {
+ if r := recover(); r != nil {
+ tx.Rollback()
+ }
+ }()
+ }
+
+ // First delete all abilities of this channel
+ err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
+ if err != nil {
+ if isNewTx {
+ tx.Rollback()
+ }
+ return err
+ }
+
+ // Then add new abilities
+ models_ := strings.Split(channel.Models, ",")
+ groups_ := strings.Split(channel.Group, ",")
+ abilitySet := make(map[string]struct{})
+ abilities := make([]Ability, 0, len(models_))
+ for _, model := range models_ {
+ for _, group := range groups_ {
+ key := group + "|" + model
+ if _, exists := abilitySet[key]; exists {
+ continue
+ }
+ abilitySet[key] = struct{}{}
+ ability := Ability{
+ Group: group,
+ Model: model,
+ ChannelId: channel.Id,
+ Enabled: channel.Status == common.ChannelStatusEnabled,
+ Priority: channel.Priority,
+ Weight: uint(channel.GetWeight()),
+ Tag: channel.Tag,
+ }
+ abilities = append(abilities, ability)
+ }
+ }
+
+ if len(abilities) > 0 {
+ for _, chunk := range lo.Chunk(abilities, 50) {
+ err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
+ if err != nil {
+ if isNewTx {
+ tx.Rollback()
+ }
+ return err
+ }
+ }
+ }
+
+ // 如果是新创建的事务,需要提交
+ if isNewTx {
+ return tx.Commit().Error
+ }
+
+ return nil
+}
+
+func UpdateAbilityStatus(channelId int, status bool) error {
+ return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
+}
+
+func UpdateAbilityStatusByTag(tag string, status bool) error {
+ return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
+}
+
+func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
+ ability := Ability{}
+ if newTag != nil {
+ ability.Tag = newTag
+ }
+ if priority != nil {
+ ability.Priority = priority
+ }
+ if weight != nil {
+ ability.Weight = *weight
+ }
+ return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
+}
+
+var fixLock = sync.Mutex{}
+
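+// FixAbility rebuilds the abilities table from the channels table: abilities are deleted and re-created in chunks of 50 channels, the per-channel success and failure counts are returned, and the in-memory channel cache is refreshed. Only one fix task may run at a time.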
+func FixAbility() (int, int, error) {
+ lock := fixLock.TryLock()
+ if !lock {
+ return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
+ }
+ defer fixLock.Unlock()
+ var channels []*Channel
+ // Find all channels
+ err := DB.Model(&Channel{}).Find(&channels).Error
+ if err != nil {
+ return 0, 0, err
+ }
+ if len(channels) == 0 {
+ return 0, 0, nil
+ }
+ successCount := 0
+ failCount := 0
+ for _, chunk := range lo.Chunk(channels, 50) {
+ ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
+ // Delete all abilities of this channel
+ err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
+ if err != nil {
+ common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+ failCount += len(chunk)
+ continue
+ }
+ // Then add new abilities
+ for _, channel := range chunk {
+ err = channel.AddAbilities()
+ if err != nil {
+ common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
+ failCount++
+ } else {
+ successCount++
+ }
+ }
+ }
+ InitChannelCache()
+ return successCount, failCount, nil
+}
diff --git a/model/channel.go b/model/channel.go
new file mode 100644
index 00000000..6277fcda
--- /dev/null
+++ b/model/channel.go
@@ -0,0 +1,909 @@
+package model
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "math/rand"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/types"
+ "strings"
+ "sync"
+
+ "gorm.io/gorm"
+)
+
+type Channel struct {
+ Id int `json:"id"`
+ Type int `json:"type" gorm:"default:0"`
+ Key string `json:"key" gorm:"not null"`
+ OpenAIOrganization *string `json:"openai_organization"`
+ TestModel *string `json:"test_model"`
+ Status int `json:"status" gorm:"default:1"`
+ Name string `json:"name" gorm:"index"`
+ Weight *uint `json:"weight" gorm:"default:0"`
+ CreatedTime int64 `json:"created_time" gorm:"bigint"`
+ TestTime int64 `json:"test_time" gorm:"bigint"`
+ ResponseTime int `json:"response_time"` // in milliseconds
+ BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
+ Other string `json:"other"`
+ Balance float64 `json:"balance"` // in USD
+ BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
+ Models string `json:"models"`
+ Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
+ UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
+ ModelMapping *string `json:"model_mapping" gorm:"type:text"`
+ //MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
+ StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
+ Priority *int64 `json:"priority" gorm:"bigint;default:0"`
+ AutoBan *int `json:"auto_ban" gorm:"default:1"`
+ OtherInfo string `json:"other_info"`
+ Tag *string `json:"tag" gorm:"index"`
+ Setting *string `json:"setting" gorm:"type:text"` // extra per-channel settings
+ ParamOverride *string `json:"param_override" gorm:"type:text"`
+ // add after v0.8.5
+ ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
+}
+
+type ChannelInfo struct {
+ IsMultiKey bool `json:"is_multi_key"` // whether the channel runs in multi-key mode
+ MultiKeySize int `json:"multi_key_size"` // number of keys in multi-key mode
+ MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key status list: key index -> status
+ MultiKeyPollingIndex int `json:"multi_key_polling_index"` // polling index used in multi-key mode
+ MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
+}
+
+// Value implements driver.Valuer interface
+func (c ChannelInfo) Value() (driver.Value, error) {
+ return common.Marshal(&c)
+}
+
+// Scan implements sql.Scanner interface
+func (c *ChannelInfo) Scan(value interface{}) error {
+ bytesValue, _ := value.([]byte)
+ return common.Unmarshal(bytesValue, c)
+}
+
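+// getKeys returns the channel's keys: parsed as a JSON array when the key field starts with '[', otherwise the key field split by newlines.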
+func (channel *Channel) getKeys() []string {
+ if channel.Key == "" {
+ return []string{}
+ }
+ trimmed := strings.TrimSpace(channel.Key)
+ // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
+ if strings.HasPrefix(trimmed, "[") {
+ var arr []json.RawMessage
+ if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
+ res := make([]string, len(arr))
+ for i, v := range arr {
+ res[i] = string(v)
+ }
+ return res
+ }
+ }
+ // Otherwise, fall back to splitting by newline
+ keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
+ return keys
+}
+
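+// GetNextEnabledKey returns the key to use for the next request and its index. Single-key channels return the raw key; multi-key channels pick an enabled key at random or by polling, depending on MultiKeyMode, and fall back to the first key when none are marked enabled.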
+func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
+ // If not in multi-key mode, return the original key string directly.
+ if !channel.ChannelInfo.IsMultiKey {
+ return channel.Key, 0, nil
+ }
+
+ // Obtain all keys (split by \n)
+ keys := channel.getKeys()
+ if len(keys) == 0 {
+ // No keys available, return error, should disable the channel
+ return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
+ }
+
+ statusList := channel.ChannelInfo.MultiKeyStatusList
+ // helper to get key status, default to enabled when missing
+ getStatus := func(idx int) int {
+ if statusList == nil {
+ return common.ChannelStatusEnabled
+ }
+ if status, ok := statusList[idx]; ok {
+ return status
+ }
+ return common.ChannelStatusEnabled
+ }
+
+ // Collect indexes of enabled keys
+ enabledIdx := make([]int, 0, len(keys))
+ for i := range keys {
+ if getStatus(i) == common.ChannelStatusEnabled {
+ enabledIdx = append(enabledIdx, i)
+ }
+ }
+ // If no specific status list or none enabled, fall back to first key
+ if len(enabledIdx) == 0 {
+ return keys[0], 0, nil
+ }
+
+ switch channel.ChannelInfo.MultiKeyMode {
+ case constant.MultiKeyModeRandom:
+ // Randomly pick one enabled key
+ selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))]
+ return keys[selectedIdx], selectedIdx, nil
+ case constant.MultiKeyModePolling:
+ // Use channel-specific lock to ensure thread-safe polling
+ lock := getChannelPollingLock(channel.Id)
+ lock.Lock()
+ defer lock.Unlock()
+
+ channelInfo, err := CacheGetChannelInfo(channel.Id)
+ if err != nil {
+ return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
+ }
+ //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
+ defer func() {
+ if common.DebugEnabled {
+ println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex))
+ }
+ if !common.MemoryCacheEnabled {
+ _ = channel.SaveChannelInfo()
+ } else {
+ // CacheUpdateChannel(channel)
+ }
+ }()
+ // Start from the saved polling index and look for the next enabled key
+ start := channelInfo.MultiKeyPollingIndex
+ if start < 0 || start >= len(keys) {
+ start = 0
+ }
+ for i := 0; i < len(keys); i++ {
+ idx := (start + i) % len(keys)
+ if getStatus(idx) == common.ChannelStatusEnabled {
+ // update polling index for next call (point to the next position)
+ channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
+ return keys[idx], idx, nil
+ }
+ }
+ // Fallback – should not happen, but return first enabled key
+ return keys[enabledIdx[0]], enabledIdx[0], nil
+ default:
+ // Unknown mode, default to first enabled key (or original key string)
+ return keys[enabledIdx[0]], enabledIdx[0], nil
+ }
+}
+
+func (channel *Channel) SaveChannelInfo() error {
+ return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error
+}
+
+func (channel *Channel) GetModels() []string {
+ if channel.Models == "" {
+ return []string{}
+ }
+ return strings.Split(strings.Trim(channel.Models, ","), ",")
+}
+
+func (channel *Channel) GetGroups() []string {
+ if channel.Group == "" {
+ return []string{}
+ }
+ groups := strings.Split(strings.Trim(channel.Group, ","), ",")
+ for i, group := range groups {
+ groups[i] = strings.TrimSpace(group)
+ }
+ return groups
+}
+
+func (channel *Channel) GetOtherInfo() map[string]interface{} {
+ otherInfo := make(map[string]interface{})
+ if channel.OtherInfo != "" {
+ err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
+ if err != nil {
+ common.SysError("failed to unmarshal other info: " + err.Error())
+ }
+ }
+ return otherInfo
+}
+
+func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
+ otherInfoBytes, err := json.Marshal(otherInfo)
+ if err != nil {
+ common.SysError("failed to marshal other info: " + err.Error())
+ return
+ }
+ channel.OtherInfo = string(otherInfoBytes)
+}
+
+func (channel *Channel) GetTag() string {
+ if channel.Tag == nil {
+ return ""
+ }
+ return *channel.Tag
+}
+
+func (channel *Channel) SetTag(tag string) {
+ channel.Tag = &tag
+}
+
+func (channel *Channel) GetAutoBan() bool {
+ if channel.AutoBan == nil {
+ return false
+ }
+ return *channel.AutoBan == 1
+}
+
+func (channel *Channel) Save() error {
+ return DB.Save(channel).Error
+}
+
+func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
+ var channels []*Channel
+ var err error
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+ if selectAll {
+ err = DB.Order(order).Find(&channels).Error
+ } else {
+ err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+ }
+ return channels, err
+}
+
+func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) {
+ var channels []*Channel
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+ err := DB.Where("tag = ?", tag).Order(order).Find(&channels).Error
+ return channels, err
+}
+
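+// SearchChannels finds channels whose id, name, key or base_url matches the keyword, optionally filtered by model and group; the key column is omitted from the results.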
+func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
+ var channels []*Channel
+ modelsCol := "`models`"
+ baseURLCol := "`base_url`"
+ // PostgreSQL quotes identifiers with double quotes
+ if common.UsingPostgreSQL {
+ modelsCol = `"models"`
+ baseURLCol = `"base_url"`
+ }
+
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+
+ // Build the base query
+ baseQuery := DB.Model(&Channel{}).Omit("key")
+
+ // Build the WHERE clause
+ var whereClause string
+ var args []interface{}
+ if group != "" && group != "null" {
+ var groupCondition string
+ if common.UsingMySQL {
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
+ } else {
+ // sqlite, PostgreSQL
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
+ }
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
+ } else {
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
+ }
+
+ // Run the query
+ err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error
+ if err != nil {
+ return nil, err
+ }
+ return channels, nil
+}
+
+func GetChannelById(id int, selectAll bool) (*Channel, error) {
+ channel := &Channel{Id: id}
+ var err error = nil
+ if selectAll {
+ err = DB.First(channel, "id = ?", id).Error
+ } else {
+ err = DB.Omit("key").First(channel, "id = ?", id).Error
+ }
+ if err != nil {
+ return nil, err
+ }
+ if channel == nil {
+ return nil, errors.New("channel not found")
+ }
+ return channel, nil
+}
+
+func BatchInsertChannels(channels []Channel) error {
+ var err error
+ err = DB.Create(&channels).Error
+ if err != nil {
+ return err
+ }
+ for _, channel_ := range channels {
+ err = channel_.AddAbilities()
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func BatchDeleteChannels(ids []int) error {
+ // Use a transaction to delete from both the channels and abilities tables
+ tx := DB.Begin()
+ err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
+ if err != nil {
+ // roll back the transaction
+ tx.Rollback()
+ return err
+ }
+ err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
+ if err != nil {
+ // roll back the transaction
+ tx.Rollback()
+ return err
+ }
+ // Commit the transaction and return any commit error
+ return tx.Commit().Error
+}
+
+func (channel *Channel) GetPriority() int64 {
+ if channel.Priority == nil {
+ return 0
+ }
+ return *channel.Priority
+}
+
+func (channel *Channel) GetWeight() int {
+ if channel.Weight == nil {
+ return 0
+ }
+ return int(*channel.Weight)
+}
+
+func (channel *Channel) GetBaseURL() string {
+ if channel.BaseURL == nil {
+ return ""
+ }
+ return *channel.BaseURL
+}
+
+func (channel *Channel) GetModelMapping() string {
+ if channel.ModelMapping == nil {
+ return ""
+ }
+ return *channel.ModelMapping
+}
+
+func (channel *Channel) GetStatusCodeMapping() string {
+ if channel.StatusCodeMapping == nil {
+ return ""
+ }
+ return *channel.StatusCodeMapping
+}
+
+func (channel *Channel) Insert() error {
+ var err error
+ err = DB.Create(channel).Error
+ if err != nil {
+ return err
+ }
+ err = channel.AddAbilities()
+ return err
+}
+
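+// Update persists the channel, recalculating MultiKeySize and pruning stale key statuses for multi-key channels, then rebuilds its abilities.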
+func (channel *Channel) Update() error {
+ // If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys
+ if channel.ChannelInfo.IsMultiKey {
+ var keyStr string
+ if channel.Key != "" {
+ keyStr = channel.Key
+ } else {
+ // If key is not provided, read the existing key from the database
+ if existing, err := GetChannelById(channel.Id, true); err == nil {
+ keyStr = existing.Key
+ }
+ }
+ // Parse the key list (supports newline separation or JSON array)
+ keys := []string{}
+ if keyStr != "" {
+ trimmed := strings.TrimSpace(keyStr)
+ if strings.HasPrefix(trimmed, "[") {
+ var arr []json.RawMessage
+ if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
+ keys = make([]string, len(arr))
+ for i, v := range arr {
+ keys[i] = string(v)
+ }
+ }
+ }
+ if len(keys) == 0 { // fallback to newline split
+ keys = strings.Split(strings.Trim(keyStr, "\n"), "\n")
+ }
+ }
+ channel.ChannelInfo.MultiKeySize = len(keys)
+ // Clean up status data that exceeds the new key count to prevent index out of range
+ if channel.ChannelInfo.MultiKeyStatusList != nil {
+ for idx := range channel.ChannelInfo.MultiKeyStatusList {
+ if idx >= channel.ChannelInfo.MultiKeySize {
+ delete(channel.ChannelInfo.MultiKeyStatusList, idx)
+ }
+ }
+ }
+ }
+ var err error
+ err = DB.Model(channel).Updates(channel).Error
+ if err != nil {
+ return err
+ }
+ DB.Model(channel).First(channel, "id = ?", channel.Id)
+ err = channel.UpdateAbilities(nil)
+ return err
+}
+
+func (channel *Channel) UpdateResponseTime(responseTime int64) {
+ err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
+ TestTime: common.GetTimestamp(),
+ ResponseTime: int(responseTime),
+ }).Error
+ if err != nil {
+ common.SysError("failed to update response time: " + err.Error())
+ }
+}
+
+func (channel *Channel) UpdateBalance(balance float64) {
+ err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
+ BalanceUpdatedTime: common.GetTimestamp(),
+ Balance: balance,
+ }).Error
+ if err != nil {
+ common.SysError("failed to update balance: " + err.Error())
+ }
+}
+
+func (channel *Channel) Delete() error {
+ var err error
+ err = DB.Delete(channel).Error
+ if err != nil {
+ return err
+ }
+ err = channel.DeleteAbilities()
+ return err
+}
+
+var channelStatusLock sync.Mutex
+
+// channelPollingLocks stores locks for each channel.id to ensure thread-safe polling
+var channelPollingLocks sync.Map
+
+// getChannelPollingLock returns or creates a mutex for the given channel ID
+func getChannelPollingLock(channelId int) *sync.Mutex {
+ if lock, exists := channelPollingLocks.Load(channelId); exists {
+ return lock.(*sync.Mutex)
+ }
+ // Create new lock for this channel
+ newLock := &sync.Mutex{}
+ actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock)
+ return actual.(*sync.Mutex)
+}
+
+// CleanupChannelPollingLocks removes locks for channels that no longer exist
+// This is optional and can be called periodically to prevent memory leaks
+func CleanupChannelPollingLocks() {
+ var activeChannelIds []int
+ DB.Model(&Channel{}).Pluck("id", &activeChannelIds)
+
+ activeChannelSet := make(map[int]bool)
+ for _, id := range activeChannelIds {
+ activeChannelSet[id] = true
+ }
+
+ channelPollingLocks.Range(func(key, value interface{}) bool {
+ channelId := key.(int)
+ if !activeChannelSet[channelId] {
+ channelPollingLocks.Delete(channelId)
+ }
+ return true
+ })
+}
+
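+// handlerMultiKeyUpdate updates the status of the single key identified by usingKey; when every key of the channel has been disabled, the whole channel is auto-disabled and the reason is recorded in OtherInfo.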
+func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
+ keys := channel.getKeys()
+ if len(keys) == 0 {
+ channel.Status = status
+ } else {
+ var keyIndex int
+ for i, key := range keys {
+ if key == usingKey {
+ keyIndex = i
+ break
+ }
+ }
+ if channel.ChannelInfo.MultiKeyStatusList == nil {
+ channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
+ }
+ if status == common.ChannelStatusEnabled {
+ delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
+ } else {
+ channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
+ }
+ if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
+ channel.Status = common.ChannelStatusAutoDisabled
+ info := channel.GetOtherInfo()
+ info["status_reason"] = "All keys are disabled"
+ info["status_time"] = common.GetTimestamp()
+ channel.SetOtherInfo(info)
+ }
+ }
+}
+
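+// UpdateChannelStatus sets the status of a channel (or of one key for multi-key channels), keeps the memory cache in sync, and updates the channel's abilities when its overall status changes. It returns true only when something was actually updated.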
+func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool {
+ if common.MemoryCacheEnabled {
+ channelStatusLock.Lock()
+ defer channelStatusLock.Unlock()
+
+ channelCache, _ := CacheGetChannel(channelId)
+ if channelCache == nil {
+ return false
+ }
+ if channelCache.ChannelInfo.IsMultiKey {
+ // In multi-key mode, update the key status in the cached channel
+ handlerMultiKeyUpdate(channelCache, usingKey, status)
+ //CacheUpdateChannel(channelCache)
+ //return true
+ } else {
+ // If the cached channel already has the target status, nothing to do
+ if channelCache.Status == status {
+ return false
+ }
+ // If the cached channel no longer exists (it has already been disabled) and the target status is not enabled, return directly
+ if status != common.ChannelStatusEnabled {
+ return false
+ }
+ CacheUpdateChannelStatus(channelId, status)
+ }
+ }
+
+ shouldUpdateAbilities := false
+ defer func() {
+ if shouldUpdateAbilities {
+ err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
+ if err != nil {
+ common.SysError("failed to update ability status: " + err.Error())
+ }
+ }
+ }()
+ channel, err := GetChannelById(channelId, true)
+ if err != nil {
+ return false
+ } else {
+ if channel.Status == status {
+ return false
+ }
+
+ if channel.ChannelInfo.IsMultiKey {
+ beforeStatus := channel.Status
+ handlerMultiKeyUpdate(channel, usingKey, status)
+ if beforeStatus != channel.Status {
+ shouldUpdateAbilities = true
+ }
+ } else {
+ info := channel.GetOtherInfo()
+ info["status_reason"] = reason
+ info["status_time"] = common.GetTimestamp()
+ channel.SetOtherInfo(info)
+ channel.Status = status
+ shouldUpdateAbilities = true
+ }
+ err = channel.Save()
+ if err != nil {
+ common.SysError("failed to update channel status: " + err.Error())
+ return false
+ }
+ }
+ return true
+}
+
+func EnableChannelByTag(tag string) error {
+ err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error
+ if err != nil {
+ return err
+ }
+ err = UpdateAbilityStatusByTag(tag, true)
+ return err
+}
+
+func DisableChannelByTag(tag string) error {
+ err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error
+ if err != nil {
+ return err
+ }
+ err = UpdateAbilityStatusByTag(tag, false)
+ return err
+}
+
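+// EditChannelByTag bulk-edits all channels sharing a tag; when models or group change, the channels' abilities are rebuilt, otherwise only the ability rows' tag, priority and weight are updated.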
+func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
+ updateData := Channel{}
+ shouldReCreateAbilities := false
+ updatedTag := tag
+ // Update the tag only when newTag is non-nil and differs from the current tag
+ if newTag != nil && *newTag != tag {
+ updateData.Tag = newTag
+ updatedTag = *newTag
+ }
+ if modelMapping != nil && *modelMapping != "" {
+ updateData.ModelMapping = modelMapping
+ }
+ if models != nil && *models != "" {
+ shouldReCreateAbilities = true
+ updateData.Models = *models
+ }
+ if group != nil && *group != "" {
+ shouldReCreateAbilities = true
+ updateData.Group = *group
+ }
+ if priority != nil {
+ updateData.Priority = priority
+ }
+ if weight != nil {
+ updateData.Weight = weight
+ }
+
+ err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
+ if err != nil {
+ return err
+ }
+ if shouldReCreateAbilities {
+ channels, err := GetChannelsByTag(updatedTag, false)
+ if err == nil {
+ for _, channel := range channels {
+ err = channel.UpdateAbilities(nil)
+ if err != nil {
+ common.SysError("failed to update abilities: " + err.Error())
+ }
+ }
+ }
+ } else {
+ err := UpdateAbilityByTag(tag, newTag, priority, weight)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func UpdateChannelUsedQuota(id int, quota int) {
+ if common.BatchUpdateEnabled {
+ addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
+ return
+ }
+ updateChannelUsedQuota(id, quota)
+}
+
+func updateChannelUsedQuota(id int, quota int) {
+ err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
+ if err != nil {
+ common.SysError("failed to update channel used quota: " + err.Error())
+ }
+}
+
+func DeleteChannelByStatus(status int64) (int64, error) {
+ result := DB.Where("status = ?", status).Delete(&Channel{})
+ return result.RowsAffected, result.Error
+}
+
+func DeleteDisabledChannel() (int64, error) {
+ result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
+ return result.RowsAffected, result.Error
+}
+
+func GetPaginatedTags(offset int, limit int) ([]*string, error) {
+ var tags []*string
+ err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error
+ return tags, err
+}
+
+func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) {
+ var tags []*string
+ modelsCol := "`models`"
+ baseURLCol := "`base_url`"
+ // PostgreSQL quotes identifiers with double quotes
+ if common.UsingPostgreSQL {
+ modelsCol = `"models"`
+ baseURLCol = `"base_url"`
+ }
+
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+
+ // Build the base query
+ baseQuery := DB.Model(&Channel{}).Omit("key")
+
+ // Build the WHERE clause
+ var whereClause string
+ var args []interface{}
+ if group != "" && group != "null" {
+ var groupCondition string
+ if common.UsingMySQL {
+ groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
+ } else {
+ // sqlite, PostgreSQL
+ groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
+ }
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
+ args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
+ } else {
+ whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
+ args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
+ }
+
+ subQuery := baseQuery.Where(whereClause, args...).
+ Select("tag").
+ Where("tag != ''").
+ Order(order)
+
+ err := DB.Table("(?) as sub", subQuery).
+ Select("DISTINCT tag").
+ Find(&tags).Error
+
+ if err != nil {
+ return nil, err
+ }
+
+ return tags, nil
+}
+
+func (channel *Channel) ValidateSettings() error {
+ channelParams := &dto.ChannelSettings{}
+ if channel.Setting != nil && *channel.Setting != "" {
+ err := json.Unmarshal([]byte(*channel.Setting), channelParams)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (channel *Channel) GetSetting() dto.ChannelSettings {
+ setting := dto.ChannelSettings{}
+ if channel.Setting != nil && *channel.Setting != "" {
+ err := json.Unmarshal([]byte(*channel.Setting), &setting)
+ if err != nil {
+ common.SysError("failed to unmarshal setting: " + err.Error())
+ channel.Setting = nil // clear the setting to avoid repeated failures
+ _ = channel.Save() // persist the cleanup
+ }
+ }
+ return setting
+}
+
+func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
+ settingBytes, err := json.Marshal(setting)
+ if err != nil {
+ common.SysError("failed to marshal setting: " + err.Error())
+ return
+ }
+ channel.Setting = common.GetPointer[string](string(settingBytes))
+}
+
+func (channel *Channel) GetParamOverride() map[string]interface{} {
+ paramOverride := make(map[string]interface{})
+ if channel.ParamOverride != nil && *channel.ParamOverride != "" {
+ err := json.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride)
+ if err != nil {
+ common.SysError("failed to unmarshal param override: " + err.Error())
+ }
+ }
+ return paramOverride
+}
+
+func GetChannelsByIds(ids []int) ([]*Channel, error) {
+ var channels []*Channel
+ err := DB.Where("id in (?)", ids).Find(&channels).Error
+ return channels, err
+}
+
+func BatchSetChannelTag(ids []int, tag *string) error {
+ // Start a transaction
+ tx := DB.Begin()
+ if tx.Error != nil {
+ return tx.Error
+ }
+
+ // Update the tag
+ err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+
+ // update ability status
+ channels, err := GetChannelsByIds(ids)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+
+ for _, channel := range channels {
+ err = channel.UpdateAbilities(tx)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+ }
+
+ // Commit the transaction
+ return tx.Commit().Error
+}
+
+// CountAllChannels returns total channels in DB
+func CountAllChannels() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Count(&total).Error
+ return total, err
+}
+
+// CountAllTags returns number of non-empty distinct tags
+func CountAllTags() (int64, error) {
+ var total int64
+ err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
+ return total, err
+}
+
+// Get channels of specified type with pagination
+func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
+ var channels []*Channel
+ order := "priority desc"
+ if idSort {
+ order = "id desc"
+ }
+ err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+ return channels, err
+}
+
+// Count channels of specific type
+func CountChannelsByType(channelType int) (int64, error) {
+ var count int64
+ err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
+ return count, err
+}
+
+// Return map[type]count for all channels
+func CountChannelsGroupByType() (map[int64]int64, error) {
+ type result struct {
+ Type int64 `gorm:"column:type"`
+ Count int64 `gorm:"column:count"`
+ }
+ var results []result
+ err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
+ if err != nil {
+ return nil, err
+ }
+ counts := make(map[int64]int64)
+ for _, r := range results {
+ counts[r.Type] = r.Count
+ }
+ return counts, nil
+}
diff --git a/model/channel_cache.go b/model/channel_cache.go
new file mode 100644
index 00000000..b2451248
--- /dev/null
+++ b/model/channel_cache.go
@@ -0,0 +1,262 @@
+package model
+
+import (
+ "errors"
+ "fmt"
+ "math/rand"
+ "one-api/common"
+ "one-api/setting"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+var group2model2channels map[string]map[string][]int // enabled channels only
+var channelsIDM map[int]*Channel // all channels, including disabled ones
+var channelSyncLock sync.RWMutex
+
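+// InitChannelCache loads all channels and abilities from the database into memory and builds the group -> model -> channel-id index (enabled channels only, sorted by priority). It is a no-op when the memory cache is disabled.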
+func InitChannelCache() {
+ if !common.MemoryCacheEnabled {
+ return
+ }
+ newChannelId2channel := make(map[int]*Channel)
+ var channels []*Channel
+ DB.Find(&channels)
+ for _, channel := range channels {
+ newChannelId2channel[channel.Id] = channel
+ }
+ var abilities []*Ability
+ DB.Find(&abilities)
+ groups := make(map[string]bool)
+ for _, ability := range abilities {
+ groups[ability.Group] = true
+ }
+ newGroup2model2channels := make(map[string]map[string][]int)
+ for group := range groups {
+ newGroup2model2channels[group] = make(map[string][]int)
+ }
+ for _, channel := range channels {
+ if channel.Status != common.ChannelStatusEnabled {
+ continue // skip disabled channels
+ }
+ groups := strings.Split(channel.Group, ",")
+ for _, group := range groups {
+ models := strings.Split(channel.Models, ",")
+ for _, model := range models {
+ if _, ok := newGroup2model2channels[group][model]; !ok {
+ newGroup2model2channels[group][model] = make([]int, 0)
+ }
+ newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id)
+ }
+ }
+ }
+
+ // sort by priority
+ for group, model2channels := range newGroup2model2channels {
+ for model, channels := range model2channels {
+ sort.Slice(channels, func(i, j int) bool {
+ return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority()
+ })
+ newGroup2model2channels[group][model] = channels
+ }
+ }
+
+ channelSyncLock.Lock()
+ group2model2channels = newGroup2model2channels
+ channelsIDM = newChannelId2channel
+ channelSyncLock.Unlock()
+ common.SysLog("channels synced from database")
+}
+
+func SyncChannelCache(frequency int) {
+ for {
+ time.Sleep(time.Duration(frequency) * time.Second)
+ common.SysLog("syncing channels from database")
+ InitChannelCache()
+ }
+}
+
+func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
+ var channel *Channel
+ var err error
+ selectGroup := group
+ if group == "auto" {
+ if len(setting.AutoGroups) == 0 {
+ return nil, selectGroup, errors.New("auto groups is not enabled")
+ }
+ for _, autoGroup := range setting.AutoGroups {
+ if common.DebugEnabled {
+ println("autoGroup:", autoGroup)
+ }
+ channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
+ if channel == nil {
+ continue
+ } else {
+ c.Set("auto_group", autoGroup)
+ selectGroup = autoGroup
+ if common.DebugEnabled {
+ println("selectGroup:", selectGroup)
+ }
+ break
+ }
+ }
+ } else {
+ channel, err = getRandomSatisfiedChannel(group, model, retry)
+ if err != nil {
+ return nil, group, err
+ }
+ }
+ if channel == nil {
+ return nil, group, errors.New("channel not found")
+ }
+ return channel, selectGroup, nil
+}
+
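+// getRandomSatisfiedChannel picks a channel for the given group and model. The retry count selects which priority tier to use (0 = highest); within that tier a channel is chosen by smoothed weighted random selection.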
+func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+ if strings.HasPrefix(model, "gpt-4-gizmo") {
+ model = "gpt-4-gizmo-*"
+ }
+ if strings.HasPrefix(model, "gpt-4o-gizmo") {
+ model = "gpt-4o-gizmo-*"
+ }
+
+ // if memory cache is disabled, get channel directly from database
+ if !common.MemoryCacheEnabled {
+ return GetRandomSatisfiedChannel(group, model, retry)
+ }
+
+ channelSyncLock.RLock()
+ defer channelSyncLock.RUnlock()
+ channels := group2model2channels[group][model]
+
+ if len(channels) == 0 {
+ return nil, errors.New("channel not found")
+ }
+
+ if len(channels) == 1 {
+ if channel, ok := channelsIDM[channels[0]]; ok {
+ return channel, nil
+ }
+ return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0])
+ }
+
+ uniquePriorities := make(map[int]bool)
+ for _, channelId := range channels {
+ if channel, ok := channelsIDM[channelId]; ok {
+ uniquePriorities[int(channel.GetPriority())] = true
+ } else {
+ return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
+ }
+ }
+ var sortedUniquePriorities []int
+ for priority := range uniquePriorities {
+ sortedUniquePriorities = append(sortedUniquePriorities, priority)
+ }
+ sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
+
+ if retry >= len(uniquePriorities) {
+ retry = len(uniquePriorities) - 1
+ }
+ targetPriority := int64(sortedUniquePriorities[retry])
+
+ // get the priority for the given retry number
+ var targetChannels []*Channel
+ for _, channelId := range channels {
+ if channel, ok := channelsIDM[channelId]; ok {
+ if channel.GetPriority() == targetPriority {
+ targetChannels = append(targetChannels, channel)
+ }
+ } else {
+ return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
+ }
+ }
+
+ // smoothing factor so zero-weight channels can still be selected
+ smoothingFactor := 10
+ // Calculate the total smoothed weight of the target channels
+ totalWeight := 0
+ for _, channel := range targetChannels {
+ totalWeight += channel.GetWeight() + smoothingFactor
+ }
+ // Generate a random value in the range [0, totalWeight)
+ randomWeight := rand.Intn(totalWeight)
+
+ // Find a channel based on its weight
+ for _, channel := range targetChannels {
+ randomWeight -= channel.GetWeight() + smoothingFactor
+ if randomWeight < 0 {
+ return channel, nil
+ }
+ }
+ // return an error if no channel was selected (should not happen)
+ return nil, errors.New("channel not found")
+}
+
+func CacheGetChannel(id int) (*Channel, error) {
+ if !common.MemoryCacheEnabled {
+ return GetChannelById(id, true)
+ }
+ channelSyncLock.RLock()
+ defer channelSyncLock.RUnlock()
+
+ c, ok := channelsIDM[id]
+ if !ok {
+ return nil, fmt.Errorf("渠道# %d,已不存在", id)
+ }
+ if c.Status != common.ChannelStatusEnabled {
+ return nil, fmt.Errorf("渠道# %d,已被禁用", id)
+ }
+ return c, nil
+}
+
+func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
+ if !common.MemoryCacheEnabled {
+ channel, err := GetChannelById(id, true)
+ if err != nil {
+ return nil, err
+ }
+ return &channel.ChannelInfo, nil
+ }
+ channelSyncLock.RLock()
+ defer channelSyncLock.RUnlock()
+
+ c, ok := channelsIDM[id]
+ if !ok {
+ return nil, fmt.Errorf("渠道# %d,已不存在", id)
+ }
+ if c.Status != common.ChannelStatusEnabled {
+ return nil, fmt.Errorf("渠道# %d,已被禁用", id)
+ }
+ return &c.ChannelInfo, nil
+}
+
+func CacheUpdateChannelStatus(id int, status int) {
+ if !common.MemoryCacheEnabled {
+ return
+ }
+ channelSyncLock.Lock()
+ defer channelSyncLock.Unlock()
+ if channel, ok := channelsIDM[id]; ok {
+ channel.Status = status
+ }
+}
+
+func CacheUpdateChannel(channel *Channel) {
+ if !common.MemoryCacheEnabled {
+ return
+ }
+ channelSyncLock.Lock()
+ defer channelSyncLock.Unlock()
+ if channel == nil {
+ return
+ }
+
+ println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex)
+
+ println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
+ channelsIDM[channel.Id] = channel
+ println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex)
+}
diff --git a/model/log.go b/model/log.go
new file mode 100644
index 00000000..2070cd6f
--- /dev/null
+++ b/model/log.go
@@ -0,0 +1,411 @@
+package model
+
+import (
+ "context"
+ "fmt"
+ "one-api/common"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "gorm.io/gorm"
+)
+
+type Log struct {
+ Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
+ UserId int `json:"user_id" gorm:"index"`
+ CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
+ Type int `json:"type" gorm:"index:idx_created_at_type"`
+ Content string `json:"content"`
+ Username string `json:"username" gorm:"index;index:index_username_model_name,priority:2;default:''"`
+ TokenName string `json:"token_name" gorm:"index;default:''"`
+ ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
+ Quota int `json:"quota" gorm:"default:0"`
+ PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
+ CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
+ UseTime int `json:"use_time" gorm:"default:0"`
+ IsStream bool `json:"is_stream"`
+ ChannelId int `json:"channel" gorm:"index"`
+ ChannelName string `json:"channel_name" gorm:"->"`
+ TokenId int `json:"token_id" gorm:"default:0;index"`
+ Group string `json:"group" gorm:"index"`
+ Ip string `json:"ip" gorm:"index;default:''"`
+ Other string `json:"other"`
+}
+
+const (
+ LogTypeUnknown = iota
+ LogTypeTopup
+ LogTypeConsume
+ LogTypeManage
+ LogTypeSystem
+ LogTypeError
+)
+
+func formatUserLogs(logs []*Log) {
+ for i := range logs {
+ logs[i].ChannelName = ""
+ var otherMap map[string]interface{}
+ otherMap, _ = common.StrToMap(logs[i].Other)
+ if otherMap != nil {
+ // delete admin
+ delete(otherMap, "admin_info")
+ }
+ logs[i].Other = common.MapToJsonStr(otherMap)
+ logs[i].Id = logs[i].Id % 1024
+ }
+}
+
+func GetLogByKey(key string) (logs []*Log, err error) {
+ if os.Getenv("LOG_SQL_DSN") != "" {
+ var tk Token
+ if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
+ return nil, err
+ }
+ err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
+ } else {
+ err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
+ }
+ formatUserLogs(logs)
+ return logs, err
+}
+
+func RecordLog(userId int, logType int, content string) {
+ if logType == LogTypeConsume && !common.LogConsumeEnabled {
+ return
+ }
+ username, _ := GetUsernameById(userId, false)
+ log := &Log{
+ UserId: userId,
+ Username: username,
+ CreatedAt: common.GetTimestamp(),
+ Type: logType,
+ Content: content,
+ }
+ err := LOG_DB.Create(log).Error
+ if err != nil {
+ common.SysError("failed to record log: " + err.Error())
+ }
+}
+
+func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
+ isStream bool, group string, other map[string]interface{}) {
+ common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
+ username := c.GetString("username")
+ otherStr := common.MapToJsonStr(other)
+ // Check whether the client IP should be recorded for this user
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if settingMap.RecordIpLog {
+ needRecordIp = true
+ }
+ }
+ log := &Log{
+ UserId: userId,
+ Username: username,
+ CreatedAt: common.GetTimestamp(),
+ Type: LogTypeError,
+ Content: content,
+ PromptTokens: 0,
+ CompletionTokens: 0,
+ TokenName: tokenName,
+ ModelName: modelName,
+ Quota: 0,
+ ChannelId: channelId,
+ TokenId: tokenId,
+ UseTime: useTimeSeconds,
+ IsStream: isStream,
+ Group: group,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
+ }
+ err := LOG_DB.Create(log).Error
+ if err != nil {
+ common.LogError(c, "failed to record log: "+err.Error())
+ }
+}
+
+type RecordConsumeLogParams struct {
+ ChannelId int `json:"channel_id"`
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ ModelName string `json:"model_name"`
+ TokenName string `json:"token_name"`
+ Quota int `json:"quota"`
+ Content string `json:"content"`
+ TokenId int `json:"token_id"`
+ UserQuota int `json:"user_quota"`
+ UseTimeSeconds int `json:"use_time_seconds"`
+ IsStream bool `json:"is_stream"`
+ Group string `json:"group"`
+ Other map[string]interface{} `json:"other"`
+}
+
+func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
+ common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
+ if !common.LogConsumeEnabled {
+ return
+ }
+ username := c.GetString("username")
+ otherStr := common.MapToJsonStr(params.Other)
+ // Check whether the client IP should be recorded for this user
+ needRecordIp := false
+ if settingMap, err := GetUserSetting(userId, false); err == nil {
+ if settingMap.RecordIpLog {
+ needRecordIp = true
+ }
+ }
+ log := &Log{
+ UserId: userId,
+ Username: username,
+ CreatedAt: common.GetTimestamp(),
+ Type: LogTypeConsume,
+ Content: params.Content,
+ PromptTokens: params.PromptTokens,
+ CompletionTokens: params.CompletionTokens,
+ TokenName: params.TokenName,
+ ModelName: params.ModelName,
+ Quota: params.Quota,
+ ChannelId: params.ChannelId,
+ TokenId: params.TokenId,
+ UseTime: params.UseTimeSeconds,
+ IsStream: params.IsStream,
+ Group: params.Group,
+ Ip: func() string {
+ if needRecordIp {
+ return c.ClientIP()
+ }
+ return ""
+ }(),
+ Other: otherStr,
+ }
+ err := LOG_DB.Create(log).Error
+ if err != nil {
+ common.LogError(c, "failed to record log: "+err.Error())
+ }
+ if common.DataExportEnabled {
+ gopool.Go(func() {
+ LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
+ })
+ }
+}
+
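+// GetAllLogs returns one page of logs matching the given filters together with the total count, and resolves each log's channel name from the channels table.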
+func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) {
+ var tx *gorm.DB
+ if logType == LogTypeUnknown {
+ tx = LOG_DB
+ } else {
+ tx = LOG_DB.Where("logs.type = ?", logType)
+ }
+
+ if modelName != "" {
+ tx = tx.Where("logs.model_name like ?", modelName)
+ }
+ if username != "" {
+ tx = tx.Where("logs.username = ?", username)
+ }
+ if tokenName != "" {
+ tx = tx.Where("logs.token_name = ?", tokenName)
+ }
+ if startTimestamp != 0 {
+ tx = tx.Where("logs.created_at >= ?", startTimestamp)
+ }
+ if endTimestamp != 0 {
+ tx = tx.Where("logs.created_at <= ?", endTimestamp)
+ }
+ if channel != 0 {
+ tx = tx.Where("logs.channel_id = ?", channel)
+ }
+ if group != "" {
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
+ }
+ err = tx.Model(&Log{}).Count(&total).Error
+ if err != nil {
+ return nil, 0, err
+ }
+ err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
+ if err != nil {
+ return nil, 0, err
+ }
+
+ channelIdsMap := make(map[int]struct{})
+ channelMap := make(map[int]string)
+ for _, log := range logs {
+ if log.ChannelId != 0 {
+ channelIdsMap[log.ChannelId] = struct{}{}
+ }
+ }
+
+ channelIds := make([]int, 0, len(channelIdsMap))
+ for channelId := range channelIdsMap {
+ channelIds = append(channelIds, channelId)
+ }
+ if len(channelIds) > 0 {
+ var channels []struct {
+ Id int `gorm:"column:id"`
+ Name string `gorm:"column:name"`
+ }
+ if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
+ return logs, total, err
+ }
+ for _, channel := range channels {
+ channelMap[channel.Id] = channel.Name
+ }
+ for i := range logs {
+ logs[i].ChannelName = channelMap[logs[i].ChannelId]
+ }
+ }
+
+ return logs, total, err
+}
+
+func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
+ var tx *gorm.DB
+ if logType == LogTypeUnknown {
+ tx = LOG_DB.Where("logs.user_id = ?", userId)
+ } else {
+ tx = LOG_DB.Where("logs.user_id = ? and logs.type = ?", userId, logType)
+ }
+
+ if modelName != "" {
+ tx = tx.Where("logs.model_name like ?", modelName)
+ }
+ if tokenName != "" {
+ tx = tx.Where("logs.token_name = ?", tokenName)
+ }
+ if startTimestamp != 0 {
+ tx = tx.Where("logs.created_at >= ?", startTimestamp)
+ }
+ if endTimestamp != 0 {
+ tx = tx.Where("logs.created_at <= ?", endTimestamp)
+ }
+ if group != "" {
+ tx = tx.Where("logs."+logGroupCol+" = ?", group)
+ }
+ err = tx.Model(&Log{}).Count(&total).Error
+ if err != nil {
+ return nil, 0, err
+ }
+ err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
+ if err != nil {
+ return nil, 0, err
+ }
+
+ formatUserLogs(logs)
+ return logs, total, err
+}
+
+func SearchAllLogs(keyword string) (logs []*Log, err error) {
+ err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
+ return logs, err
+}
+
+func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
+ err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
+ formatUserLogs(logs)
+ return logs, err
+}
+
+type Stat struct {
+ Quota int `json:"quota"`
+ Rpm int `json:"rpm"`
+ Tpm int `json:"tpm"`
+}
+
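+// SumUsedQuota aggregates consumed quota for the given filters; rpm and tpm come from a separate query restricted to the last 60 seconds.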
+func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
+ tx := LOG_DB.Table("logs").Select("sum(quota) quota")
+
+ // Build a separate query for rpm and tpm
+ rpmTpmQuery := LOG_DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
+
+ if username != "" {
+ tx = tx.Where("username = ?", username)
+ rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
+ }
+ if tokenName != "" {
+ tx = tx.Where("token_name = ?", tokenName)
+ rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
+ }
+ if startTimestamp != 0 {
+ tx = tx.Where("created_at >= ?", startTimestamp)
+ }
+ if endTimestamp != 0 {
+ tx = tx.Where("created_at <= ?", endTimestamp)
+ }
+ if modelName != "" {
+ tx = tx.Where("model_name like ?", modelName)
+ rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
+ }
+ if channel != 0 {
+ tx = tx.Where("channel_id = ?", channel)
+ rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
+ }
+ if group != "" {
+ tx = tx.Where(logGroupCol+" = ?", group)
+ rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
+ }
+
+ tx = tx.Where("type = ?", LogTypeConsume)
+ rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
+
+ // rpm and tpm only count the last 60 seconds
+ rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
+
+ // Run both queries
+ tx.Scan(&stat)
+ rpmTpmQuery.Scan(&stat)
+
+ return stat
+}
+
+func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
+ tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
+ if username != "" {
+ tx = tx.Where("username = ?", username)
+ }
+ if tokenName != "" {
+ tx = tx.Where("token_name = ?", tokenName)
+ }
+ if startTimestamp != 0 {
+ tx = tx.Where("created_at >= ?", startTimestamp)
+ }
+ if endTimestamp != 0 {
+ tx = tx.Where("created_at <= ?", endTimestamp)
+ }
+ if modelName != "" {
+ tx = tx.Where("model_name = ?", modelName)
+ }
+ tx.Where("type = ?", LogTypeConsume).Scan(&token)
+ return token
+}
+
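+// DeleteOldLog deletes logs created before targetTimestamp in batches of limit rows until none remain or the context is cancelled, returning the number of rows deleted.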
+func DeleteOldLog(ctx context.Context, targetTimestamp int64, limit int) (int64, error) {
+ var total int64 = 0
+
+ for {
+ if nil != ctx.Err() {
+ return total, ctx.Err()
+ }
+
+ result := LOG_DB.Where("created_at < ?", targetTimestamp).Limit(limit).Delete(&Log{})
+ if nil != result.Error {
+ return total, result.Error
+ }
+
+ total += result.RowsAffected
+
+ if result.RowsAffected < int64(limit) {
+ break
+ }
+ }
+
+ return total, nil
+}
diff --git a/model/main.go b/model/main.go
new file mode 100644
index 00000000..013beacd
--- /dev/null
+++ b/model/main.go
@@ -0,0 +1,363 @@
+package model
+
+import (
+ "fmt"
+ "log"
+ "one-api/common"
+ "one-api/constant"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/glebarez/sqlite"
+ "gorm.io/driver/mysql"
+ "gorm.io/driver/postgres"
+ "gorm.io/gorm"
+)
+
+var commonGroupCol string
+var commonKeyCol string
+var commonTrueVal string
+var commonFalseVal string
+
+var logKeyCol string
+var logGroupCol string
+
+func initCol() {
+ // init common column names
+ if common.UsingPostgreSQL {
+ commonGroupCol = `"group"`
+ commonKeyCol = `"key"`
+ commonTrueVal = "true"
+ commonFalseVal = "false"
+ } else {
+ commonGroupCol = "`group`"
+ commonKeyCol = "`key`"
+ commonTrueVal = "1"
+ commonFalseVal = "0"
+ }
+ if os.Getenv("LOG_SQL_DSN") != "" {
+ switch common.LogSqlType {
+ case common.DatabaseTypePostgreSQL:
+ logGroupCol = `"group"`
+ logKeyCol = `"key"`
+ default:
+ logGroupCol = commonGroupCol
+ logKeyCol = commonKeyCol
+ }
+ } else {
+ // When LOG_SQL_DSN is empty, the log database is the main database, so reuse its column quoting
+ logGroupCol = commonGroupCol
+ logKeyCol = commonKeyCol
+ }
+ // log sql type and database type
+ //common.SysLog("Using Log SQL Type: " + common.LogSqlType)
+}
+
+var DB *gorm.DB
+
+var LOG_DB *gorm.DB
+
+func createRootAccountIfNeed() error {
+ var user User
+ //if user.Status != common.UserStatusEnabled {
+ if err := DB.First(&user).Error; err != nil {
+ common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
+ hashedPassword, err := common.Password2Hash("123456")
+ if err != nil {
+ return err
+ }
+ rootUser := User{
+ Username: "root",
+ Password: hashedPassword,
+ Role: common.RoleRootUser,
+ Status: common.UserStatusEnabled,
+ DisplayName: "Root User",
+ AccessToken: nil,
+ Quota: 100000000,
+ }
+ DB.Create(&rootUser)
+ }
+ return nil
+}
+
+func CheckSetup() {
+ setup := GetSetup()
+ if setup == nil {
+ // No setup record exists, check if we have a root user
+ if RootUserExists() {
+ common.SysLog("system is not initialized, but root user exists")
+ // Create setup record
+ newSetup := Setup{
+ Version: common.Version,
+ InitializedAt: time.Now().Unix(),
+ }
+ err := DB.Create(&newSetup).Error
+ if err != nil {
+ common.SysLog("failed to create setup record: " + err.Error())
+ }
+ constant.Setup = true
+ } else {
+ common.SysLog("system is not initialized and no root user exists")
+ constant.Setup = false
+ }
+ } else {
+ // Setup record exists, system is initialized
+ common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
+ constant.Setup = true
+ }
+}
+
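+// chooseDB opens the database named by the env var: PostgreSQL for postgres:// DSNs, SQLite when the DSN is empty or starts with "local", MySQL otherwise (appending parseTime=true when missing).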
+func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
+ defer func() {
+ initCol()
+ }()
+ dsn := os.Getenv(envName)
+ if dsn != "" {
+ if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
+ // Use PostgreSQL
+ common.SysLog("using PostgreSQL as database")
+ if !isLog {
+ common.UsingPostgreSQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypePostgreSQL
+ }
+ return gorm.Open(postgres.New(postgres.Config{
+ DSN: dsn,
+ PreferSimpleProtocol: true, // disables implicit prepared statement usage
+ }), &gorm.Config{
+ PrepareStmt: true, // precompile SQL
+ })
+ }
+ if strings.HasPrefix(dsn, "local") {
+ common.SysLog("SQL_DSN not set, using SQLite as database")
+ if !isLog {
+ common.UsingSQLite = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeSQLite
+ }
+ return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
+ PrepareStmt: true, // precompile SQL
+ })
+ }
+ // Use MySQL
+ common.SysLog("using MySQL as database")
+ // check parseTime
+ if !strings.Contains(dsn, "parseTime") {
+ if strings.Contains(dsn, "?") {
+ dsn += "&parseTime=true"
+ } else {
+ dsn += "?parseTime=true"
+ }
+ }
+ if !isLog {
+ common.UsingMySQL = true
+ } else {
+ common.LogSqlType = common.DatabaseTypeMySQL
+ }
+ return gorm.Open(mysql.Open(dsn), &gorm.Config{
+ PrepareStmt: true, // precompile SQL
+ })
+ }
+ // Use SQLite
+ common.SysLog("SQL_DSN not set, using SQLite as database")
+ common.UsingSQLite = true
+ return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
+ PrepareStmt: true, // precompile SQL
+ })
+}
+
+func InitDB() (err error) {
+ db, err := chooseDB("SQL_DSN", false)
+ if err == nil {
+ if common.DebugEnabled {
+ db = db.Debug()
+ }
+ DB = db
+ sqlDB, err := DB.DB()
+ if err != nil {
+ return err
+ }
+ sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
+ sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
+ sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
+
+ if !common.IsMasterNode {
+ return nil
+ }
+ if common.UsingMySQL {
+ //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
+ }
+ common.SysLog("database migration started")
+ err = migrateDB()
+ return err
+ } else {
+ common.FatalLog(err)
+ }
+ return err
+}
+
+func InitLogDB() (err error) {
+ if os.Getenv("LOG_SQL_DSN") == "" {
+ LOG_DB = DB
+ return
+ }
+ db, err := chooseDB("LOG_SQL_DSN", true)
+ if err == nil {
+ if common.DebugEnabled {
+ db = db.Debug()
+ }
+ LOG_DB = db
+ sqlDB, err := LOG_DB.DB()
+ if err != nil {
+ return err
+ }
+ sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
+ sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
+ sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)))
+
+ if !common.IsMasterNode {
+ return nil
+ }
+ common.SysLog("database migration started")
+ err = migrateLOGDB()
+ return err
+ } else {
+ common.FatalLog(err)
+ }
+ return err
+}
+
+func migrateDB() error {
+ if !common.UsingPostgreSQL {
+ return migrateDBFast()
+ }
+ err := DB.AutoMigrate(
+ &Channel{},
+ &Token{},
+ &User{},
+ &Option{},
+ &Redemption{},
+ &Ability{},
+ &Log{},
+ &Midjourney{},
+ &TopUp{},
+ &QuotaData{},
+ &Task{},
+ &Setup{},
+ )
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
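+// migrateDBFast runs AutoMigrate for every model concurrently and returns the first migration error, if any; it is used for non-PostgreSQL databases.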
+func migrateDBFast() error {
+ var wg sync.WaitGroup
+
+ migrations := []struct {
+ model interface{}
+ name string
+ }{
+ {&Channel{}, "Channel"},
+ {&Token{}, "Token"},
+ {&User{}, "User"},
+ {&Option{}, "Option"},
+ {&Redemption{}, "Redemption"},
+ {&Ability{}, "Ability"},
+ {&Log{}, "Log"},
+ {&Midjourney{}, "Midjourney"},
+ {&TopUp{}, "TopUp"},
+ {&QuotaData{}, "QuotaData"},
+ {&Task{}, "Task"},
+ {&Setup{}, "Setup"},
+ }
+ // Size errChan by the number of migrations so no goroutine blocks on send
+ errChan := make(chan error, len(migrations))
+
+ for _, m := range migrations {
+ wg.Add(1)
+ go func(model interface{}, name string) {
+ defer wg.Done()
+ if err := DB.AutoMigrate(model); err != nil {
+ errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
+ }
+ }(m.model, m.name)
+ }
+
+ // Wait for all migrations to complete
+ wg.Wait()
+ close(errChan)
+
+ // Check for any errors
+ for err := range errChan {
+ if err != nil {
+ return err
+ }
+ }
+ common.SysLog("database migrated")
+ return nil
+}
+
+func migrateLOGDB() error {
+ var err error
+ if err = LOG_DB.AutoMigrate(&Log{}); err != nil {
+ return err
+ }
+ return nil
+}
+
+func closeDB(db *gorm.DB) error {
+ sqlDB, err := db.DB()
+ if err != nil {
+ return err
+ }
+ err = sqlDB.Close()
+ return err
+}
+
+func CloseDB() error {
+ if LOG_DB != DB {
+ err := closeDB(LOG_DB)
+ if err != nil {
+ return err
+ }
+ }
+ return closeDB(DB)
+}
+
+var (
+ lastPingTime time.Time
+ pingMutex sync.Mutex
+)
+
+func PingDB() error {
+ pingMutex.Lock()
+ defer pingMutex.Unlock()
+
+ if time.Since(lastPingTime) < time.Second*10 {
+ return nil
+ }
+
+ sqlDB, err := DB.DB()
+ if err != nil {
+ log.Printf("Error getting sql.DB from GORM: %v", err)
+ return err
+ }
+
+ err = sqlDB.Ping()
+ if err != nil {
+ log.Printf("Error pinging DB: %v", err)
+ return err
+ }
+
+ lastPingTime = time.Now()
+ common.SysLog("Database pinged successfully")
+ return nil
+}
diff --git a/model/midjourney.go b/model/midjourney.go
new file mode 100644
index 00000000..c6ef5de5
--- /dev/null
+++ b/model/midjourney.go
@@ -0,0 +1,207 @@
+package model
+
+type Midjourney struct {
+ Id int `json:"id"`
+ Code int `json:"code"`
+ UserId int `json:"user_id" gorm:"index"`
+ Action string `json:"action" gorm:"type:varchar(40);index"`
+ MjId string `json:"mj_id" gorm:"index"`
+ Prompt string `json:"prompt"`
+ PromptEn string `json:"prompt_en"`
+ Description string `json:"description"`
+ State string `json:"state"`
+ SubmitTime int64 `json:"submit_time" gorm:"index"`
+ StartTime int64 `json:"start_time" gorm:"index"`
+ FinishTime int64 `json:"finish_time" gorm:"index"`
+ ImageUrl string `json:"image_url"`
+ VideoUrl string `json:"video_url"`
+ VideoUrls string `json:"video_urls"`
+ Status string `json:"status" gorm:"type:varchar(20);index"`
+ Progress string `json:"progress" gorm:"type:varchar(30);index"`
+ FailReason string `json:"fail_reason"`
+ ChannelId int `json:"channel_id"`
+ Quota int `json:"quota"`
+ Buttons string `json:"buttons"`
+ Properties string `json:"properties"`
+}
+
+// TaskQueryParams bundles all search conditions; add more fields as needed
+type TaskQueryParams struct {
+ ChannelID string
+ MjID string
+ StartTimestamp string
+ EndTimestamp string
+}
+
+func GetAllUserTask(userId int, startIdx int, num int, queryParams TaskQueryParams) []*Midjourney {
+ var tasks []*Midjourney
+ var err error
+
+ // Initialize the query builder
+ query := DB.Where("user_id = ?", userId)
+
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ // Assumes the frontend timestamp has already been validated, parsed, and converted to the format the database expects
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+
+ // Fetch the rows
+ err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+
+ return tasks
+}
+
+func GetAllTasks(startIdx int, num int, queryParams TaskQueryParams) []*Midjourney {
+ var tasks []*Midjourney
+ var err error
+
+ // Initialize the query builder
+ query := DB
+
+ // Apply filter conditions
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+
+ // Fetch the rows
+ err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+
+ return tasks
+}
+
+func GetAllUnFinishTasks() []*Midjourney {
+ var tasks []*Midjourney
+ var err error
+ // get all tasks whose progress is not 100%
+ err = DB.Where("progress != ?", "100%").Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+ return tasks
+}
+
+func GetByOnlyMJId(mjId string) *Midjourney {
+ var mj *Midjourney
+ var err error
+ err = DB.Where("mj_id = ?", mjId).First(&mj).Error
+ if err != nil {
+ return nil
+ }
+ return mj
+}
+
+func GetByMJId(userId int, mjId string) *Midjourney {
+ var mj *Midjourney
+ var err error
+ err = DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
+ if err != nil {
+ return nil
+ }
+ return mj
+}
+
+func GetByMJIds(userId int, mjIds []string) []*Midjourney {
+ var mj []*Midjourney
+ var err error
+ err = DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
+ if err != nil {
+ return nil
+ }
+ return mj
+}
+
+func GetMjByuId(id int) *Midjourney {
+ var mj *Midjourney
+ var err error
+ err = DB.Where("id = ?", id).First(&mj).Error
+ if err != nil {
+ return nil
+ }
+ return mj
+}
+
+func UpdateProgress(id int, progress string) error {
+ return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
+}
+
+func (midjourney *Midjourney) Insert() error {
+ var err error
+ err = DB.Create(midjourney).Error
+ return err
+}
+
+func (midjourney *Midjourney) Update() error {
+ var err error
+ err = DB.Save(midjourney).Error
+ return err
+}
+
+func MjBulkUpdate(mjIds []string, params map[string]any) error {
+ return DB.Model(&Midjourney{}).
+ Where("mj_id in (?)", mjIds).
+ Updates(params).Error
+}
+
+func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
+ return DB.Model(&Midjourney{}).
+ Where("id in (?)", taskIDs).
+ Updates(params).Error
+}
+
+// CountAllTasks returns total midjourney tasks for admin query
+func CountAllTasks(queryParams TaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Midjourney{})
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
+
+// CountAllUserTask returns total midjourney tasks for user
+func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
+ if queryParams.MjID != "" {
+ query = query.Where("mj_id = ?", queryParams.MjID)
+ }
+ if queryParams.StartTimestamp != "" {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != "" {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
diff --git a/model/option.go b/model/option.go
new file mode 100644
index 00000000..05b99b41
--- /dev/null
+++ b/model/option.go
@@ -0,0 +1,442 @@
+package model
+
+import (
+ "one-api/common"
+ "one-api/setting"
+ "one-api/setting/config"
+ "one-api/setting/operation_setting"
+ "one-api/setting/ratio_setting"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type Option struct {
+ Key string `json:"key" gorm:"primaryKey"`
+ Value string `json:"value"`
+}
+
+func AllOption() ([]*Option, error) {
+ var options []*Option
+ var err error
+ err = DB.Find(&options).Error
+ return options, err
+}
+
+func InitOptionMap() {
+ common.OptionMapRWMutex.Lock()
+ common.OptionMap = make(map[string]string)
+
+ // Register the built-in system options
+ common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
+ common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
+ common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
+ common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
+ common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
+ common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
+ common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
+ common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
+ common.OptionMap["LinuxDOOAuthEnabled"] = strconv.FormatBool(common.LinuxDOOAuthEnabled)
+ common.OptionMap["TelegramOAuthEnabled"] = strconv.FormatBool(common.TelegramOAuthEnabled)
+ common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
+ common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
+ common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
+ common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
+ common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
+ common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
+ common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
+ common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
+ common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
+ common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
+ common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
+ common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
+ common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
+ common.OptionMap["EmailAliasRestrictionEnabled"] = strconv.FormatBool(common.EmailAliasRestrictionEnabled)
+ common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
+ common.OptionMap["SMTPServer"] = ""
+ common.OptionMap["SMTPFrom"] = ""
+ common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
+ common.OptionMap["SMTPAccount"] = ""
+ common.OptionMap["SMTPToken"] = ""
+ common.OptionMap["SMTPSSLEnabled"] = strconv.FormatBool(common.SMTPSSLEnabled)
+ common.OptionMap["Notice"] = ""
+ common.OptionMap["About"] = ""
+ common.OptionMap["HomePageContent"] = ""
+ common.OptionMap["Footer"] = common.Footer
+ common.OptionMap["SystemName"] = common.SystemName
+ common.OptionMap["Logo"] = common.Logo
+ common.OptionMap["ServerAddress"] = ""
+ common.OptionMap["WorkerUrl"] = setting.WorkerUrl
+ common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
+ common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
+ common.OptionMap["PayAddress"] = ""
+ common.OptionMap["CustomCallbackAddress"] = ""
+ common.OptionMap["EpayId"] = ""
+ common.OptionMap["EpayKey"] = ""
+ common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
+ common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
+ common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
+ common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
+ common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
+ common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
+ common.OptionMap["StripePriceId"] = setting.StripePriceId
+ common.OptionMap["StripeUnitPrice"] = strconv.FormatFloat(setting.StripeUnitPrice, 'f', -1, 64)
+ common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
+ common.OptionMap["Chats"] = setting.Chats2JsonString()
+ common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
+ common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
+ common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
+ common.OptionMap["GitHubClientId"] = ""
+ common.OptionMap["GitHubClientSecret"] = ""
+ common.OptionMap["TelegramBotToken"] = ""
+ common.OptionMap["TelegramBotName"] = ""
+ common.OptionMap["WeChatServerAddress"] = ""
+ common.OptionMap["WeChatServerToken"] = ""
+ common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
+ common.OptionMap["TurnstileSiteKey"] = ""
+ common.OptionMap["TurnstileSecretKey"] = ""
+ common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
+ common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
+ common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
+ common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
+ common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
+ common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
+ common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
+ common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
+ common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
+ common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
+ common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
+ common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
+ common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
+ common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
+ common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
+ common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
+ common.OptionMap["TopUpLink"] = common.TopUpLink
+ //common.OptionMap["ChatLink"] = common.ChatLink
+ //common.OptionMap["ChatLink2"] = common.ChatLink2
+ common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
+ common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
+ common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
+ common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
+ common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
+ common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
+ common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
+ common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
+ common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
+ common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
+ common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
+ common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(operation_setting.DemoSiteEnabled)
+ common.OptionMap["SelfUseModeEnabled"] = strconv.FormatBool(operation_setting.SelfUseModeEnabled)
+ common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
+ common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
+ common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
+ common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
+ common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
+ common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
+ common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
+
+ // Automatically add all registered model configs
+ modelConfigs := config.GlobalConfig.ExportAllConfigs()
+ for k, v := range modelConfigs {
+ common.OptionMap[k] = v
+ }
+
+ common.OptionMapRWMutex.Unlock()
+ loadOptionsFromDatabase()
+}
+
+func loadOptionsFromDatabase() {
+ options, _ := AllOption()
+ for _, option := range options {
+ err := updateOptionMap(option.Key, option.Value)
+ if err != nil {
+ common.SysError("failed to update option map: " + err.Error())
+ }
+ }
+}
+
+func SyncOptions(frequency int) {
+ for {
+ time.Sleep(time.Duration(frequency) * time.Second)
+ common.SysLog("syncing options from database")
+ loadOptionsFromDatabase()
+ }
+}
+
+func UpdateOption(key string, value string) error {
+ // Save to database first
+ option := Option{
+ Key: key,
+ }
+ // https://gorm.io/docs/update.html#Save-All-Fields
+ DB.FirstOrCreate(&option, Option{Key: key})
+ option.Value = value
+ // Save is a combination function.
+ // If save value does not contain primary key, it will execute Create,
+ // otherwise it will execute Update (with all fields).
+ DB.Save(&option)
+ // Update OptionMap
+ return updateOptionMap(key, value)
+}
+
+func updateOptionMap(key string, value string) (err error) {
+ common.OptionMapRWMutex.Lock()
+ defer common.OptionMapRWMutex.Unlock()
+ common.OptionMap[key] = value
+
+ // Check whether this is a model config key - handled through the structured config system
+ if handleConfigUpdate(key, value) {
+ return nil // already handled by the config system
+ }
+
+ // Handle legacy option keys...
+ if strings.HasSuffix(key, "Permission") {
+ intValue, _ := strconv.Atoi(value)
+ switch key {
+ case "FileUploadPermission":
+ common.FileUploadPermission = intValue
+ case "FileDownloadPermission":
+ common.FileDownloadPermission = intValue
+ case "ImageUploadPermission":
+ common.ImageUploadPermission = intValue
+ case "ImageDownloadPermission":
+ common.ImageDownloadPermission = intValue
+ }
+ }
+ if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
+ boolValue := value == "true"
+ switch key {
+ case "PasswordRegisterEnabled":
+ common.PasswordRegisterEnabled = boolValue
+ case "PasswordLoginEnabled":
+ common.PasswordLoginEnabled = boolValue
+ case "EmailVerificationEnabled":
+ common.EmailVerificationEnabled = boolValue
+ case "GitHubOAuthEnabled":
+ common.GitHubOAuthEnabled = boolValue
+ case "LinuxDOOAuthEnabled":
+ common.LinuxDOOAuthEnabled = boolValue
+ case "WeChatAuthEnabled":
+ common.WeChatAuthEnabled = boolValue
+ case "TelegramOAuthEnabled":
+ common.TelegramOAuthEnabled = boolValue
+ case "TurnstileCheckEnabled":
+ common.TurnstileCheckEnabled = boolValue
+ case "RegisterEnabled":
+ common.RegisterEnabled = boolValue
+ case "EmailDomainRestrictionEnabled":
+ common.EmailDomainRestrictionEnabled = boolValue
+ case "EmailAliasRestrictionEnabled":
+ common.EmailAliasRestrictionEnabled = boolValue
+ case "AutomaticDisableChannelEnabled":
+ common.AutomaticDisableChannelEnabled = boolValue
+ case "AutomaticEnableChannelEnabled":
+ common.AutomaticEnableChannelEnabled = boolValue
+ case "LogConsumeEnabled":
+ common.LogConsumeEnabled = boolValue
+ case "DisplayInCurrencyEnabled":
+ common.DisplayInCurrencyEnabled = boolValue
+ case "DisplayTokenStatEnabled":
+ common.DisplayTokenStatEnabled = boolValue
+ case "DrawingEnabled":
+ common.DrawingEnabled = boolValue
+ case "TaskEnabled":
+ common.TaskEnabled = boolValue
+ case "DataExportEnabled":
+ common.DataExportEnabled = boolValue
+ case "DefaultCollapseSidebar":
+ common.DefaultCollapseSidebar = boolValue
+ case "MjNotifyEnabled":
+ setting.MjNotifyEnabled = boolValue
+ case "MjAccountFilterEnabled":
+ setting.MjAccountFilterEnabled = boolValue
+ case "MjModeClearEnabled":
+ setting.MjModeClearEnabled = boolValue
+ case "MjForwardUrlEnabled":
+ setting.MjForwardUrlEnabled = boolValue
+ case "MjActionCheckSuccessEnabled":
+ setting.MjActionCheckSuccessEnabled = boolValue
+ case "CheckSensitiveEnabled":
+ setting.CheckSensitiveEnabled = boolValue
+ case "DemoSiteEnabled":
+ operation_setting.DemoSiteEnabled = boolValue
+ case "SelfUseModeEnabled":
+ operation_setting.SelfUseModeEnabled = boolValue
+ case "CheckSensitiveOnPromptEnabled":
+ setting.CheckSensitiveOnPromptEnabled = boolValue
+ case "ModelRequestRateLimitEnabled":
+ setting.ModelRequestRateLimitEnabled = boolValue
+ case "StopOnSensitiveEnabled":
+ setting.StopOnSensitiveEnabled = boolValue
+ case "SMTPSSLEnabled":
+ common.SMTPSSLEnabled = boolValue
+ case "WorkerAllowHttpImageRequestEnabled":
+ setting.WorkerAllowHttpImageRequestEnabled = boolValue
+ case "DefaultUseAutoGroup":
+ setting.DefaultUseAutoGroup = boolValue
+ case "ExposeRatioEnabled":
+ ratio_setting.SetExposeRatioEnabled(boolValue)
+ }
+ }
+ switch key {
+ case "EmailDomainWhitelist":
+ common.EmailDomainWhitelist = strings.Split(value, ",")
+ case "SMTPServer":
+ common.SMTPServer = value
+ case "SMTPPort":
+ intValue, _ := strconv.Atoi(value)
+ common.SMTPPort = intValue
+ case "SMTPAccount":
+ common.SMTPAccount = value
+ case "SMTPFrom":
+ common.SMTPFrom = value
+ case "SMTPToken":
+ common.SMTPToken = value
+ case "ServerAddress":
+ setting.ServerAddress = value
+ case "WorkerUrl":
+ setting.WorkerUrl = value
+ case "WorkerValidKey":
+ setting.WorkerValidKey = value
+ case "PayAddress":
+ setting.PayAddress = value
+ case "Chats":
+ err = setting.UpdateChatsByJsonString(value)
+ case "AutoGroups":
+ err = setting.UpdateAutoGroupsByJsonString(value)
+ case "CustomCallbackAddress":
+ setting.CustomCallbackAddress = value
+ case "EpayId":
+ setting.EpayId = value
+ case "EpayKey":
+ setting.EpayKey = value
+ case "Price":
+ setting.Price, _ = strconv.ParseFloat(value, 64)
+ case "USDExchangeRate":
+ setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
+ case "MinTopUp":
+ setting.MinTopUp, _ = strconv.Atoi(value)
+ case "StripeApiSecret":
+ setting.StripeApiSecret = value
+ case "StripeWebhookSecret":
+ setting.StripeWebhookSecret = value
+ case "StripePriceId":
+ setting.StripePriceId = value
+ case "StripeUnitPrice":
+ setting.StripeUnitPrice, _ = strconv.ParseFloat(value, 64)
+ case "StripeMinTopUp":
+ setting.StripeMinTopUp, _ = strconv.Atoi(value)
+ case "TopupGroupRatio":
+ err = common.UpdateTopupGroupRatioByJSONString(value)
+ case "GitHubClientId":
+ common.GitHubClientId = value
+ case "GitHubClientSecret":
+ common.GitHubClientSecret = value
+ case "LinuxDOClientId":
+ common.LinuxDOClientId = value
+ case "LinuxDOClientSecret":
+ common.LinuxDOClientSecret = value
+ case "Footer":
+ common.Footer = value
+ case "SystemName":
+ common.SystemName = value
+ case "Logo":
+ common.Logo = value
+ case "WeChatServerAddress":
+ common.WeChatServerAddress = value
+ case "WeChatServerToken":
+ common.WeChatServerToken = value
+ case "WeChatAccountQRCodeImageURL":
+ common.WeChatAccountQRCodeImageURL = value
+ case "TelegramBotToken":
+ common.TelegramBotToken = value
+ case "TelegramBotName":
+ common.TelegramBotName = value
+ case "TurnstileSiteKey":
+ common.TurnstileSiteKey = value
+ case "TurnstileSecretKey":
+ common.TurnstileSecretKey = value
+ case "QuotaForNewUser":
+ common.QuotaForNewUser, _ = strconv.Atoi(value)
+ case "QuotaForInviter":
+ common.QuotaForInviter, _ = strconv.Atoi(value)
+ case "QuotaForInvitee":
+ common.QuotaForInvitee, _ = strconv.Atoi(value)
+ case "QuotaRemindThreshold":
+ common.QuotaRemindThreshold, _ = strconv.Atoi(value)
+ case "PreConsumedQuota":
+ common.PreConsumedQuota, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitCount":
+ setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitDurationMinutes":
+ setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitSuccessCount":
+ setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
+ case "ModelRequestRateLimitGroup":
+ err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
+ case "RetryTimes":
+ common.RetryTimes, _ = strconv.Atoi(value)
+ case "DataExportInterval":
+ common.DataExportInterval, _ = strconv.Atoi(value)
+ case "DataExportDefaultTime":
+ common.DataExportDefaultTime = value
+ case "ModelRatio":
+ err = ratio_setting.UpdateModelRatioByJSONString(value)
+ case "GroupRatio":
+ err = ratio_setting.UpdateGroupRatioByJSONString(value)
+ case "GroupGroupRatio":
+ err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
+ case "UserUsableGroups":
+ err = setting.UpdateUserUsableGroupsByJSONString(value)
+ case "CompletionRatio":
+ err = ratio_setting.UpdateCompletionRatioByJSONString(value)
+ case "ModelPrice":
+ err = ratio_setting.UpdateModelPriceByJSONString(value)
+ case "CacheRatio":
+ err = ratio_setting.UpdateCacheRatioByJSONString(value)
+ case "TopUpLink":
+ common.TopUpLink = value
+ //case "ChatLink":
+ // common.ChatLink = value
+ //case "ChatLink2":
+ // common.ChatLink2 = value
+ case "ChannelDisableThreshold":
+ common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
+ case "QuotaPerUnit":
+ common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
+ case "SensitiveWords":
+ setting.SensitiveWordsFromString(value)
+ case "AutomaticDisableKeywords":
+ operation_setting.AutomaticDisableKeywordsFromString(value)
+ case "StreamCacheQueueLength":
+ setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
+ case "PayMethods":
+ err = setting.UpdatePayMethodsByJsonString(value)
+ }
+ return err
+}
+
+// handleConfigUpdate handles hierarchical config updates and reports whether the key was handled
+func handleConfigUpdate(key, value string) bool {
+ parts := strings.SplitN(key, ".", 2)
+ if len(parts) != 2 {
+ return false // not a hierarchical config key
+ }
+
+ configName := parts[0]
+ configKey := parts[1]
+
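+ // e.g. an option key of the form "foo.bar" (hypothetical names) resolves to the
+ // registered config object "foo" and its field "bar"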
+ // Look up the config object
+ cfg := config.GlobalConfig.Get(configName)
+ if cfg == nil {
+ return false // config not registered
+ }
+
+ // Apply the update
+ configMap := map[string]string{
+ configKey: value,
+ }
+ config.UpdateConfigFromMap(cfg, configMap)
+
+ return true // handled
+}
diff --git a/model/pricing.go b/model/pricing.go
new file mode 100644
index 00000000..a280b524
--- /dev/null
+++ b/model/pricing.go
@@ -0,0 +1,127 @@
+package model
+
+import (
+ "fmt"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/setting/ratio_setting"
+ "one-api/types"
+ "sync"
+ "time"
+)
+
+type Pricing struct {
+ ModelName string `json:"model_name"`
+ QuotaType int `json:"quota_type"`
+ ModelRatio float64 `json:"model_ratio"`
+ ModelPrice float64 `json:"model_price"`
+ OwnerBy string `json:"owner_by"`
+ CompletionRatio float64 `json:"completion_ratio"`
+ EnableGroup []string `json:"enable_groups"`
+ SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
+}
+
+var (
+ pricingMap []Pricing
+ lastGetPricingTime time.Time
+ updatePricingLock sync.Mutex
+)
+
+var (
+ modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
+ modelSupportEndpointsLock = sync.RWMutex{}
+)
+
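+// GetPricing returns the cached pricing table, rebuilding it at most once per minute;
+// double-checked locking keeps concurrent readers from rebuilding it more than once.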
+func GetPricing() []Pricing {
+ if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
+ updatePricingLock.Lock()
+ defer updatePricingLock.Unlock()
+ // Double check after acquiring the lock
+ if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
+ modelSupportEndpointsLock.Lock()
+ defer modelSupportEndpointsLock.Unlock()
+ updatePricing()
+ }
+ }
+ return pricingMap
+}
+
+func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
+ if model == "" {
+ return make([]constant.EndpointType, 0)
+ }
+ modelSupportEndpointsLock.RLock()
+ defer modelSupportEndpointsLock.RUnlock()
+ if endpoints, ok := modelSupportEndpointTypes[model]; ok {
+ return endpoints
+ }
+ return make([]constant.EndpointType, 0)
+}
+
+func updatePricing() {
+ //modelRatios := common.GetModelRatios()
+ enableAbilities, err := GetAllEnableAbilityWithChannels()
+ if err != nil {
+ common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
+ return
+ }
+ modelGroupsMap := make(map[string]*types.Set[string])
+
+ for _, ability := range enableAbilities {
+ groups, ok := modelGroupsMap[ability.Model]
+ if !ok {
+ groups = types.NewSet[string]()
+ modelGroupsMap[ability.Model] = groups
+ }
+ groups.Add(ability.Group)
+ }
+
+ // Use a slice rather than a set here: a model may support multiple endpoint types, and the first one is the preferred endpoint
+ modelSupportEndpointsStr := make(map[string][]string)
+
+ for _, ability := range enableAbilities {
+ endpoints, ok := modelSupportEndpointsStr[ability.Model]
+ if !ok {
+ endpoints = make([]string, 0)
+ modelSupportEndpointsStr[ability.Model] = endpoints
+ }
+ channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
+ for _, channelType := range channelTypes {
+ if !common.StringsContains(endpoints, string(channelType)) {
+ endpoints = append(endpoints, string(channelType))
+ }
+ }
+ modelSupportEndpointsStr[ability.Model] = endpoints
+ }
+
+ modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
+ for model, endpoints := range modelSupportEndpointsStr {
+ supportedEndpoints := make([]constant.EndpointType, 0)
+ for _, endpointStr := range endpoints {
+ endpointType := constant.EndpointType(endpointStr)
+ supportedEndpoints = append(supportedEndpoints, endpointType)
+ }
+ modelSupportEndpointTypes[model] = supportedEndpoints
+ }
+
+ pricingMap = make([]Pricing, 0)
+ for model, groups := range modelGroupsMap {
+ pricing := Pricing{
+ ModelName: model,
+ EnableGroup: groups.Items(),
+ SupportedEndpointTypes: modelSupportEndpointTypes[model],
+ }
+ modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
+ if findPrice {
+ pricing.ModelPrice = modelPrice
+ pricing.QuotaType = 1
+ } else {
+ modelRatio, _, _ := ratio_setting.GetModelRatio(model)
+ pricing.ModelRatio = modelRatio
+ pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
+ pricing.QuotaType = 0
+ }
+ pricingMap = append(pricingMap, pricing)
+ }
+ lastGetPricingTime = time.Now()
+}
diff --git a/model/redemption.go b/model/redemption.go
new file mode 100644
index 00000000..bf237668
--- /dev/null
+++ b/model/redemption.go
@@ -0,0 +1,195 @@
+package model
+
+import (
+ "errors"
+ "fmt"
+ "one-api/common"
+ "strconv"
+
+ "gorm.io/gorm"
+)
+
+type Redemption struct {
+ Id int `json:"id"`
+ UserId int `json:"user_id"`
+ Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
+ Status int `json:"status" gorm:"default:1"`
+ Name string `json:"name" gorm:"index"`
+ Quota int `json:"quota" gorm:"default:100"`
+ CreatedTime int64 `json:"created_time" gorm:"bigint"`
+ RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
+ Count int `json:"count" gorm:"-:all"` // only for api request
+ UsedUserId int `json:"used_user_id"`
+ DeletedAt gorm.DeletedAt `gorm:"index"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // expiration time; 0 means the code never expires
+}
+
+func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
+ // Begin a transaction
+ tx := DB.Begin()
+ if tx.Error != nil {
+ return nil, 0, tx.Error
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ tx.Rollback()
+ }
+ }()
+
+ // Get the total count
+ err = tx.Model(&Redemption{}).Count(&total).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ // Get the paginated rows
+ err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ // Commit the transaction
+ if err = tx.Commit().Error; err != nil {
+ return nil, 0, err
+ }
+
+ return redemptions, total, nil
+}
+
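+// SearchRedemptions matches the keyword against the numeric id (when it parses as an
+// integer) or as a name prefix, returning count and page from a single transaction.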
+func SearchRedemptions(keyword string, startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
+ tx := DB.Begin()
+ if tx.Error != nil {
+ return nil, 0, tx.Error
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ tx.Rollback()
+ }
+ }()
+
+ // Build query based on keyword type
+ query := tx.Model(&Redemption{})
+
+ // Only try to convert to ID if the string represents a valid integer
+ if id, err := strconv.Atoi(keyword); err == nil {
+ query = query.Where("id = ? OR name LIKE ?", id, keyword+"%")
+ } else {
+ query = query.Where("name LIKE ?", keyword+"%")
+ }
+
+ // Get total count
+ err = query.Count(&total).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ // Get paginated data
+ err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ if err = tx.Commit().Error; err != nil {
+ return nil, 0, err
+ }
+
+ return redemptions, total, nil
+}
+
+func GetRedemptionById(id int) (*Redemption, error) {
+ if id == 0 {
+ return nil, errors.New("id 为空!")
+ }
+ redemption := Redemption{Id: id}
+ var err error = nil
+ err = DB.First(&redemption, "id = ?", id).Error
+ return &redemption, err
+}
+
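+// Redeem credits a redemption code to the user inside a single transaction, locking the
+// row with SELECT ... FOR UPDATE so a code cannot be redeemed twice under concurrency.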
+func Redeem(key string, userId int) (quota int, err error) {
+ if key == "" {
+ return 0, errors.New("未提供兑换码")
+ }
+ if userId == 0 {
+ return 0, errors.New("无效的 user id")
+ }
+ redemption := &Redemption{}
+
+ keyCol := "`key`"
+ if common.UsingPostgreSQL {
+ keyCol = `"key"`
+ }
+ common.RandomSleep()
+ err = DB.Transaction(func(tx *gorm.DB) error {
+ err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
+ if err != nil {
+ return errors.New("无效的兑换码")
+ }
+ if redemption.Status != common.RedemptionCodeStatusEnabled {
+ return errors.New("该兑换码已被使用")
+ }
+ if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
+ return errors.New("该兑换码已过期")
+ }
+ err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
+ if err != nil {
+ return err
+ }
+ redemption.RedeemedTime = common.GetTimestamp()
+ redemption.Status = common.RedemptionCodeStatusUsed
+ redemption.UsedUserId = userId
+ err = tx.Save(redemption).Error
+ return err
+ })
+ if err != nil {
+ return 0, errors.New("兑换失败," + err.Error())
+ }
+ RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s,兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id))
+ return redemption.Quota, nil
+}
+
+func (redemption *Redemption) Insert() error {
+ var err error
+ err = DB.Create(redemption).Error
+ return err
+}
+
+func (redemption *Redemption) SelectUpdate() error {
+ // This can update zero values
+ return DB.Model(redemption).Select("redeemed_time", "status").Updates(redemption).Error
+}
+
+// Update Make sure your redemption's fields are complete, because only the selected fields are updated
+func (redemption *Redemption) Update() error {
+ var err error
+ err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
+ return err
+}
+
+func (redemption *Redemption) Delete() error {
+ var err error
+ err = DB.Delete(redemption).Error
+ return err
+}
+
+func DeleteRedemptionById(id int) (err error) {
+ if id == 0 {
+ return errors.New("id 为空!")
+ }
+ redemption := Redemption{Id: id}
+ err = DB.Where(redemption).First(&redemption).Error
+ if err != nil {
+ return err
+ }
+ return redemption.Delete()
+}
+
+func DeleteInvalidRedemptions() (int64, error) {
+ now := common.GetTimestamp()
+ result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
+ return result.RowsAffected, result.Error
+}
diff --git a/model/setup.go b/model/setup.go
new file mode 100644
index 00000000..c4d7997f
--- /dev/null
+++ b/model/setup.go
@@ -0,0 +1,16 @@
+package model
+
+type Setup struct {
+ ID uint `json:"id" gorm:"primaryKey"`
+ Version string `json:"version" gorm:"type:varchar(50);not null"`
+ InitializedAt int64 `json:"initialized_at" gorm:"type:bigint;not null"`
+}
+
+func GetSetup() *Setup {
+ var setup Setup
+ err := DB.First(&setup).Error
+ if err != nil {
+ return nil
+ }
+ return &setup
+}
diff --git a/model/task.go b/model/task.go
new file mode 100644
index 00000000..9e4177ba
--- /dev/null
+++ b/model/task.go
@@ -0,0 +1,365 @@
+package model
+
+import (
+ "database/sql/driver"
+ "encoding/json"
+ "one-api/constant"
+ commonRelay "one-api/relay/common"
+ "time"
+)
+
+type TaskStatus string
+
+const (
+ TaskStatusNotStart TaskStatus = "NOT_START"
+ TaskStatusSubmitted = "SUBMITTED"
+ TaskStatusQueued = "QUEUED"
+ TaskStatusInProgress = "IN_PROGRESS"
+ TaskStatusFailure = "FAILURE"
+ TaskStatusSuccess = "SUCCESS"
+ TaskStatusUnknown = "UNKNOWN"
+)
+
+type Task struct {
+ ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
+ CreatedAt int64 `json:"created_at" gorm:"index"`
+ UpdatedAt int64 `json:"updated_at"`
+ TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // third-party id, may be absent / song id / task id
+ Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // platform
+ UserId int `json:"user_id" gorm:"index"`
+ ChannelId int `json:"channel_id" gorm:"index"`
+ Quota int `json:"quota"`
+ Action string `json:"action" gorm:"type:varchar(40);index"` // task type: song, lyrics, description-mode
+ Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // task status
+ FailReason string `json:"fail_reason"`
+ SubmitTime int64 `json:"submit_time" gorm:"index"`
+ StartTime int64 `json:"start_time" gorm:"index"`
+ FinishTime int64 `json:"finish_time" gorm:"index"`
+ Progress string `json:"progress" gorm:"type:varchar(20);index"`
+ Properties Properties `json:"properties" gorm:"type:json"`
+
+ Data json.RawMessage `json:"data" gorm:"type:json"`
+}
+
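+// SetData and GetData round-trip an arbitrary payload through the JSON "data" column,
+// e.g. (hypothetically) t.SetData(resp) on submit and t.GetData(&resp) when syncing.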
+func (t *Task) SetData(data any) {
+ b, _ := json.Marshal(data)
+ t.Data = json.RawMessage(b)
+}
+
+func (t *Task) GetData(v any) error {
+ err := json.Unmarshal(t.Data, &v)
+ return err
+}
+
+type Properties struct {
+ Input string `json:"input"`
+}
+
+func (m *Properties) Scan(val interface{}) error {
+ bytesValue, _ := val.([]byte)
+ return json.Unmarshal(bytesValue, m)
+}
+
+func (m Properties) Value() (driver.Value, error) {
+ return json.Marshal(m)
+}
+
+// SyncTaskQueryParams bundles all search conditions; add more fields as needed
+type SyncTaskQueryParams struct {
+ Platform constant.TaskPlatform
+ ChannelID string
+ TaskID string
+ UserID string
+ Action string
+ Status string
+ StartTimestamp int64
+ EndTimestamp int64
+ UserIDs []int
+}
+
+func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
+ t := &Task{
+ UserId: relayInfo.UserId,
+ SubmitTime: time.Now().Unix(),
+ Status: TaskStatusNotStart,
+ Progress: "0%",
+ ChannelId: relayInfo.ChannelId,
+ Platform: platform,
+ }
+ return t
+}
+
+func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
+ var tasks []*Task
+ var err error
+
+ // Initialize the query builder
+ query := DB.Where("user_id = ?", userId)
+
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.StartTimestamp != 0 {
+ // Assumes the frontend timestamp has already been converted to the database format, with validation and parsing handled
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+
+ // Fetch the rows
+ err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+
+ return tasks
+}
+
+func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
+ var tasks []*Task
+ var err error
+
+ // Initialize the query builder
+ query := DB
+
+ // Apply filter conditions
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.UserID != "" {
+ query = query.Where("user_id = ?", queryParams.UserID)
+ }
+ if len(queryParams.UserIDs) != 0 {
+ query = query.Where("user_id in (?)", queryParams.UserIDs)
+ }
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+
+ // Fetch the rows
+ err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+
+ return tasks
+}
+
+func GetAllUnFinishSyncTasks(limit int) []*Task {
+ var tasks []*Task
+ var err error
+ // get all tasks whose progress is not 100%
+ err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
+ if err != nil {
+ return nil
+ }
+ return tasks
+}
+
+func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
+ if taskId == "" {
+ return nil, false, nil
+ }
+ var task *Task
+ var err error
+ err = DB.Where("task_id = ?", taskId).First(&task).Error
+ exist, err := RecordExist(err)
+ if err != nil {
+ return nil, false, err
+ }
+ return task, exist, err
+}
+
+func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
+ if taskId == "" {
+ return nil, false, nil
+ }
+ var task *Task
+ var err error
+ err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
+ First(&task).Error
+ exist, err := RecordExist(err)
+ if err != nil {
+ return nil, false, err
+ }
+ return task, exist, err
+}
+
+func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
+ if len(taskIds) == 0 {
+ return nil, nil
+ }
+ var task []*Task
+ var err error
+ err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
+ Find(&task).Error
+ if err != nil {
+ return nil, err
+ }
+ return task, nil
+}
+
+func TaskUpdateProgress(id int64, progress string) error {
+ return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
+}
+
+func (Task *Task) Insert() error {
+ var err error
+ err = DB.Create(Task).Error
+ return err
+}
+
+func (Task *Task) Update() error {
+ var err error
+ err = DB.Save(Task).Error
+ return err
+}
+
+func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
+ if len(TaskIds) == 0 {
+ return nil
+ }
+ return DB.Model(&Task{}).
+ Where("task_id in (?)", TaskIds).
+ Updates(params).Error
+}
+
+func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
+ if len(taskIDs) == 0 {
+ return nil
+ }
+ return DB.Model(&Task{}).
+ Where("id in (?)", taskIDs).
+ Updates(params).Error
+}
+
+func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
+ if len(ids) == 0 {
+ return nil
+ }
+ return DB.Model(&Task{}).
+ Where("id in (?)", ids).
+ Updates(params).Error
+}
+
+type TaskQuotaUsage struct {
+ Mode string `json:"mode"`
+ Count float64 `json:"count"`
+}
+
+func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
+ query := DB.Model(Task{})
+ // Apply filter conditions
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.UserID != "" {
+ query = query.Where("user_id = ?", queryParams.UserID)
+ }
+ if len(queryParams.UserIDs) != 0 {
+ query = query.Where("user_id in (?)", queryParams.UserIDs)
+ }
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
+ return stat, err
+}
+
+// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
+func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{})
+ if queryParams.ChannelID != "" {
+ query = query.Where("channel_id = ?", queryParams.ChannelID)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.UserID != "" {
+ query = query.Where("user_id = ?", queryParams.UserID)
+ }
+ if len(queryParams.UserIDs) != 0 {
+ query = query.Where("user_id in (?)", queryParams.UserIDs)
+ }
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
+
+// TaskCountAllUserTask returns total tasks for given user
+func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
+ var total int64
+ query := DB.Model(&Task{}).Where("user_id = ?", userId)
+ if queryParams.TaskID != "" {
+ query = query.Where("task_id = ?", queryParams.TaskID)
+ }
+ if queryParams.Action != "" {
+ query = query.Where("action = ?", queryParams.Action)
+ }
+ if queryParams.Status != "" {
+ query = query.Where("status = ?", queryParams.Status)
+ }
+ if queryParams.Platform != "" {
+ query = query.Where("platform = ?", queryParams.Platform)
+ }
+ if queryParams.StartTimestamp != 0 {
+ query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+ }
+ if queryParams.EndTimestamp != 0 {
+ query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+ }
+ _ = query.Count(&total).Error
+ return total
+}
diff --git a/model/token.go b/model/token.go
new file mode 100644
index 00000000..e85a445e
--- /dev/null
+++ b/model/token.go
@@ -0,0 +1,363 @@
+package model
+
+import (
+ "errors"
+ "fmt"
+ "one-api/common"
+ "strings"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "gorm.io/gorm"
+)
+
+type Token struct {
+ Id int `json:"id"`
+ UserId int `json:"user_id" gorm:"index"`
+ Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
+ Status int `json:"status" gorm:"default:1"`
+ Name string `json:"name" gorm:"index" `
+ CreatedTime int64 `json:"created_time" gorm:"bigint"`
+ AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
+ RemainQuota int `json:"remain_quota" gorm:"default:0"`
+ UnlimitedQuota bool `json:"unlimited_quota"`
+ ModelLimitsEnabled bool `json:"model_limits_enabled"`
+ ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
+ AllowIps *string `json:"allow_ips" gorm:"default:''"`
+ UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
+ Group string `json:"group" gorm:"default:''"`
+ DeletedAt gorm.DeletedAt `gorm:"index"`
+}
+
+func (token *Token) Clean() {
+ token.Key = ""
+}
+
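+// GetIpLimitsMap parses AllowIps, a newline-separated whitelist, into a set keyed by IP;
+// spaces and commas are stripped and entries that are not valid IPs are ignored.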
+func (token *Token) GetIpLimitsMap() map[string]any {
+ // delete empty spaces
+ //split with \n
+ ipLimitsMap := make(map[string]any)
+ if token.AllowIps == nil {
+ return ipLimitsMap
+ }
+ cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
+ if cleanIps == "" {
+ return ipLimitsMap
+ }
+ ips := strings.Split(cleanIps, "\n")
+ for _, ip := range ips {
+ ip = strings.TrimSpace(ip)
+ ip = strings.ReplaceAll(ip, ",", "")
+ if common.IsIP(ip) {
+ ipLimitsMap[ip] = true
+ }
+ }
+ return ipLimitsMap
+}
+
+func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
+ var tokens []*Token
+ var err error
+ err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
+ return tokens, err
+}
+
+func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
+ if token != "" {
+ token = strings.Trim(token, "sk-")
+ }
+ err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
+ return tokens, err
+}
+
+func ValidateUserToken(key string) (token *Token, err error) {
+ if key == "" {
+ return nil, errors.New("未提供令牌")
+ }
+ token, err = GetTokenByKey(key, false)
+ if err == nil {
+ if token.Status == common.TokenStatusExhausted {
+ keyPrefix := key[:3]
+ keySuffix := key[len(key)-3:]
+ return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
+ } else if token.Status == common.TokenStatusExpired {
+ return token, errors.New("该令牌已过期")
+ }
+ if token.Status != common.TokenStatusEnabled {
+ return token, errors.New("该令牌状态不可用")
+ }
+ if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
+ if !common.RedisEnabled {
+ token.Status = common.TokenStatusExpired
+ err := token.SelectUpdate()
+ if err != nil {
+ common.SysError("failed to update token status" + err.Error())
+ }
+ }
+ return token, errors.New("该令牌已过期")
+ }
+ if !token.UnlimitedQuota && token.RemainQuota <= 0 {
+ if !common.RedisEnabled {
+ // in this case, we can make sure the token is exhausted
+ token.Status = common.TokenStatusExhausted
+ err := token.SelectUpdate()
+ if err != nil {
+ common.SysError("failed to update token status" + err.Error())
+ }
+ }
+ keyPrefix := key[:3]
+ keySuffix := key[len(key)-3:]
+ return token, fmt.Errorf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)
+ }
+ return token, nil
+ }
+ return nil, errors.New("无效的令牌")
+}
+
+func GetTokenByIds(id int, userId int) (*Token, error) {
+ if id == 0 || userId == 0 {
+ return nil, errors.New("id 或 userId 为空!")
+ }
+ token := Token{Id: id, UserId: userId}
+ var err error = nil
+ err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error
+ return &token, err
+}
+
+func GetTokenById(id int) (*Token, error) {
+ if id == 0 {
+ return nil, errors.New("id 为空!")
+ }
+ token := Token{Id: id}
+ var err error = nil
+ err = DB.First(&token, "id = ?", id).Error
+ if shouldUpdateRedis(true, err) {
+ gopool.Go(func() {
+ if err := cacheSetToken(token); err != nil {
+ common.SysError("failed to update user status cache: " + err.Error())
+ }
+ })
+ }
+ return &token, err
+}
+
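+// GetTokenByKey is a cache-aside read: unless fromDB is set it tries Redis first, falls
+// back to the database on a miss, and refreshes the cache asynchronously via the defer.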
+func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) && token != nil {
+ gopool.Go(func() {
+ if err := cacheSetToken(*token); err != nil {
+ common.SysError("failed to update user status cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ if !fromDB && common.RedisEnabled {
+ // Try Redis first
+ token, err := cacheGetTokenByKey(key)
+ if err == nil {
+ return token, nil
+ }
+ // Don't return error - fall through to DB
+ }
+ fromDB = true
+ err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
+ return token, err
+}
+
+func (token *Token) Insert() error {
+ var err error
+ err = DB.Create(token).Error
+ return err
+}
+
+// Update Make sure your token's fields are complete, because only the selected fields are updated
+func (token *Token) Update() (err error) {
+ defer func() {
+ if shouldUpdateRedis(true, err) {
+ gopool.Go(func() {
+ err := cacheSetToken(*token)
+ if err != nil {
+ common.SysError("failed to update token cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
+ "model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
+ return err
+}
+
+func (token *Token) SelectUpdate() (err error) {
+ defer func() {
+ if shouldUpdateRedis(true, err) {
+ gopool.Go(func() {
+ err := cacheSetToken(*token)
+ if err != nil {
+ common.SysError("failed to update token cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ // This can update zero values
+ return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
+}
+
+func (token *Token) Delete() (err error) {
+ defer func() {
+ if shouldUpdateRedis(true, err) {
+ gopool.Go(func() {
+ err := cacheDeleteToken(token.Key)
+ if err != nil {
+ common.SysError("failed to delete token cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ err = DB.Delete(token).Error
+ return err
+}
+
+func (token *Token) IsModelLimitsEnabled() bool {
+ return token.ModelLimitsEnabled
+}
+
+func (token *Token) GetModelLimits() []string {
+ if token.ModelLimits == "" {
+ return []string{}
+ }
+ return strings.Split(token.ModelLimits, ",")
+}
+
+func (token *Token) GetModelLimitsMap() map[string]bool {
+ limits := token.GetModelLimits()
+ limitsMap := make(map[string]bool)
+ for _, limit := range limits {
+ limitsMap[limit] = true
+ }
+ return limitsMap
+}
+
+func DisableModelLimits(tokenId int) error {
+ token, err := GetTokenById(tokenId)
+ if err != nil {
+ return err
+ }
+ token.ModelLimitsEnabled = false
+ token.ModelLimits = ""
+ return token.Update()
+}
+
+func DeleteTokenById(id int, userId int) (err error) {
+ // Why do we need userId here? To keep a user from deleting someone else's token.
+ if id == 0 || userId == 0 {
+ return errors.New("id 或 userId 为空!")
+ }
+ token := Token{Id: id, UserId: userId}
+ err = DB.Where(token).First(&token).Error
+ if err != nil {
+ return err
+ }
+ return token.Delete()
+}
+
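+// IncreaseTokenQuota updates the Redis counter immediately, then either queues the SQL
+// write for the batch updater or applies it directly, depending on BatchUpdateEnabled.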
+func IncreaseTokenQuota(id int, key string, quota int) (err error) {
+ if quota < 0 {
+ return errors.New("quota 不能为负数!")
+ }
+ if common.RedisEnabled {
+ gopool.Go(func() {
+ err := cacheIncrTokenQuota(key, int64(quota))
+ if err != nil {
+ common.SysError("failed to increase token quota: " + err.Error())
+ }
+ })
+ }
+ if common.BatchUpdateEnabled {
+ addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
+ return nil
+ }
+ return increaseTokenQuota(id, quota)
+}
+
+func increaseTokenQuota(id int, quota int) (err error) {
+ err = DB.Model(&Token{}).Where("id = ?", id).Updates(
+ map[string]interface{}{
+ "remain_quota": gorm.Expr("remain_quota + ?", quota),
+ "used_quota": gorm.Expr("used_quota - ?", quota),
+ "accessed_time": common.GetTimestamp(),
+ },
+ ).Error
+ return err
+}
+
+func DecreaseTokenQuota(id int, key string, quota int) (err error) {
+ if quota < 0 {
+ return errors.New("quota 不能为负数!")
+ }
+ if common.RedisEnabled {
+ gopool.Go(func() {
+ err := cacheDecrTokenQuota(key, int64(quota))
+ if err != nil {
+ common.SysError("failed to decrease token quota: " + err.Error())
+ }
+ })
+ }
+ if common.BatchUpdateEnabled {
+ addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
+ return nil
+ }
+ return decreaseTokenQuota(id, quota)
+}
+
+func decreaseTokenQuota(id int, quota int) (err error) {
+ err = DB.Model(&Token{}).Where("id = ?", id).Updates(
+ map[string]interface{}{
+ "remain_quota": gorm.Expr("remain_quota - ?", quota),
+ "used_quota": gorm.Expr("used_quota + ?", quota),
+ "accessed_time": common.GetTimestamp(),
+ },
+ ).Error
+ return err
+}
+
+// CountUserTokens returns total number of tokens for the given user, used for pagination
+func CountUserTokens(userId int) (int64, error) {
+ var total int64
+ err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
+ return total, err
+}
+
+// BatchDeleteTokens deletes the given user's tokens by id and returns the number deleted
+func BatchDeleteTokens(ids []int, userId int) (int, error) {
+ if len(ids) == 0 {
+ return 0, errors.New("ids 不能为空!")
+ }
+
+ tx := DB.Begin()
+
+ var tokens []Token
+ if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
+ tx.Rollback()
+ return 0, err
+ }
+
+ if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
+ tx.Rollback()
+ return 0, err
+ }
+
+ if err := tx.Commit().Error; err != nil {
+ return 0, err
+ }
+
+ if common.RedisEnabled {
+ gopool.Go(func() {
+ for _, t := range tokens {
+ _ = cacheDeleteToken(t.Key)
+ }
+ })
+ }
+
+ return len(tokens), nil
+}
diff --git a/model/token_cache.go b/model/token_cache.go
new file mode 100644
index 00000000..5399dbc8
--- /dev/null
+++ b/model/token_cache.go
@@ -0,0 +1,64 @@
+package model
+
+import (
+ "fmt"
+ "one-api/common"
+ "one-api/constant"
+ "time"
+)
+
+func cacheSetToken(token Token) error {
+ key := common.GenerateHMAC(token.Key)
+ token.Clean()
+ err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func cacheDeleteToken(key string) error {
+ key = common.GenerateHMAC(key)
+ err := common.RedisDelKey(fmt.Sprintf("token:%s", key))
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func cacheIncrTokenQuota(key string, increment int64) error {
+ key = common.GenerateHMAC(key)
+ err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func cacheDecrTokenQuota(key string, decrement int64) error {
+ return cacheIncrTokenQuota(key, -decrement)
+}
+
+func cacheSetTokenField(key string, field string, value string) error {
+ key = common.GenerateHMAC(key)
+ err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+// cacheGetTokenByKey fetches a token from the Redis cache; on a miss the caller falls back to the database
+func cacheGetTokenByKey(key string) (*Token, error) {
+ hmacKey := common.GenerateHMAC(key)
+ if !common.RedisEnabled {
+ return nil, fmt.Errorf("redis is not enabled")
+ }
+ var token Token
+ err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
+ if err != nil {
+ return nil, err
+ }
+ token.Key = key
+ return &token, nil
+}
diff --git a/model/topup.go b/model/topup.go
new file mode 100644
index 00000000..c34c0ce6
--- /dev/null
+++ b/model/topup.go
@@ -0,0 +1,100 @@
+package model
+
+import (
+ "errors"
+ "fmt"
+ "one-api/common"
+
+ "gorm.io/gorm"
+)
+
+type TopUp struct {
+ Id int `json:"id"`
+ UserId int `json:"user_id" gorm:"index"`
+ Amount int64 `json:"amount"`
+ Money float64 `json:"money"`
+ TradeNo string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
+ CreateTime int64 `json:"create_time"`
+ CompleteTime int64 `json:"complete_time"`
+ Status string `json:"status"`
+}
+
+func (topUp *TopUp) Insert() error {
+ var err error
+ err = DB.Create(topUp).Error
+ return err
+}
+
+func (topUp *TopUp) Update() error {
+ var err error
+ err = DB.Save(topUp).Error
+ return err
+}
+
+func GetTopUpById(id int) *TopUp {
+ var topUp *TopUp
+ var err error
+ err = DB.Where("id = ?", id).First(&topUp).Error
+ if err != nil {
+ return nil
+ }
+ return topUp
+}
+
+func GetTopUpByTradeNo(tradeNo string) *TopUp {
+ var topUp *TopUp
+ var err error
+ err = DB.Where("trade_no = ?", tradeNo).First(&topUp).Error
+ if err != nil {
+ return nil
+ }
+ return topUp
+}
+
+func Recharge(referenceId string, customerId string) (err error) {
+ if referenceId == "" {
+ return errors.New("未提供支付单号")
+ }
+
+ var quota float64
+ topUp := &TopUp{}
+
+ refCol := "`trade_no`"
+ if common.UsingPostgreSQL {
+ refCol = `"trade_no"`
+ }
+
+ err = DB.Transaction(func(tx *gorm.DB) error {
+ err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", referenceId).First(topUp).Error
+ if err != nil {
+ return errors.New("充值订单不存在")
+ }
+
+ if topUp.Status != common.TopUpStatusPending {
+ return errors.New("充值订单状态错误")
+ }
+
+ topUp.CompleteTime = common.GetTimestamp()
+ topUp.Status = common.TopUpStatusSuccess
+ err = tx.Save(topUp).Error
+ if err != nil {
+ return err
+ }
+
+ quota = topUp.Money * common.QuotaPerUnit
+ err = tx.Model(&User{}).Where("id = ?", topUp.UserId).Updates(map[string]interface{}{"stripe_customer": customerId, "quota": gorm.Expr("quota + ?", quota)}).Error
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+
+ if err != nil {
+ return errors.New("充值失败," + err.Error())
+ }
+
+ RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%d", common.FormatQuota(int(quota)), topUp.Amount))
+
+ return nil
+}
diff --git a/model/usedata.go b/model/usedata.go
new file mode 100644
index 00000000..1255b0be
--- /dev/null
+++ b/model/usedata.go
@@ -0,0 +1,133 @@
+package model
+
+import (
+ "fmt"
+ "gorm.io/gorm"
+ "one-api/common"
+ "sync"
+ "time"
+)
+
+// QuotaData backs the dashboard bar-chart data
+type QuotaData struct {
+ Id int `json:"id"`
+ UserID int `json:"user_id" gorm:"index"`
+ Username string `json:"username" gorm:"index:idx_qdt_model_user_name,priority:2;size:64;default:''"`
+ ModelName string `json:"model_name" gorm:"index:idx_qdt_model_user_name,priority:1;size:64;default:''"`
+ CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_qdt_created_at,priority:2"`
+ TokenUsed int `json:"token_used" gorm:"default:0"`
+ Count int `json:"count" gorm:"default:0"`
+ Quota int `json:"quota" gorm:"default:0"`
+}
+
+func UpdateQuotaData() {
+ // recover
+ defer func() {
+ if r := recover(); r != nil {
+ common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
+ }
+ }()
+ for {
+ if common.DataExportEnabled {
+ common.SysLog("正在更新数据看板数据...")
+ SaveQuotaDataCache()
+ }
+ time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
+ }
+}
+
+var CacheQuotaData = make(map[string]*QuotaData)
+var CacheQuotaDataLock = sync.Mutex{}
+
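+// logQuotaDataCache aggregates one request into the in-memory cache, keyed by user, username, model and hour; the caller must hold CacheQuotaDataLock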
+func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
+ key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt)
+ quotaData, ok := CacheQuotaData[key]
+ if ok {
+ quotaData.Count += 1
+ quotaData.Quota += quota
+ quotaData.TokenUsed += tokenUsed
+ } else {
+ quotaData = &QuotaData{
+ UserID: userId,
+ Username: username,
+ ModelName: modelName,
+ CreatedAt: createdAt,
+ Count: 1,
+ Quota: quota,
+ TokenUsed: tokenUsed,
+ }
+ }
+ CacheQuotaData[key] = quotaData
+}
+
+func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) {
+ // truncate the timestamp to the hour
+ createdAt = createdAt - (createdAt % 3600)
+
+ CacheQuotaDataLock.Lock()
+ defer CacheQuotaDataLock.Unlock()
+ logQuotaDataCache(userId, username, modelName, quota, createdAt, tokenUsed)
+}
+
+func SaveQuotaDataCache() {
+ CacheQuotaDataLock.Lock()
+ defer CacheQuotaDataLock.Unlock()
+ size := len(CacheQuotaData)
+ // flush the cached entries to the database:
+ // 1. check whether a matching row already exists
+ // 2. if it exists, increment it
+ // 3. otherwise, insert a new row
+ for _, quotaData := range CacheQuotaData {
+ quotaDataDB := &QuotaData{}
+ DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
+ quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.CreatedAt).First(quotaDataDB)
+ if quotaDataDB.Id > 0 {
+ //quotaDataDB.Count += quotaData.Count
+ //quotaDataDB.Quota += quotaData.Quota
+ //DB.Table("quota_data").Save(quotaDataDB)
+ increaseQuotaData(quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.Count, quotaData.Quota, quotaData.CreatedAt, quotaData.TokenUsed)
+ } else {
+ DB.Table("quota_data").Create(quotaData)
+ }
+ }
+ CacheQuotaData = make(map[string]*QuotaData)
+ common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
+}
+
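+// increaseQuotaData increments an existing quota_data row in place using SQL expressions, so concurrent flushes do not lose updates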
+func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) {
+ err := DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
+ userId, username, modelName, createdAt).Updates(map[string]interface{}{
+ "count": gorm.Expr("count + ?", count),
+ "quota": gorm.Expr("quota + ?", quota),
+ "token_used": gorm.Expr("token_used + ?", tokenUsed),
+ }).Error
+ if err != nil {
+ common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
+ }
+}
+
+func GetQuotaDataByUsername(username string, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
+ var quotaDatas []*QuotaData
+ // query data from the quota_data table
+ err = DB.Table("quota_data").Where("username = ? and created_at >= ? and created_at <= ?", username, startTime, endTime).Find(&quotaDatas).Error
+ return quotaDatas, err
+}
+
+func GetQuotaDataByUserId(userId int, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
+ var quotaDatas []*QuotaData
+ // query data from the quota_data table
+ err = DB.Table("quota_data").Where("user_id = ? and created_at >= ? and created_at <= ?", userId, startTime, endTime).Find(&quotaDatas).Error
+ return quotaDatas, err
+}
+
+func GetAllQuotaDates(startTime int64, endTime int64, username string) (quotaData []*QuotaData, err error) {
+ if username != "" {
+ return GetQuotaDataByUsername(username, startTime, endTime)
+ }
+ var quotaDatas []*QuotaData
+ // query data from the quota_data table
+ // only select model_name, sum(count) as count, sum(quota) as quota, model_name, created_at from quota_data group by model_name, created_at;
+ //err = DB.Table("quota_data").Where("created_at >= ? and created_at <= ?", startTime, endTime).Find(&quotaDatas).Error
+ err = DB.Table("quota_data").Select("model_name, sum(count) as count, sum(quota) as quota, sum(token_used) as token_used, created_at").Where("created_at >= ? and created_at <= ?", startTime, endTime).Group("model_name, created_at").Find(&quotaDatas).Error
+ return quotaDatas, err
+}
diff --git a/model/user.go b/model/user.go
new file mode 100644
index 00000000..6021f495
--- /dev/null
+++ b/model/user.go
@@ -0,0 +1,830 @@
+package model
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "one-api/common"
+ "one-api/dto"
+ "strconv"
+ "strings"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "gorm.io/gorm"
+)
+
+// User if you add sensitive fields, don't forget to clean them in setupLogin function.
+// Otherwise, the sensitive information will be saved on local storage in plain text!
+type User struct {
+ Id int `json:"id"`
+ Username string `json:"username" gorm:"unique;index" validate:"max=12"`
+ Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
+ OriginalPassword string `json:"original_password" gorm:"-:all"` // this field is only for Password change verification, don't save it to database!
+ DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
+ Role int `json:"role" gorm:"type:int;default:1"` // admin, common
+ Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
+ Email string `json:"email" gorm:"index" validate:"max=50"`
+ GitHubId string `json:"github_id" gorm:"column:github_id;index"`
+ OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
+ WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
+ TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
+ VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
+ AccessToken *string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
+ Quota int `json:"quota" gorm:"type:int;default:0"`
+ UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
+ RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
+ Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
+ AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
+ AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
+ AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // remaining invite reward quota
+ AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // lifetime invite reward quota
+ InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
+ DeletedAt gorm.DeletedAt `gorm:"index"`
+ LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
+ Setting string `json:"setting" gorm:"type:text;column:setting"`
+ Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
+ StripeCustomer string `json:"stripe_customer" gorm:"type:varchar(64);column:stripe_customer;index"`
+}
+
+func (user *User) ToBaseUser() *UserBase {
+ cache := &UserBase{
+ Id: user.Id,
+ Group: user.Group,
+ Quota: user.Quota,
+ Status: user.Status,
+ Username: user.Username,
+ Setting: user.Setting,
+ Email: user.Email,
+ }
+ return cache
+}
+
+func (user *User) GetAccessToken() string {
+ if user.AccessToken == nil {
+ return ""
+ }
+ return *user.AccessToken
+}
+
+func (user *User) SetAccessToken(token string) {
+ user.AccessToken = &token
+}
+
+func (user *User) GetSetting() dto.UserSetting {
+ setting := dto.UserSetting{}
+ if user.Setting != "" {
+ err := json.Unmarshal([]byte(user.Setting), &setting)
+ if err != nil {
+ common.SysError("failed to unmarshal setting: " + err.Error())
+ }
+ }
+ return setting
+}
+
+func (user *User) SetSetting(setting dto.UserSetting) {
+ settingBytes, err := json.Marshal(setting)
+ if err != nil {
+ common.SysError("failed to marshal setting: " + err.Error())
+ return
+ }
+ user.Setting = string(settingBytes)
+}
+
+// CheckUserExistOrDeleted reports whether a user record exists, including soft-deleted ones: it returns false, nil when no record is found and true, nil when a matching record exists.
+func CheckUserExistOrDeleted(username string, email string) (bool, error) {
+ var user User
+
+ // err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
+ // check email if empty
+ var err error
+ if email == "" {
+ err = DB.Unscoped().First(&user, "username = ?", username).Error
+ } else {
+ err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
+ }
+ if err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ // not exist, return false, nil
+ return false, nil
+ }
+ // other error, return false, err
+ return false, err
+ }
+ // exist, return true, nil
+ return true, nil
+}
+
+func GetMaxUserId() int {
+ var user User
+ DB.Unscoped().Last(&user)
+ return user.Id
+}
+
+func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
+ // Start transaction
+ tx := DB.Begin()
+ if tx.Error != nil {
+ return nil, 0, tx.Error
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ tx.Rollback()
+ }
+ }()
+
+ // Get total count within transaction
+ err = tx.Unscoped().Model(&User{}).Count(&total).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ // Get paginated users within same transaction
+ err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ // Commit transaction
+ if err = tx.Commit().Error; err != nil {
+ return nil, 0, err
+ }
+
+ return users, total, nil
+}
+
+func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
+ var users []*User
+ var total int64
+ var err error
+
+ // begin transaction
+ tx := DB.Begin()
+ if tx.Error != nil {
+ return nil, 0, tx.Error
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ tx.Rollback()
+ }
+ }()
+
+ // build the base query
+ query := tx.Unscoped().Model(&User{})
+
+ // build the search condition
+ likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
+
+ // try to interpret the keyword as a numeric ID
+ keywordInt, err := strconv.Atoi(keyword)
+ if err == nil {
+ // numeric keyword: search by ID as well as the text fields
+ likeCondition = "id = ? OR " + likeCondition
+ if group != "" {
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
+ keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
+ } else {
+ query = query.Where(likeCondition,
+ keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
+ }
+ } else {
+ // non-numeric keyword: search the text fields only
+ if group != "" {
+ query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
+ "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
+ } else {
+ query = query.Where(likeCondition,
+ "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
+ }
+ }
+
+ // get total count
+ err = query.Count(&total).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ // get paginated data
+ err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
+ if err != nil {
+ tx.Rollback()
+ return nil, 0, err
+ }
+
+ // commit transaction
+ if err = tx.Commit().Error; err != nil {
+ return nil, 0, err
+ }
+
+ return users, total, nil
+}
+
+func GetUserById(id int, selectAll bool) (*User, error) {
+ if id == 0 {
+ return nil, errors.New("id 为空!")
+ }
+ user := User{Id: id}
+ var err error = nil
+ if selectAll {
+ err = DB.First(&user, "id = ?", id).Error
+ } else {
+ err = DB.Omit("password").First(&user, "id = ?", id).Error
+ }
+ return &user, err
+}
+
+func GetUserIdByAffCode(affCode string) (int, error) {
+ if affCode == "" {
+ return 0, errors.New("affCode 为空!")
+ }
+ var user User
+ err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
+ return user.Id, err
+}
+
+func DeleteUserById(id int) (err error) {
+ if id == 0 {
+ return errors.New("id 为空!")
+ }
+ user := User{Id: id}
+ return user.Delete()
+}
+
+func HardDeleteUserById(id int) error {
+ if id == 0 {
+ return errors.New("id 为空!")
+ }
+ err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
+ return err
+}
+
+func inviteUser(inviterId int) (err error) {
+ user, err := GetUserById(inviterId, true)
+ if err != nil {
+ return err
+ }
+ user.AffCount++
+ user.AffQuota += common.QuotaForInviter
+ user.AffHistoryQuota += common.QuotaForInviter
+ return DB.Save(user).Error
+}
+
+func (user *User) TransferAffQuotaToQuota(quota int) error {
+ // reject transfers below the minimum amount
+ if float64(quota) < common.QuotaPerUnit {
+ return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit)))
+ }
+
+ // begin database transaction
+ tx := DB.Begin()
+ if tx.Error != nil {
+ return tx.Error
+ }
+ defer tx.Rollback() // ensure the transaction is rolled back if it is never committed
+
+ // lock the user row to keep the read-modify-write consistent
+ err := tx.Set("gorm:query_option", "FOR UPDATE").First(&user, user.Id).Error
+ if err != nil {
+ return err
+ }
+
+ // re-check that the user's AffQuota is sufficient
+ if user.AffQuota < quota {
+ return errors.New("邀请额度不足!")
+ }
+
+ // update the user's quotas
+ user.AffQuota -= quota
+ user.Quota += quota
+
+ // persist the user
+ if err := tx.Save(user).Error; err != nil {
+ return err
+ }
+
+ // commit the transaction
+ return tx.Commit().Error
+}
+
+func (user *User) Insert(inviterId int) error {
+ var err error
+ if user.Password != "" {
+ user.Password, err = common.Password2Hash(user.Password)
+ if err != nil {
+ return err
+ }
+ }
+ user.Quota = common.QuotaForNewUser
+ //user.SetAccessToken(common.GetUUID())
+ user.AffCode = common.GetRandomString(4)
+ result := DB.Create(user)
+ if result.Error != nil {
+ return result.Error
+ }
+ if common.QuotaForNewUser > 0 {
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
+ }
+ if inviterId != 0 {
+ if common.QuotaForInvitee > 0 {
+ _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
+ }
+ if common.QuotaForInviter > 0 {
+ //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
+ _ = inviteUser(inviterId)
+ }
+ }
+ return nil
+}
+
+func (user *User) Update(updatePassword bool) error {
+ var err error
+ if updatePassword {
+ user.Password, err = common.Password2Hash(user.Password)
+ if err != nil {
+ return err
+ }
+ }
+ newUser := *user
+ DB.First(&user, user.Id)
+ if err = DB.Model(user).Updates(newUser).Error; err != nil {
+ return err
+ }
+
+ // Update cache
+ return updateUserCache(*user)
+}
+
+func (user *User) Edit(updatePassword bool) error {
+ var err error
+ if updatePassword {
+ user.Password, err = common.Password2Hash(user.Password)
+ if err != nil {
+ return err
+ }
+ }
+
+ newUser := *user
+ updates := map[string]interface{}{
+ "username": newUser.Username,
+ "display_name": newUser.DisplayName,
+ "group": newUser.Group,
+ "quota": newUser.Quota,
+ "remark": newUser.Remark,
+ }
+ if updatePassword {
+ updates["password"] = newUser.Password
+ }
+
+ DB.First(&user, user.Id)
+ if err = DB.Model(user).Updates(updates).Error; err != nil {
+ return err
+ }
+
+ // Update cache
+ return updateUserCache(*user)
+}
+
+func (user *User) Delete() error {
+ if user.Id == 0 {
+ return errors.New("id 为空!")
+ }
+ if err := DB.Delete(user).Error; err != nil {
+ return err
+ }
+
+ // invalidate the cache
+ return invalidateUserCache(user.Id)
+}
+
+func (user *User) HardDelete() error {
+ if user.Id == 0 {
+ return errors.New("id 为空!")
+ }
+ err := DB.Unscoped().Delete(user).Error
+ return err
+}
+
+// ValidateAndFill check password & user status
+func (user *User) ValidateAndFill() (err error) {
+ // When querying with struct, GORM will only query with non-zero fields,
+ // that means if your field's value is 0, '', false or other zero values,
+ // it won't be used to build query conditions
+ password := user.Password
+ username := strings.TrimSpace(user.Username)
+ if username == "" || password == "" {
+ return errors.New("用户名或密码为空")
+ }
+ // find by username or email
+ DB.Where("username = ? OR email = ?", username, username).First(user)
+ okay := common.ValidatePasswordAndHash(password, user.Password)
+ if !okay || user.Status != common.UserStatusEnabled {
+ return errors.New("用户名或密码错误,或用户已被封禁")
+ }
+ return nil
+}
+
+func (user *User) FillUserById() error {
+ if user.Id == 0 {
+ return errors.New("id 为空!")
+ }
+ DB.Where(User{Id: user.Id}).First(user)
+ return nil
+}
+
+func (user *User) FillUserByEmail() error {
+ if user.Email == "" {
+ return errors.New("email 为空!")
+ }
+ DB.Where(User{Email: user.Email}).First(user)
+ return nil
+}
+
+func (user *User) FillUserByGitHubId() error {
+ if user.GitHubId == "" {
+ return errors.New("GitHub id 为空!")
+ }
+ DB.Where(User{GitHubId: user.GitHubId}).First(user)
+ return nil
+}
+
+func (user *User) FillUserByOidcId() error {
+ if user.OidcId == "" {
+ return errors.New("oidc id 为空!")
+ }
+ DB.Where(User{OidcId: user.OidcId}).First(user)
+ return nil
+}
+
+func (user *User) FillUserByWeChatId() error {
+ if user.WeChatId == "" {
+ return errors.New("WeChat id 为空!")
+ }
+ DB.Where(User{WeChatId: user.WeChatId}).First(user)
+ return nil
+}
+
+func (user *User) FillUserByTelegramId() error {
+ if user.TelegramId == "" {
+ return errors.New("Telegram id 为空!")
+ }
+ err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return errors.New("该 Telegram 账户未绑定")
+ }
+ return nil
+}
+
+func IsEmailAlreadyTaken(email string) bool {
+ return DB.Unscoped().Where("email = ?", email).Find(&User{}).RowsAffected == 1
+}
+
+func IsWeChatIdAlreadyTaken(wechatId string) bool {
+ return DB.Unscoped().Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
+}
+
+func IsGitHubIdAlreadyTaken(githubId string) bool {
+ return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
+}
+
+func IsOidcIdAlreadyTaken(oidcId string) bool {
+ return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
+}
+
+func IsTelegramIdAlreadyTaken(telegramId string) bool {
+ return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
+}
+
+func ResetUserPasswordByEmail(email string, password string) error {
+ if email == "" || password == "" {
+ return errors.New("邮箱地址或密码为空!")
+ }
+ hashedPassword, err := common.Password2Hash(password)
+ if err != nil {
+ return err
+ }
+ err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
+ return err
+}
+
+func IsAdmin(userId int) bool {
+ if userId == 0 {
+ return false
+ }
+ var user User
+ err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
+ if err != nil {
+ common.SysError("no such user " + err.Error())
+ return false
+ }
+ return user.Role >= common.RoleAdminUser
+}
+
+//// IsUserEnabled checks user status from Redis first, falls back to DB if needed
+//func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
+// defer func() {
+// // Update Redis cache asynchronously on successful DB read
+// if shouldUpdateRedis(fromDB, err) {
+// gopool.Go(func() {
+// if err := updateUserStatusCache(id, status); err != nil {
+// common.SysError("failed to update user status cache: " + err.Error())
+// }
+// })
+// }
+// }()
+// if !fromDB && common.RedisEnabled {
+// // Try Redis first
+// status, err := getUserStatusCache(id)
+// if err == nil {
+// return status == common.UserStatusEnabled, nil
+// }
+// // Don't return error - fall through to DB
+// }
+// fromDB = true
+// var user User
+// err = DB.Where("id = ?", id).Select("status").Find(&user).Error
+// if err != nil {
+// return false, err
+// }
+//
+// return user.Status == common.UserStatusEnabled, nil
+//}
+
+func ValidateAccessToken(token string) (user *User) {
+ if token == "" {
+ return nil
+ }
+ token = strings.Replace(token, "Bearer ", "", 1)
+ user = &User{}
+ if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
+ return user
+ }
+ return nil
+}
+
+// GetUserQuota gets quota from Redis first, falls back to DB if needed
+func GetUserQuota(id int, fromDB bool) (quota int, err error) {
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) {
+ gopool.Go(func() {
+ if err := updateUserQuotaCache(id, quota); err != nil {
+ common.SysError("failed to update user quota cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ if !fromDB && common.RedisEnabled {
+ quota, err := getUserQuotaCache(id)
+ if err == nil {
+ return quota, nil
+ }
+ // Don't return error - fall through to DB
+ }
+ fromDB = true
+ err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
+ if err != nil {
+ return 0, err
+ }
+
+ return quota, nil
+}
+
+func GetUserUsedQuota(id int) (quota int, err error) {
+ err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
+ return quota, err
+}
+
+func GetUserEmail(id int) (email string, err error) {
+ err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
+ return email, err
+}
+
+// GetUserGroup gets group from Redis first, falls back to DB if needed
+func GetUserGroup(id int, fromDB bool) (group string, err error) {
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) {
+ gopool.Go(func() {
+ if err := updateUserGroupCache(id, group); err != nil {
+ common.SysError("failed to update user group cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ if !fromDB && common.RedisEnabled {
+ group, err := getUserGroupCache(id)
+ if err == nil {
+ return group, nil
+ }
+ // Don't return error - fall through to DB
+ }
+ fromDB = true
+ err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
+ if err != nil {
+ return "", err
+ }
+
+ return group, nil
+}
+
+// GetUserSetting gets setting from Redis first, falls back to DB if needed
+func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
+ var setting string
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) {
+ gopool.Go(func() {
+ if err := updateUserSettingCache(id, setting); err != nil {
+ common.SysError("failed to update user setting cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ if !fromDB && common.RedisEnabled {
+ setting, err := getUserSettingCache(id)
+ if err == nil {
+ return setting, nil
+ }
+ // Don't return error - fall through to DB
+ }
+ fromDB = true
+ err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
+ if err != nil {
+ return settingMap, err
+ }
+ userBase := &UserBase{
+ Setting: setting,
+ }
+ return userBase.GetSetting(), nil
+}
+
+func IncreaseUserQuota(id int, quota int, db bool) (err error) {
+ if quota < 0 {
+ return errors.New("quota 不能为负数!")
+ }
+ gopool.Go(func() {
+ err := cacheIncrUserQuota(id, int64(quota))
+ if err != nil {
+ common.SysError("failed to increase user quota: " + err.Error())
+ }
+ })
+ if !db && common.BatchUpdateEnabled {
+ addNewRecord(BatchUpdateTypeUserQuota, id, quota)
+ return nil
+ }
+ return increaseUserQuota(id, quota)
+}
+
+func increaseUserQuota(id int, quota int) (err error) {
+ err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
+ if err != nil {
+ return err
+ }
+ return err
+}
+
+func DecreaseUserQuota(id int, quota int) (err error) {
+ if quota < 0 {
+ return errors.New("quota 不能为负数!")
+ }
+ gopool.Go(func() {
+ err := cacheDecrUserQuota(id, int64(quota))
+ if err != nil {
+ common.SysError("failed to decrease user quota: " + err.Error())
+ }
+ })
+ if common.BatchUpdateEnabled {
+ addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
+ return nil
+ }
+ return decreaseUserQuota(id, quota)
+}
+
+func decreaseUserQuota(id int, quota int) (err error) {
+ err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
+ if err != nil {
+ return err
+ }
+ return err
+}
+
+func DeltaUpdateUserQuota(id int, delta int) (err error) {
+ if delta == 0 {
+ return nil
+ }
+ if delta > 0 {
+ return IncreaseUserQuota(id, delta, false)
+ } else {
+ return DecreaseUserQuota(id, -delta)
+ }
+}
+
+//func GetRootUserEmail() (email string) {
+// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+// return email
+//}
+
+func GetRootUser() (user *User) {
+ DB.Where("role = ?", common.RoleRootUser).First(&user)
+ return user
+}
+
+func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
+ if common.BatchUpdateEnabled {
+ addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
+ addNewRecord(BatchUpdateTypeRequestCount, id, 1)
+ return
+ }
+ updateUserUsedQuotaAndRequestCount(id, quota, 1)
+}
+
+func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
+ err := DB.Model(&User{}).Where("id = ?", id).Updates(
+ map[string]interface{}{
+ "used_quota": gorm.Expr("used_quota + ?", quota),
+ "request_count": gorm.Expr("request_count + ?", count),
+ },
+ ).Error
+ if err != nil {
+ common.SysError("failed to update user used quota and request count: " + err.Error())
+ return
+ }
+
+ //// update the cache
+ //if err := invalidateUserCache(id); err != nil {
+ // common.SysError("failed to invalidate user cache: " + err.Error())
+ //}
+}
+
+func updateUserUsedQuota(id int, quota int) {
+ err := DB.Model(&User{}).Where("id = ?", id).Updates(
+ map[string]interface{}{
+ "used_quota": gorm.Expr("used_quota + ?", quota),
+ },
+ ).Error
+ if err != nil {
+ common.SysError("failed to update user used quota: " + err.Error())
+ }
+}
+
+func updateUserRequestCount(id int, count int) {
+ err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
+ if err != nil {
+ common.SysError("failed to update user request count: " + err.Error())
+ }
+}
+
+// GetUsernameById gets username from Redis first, falls back to DB if needed
+func GetUsernameById(id int, fromDB bool) (username string, err error) {
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) {
+ gopool.Go(func() {
+ if err := updateUserNameCache(id, username); err != nil {
+ common.SysError("failed to update user name cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ if !fromDB && common.RedisEnabled {
+ username, err := getUserNameCache(id)
+ if err == nil {
+ return username, nil
+ }
+ // Don't return error - fall through to DB
+ }
+ fromDB = true
+ err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
+ if err != nil {
+ return "", err
+ }
+
+ return username, nil
+}
+
+func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
+ var user User
+ err := DB.Unscoped().Where("linux_do_id = ?", linuxDOId).First(&user).Error
+ return !errors.Is(err, gorm.ErrRecordNotFound)
+}
+
+func (user *User) FillUserByLinuxDOId() error {
+ if user.LinuxDOId == "" {
+ return errors.New("linux do id is empty")
+ }
+ err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
+ return err
+}
+
+func RootUserExists() bool {
+ var user User
+ err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
+ if err != nil {
+ return false
+ }
+ return true
+}
diff --git a/model/user_cache.go b/model/user_cache.go
new file mode 100644
index 00000000..a631457c
--- /dev/null
+++ b/model/user_cache.go
@@ -0,0 +1,218 @@
+package model
+
+import (
+ "fmt"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "time"
+
+ "github.com/gin-gonic/gin"
+
+ "github.com/bytedance/gopkg/util/gopool"
+)
+
+// UserBase struct remains the same as it represents the cached data structure
+type UserBase struct {
+ Id int `json:"id"`
+ Group string `json:"group"`
+ Email string `json:"email"`
+ Quota int `json:"quota"`
+ Status int `json:"status"`
+ Username string `json:"username"`
+ Setting string `json:"setting"`
+}
+
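+// WriteContext copies the cached user fields into the gin request context for downstream handlers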
+func (user *UserBase) WriteContext(c *gin.Context) {
+ common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
+ common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
+ common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
+ common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
+ common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
+ common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
+}
+
+func (user *UserBase) GetSetting() dto.UserSetting {
+ setting := dto.UserSetting{}
+ if user.Setting != "" {
+ err := common.Unmarshal([]byte(user.Setting), &setting)
+ if err != nil {
+ common.SysError("failed to unmarshal setting: " + err.Error())
+ }
+ }
+ return setting
+}
+
+// getUserCacheKey returns the key for user cache
+func getUserCacheKey(userId int) string {
+ return fmt.Sprintf("user:%d", userId)
+}
+
+// invalidateUserCache clears user cache
+func invalidateUserCache(userId int) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisDelKey(getUserCacheKey(userId))
+}
+
+// updateUserCache updates all user cache fields using hash
+func updateUserCache(user User) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+
+ return common.RedisHSetObj(
+ getUserCacheKey(user.Id),
+ user.ToBaseUser(),
+ time.Duration(common.RedisKeyCacheSeconds())*time.Second,
+ )
+}
+
+// GetUserCache gets complete user cache from hash
+func GetUserCache(userId int) (userCache *UserBase, err error) {
+ var user *User
+ var fromDB bool
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) && user != nil {
+ gopool.Go(func() {
+ if err := updateUserCache(*user); err != nil {
+ common.SysError("failed to update user status cache: " + err.Error())
+ }
+ })
+ }
+ }()
+
+ // Try getting from Redis first
+ userCache, err = cacheGetUserBase(userId)
+ if err == nil {
+ return userCache, nil
+ }
+
+ // If Redis fails, get from DB
+ fromDB = true
+ user, err = GetUserById(userId, false)
+ if err != nil {
+ return nil, err // Return nil and error if DB lookup fails
+ }
+
+ // Create cache object from user data
+ userCache = &UserBase{
+ Id: user.Id,
+ Group: user.Group,
+ Quota: user.Quota,
+ Status: user.Status,
+ Username: user.Username,
+ Setting: user.Setting,
+ Email: user.Email,
+ }
+
+ return userCache, nil
+}
+
+func cacheGetUserBase(userId int) (*UserBase, error) {
+ if !common.RedisEnabled {
+ return nil, fmt.Errorf("redis is not enabled")
+ }
+ var userCache UserBase
+ // Try getting from Redis first
+ err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
+ if err != nil {
+ return nil, err
+ }
+ return &userCache, nil
+}
+
+// Add atomic quota operations using hash fields
+func cacheIncrUserQuota(userId int, delta int64) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
+}
+
+func cacheDecrUserQuota(userId int, delta int64) error {
+ return cacheIncrUserQuota(userId, -delta)
+}
+
+// Helper functions to get individual fields if needed
+func getUserGroupCache(userId int) (string, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return "", err
+ }
+ return cache.Group, nil
+}
+
+func getUserQuotaCache(userId int) (int, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return 0, err
+ }
+ return cache.Quota, nil
+}
+
+func getUserStatusCache(userId int) (int, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return 0, err
+ }
+ return cache.Status, nil
+}
+
+func getUserNameCache(userId int) (string, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return "", err
+ }
+ return cache.Username, nil
+}
+
+func getUserSettingCache(userId int) (dto.UserSetting, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return dto.UserSetting{}, err
+ }
+ return cache.GetSetting(), nil
+}
+
+// New functions for individual field updates
+func updateUserStatusCache(userId int, status bool) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ statusInt := common.UserStatusEnabled
+ if !status {
+ statusInt = common.UserStatusDisabled
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
+}
+
+func updateUserQuotaCache(userId int, quota int) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
+}
+
+func updateUserGroupCache(userId int, group string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
+}
+
+func updateUserNameCache(userId int, username string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
+}
+
+func updateUserSettingCache(userId int, setting string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
+}
diff --git a/model/utils.go b/model/utils.go
new file mode 100644
index 00000000..1f8a0963
--- /dev/null
+++ b/model/utils.go
@@ -0,0 +1,111 @@
+package model
+
+import (
+ "errors"
+ "one-api/common"
+ "sync"
+ "time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "gorm.io/gorm"
+)
+
+const (
+ BatchUpdateTypeUserQuota = iota
+ BatchUpdateTypeTokenQuota
+ BatchUpdateTypeUsedQuota
+ BatchUpdateTypeChannelUsedQuota
+ BatchUpdateTypeRequestCount
+ BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
+)
+
+var batchUpdateStores []map[int]int
+var batchUpdateLocks []sync.Mutex
+
+func init() {
+ for i := 0; i < BatchUpdateTypeCount; i++ {
+ batchUpdateStores = append(batchUpdateStores, make(map[int]int))
+ batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
+ }
+}
+
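+// InitBatchUpdater starts a background goroutine that flushes the accumulated deltas every BatchUpdateInterval seconds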
+func InitBatchUpdater() {
+ gopool.Go(func() {
+ for {
+ time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
+ batchUpdate()
+ }
+ })
+}
+
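+// addNewRecord accumulates a pending delta for the given id under the specified batch-update type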
+func addNewRecord(type_ int, id int, value int) {
+ batchUpdateLocks[type_].Lock()
+ defer batchUpdateLocks[type_].Unlock()
+ if _, ok := batchUpdateStores[type_][id]; !ok {
+ batchUpdateStores[type_][id] = value
+ } else {
+ batchUpdateStores[type_][id] += value
+ }
+}
+
+func batchUpdate() {
+ // check if there's any data to update
+ hasData := false
+ for i := 0; i < BatchUpdateTypeCount; i++ {
+ batchUpdateLocks[i].Lock()
+ if len(batchUpdateStores[i]) > 0 {
+ hasData = true
+ batchUpdateLocks[i].Unlock()
+ break
+ }
+ batchUpdateLocks[i].Unlock()
+ }
+
+ if !hasData {
+ return
+ }
+
+ common.SysLog("batch update started")
+ for i := 0; i < BatchUpdateTypeCount; i++ {
+ batchUpdateLocks[i].Lock()
+ store := batchUpdateStores[i]
+ batchUpdateStores[i] = make(map[int]int)
+ batchUpdateLocks[i].Unlock()
+ // TODO: maybe we can combine updates with same key?
+ for key, value := range store {
+ switch i {
+ case BatchUpdateTypeUserQuota:
+ err := increaseUserQuota(key, value)
+ if err != nil {
+ common.SysError("failed to batch update user quota: " + err.Error())
+ }
+ case BatchUpdateTypeTokenQuota:
+ err := increaseTokenQuota(key, value)
+ if err != nil {
+ common.SysError("failed to batch update token quota: " + err.Error())
+ }
+ case BatchUpdateTypeUsedQuota:
+ updateUserUsedQuota(key, value)
+ case BatchUpdateTypeRequestCount:
+ updateUserRequestCount(key, value)
+ case BatchUpdateTypeChannelUsedQuota:
+ updateChannelUsedQuota(key, value)
+ }
+ }
+ }
+ common.SysLog("batch update finished")
+}
+
+func RecordExist(err error) (bool, error) {
+ if err == nil {
+ return true, nil
+ }
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return false, nil
+ }
+ return false, err
+}
+
+func shouldUpdateRedis(fromDB bool, err error) bool {
+ return common.RedisEnabled && fromDB && err == nil
+}
diff --git a/one-api.service b/one-api.service
new file mode 100644
index 00000000..17e236bc
--- /dev/null
+++ b/one-api.service
@@ -0,0 +1,18 @@
+# File path: /etc/systemd/system/one-api.service
+# sudo systemctl daemon-reload
+# sudo systemctl start one-api
+# sudo systemctl enable one-api
+# sudo systemctl status one-api
+[Unit]
+Description=One API Service
+After=network.target
+
+[Service]
+# Adjust the user, paths, and port below; systemd does not allow trailing comments on directive lines
+User=ubuntu
+WorkingDirectory=/path/to/one-api
+ExecStart=/path/to/one-api/one-api --port 3000 --log-dir /path/to/one-api/logs
+Restart=always
+RestartSec=5
+
+[Install]
+WantedBy=multi-user.target
diff --git a/relay/audio_handler.go b/relay/audio_handler.go
new file mode 100644
index 00000000..f39dbd82
--- /dev/null
+++ b/relay/audio_handler.go
@@ -0,0 +1,134 @@
+package relay
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
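+// getAndValidAudioRequest parses the audio request from the JSON body for speech requests, or from the submitted form otherwise, and validates that a model is provided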
+func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
+ audioRequest := &dto.AudioRequest{}
+ err := common.UnmarshalBodyReusable(c, audioRequest)
+ if err != nil {
+ return nil, err
+ }
+ switch info.RelayMode {
+ case relayconstant.RelayModeAudioSpeech:
+ if audioRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if setting.ShouldCheckPromptSensitive() {
+ words, err := service.CheckSensitiveInput(audioRequest.Input)
+ if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
+ return nil, err
+ }
+ }
+ default:
+ err = c.Request.ParseForm()
+ if err != nil {
+ return nil, err
+ }
+ formData := c.Request.PostForm
+ if audioRequest.Model == "" {
+ audioRequest.Model = formData.Get("model")
+ }
+
+ if audioRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ audioRequest.ResponseFormat = formData.Get("response_format")
+ if audioRequest.ResponseFormat == "" {
+ audioRequest.ResponseFormat = "json"
+ }
+ }
+ return audioRequest, nil
+}
+
+func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+ relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
+ audioRequest, err := getAndValidAudioRequest(c, relayInfo)
+
+ if err != nil {
+ common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ promptTokens := 0
+ preConsumedTokens := common.PreConsumedQuota
+ if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
+ promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
+ preConsumedTokens = promptTokens
+ relayInfo.PromptTokens = promptTokens
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+
+ preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if openaiErr != nil {
+ return openaiErr
+ }
+ defer func() {
+ if openaiErr != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
+
+ ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+
+ resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeDoRequestFailed)
+ }
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ if newAPIError != nil {
+ // reset status code
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+
+ return nil
+}
diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go
new file mode 100644
index 00000000..ab8836ba
--- /dev/null
+++ b/relay/channel/adapter.go
@@ -0,0 +1,50 @@
+package channel
+
+import (
+ "io"
+ "net/http"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
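+// Adaptor abstracts a single upstream channel: it builds the provider-specific request (URL, headers, converted body) and translates the provider response back into the unified format.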
+type Adaptor interface {
+ // Init IsStream bool
+ Init(info *relaycommon.RelayInfo)
+ GetRequestURL(info *relaycommon.RelayInfo) (string, error)
+ SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
+ ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
+ ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
+ ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
+ ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
+ ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
+ ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
+ DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
+ DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
+ GetModelList() []string
+ GetChannelName() string
+ ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
+}
+
+type TaskAdaptor interface {
+ Init(info *relaycommon.TaskRelayInfo)
+
+ ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
+
+ BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
+ BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
+ BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
+
+ DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
+ DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
+
+ GetModelList() []string
+ GetChannelName() string
+
+ // FetchTask
+ FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
+
+ ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
+}
diff --git a/relay/channel/ai360/constants.go b/relay/channel/ai360/constants.go
new file mode 100644
index 00000000..4b09dd56
--- /dev/null
+++ b/relay/channel/ai360/constants.go
@@ -0,0 +1,14 @@
+package ai360
+
+var ModelList = []string{
+ "360gpt-turbo",
+ "360gpt-turbo-responsibility-8k",
+ "360gpt-pro",
+ "360gpt2-pro",
+ "360GPT_S2_V9",
+ "embedding-bert-512-v1",
+ "embedding_s1_v1",
+ "semantic_similarity_s1_v1",
+}
+
+var ChannelName = "ai360"
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
new file mode 100644
index 00000000..d941a1bc
--- /dev/null
+++ b/relay/channel/ali/adaptor.go
@@ -0,0 +1,127 @@
+package ali
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+ return nil, nil
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ var fullRequestURL string
+ switch info.RelayMode {
+ case constant.RelayModeEmbeddings:
+ fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
+ case constant.RelayModeRerank:
+ fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
+ case constant.RelayModeImagesGenerations:
+ fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
+ case constant.RelayModeCompletions:
+ fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
+ default:
+ fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
+ }
+ return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ if info.IsStream {
+ req.Set("X-DashScope-SSE", "enable")
+ }
+ if c.GetString("plugin") != "" {
+ req.Set("X-DashScope-Plugin", c.GetString("plugin"))
+ }
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+
+ // fix: ali parameter.enable_thinking must be set to false for non-streaming calls
+ if !info.IsStream {
+ request.EnableThinking = false
+ }
+
+ switch info.RelayMode {
+ default:
+ aliReq := requestOpenAI2Ali(*request)
+ return aliReq, nil
+ }
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ aliRequest := oaiImage2Ali(request)
+ return aliRequest, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return ConvertRerankRequest(request), nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayMode {
+ case constant.RelayModeImagesGenerations:
+ err, usage = aliImageHandler(c, resp, info)
+ case constant.RelayModeEmbeddings:
+ err, usage = aliEmbeddingHandler(c, resp)
+ case constant.RelayModeRerank:
+ err, usage = RerankHandler(c, resp, info)
+ default:
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/ali/constants.go b/relay/channel/ali/constants.go
new file mode 100644
index 00000000..df64439b
--- /dev/null
+++ b/relay/channel/ali/constants.go
@@ -0,0 +1,14 @@
+package ali
+
+var ModelList = []string{
+ "qwen-turbo",
+ "qwen-plus",
+ "qwen-max",
+ "qwen-max-longcontext",
+ "qwq-32b",
+ "qwen3-235b-a22b",
+ "text-embedding-v1",
+ "gte-rerank-v2",
+}
+
+var ChannelName = "ali"
diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go
new file mode 100644
index 00000000..dbd18968
--- /dev/null
+++ b/relay/channel/ali/dto.go
@@ -0,0 +1,126 @@
+package ali
+
+import "one-api/dto"
+
+type AliMessage struct {
+ Content string `json:"content"`
+ Role string `json:"role"`
+}
+
+type AliInput struct {
+ Prompt string `json:"prompt,omitempty"`
+ //History []AliMessage `json:"history,omitempty"`
+ Messages []AliMessage `json:"messages"`
+}
+
+type AliParameters struct {
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Seed uint64 `json:"seed,omitempty"`
+ EnableSearch bool `json:"enable_search,omitempty"`
+ IncrementalOutput bool `json:"incremental_output,omitempty"`
+}
+
+type AliChatRequest struct {
+ Model string `json:"model"`
+ Input AliInput `json:"input,omitempty"`
+ Parameters AliParameters `json:"parameters,omitempty"`
+}
+
+type AliEmbeddingRequest struct {
+ Model string `json:"model"`
+ Input struct {
+ Texts []string `json:"texts"`
+ } `json:"input"`
+ Parameters *struct {
+ TextType string `json:"text_type,omitempty"`
+ } `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+ Embedding []float64 `json:"embedding"`
+ TextIndex int `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+ Output struct {
+ Embeddings []AliEmbedding `json:"embeddings"`
+ } `json:"output"`
+ Usage AliUsage `json:"usage"`
+ AliError
+}
+
+type AliError struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+}
+
+type AliUsage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+type TaskResult struct {
+ B64Image string `json:"b64_image,omitempty"`
+ Url string `json:"url,omitempty"`
+ Code string `json:"code,omitempty"`
+ Message string `json:"message,omitempty"`
+}
+
+type AliOutput struct {
+ TaskId string `json:"task_id,omitempty"`
+ TaskStatus string `json:"task_status,omitempty"`
+ Text string `json:"text"`
+ FinishReason string `json:"finish_reason"`
+ Message string `json:"message,omitempty"`
+ Code string `json:"code,omitempty"`
+ Results []TaskResult `json:"results,omitempty"`
+}
+
+type AliResponse struct {
+ Output AliOutput `json:"output"`
+ Usage AliUsage `json:"usage"`
+ AliError
+}
+
+type AliImageRequest struct {
+ Model string `json:"model"`
+ Input struct {
+ Prompt string `json:"prompt"`
+ NegativePrompt string `json:"negative_prompt,omitempty"`
+ } `json:"input"`
+ Parameters struct {
+ Size string `json:"size,omitempty"`
+ N int `json:"n,omitempty"`
+ Steps string `json:"steps,omitempty"`
+ Scale string `json:"scale,omitempty"`
+ } `json:"parameters,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+}
+
+type AliRerankParameters struct {
+ TopN *int `json:"top_n,omitempty"`
+ ReturnDocuments *bool `json:"return_documents,omitempty"`
+}
+
+type AliRerankInput struct {
+ Query string `json:"query"`
+ Documents []any `json:"documents"`
+}
+
+type AliRerankRequest struct {
+ Model string `json:"model"`
+ Input AliRerankInput `json:"input"`
+ Parameters AliRerankParameters `json:"parameters,omitempty"`
+}
+
+type AliRerankResponse struct {
+ Output struct {
+ Results []dto.RerankResponseResult `json:"results"`
+ } `json:"output"`
+ Usage AliUsage `json:"usage"`
+ RequestId string `json:"request_id"`
+ AliError
+}
diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go
new file mode 100644
index 00000000..0d430c62
--- /dev/null
+++ b/relay/channel/ali/image.go
@@ -0,0 +1,171 @@
+package ali
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
+ var imageRequest AliImageRequest
+ imageRequest.Input.Prompt = request.Prompt
+ imageRequest.Model = request.Model
+ imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
+ imageRequest.Parameters.N = request.N
+ imageRequest.ResponseFormat = request.ResponseFormat
+
+ return &imageRequest
+}
+
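+// updateTask queries the Ali task endpoint once and returns the current status of an asynchronous image generation task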
+func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
+ url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
+
+ var aliResponse AliResponse
+
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return &aliResponse, err, nil
+ }
+
+ req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ common.SysError("updateTask client.Do err: " + err.Error())
+ return &aliResponse, err, nil
+ }
+ defer resp.Body.Close()
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ common.SysError("updateTask io.ReadAll err: " + err.Error())
+ return &aliResponse, err, nil
+ }
+
+ var response AliResponse
+ err = json.Unmarshal(responseBody, &response)
+ if err != nil {
+ common.SysError("updateTask NewDecoder err: " + err.Error())
+ return &aliResponse, err, nil
+ }
+
+ return &response, nil, responseBody
+}
+
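+// asyncTaskWait polls the task status every few seconds until it reaches a terminal state or the retry limit is exhausted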
+func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
+ waitSeconds := 3
+ step := 0
+ maxStep := 20
+
+ var taskResponse AliResponse
+ var responseBody []byte
+
+ for {
+ step++
+ rsp, err, body := updateTask(info, taskID)
+ responseBody = body
+ if err != nil {
+ return &taskResponse, responseBody, err
+ }
+
+ if rsp.Output.TaskStatus == "" {
+ return &taskResponse, responseBody, nil
+ }
+
+ switch rsp.Output.TaskStatus {
+ case "FAILED":
+ fallthrough
+ case "CANCELED":
+ fallthrough
+ case "SUCCEEDED":
+ fallthrough
+ case "UNKNOWN":
+ return rsp, responseBody, nil
+ }
+ if step >= maxStep {
+ break
+ }
+ time.Sleep(time.Duration(waitSeconds) * time.Second)
+ }
+
+ return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
+}
+
+func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
+ imageResponse := dto.ImageResponse{
+ Created: info.StartTime.Unix(),
+ }
+
+ for _, data := range response.Output.Results {
+ var b64Json string
+ if responseFormat == "b64_json" {
+ _, b64, err := service.GetImageFromUrl(data.Url)
+ if err != nil {
+ common.LogError(c, "get_image_data_failed: "+err.Error())
+ continue
+ }
+ b64Json = b64
+ } else {
+ b64Json = data.B64Image
+ }
+
+ imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+ Url: data.Url,
+ B64Json: b64Json,
+ RevisedPrompt: "",
+ })
+ }
+ return &imageResponse
+}
+
+func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+ responseFormat := c.GetString("response_format")
+
+ var aliTaskResponse AliResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &aliTaskResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+
+ if aliTaskResponse.Message != "" {
+ common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
+ return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
+ }
+
+ aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponse), nil
+ }
+
+ if aliResponse.Output.TaskStatus != "SUCCEEDED" {
+ return types.WithOpenAIError(types.OpenAIError{
+ Message: aliResponse.Output.Message,
+ Type: "ali_error",
+ Param: "",
+ Code: aliResponse.Output.Code,
+ }, resp.StatusCode), nil
+ }
+
+ fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ c.Writer.Write(jsonResponse)
+ return nil, &dto.Usage{}
+}
diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go
new file mode 100644
index 00000000..59cb0a11
--- /dev/null
+++ b/relay/channel/ali/rerank.go
@@ -0,0 +1,74 @@
+package ali
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
+ returnDocuments := request.ReturnDocuments
+ if returnDocuments == nil {
+ t := true
+ returnDocuments = &t
+ }
+ return &AliRerankRequest{
+ Model: request.Model,
+ Input: AliRerankInput{
+ Query: request.Query,
+ Documents: request.Documents,
+ },
+ Parameters: AliRerankParameters{
+ TopN: &request.TopN,
+ ReturnDocuments: returnDocuments,
+ },
+ }
+}
+
+func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
+ }
+ common.CloseResponseBodyGracefully(resp)
+
+ var aliResponse AliRerankResponse
+ err = json.Unmarshal(responseBody, &aliResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+
+ if aliResponse.Code != "" {
+ return types.WithOpenAIError(types.OpenAIError{
+ Message: aliResponse.Message,
+ Type: aliResponse.Code,
+ Param: aliResponse.RequestId,
+ Code: aliResponse.Code,
+ }, resp.StatusCode), nil
+ }
+
+ usage := dto.Usage{
+ PromptTokens: aliResponse.Usage.TotalTokens,
+ CompletionTokens: 0,
+ TotalTokens: aliResponse.Usage.TotalTokens,
+ }
+ rerankResponse := dto.RerankResponse{
+ Results: aliResponse.Output.Results,
+ Usage: usage,
+ }
+
+ jsonResponse, err := json.Marshal(rerankResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ c.Writer.Write(jsonResponse)
+ return nil, &usage
+}
diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go
new file mode 100644
index 00000000..6d90fa71
--- /dev/null
+++ b/relay/channel/ali/text.go
@@ -0,0 +1,206 @@
+package ali
+
+import (
+ "bufio"
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/helper"
+ "strings"
+
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
+
+const EnableSearchModelSuffix = "-internet"
+
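+// requestOpenAI2Ali clamps top_p to the open interval (0, 1): values at or above 1
+// become 0.999 and values at or below 0 become 0.001 before the request is forwarded.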
+func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+ if request.TopP >= 1 {
+ request.TopP = 0.999
+ } else if request.TopP <= 0 {
+ request.TopP = 0.001
+ }
+ return &request
+}
+
+func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
+ return &AliEmbeddingRequest{
+ Model: request.Model,
+ Input: struct {
+ Texts []string `json:"texts"`
+ }{
+ Texts: request.ParseInput(),
+ },
+ }
+}
+
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ var fullTextResponse dto.FlexibleEmbeddingResponse
+ err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+
+ common.CloseResponseBodyGracefully(resp)
+
+ model := c.GetString("model")
+ if model == "" {
+ model = "text-embedding-v4"
+ }
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ c.Writer.Write(jsonResponse)
+ return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
+ openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
+ Object: "list",
+ Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+ Model: model,
+ Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
+ }
+
+ for _, item := range response.Output.Embeddings {
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
+ Object: `embedding`,
+ Index: item.TextIndex,
+ Embedding: item.Embedding,
+ })
+ }
+ return &openAIEmbeddingResponse
+}
+
+func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
+ choice := dto.OpenAITextResponseChoice{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: response.Output.Text,
+ },
+ FinishReason: response.Output.FinishReason,
+ }
+ fullTextResponse := dto.OpenAITextResponse{
+ Id: response.RequestId,
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ Choices: []dto.OpenAITextResponseChoice{choice},
+ Usage: dto.Usage{
+ PromptTokens: response.Usage.InputTokens,
+ CompletionTokens: response.Usage.OutputTokens,
+ TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
+ },
+ }
+ return &fullTextResponse
+}
+
+func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
+ var choice dto.ChatCompletionsStreamResponseChoice
+ choice.Delta.SetContentString(aliResponse.Output.Text)
+ if aliResponse.Output.FinishReason != "null" {
+ finishReason := aliResponse.Output.FinishReason
+ choice.FinishReason = &finishReason
+ }
+ response := dto.ChatCompletionsStreamResponse{
+ Id: aliResponse.RequestId,
+ Object: "chat.completion.chunk",
+ Created: common.GetTimestamp(),
+ Model: "ernie-bot",
+ Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
+ }
+ return &response
+}
+
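+// aliStreamHandler reads Ali's SSE stream line by line ("data:"-prefixed payloads).
+// Ali returns the accumulated text on every event, so each chunk is diffed against
+// lastResponseText and only the new suffix is emitted as an OpenAI-style delta.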
+func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ var usage dto.Usage
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(bufio.ScanLines)
+ dataChan := make(chan string)
+ stopChan := make(chan bool)
+ go func() {
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 5 { // ignore blank line or wrong format
+ continue
+ }
+ if data[:5] != "data:" {
+ continue
+ }
+ data = data[5:]
+ dataChan <- data
+ }
+ stopChan <- true
+ }()
+ helper.SetEventStreamHeaders(c)
+ lastResponseText := ""
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case data := <-dataChan:
+ var aliResponse AliResponse
+ err := json.Unmarshal([]byte(data), &aliResponse)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return true
+ }
+ if aliResponse.Usage.OutputTokens != 0 {
+ usage.PromptTokens = aliResponse.Usage.InputTokens
+ usage.CompletionTokens = aliResponse.Usage.OutputTokens
+ usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+ }
+ response := streamResponseAli2OpenAI(&aliResponse)
+ response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
+ lastResponseText = aliResponse.Output.Text
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ common.SysError("error marshalling stream response: " + err.Error())
+ return true
+ }
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+ return true
+ case <-stopChan:
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+ return false
+ }
+ })
+ common.CloseResponseBodyGracefully(resp)
+ return nil, &usage
+}
+
+func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ var aliResponse AliResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &aliResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ if aliResponse.Code != "" {
+ return types.WithOpenAIError(types.OpenAIError{
+ Message: aliResponse.Message,
+ Type: "ali_error",
+ Param: aliResponse.RequestId,
+ Code: aliResponse.Code,
+ }, resp.StatusCode), nil
+ }
+ fullTextResponse := responseAli2OpenAI(&aliResponse)
+ jsonResponse, err := common.Marshal(fullTextResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return nil, &fullTextResponse.Usage
+}
diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go
new file mode 100644
index 00000000..ff7c63fa
--- /dev/null
+++ b/relay/channel/api_request.go
@@ -0,0 +1,277 @@
+package channel
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ common2 "one-api/common"
+ "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/operation_setting"
+ "sync"
+ "time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+)
+
+func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
+ if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
+ // multipart/form-data
+ } else if info.RelayMode == constant.RelayModeRealtime {
+ // websocket
+ } else {
+ req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Set("Accept", c.Request.Header.Get("Accept"))
+ if info.IsStream && c.Request.Header.Get("Accept") == "" {
+ req.Set("Accept", "text/event-stream")
+ }
+ }
+}
+
+func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+ fullRequestURL, err := a.GetRequestURL(info)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ if common2.DebugEnabled {
+ println("fullRequestURL:", fullRequestURL)
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ err = a.SetupRequestHeader(c, &req.Header, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := doRequest(c, req, info)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
+func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+ fullRequestURL, err := a.GetRequestURL(info)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ if common2.DebugEnabled {
+ println("fullRequestURL:", fullRequestURL)
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ // set form data
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+
+ err = a.SetupRequestHeader(c, &req.Header, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := doRequest(c, req, info)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
+func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
+ fullRequestURL, err := a.GetRequestURL(info)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ targetHeader := http.Header{}
+ err = a.SetupRequestHeader(c, &targetHeader, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
+ if err != nil {
+ return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
+ }
+ // send request body
+ //all, err := io.ReadAll(requestBody)
+ //err = service.WssString(c, targetConn, string(all))
+ return targetConn, nil
+}
+
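+// startPingKeepAlive launches a goroutine that periodically writes SSE ping events
+// so long streaming responses keep the connection alive. It exits when the returned
+// CancelFunc is called, when the client request context ends, or after a
+// 120-minute safety timeout.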
+func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
+ pingerCtx, stopPinger := context.WithCancel(context.Background())
+
+ gopool.Go(func() {
+ defer func() {
+ // Recover from panics so a failing ping goroutine cannot crash the process
+ if r := recover(); r != nil {
+ if common2.DebugEnabled {
+ println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
+ }
+ }
+ if common2.DebugEnabled {
+ println("SSE ping goroutine stopped.")
+ }
+ }()
+
+ if pingInterval <= 0 {
+ pingInterval = helper.DefaultPingInterval
+ }
+
+ ticker := time.NewTicker(pingInterval)
+ // Make sure the ticker is cleaned up in every case
+ defer func() {
+ ticker.Stop()
+ if common2.DebugEnabled {
+ println("SSE ping ticker stopped")
+ }
+ }()
+
+ var pingMutex sync.Mutex
+ if common2.DebugEnabled {
+ println("SSE ping goroutine started")
+ }
+
+ // Hard timeout so the goroutine cannot run indefinitely
+ maxPingDuration := 120 * time.Minute // maximum ping duration
+ pingTimeout := time.NewTimer(maxPingDuration)
+ defer pingTimeout.Stop()
+
+ for {
+ select {
+ // Send ping data
+ case <-ticker.C:
+ if err := sendPingData(c, &pingMutex); err != nil {
+ if common2.DebugEnabled {
+ println("SSE ping error, stopping goroutine:", err.Error())
+ }
+ return
+ }
+ // Stop signal received
+ case <-pingerCtx.Done():
+ return
+ // Request finished
+ case <-c.Request.Context().Done():
+ return
+ // Timeout guard so the goroutine cannot run forever
+ case <-pingTimeout.C:
+ if common2.DebugEnabled {
+ println("SSE ping goroutine timeout, stopping")
+ }
+ return
+ }
+ }
+ })
+
+ return stopPinger
+}
+
+func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
+ // Bound the wait with a timeout so a stuck lock cannot block forever
+ done := make(chan error, 1)
+ go func() {
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ err := helper.PingData(c)
+ if err != nil {
+ common2.LogError(c, "SSE ping error: "+err.Error())
+ done <- err
+ return
+ }
+
+ if common2.DebugEnabled {
+ println("SSE ping data sent.")
+ }
+ done <- nil
+ }()
+
+ // Time out the ping write if it takes too long
+ select {
+ case err := <-done:
+ return err
+ case <-time.After(10 * time.Second):
+ return errors.New("SSE ping data send timeout")
+ case <-c.Request.Context().Done():
+ return errors.New("request context cancelled during ping")
+ }
+}
+
+func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+ return doRequest(c, req, info)
+}
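+
+// doRequest selects a per-channel proxy HTTP client when one is configured,
+// optionally starts the SSE ping keep-alive for streaming requests, and closes
+// both the upstream request body and the client request body once the upstream
+// response has been received.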
+func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+ var client *http.Client
+ var err error
+ if info.ChannelSetting.Proxy != "" {
+ client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
+ if err != nil {
+ return nil, fmt.Errorf("new proxy http client failed: %w", err)
+ }
+ } else {
+ client = service.GetHttpClient()
+ }
+
+ var stopPinger context.CancelFunc
+ if info.IsStream {
+ helper.SetEventStreamHeaders(c)
+ // Ping keep-alive handling for streaming requests
+ generalSettings := operation_setting.GetGeneralSetting()
+ if generalSettings.PingIntervalEnabled {
+ pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+ stopPinger = startPingKeepAlive(c, pingInterval)
+ // Defer stopping the ping goroutine so it is cleaned up on every code path
+ defer func() {
+ if stopPinger != nil {
+ stopPinger()
+ if common2.DebugEnabled {
+ println("SSE ping goroutine stopped by defer")
+ }
+ }
+ }()
+ }
+ }
+
+ resp, err := client.Do(req)
+
+ if err != nil {
+ return nil, err
+ }
+ if resp == nil {
+ return nil, errors.New("resp is nil")
+ }
+
+ if req.Body != nil {
+ _ = req.Body.Close()
+ }
+ _ = c.Request.Body.Close()
+ return resp, nil
+}
+
+func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ fullRequestURL, err := a.BuildRequestURL(info)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ req.GetBody = func() (io.ReadCloser, error) {
+ return io.NopCloser(requestBody), nil
+ }
+
+ err = a.BuildRequestHeader(c, req, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := doRequest(c, req, info.RelayInfo)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go
new file mode 100644
index 00000000..d3354f00
--- /dev/null
+++ b/relay/channel/aws/adaptor.go
@@ -0,0 +1,107 @@
+package aws
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel/claude"
+ relaycommon "one-api/relay/common"
+ "one-api/setting/model_setting"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ RequestModeCompletion = 1
+ RequestModeMessage = 2
+)
+
+type Adaptor struct {
+ RequestMode int
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+ c.Set("request_model", request.Model)
+ c.Set("converted_request", request)
+ return request, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+ a.RequestMode = RequestModeMessage
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return "", nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+
+ var claudeReq *dto.ClaudeRequest
+ var err error
+ claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
+ if err != nil {
+ return nil, err
+ }
+ c.Set("request_model", claudeReq.Model)
+ c.Set("converted_request", claudeReq)
+ return claudeReq, err
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
+ } else {
+ err, usage = awsHandler(c, info, a.RequestMode)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() (models []string) {
+ for n := range awsModelIDMap {
+ models = append(models, n)
+ }
+
+ return
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go
new file mode 100644
index 00000000..64c7b747
--- /dev/null
+++ b/relay/channel/aws/constants.go
@@ -0,0 +1,65 @@
+package aws
+
+var awsModelIDMap = map[string]string{
+ "claude-instant-1.2": "anthropic.claude-instant-v1",
+ "claude-2.0": "anthropic.claude-v2",
+ "claude-2.1": "anthropic.claude-v2:1",
+ "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
+ "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
+ "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
+ "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
+ "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
+ "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
+ "claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
+ "claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
+}
+
+var awsModelCanCrossRegionMap = map[string]map[string]bool{
+ "anthropic.claude-3-sonnet-20240229-v1:0": {
+ "us": true,
+ "eu": true,
+ "ap": true,
+ },
+ "anthropic.claude-3-opus-20240229-v1:0": {
+ "us": true,
+ },
+ "anthropic.claude-3-haiku-20240307-v1:0": {
+ "us": true,
+ "eu": true,
+ "ap": true,
+ },
+ "anthropic.claude-3-5-sonnet-20240620-v1:0": {
+ "us": true,
+ "eu": true,
+ "ap": true,
+ },
+ "anthropic.claude-3-5-sonnet-20241022-v2:0": {
+ "us": true,
+ "ap": true,
+ },
+ "anthropic.claude-3-5-haiku-20241022-v1:0": {
+ "us": true,
+ },
+ "anthropic.claude-3-7-sonnet-20250219-v1:0": {
+ "us": true,
+ "ap": true,
+ "eu": true,
+ },
+ "anthropic.claude-sonnet-4-20250514-v1:0": {
+ "us": true,
+ "ap": true,
+ "eu": true,
+ },
+ "anthropic.claude-opus-4-20250514-v1:0": {
+ "us": true,
+ },
+}
+
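+// awsRegionCrossModelPrefixMap maps the first segment of an AWS region id (for
+// example "us" from "us-east-1") to the cross-region inference profile prefix that
+// is prepended to the model id, producing ids such as
+// "us.anthropic.claude-3-5-sonnet-20241022-v2:0".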
+var awsRegionCrossModelPrefixMap = map[string]string{
+ "us": "us",
+ "eu": "eu",
+ "ap": "apac",
+}
+
+var ChannelName = "aws"
diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go
new file mode 100644
index 00000000..0188c30a
--- /dev/null
+++ b/relay/channel/aws/dto.go
@@ -0,0 +1,36 @@
+package aws
+
+import (
+ "one-api/dto"
+)
+
+type AwsClaudeRequest struct {
+ // AnthropicVersion should be "bedrock-2023-05-31"
+ AnthropicVersion string `json:"anthropic_version"`
+ System any `json:"system,omitempty"`
+ Messages []dto.ClaudeMessage `json:"messages"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ Thinking *dto.Thinking `json:"thinking,omitempty"`
+}
+
+func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
+ return &AwsClaudeRequest{
+ AnthropicVersion: "bedrock-2023-05-31",
+ System: req.System,
+ Messages: req.Messages,
+ MaxTokens: req.MaxTokens,
+ Temperature: req.Temperature,
+ TopP: req.TopP,
+ TopK: req.TopK,
+ StopSequences: req.StopSequences,
+ Tools: req.Tools,
+ ToolChoice: req.ToolChoice,
+ Thinking: req.Thinking,
+ }
+}
diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go
new file mode 100644
index 00000000..0df19e07
--- /dev/null
+++ b/relay/channel/aws/relay-aws.go
@@ -0,0 +1,196 @@
+package aws
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel/claude"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+ "github.com/pkg/errors"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/credentials"
+ "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
+ bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
+)
+
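+// newAwsClient expects the channel key in the form "ak|sk|region" (three segments
+// separated by "|", for example "AKIA...|secret...|us-west-2", illustrative values
+// only) and builds a Bedrock runtime client with static credentials for that region.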
+func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
+ awsSecret := strings.Split(info.ApiKey, "|")
+ if len(awsSecret) != 3 {
+ return nil, errors.New("invalid aws secret key")
+ }
+ ak := awsSecret[0]
+ sk := awsSecret[1]
+ region := awsSecret[2]
+ client := bedrockruntime.New(bedrockruntime.Options{
+ Region: region,
+ Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
+ })
+
+ return client, nil
+}
+
+func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
+ return &dto.OpenAIErrorWithStatusCode{
+ StatusCode: http.StatusInternalServerError,
+ Error: dto.OpenAIError{
+ Message: err.Error(),
+ },
+ }
+}
+
+func awsRegionPrefix(awsRegionId string) string {
+ parts := strings.Split(awsRegionId, "-")
+ regionPrefix := ""
+ if len(parts) > 0 {
+ regionPrefix = parts[0]
+ }
+ return regionPrefix
+}
+
+func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
+ regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
+ return exists && regionSet[awsRegionPrefix]
+}
+
+func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
+ modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
+ if !find {
+ return awsModelId
+ }
+ return modelPrefix + "." + awsModelId
+}
+
+func awsModelID(requestModel string) string {
+ if awsModelID, ok := awsModelIDMap[requestModel]; ok {
+ return awsModelID
+ }
+
+ return requestModel
+}
+
+func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
+ awsCli, err := newAwsClient(c, info)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
+ }
+
+ awsModelId := awsModelID(c.GetString("request_model"))
+
+ awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
+ canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
+ if canCrossRegion {
+ awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
+ }
+
+ awsReq := &bedrockruntime.InvokeModelInput{
+ ModelId: aws.String(awsModelId),
+ Accept: aws.String("application/json"),
+ ContentType: aws.String("application/json"),
+ }
+
+ claudeReq_, ok := c.Get("converted_request")
+ if !ok {
+ return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
+ }
+ claudeReq := claudeReq_.(*dto.ClaudeRequest)
+ awsClaudeReq := copyRequest(claudeReq)
+ awsReq.Body, err = json.Marshal(awsClaudeReq)
+ if err != nil {
+ return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
+ }
+
+ awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
+ if err != nil {
+ return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
+ }
+
+ claudeInfo := &claude.ClaudeResponseInfo{
+ ResponseId: helper.GetResponseID(c),
+ Created: common.GetTimestamp(),
+ Model: info.UpstreamModelName,
+ ResponseText: strings.Builder{},
+ Usage: &dto.Usage{},
+ }
+
+ handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
+ if handlerErr != nil {
+ return handlerErr, nil
+ }
+ return nil, claudeInfo.Usage
+}
+
+func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
+ awsCli, err := newAwsClient(c, info)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
+ }
+
+ awsModelId := awsModelID(c.GetString("request_model"))
+
+ awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
+ canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
+ if canCrossRegion {
+ awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
+ }
+
+ awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
+ ModelId: aws.String(awsModelId),
+ Accept: aws.String("application/json"),
+ ContentType: aws.String("application/json"),
+ }
+
+ claudeReq_, ok := c.Get("converted_request")
+ if !ok {
+ return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
+ }
+ claudeReq := claudeReq_.(*dto.ClaudeRequest)
+
+ awsClaudeReq := copyRequest(claudeReq)
+ awsReq.Body, err = json.Marshal(awsClaudeReq)
+ if err != nil {
+ return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
+ }
+
+ awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
+ if err != nil {
+ return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
+ }
+ stream := awsResp.GetStream()
+ defer stream.Close()
+
+ claudeInfo := &claude.ClaudeResponseInfo{
+ ResponseId: helper.GetResponseID(c),
+ Created: common.GetTimestamp(),
+ Model: info.UpstreamModelName,
+ ResponseText: strings.Builder{},
+ Usage: &dto.Usage{},
+ }
+
+ for event := range stream.Events() {
+ switch v := event.(type) {
+ case *bedrockruntimeTypes.ResponseStreamMemberChunk:
+ info.SetFirstResponseTime()
+ respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
+ if respErr != nil {
+ return respErr, nil
+ }
+ case *bedrockruntimeTypes.UnknownUnionMember:
+ fmt.Println("unknown tag:", v.Tag)
+ return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
+ default:
+ fmt.Println("union is nil or unknown type")
+ return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
+ }
+ }
+
+ claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
+ return nil, claudeInfo.Usage
+}
diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go
new file mode 100644
index 00000000..22443354
--- /dev/null
+++ b/relay/channel/baidu/adaptor.go
@@ -0,0 +1,164 @@
+package baidu
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+ suffix := "chat/"
+ if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
+ suffix = "embeddings/"
+ }
+ if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
+ suffix = "embeddings/"
+ }
+ if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
+ suffix = "embeddings/"
+ }
+ switch info.UpstreamModelName {
+ case "ERNIE-4.0":
+ suffix += "completions_pro"
+ case "ERNIE-Bot-4":
+ suffix += "completions_pro"
+ case "ERNIE-Bot":
+ suffix += "completions"
+ case "ERNIE-Bot-turbo":
+ suffix += "eb-instant"
+ case "ERNIE-Speed":
+ suffix += "ernie_speed"
+ case "ERNIE-4.0-8K":
+ suffix += "completions_pro"
+ case "ERNIE-3.5-8K":
+ suffix += "completions"
+ case "ERNIE-3.5-8K-0205":
+ suffix += "ernie-3.5-8k-0205"
+ case "ERNIE-3.5-8K-1222":
+ suffix += "ernie-3.5-8k-1222"
+ case "ERNIE-Bot-8K":
+ suffix += "ernie_bot_8k"
+ case "ERNIE-3.5-4K-0205":
+ suffix += "ernie-3.5-4k-0205"
+ case "ERNIE-Speed-8K":
+ suffix += "ernie_speed"
+ case "ERNIE-Speed-128K":
+ suffix += "ernie-speed-128k"
+ case "ERNIE-Lite-8K-0922":
+ suffix += "eb-instant"
+ case "ERNIE-Lite-8K-0308":
+ suffix += "ernie-lite-8k"
+ case "ERNIE-Tiny-8K":
+ suffix += "ernie-tiny-8k"
+ case "BLOOMZ-7B":
+ suffix += "bloomz_7b1"
+ case "Embedding-V1":
+ suffix += "embedding-v1"
+ case "bge-large-zh":
+ suffix += "bge_large_zh"
+ case "bge-large-en":
+ suffix += "bge_large_en"
+ case "tao-8k":
+ suffix += "tao_8k"
+ default:
+ suffix += strings.ToLower(info.UpstreamModelName)
+ }
+ fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
+ var accessToken string
+ var err error
+ if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
+ return "", err
+ }
+ fullRequestURL += "?access_token=" + accessToken
+ return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch info.RelayMode {
+ default:
+ baiduRequest := requestOpenAI2Baidu(*request)
+ return baiduRequest, nil
+ }
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
+ return baiduEmbeddingRequest, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ err, usage = baiduStreamHandler(c, info, resp)
+ } else {
+ switch info.RelayMode {
+ case constant.RelayModeEmbeddings:
+ err, usage = baiduEmbeddingHandler(c, info, resp)
+ default:
+ err, usage = baiduHandler(c, info, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go
new file mode 100644
index 00000000..46914330
--- /dev/null
+++ b/relay/channel/baidu/constants.go
@@ -0,0 +1,22 @@
+package baidu
+
+var ModelList = []string{
+ "ERNIE-4.0-8K",
+ "ERNIE-3.5-8K",
+ "ERNIE-3.5-8K-0205",
+ "ERNIE-3.5-8K-1222",
+ "ERNIE-Bot-8K",
+ "ERNIE-3.5-4K-0205",
+ "ERNIE-Speed-8K",
+ "ERNIE-Speed-128K",
+ "ERNIE-Lite-8K-0922",
+ "ERNIE-Lite-8K-0308",
+ "ERNIE-Tiny-8K",
+ "BLOOMZ-7B",
+ "Embedding-V1",
+ "bge-large-zh",
+ "bge-large-en",
+ "tao-8k",
+}
+
+var ChannelName = "baidu"
diff --git a/relay/channel/baidu/dto.go b/relay/channel/baidu/dto.go
new file mode 100644
index 00000000..a486de5a
--- /dev/null
+++ b/relay/channel/baidu/dto.go
@@ -0,0 +1,78 @@
+package baidu
+
+import (
+ "one-api/dto"
+ "time"
+)
+
+type BaiduMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type BaiduChatRequest struct {
+ Messages []BaiduMessage `json:"messages"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ PenaltyScore float64 `json:"penalty_score,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ System string `json:"system,omitempty"`
+ DisableSearch bool `json:"disable_search,omitempty"`
+ EnableCitation bool `json:"enable_citation,omitempty"`
+ MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
+ UserId string `json:"user_id,omitempty"`
+}
+
+type Error struct {
+ ErrorCode int `json:"error_code"`
+ ErrorMsg string `json:"error_msg"`
+}
+
+type BaiduChatResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Result string `json:"result"`
+ IsTruncated bool `json:"is_truncated"`
+ NeedClearHistory bool `json:"need_clear_history"`
+ Usage dto.Usage `json:"usage"`
+ Error
+}
+
+type BaiduChatStreamResponse struct {
+ BaiduChatResponse
+ SentenceId int `json:"sentence_id"`
+ IsEnd bool `json:"is_end"`
+}
+
+type BaiduEmbeddingRequest struct {
+ Input []string `json:"input"`
+}
+
+type BaiduEmbeddingData struct {
+ Object string `json:"object"`
+ Embedding []float64 `json:"embedding"`
+ Index int `json:"index"`
+}
+
+type BaiduEmbeddingResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Data []BaiduEmbeddingData `json:"data"`
+ Usage dto.Usage `json:"usage"`
+ Error
+}
+
+type BaiduAccessToken struct {
+ AccessToken string `json:"access_token"`
+ Error string `json:"error,omitempty"`
+ ErrorDescription string `json:"error_description,omitempty"`
+ ExpiresIn int64 `json:"expires_in,omitempty"`
+ ExpiresAt time.Time `json:"-"`
+}
+
+type BaiduTokenResponse struct {
+ ExpiresIn int `json:"expires_in"`
+ AccessToken string `json:"access_token"`
+}
diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go
new file mode 100644
index 00000000..06b48c20
--- /dev/null
+++ b/relay/channel/baidu/relay-baidu.go
@@ -0,0 +1,245 @@
+package baidu
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
+
+var baiduTokenStore sync.Map
+
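+// requestOpenAI2Baidu flattens the OpenAI chat request into Baidu's schema: the
+// system message is moved into the dedicated system field, other messages are
+// copied as-is, and a max_tokens of exactly 1 is bumped to 2 before forwarding.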
+func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
+ baiduRequest := BaiduChatRequest{
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ PenaltyScore: request.FrequencyPenalty,
+ Stream: request.Stream,
+ DisableSearch: false,
+ EnableCitation: false,
+ UserId: request.User,
+ }
+ if request.MaxTokens != 0 {
+ maxTokens := int(request.MaxTokens)
+ if request.MaxTokens == 1 {
+ maxTokens = 2
+ }
+ baiduRequest.MaxOutputTokens = &maxTokens
+ }
+ for _, message := range request.Messages {
+ if message.Role == "system" {
+ baiduRequest.System = message.StringContent()
+ } else {
+ baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
+ Role: message.Role,
+ Content: message.StringContent(),
+ })
+ }
+ }
+ return &baiduRequest
+}
+
+func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
+ choice := dto.OpenAITextResponseChoice{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: response.Result,
+ },
+ FinishReason: "stop",
+ }
+ fullTextResponse := dto.OpenAITextResponse{
+ Id: response.Id,
+ Object: "chat.completion",
+ Created: response.Created,
+ Choices: []dto.OpenAITextResponseChoice{choice},
+ Usage: response.Usage,
+ }
+ return &fullTextResponse
+}
+
+func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
+ var choice dto.ChatCompletionsStreamResponseChoice
+ choice.Delta.SetContentString(baiduResponse.Result)
+ if baiduResponse.IsEnd {
+ choice.FinishReason = &constant.FinishReasonStop
+ }
+ response := dto.ChatCompletionsStreamResponse{
+ Id: baiduResponse.Id,
+ Object: "chat.completion.chunk",
+ Created: baiduResponse.Created,
+ Model: "ernie-bot",
+ Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
+ }
+ return &response
+}
+
+func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
+ return &BaiduEmbeddingRequest{
+ Input: request.ParseInput(),
+ }
+}
+
+func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+ openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
+ Object: "list",
+ Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+ Model: "baidu-embedding",
+ Usage: response.Usage,
+ }
+ for _, item := range response.Data {
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
+ Object: item.Object,
+ Index: item.Index,
+ Embedding: item.Embedding,
+ })
+ }
+ return &openAIEmbeddingResponse
+}
+
+func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ usage := &dto.Usage{}
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ var baiduResponse BaiduChatStreamResponse
+ err := common.Unmarshal([]byte(data), &baiduResponse)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return true
+ }
+ if baiduResponse.Usage.TotalTokens != 0 {
+ usage.TotalTokens = baiduResponse.Usage.TotalTokens
+ usage.PromptTokens = baiduResponse.Usage.PromptTokens
+ usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
+ }
+ response := streamResponseBaidu2OpenAI(&baiduResponse)
+ err = helper.ObjectData(c, response)
+ if err != nil {
+ common.SysError("error sending stream response: " + err.Error())
+ }
+ return true
+ })
+ common.CloseResponseBodyGracefully(resp)
+ return nil, usage
+}
+
+func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ var baiduResponse BaiduChatResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &baiduResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ if baiduResponse.ErrorMsg != "" {
+ return types.NewError(errors.New(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
+ }
+ fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return nil, &fullTextResponse.Usage
+}
+
+func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ var baiduResponse BaiduEmbeddingResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &baiduResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ if baiduResponse.ErrorMsg != "" {
+ return types.NewError(errors.New(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
+ }
+ fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return nil, &fullTextResponse.Usage
+}
+
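+// getBaiduAccessToken returns a cached OAuth access token for a
+// "client_id|client_secret" style key, kicking off a background refresh when the
+// cached token is within an hour of expiry and fetching a new one synchronously on
+// a cache miss.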
+func getBaiduAccessToken(apiKey string) (string, error) {
+ if val, ok := baiduTokenStore.Load(apiKey); ok {
+ var accessToken BaiduAccessToken
+ if accessToken, ok = val.(BaiduAccessToken); ok {
+ // soon this will expire
+ if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
+ go func() {
+ _, _ = getBaiduAccessTokenHelper(apiKey)
+ }()
+ }
+ return accessToken.AccessToken, nil
+ }
+ }
+ accessToken, err := getBaiduAccessTokenHelper(apiKey)
+ if err != nil {
+ return "", err
+ }
+ if accessToken == nil {
+ return "", errors.New("getBaiduAccessToken return a nil token")
+ }
+ return (*accessToken).AccessToken, nil
+}
+
+func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
+ parts := strings.Split(apiKey, "|")
+ if len(parts) != 2 {
+ return nil, errors.New("invalid baidu apikey")
+ }
+ req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
+ parts[0], parts[1]), nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Add("Content-Type", "application/json")
+ req.Header.Add("Accept", "application/json")
+ res, err := service.GetHttpClient().Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer res.Body.Close()
+
+ var accessToken BaiduAccessToken
+ err = json.NewDecoder(res.Body).Decode(&accessToken)
+ if err != nil {
+ return nil, err
+ }
+ if accessToken.Error != "" {
+ return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
+ }
+ if accessToken.AccessToken == "" {
+ return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
+ }
+ accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
+ baiduTokenStore.Store(apiKey, accessToken)
+ return &accessToken, nil
+}
diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go
new file mode 100644
index 00000000..375fd531
--- /dev/null
+++ b/relay/channel/baidu_v2/adaptor.go
@@ -0,0 +1,111 @@
+package baidu_v2
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
+}
+
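+// SetupRequestHeader accepts the channel key as either "api_key" or
+// "api_key|appid"; when the optional second segment is present it is forwarded in
+// the appid header alongside the Bearer token.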
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ keyParts := strings.Split(info.ApiKey, "|")
+ if len(keyParts) == 0 || keyParts[0] == "" {
+ return errors.New("invalid API key: authorization token is required")
+ }
+ if len(keyParts) > 1 {
+ if keyParts[1] != "" {
+ req.Set("appid", keyParts[1])
+ }
+ }
+ req.Set("Authorization", "Bearer "+keyParts[0])
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ if strings.HasSuffix(info.UpstreamModelName, "-search") {
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
+ request.Model = info.UpstreamModelName
+ toMap := request.ToMap()
+ toMap["web_search"] = map[string]any{
+ "enable": true,
+ "enable_citation": true,
+ "enable_trace": true,
+ "enable_status": false,
+ }
+ return toMap, nil
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/baidu_v2/constants.go b/relay/channel/baidu_v2/constants.go
new file mode 100644
index 00000000..a7cee248
--- /dev/null
+++ b/relay/channel/baidu_v2/constants.go
@@ -0,0 +1,29 @@
+package baidu_v2
+
+var ModelList = []string{
+ "ernie-4.0-8k-latest",
+ "ernie-4.0-8k-preview",
+ "ernie-4.0-8k",
+ "ernie-4.0-turbo-8k-latest",
+ "ernie-4.0-turbo-8k-preview",
+ "ernie-4.0-turbo-8k",
+ "ernie-4.0-turbo-128k",
+ "ernie-3.5-8k-preview",
+ "ernie-3.5-8k",
+ "ernie-3.5-128k",
+ "ernie-speed-8k",
+ "ernie-speed-128k",
+ "ernie-speed-pro-128k",
+ "ernie-lite-8k",
+ "ernie-lite-pro-128k",
+ "ernie-tiny-8k",
+ "ernie-char-8k",
+ "ernie-char-fiction-8k",
+ "ernie-novel-8k",
+ "deepseek-v3",
+ "deepseek-r1",
+ "deepseek-r1-distill-qwen-32b",
+ "deepseek-r1-distill-qwen-14b",
+}
+
+var ChannelName = "volcengine"
diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go
new file mode 100644
index 00000000..540742d6
--- /dev/null
+++ b/relay/channel/claude/adaptor.go
@@ -0,0 +1,113 @@
+package claude
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ RequestModeCompletion = 1
+ RequestModeMessage = 2
+)
+
+type Adaptor struct {
+ RequestMode int
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+ if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
+ a.RequestMode = RequestModeCompletion
+ } else {
+ a.RequestMode = RequestModeMessage
+ }
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if a.RequestMode == RequestModeMessage {
+ return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
+ } else {
+ return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("x-api-key", info.ApiKey)
+ anthropicVersion := c.Request.Header.Get("anthropic-version")
+ if anthropicVersion == "" {
+ anthropicVersion = "2023-06-01"
+ }
+ req.Set("anthropic-version", anthropicVersion)
+ model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ if a.RequestMode == RequestModeCompletion {
+ return RequestOpenAI2ClaudeComplete(*request), nil
+ } else {
+ return RequestOpenAI2ClaudeMessage(*request)
+ }
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
+ } else {
+ err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/claude/constants.go b/relay/channel/claude/constants.go
new file mode 100644
index 00000000..e0e3c421
--- /dev/null
+++ b/relay/channel/claude/constants.go
@@ -0,0 +1,22 @@
+package claude
+
+var ModelList = []string{
+ "claude-instant-1.2",
+ "claude-2",
+ "claude-2.0",
+ "claude-2.1",
+ "claude-3-sonnet-20240229",
+ "claude-3-opus-20240229",
+ "claude-3-haiku-20240307",
+ "claude-3-5-haiku-20241022",
+ "claude-3-5-sonnet-20240620",
+ "claude-3-5-sonnet-20241022",
+ "claude-3-7-sonnet-20250219",
+ "claude-3-7-sonnet-20250219-thinking",
+ "claude-sonnet-4-20250514",
+ "claude-sonnet-4-20250514-thinking",
+ "claude-opus-4-20250514",
+ "claude-opus-4-20250514-thinking",
+}
+
+var ChannelName = "claude"
diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go
new file mode 100644
index 00000000..89415868
--- /dev/null
+++ b/relay/channel/claude/dto.go
@@ -0,0 +1,95 @@
+package claude
+
+//
+//type ClaudeMetadata struct {
+// UserId string `json:"user_id"`
+//}
+//
+//type ClaudeMediaMessage struct {
+// Type string `json:"type"`
+// Text string `json:"text,omitempty"`
+// Source *ClaudeMessageSource `json:"source,omitempty"`
+// Usage *ClaudeUsage `json:"usage,omitempty"`
+// StopReason *string `json:"stop_reason,omitempty"`
+// PartialJson string `json:"partial_json,omitempty"`
+// Thinking string `json:"thinking,omitempty"`
+// Signature string `json:"signature,omitempty"`
+// Delta string `json:"delta,omitempty"`
+// // tool_calls
+// Id string `json:"id,omitempty"`
+// Name string `json:"name,omitempty"`
+// Input any `json:"input,omitempty"`
+// Content string `json:"content,omitempty"`
+// ToolUseId string `json:"tool_use_id,omitempty"`
+//}
+//
+//type ClaudeMessageSource struct {
+// Type string `json:"type"`
+// MediaType string `json:"media_type"`
+// Data string `json:"data"`
+//}
+//
+//type ClaudeMessage struct {
+// Role string `json:"role"`
+// Content any `json:"content"`
+//}
+//
+//type Tool struct {
+// Name string `json:"name"`
+// Description string `json:"description,omitempty"`
+// InputSchema map[string]interface{} `json:"input_schema"`
+//}
+//
+//type InputSchema struct {
+// Type string `json:"type"`
+// Properties any `json:"properties,omitempty"`
+// Required any `json:"required,omitempty"`
+//}
+//
+//type ClaudeRequest struct {
+// Model string `json:"model"`
+// Prompt string `json:"prompt,omitempty"`
+// System string `json:"system,omitempty"`
+// Messages []ClaudeMessage `json:"messages,omitempty"`
+// MaxTokens uint `json:"max_tokens,omitempty"`
+// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
+// StopSequences []string `json:"stop_sequences,omitempty"`
+// Temperature *float64 `json:"temperature,omitempty"`
+// TopP float64 `json:"top_p,omitempty"`
+// TopK int `json:"top_k,omitempty"`
+// //ClaudeMetadata `json:"metadata,omitempty"`
+// Stream bool `json:"stream,omitempty"`
+// Tools any `json:"tools,omitempty"`
+// ToolChoice any `json:"tool_choice,omitempty"`
+// Thinking *Thinking `json:"thinking,omitempty"`
+//}
+//
+//type Thinking struct {
+// Type string `json:"type"`
+// BudgetTokens int `json:"budget_tokens"`
+//}
+//
+//type ClaudeError struct {
+// Type string `json:"type"`
+// Message string `json:"message"`
+//}
+//
+//type ClaudeResponse struct {
+// Id string `json:"id"`
+// Type string `json:"type"`
+// Content []ClaudeMediaMessage `json:"content"`
+// Completion string `json:"completion"`
+// StopReason string `json:"stop_reason"`
+// Model string `json:"model"`
+// Error ClaudeError `json:"error"`
+// Usage ClaudeUsage `json:"usage"`
+// Index int `json:"index"` // stream only
+// ContentBlock *ClaudeMediaMessage `json:"content_block"`
+// Delta *ClaudeMediaMessage `json:"delta"` // stream only
+// Message *ClaudeResponse `json:"message"` // stream only: message_start
+//}
+//
+//type ClaudeUsage struct {
+// InputTokens int `json:"input_tokens"`
+// OutputTokens int `json:"output_tokens"`
+//}
diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go
new file mode 100644
index 00000000..f20b573d
--- /dev/null
+++ b/relay/channel/claude/relay-claude.go
@@ -0,0 +1,813 @@
+package claude
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel/openrouter"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
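+// OpenAI's web_search_options.search_context_size ("low", "medium", "high") is
+// translated into Claude's web_search tool max_uses via these constants.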
+const (
+ WebSearchMaxUsesLow = 1
+ WebSearchMaxUsesMedium = 5
+ WebSearchMaxUsesHigh = 10
+)
+
+func stopReasonClaude2OpenAI(reason string) string {
+ switch reason {
+ case "stop_sequence":
+ return "stop"
+ case "end_turn":
+ return "stop"
+ case "max_tokens":
+ return "max_tokens"
+ case "tool_use":
+ return "tool_calls"
+ default:
+ return reason
+ }
+}
+
+func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
+
+ claudeRequest := dto.ClaudeRequest{
+ Model: textRequest.Model,
+ Prompt: "",
+ StopSequences: nil,
+ Temperature: textRequest.Temperature,
+ TopP: textRequest.TopP,
+ TopK: textRequest.TopK,
+ Stream: textRequest.Stream,
+ }
+ if claudeRequest.MaxTokensToSample == 0 {
+ claudeRequest.MaxTokensToSample = 4096
+ }
+ prompt := ""
+ for _, message := range textRequest.Messages {
+ if message.Role == "user" {
+ prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
+ } else if message.Role == "assistant" {
+ prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
+ } else if message.Role == "system" {
+ if prompt == "" {
+ prompt = message.StringContent()
+ }
+ }
+ }
+ prompt += "\n\nAssistant:"
+ claudeRequest.Prompt = prompt
+ return &claudeRequest
+}
+
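+// RequestOpenAI2ClaudeMessage converts an OpenAI chat request into a Claude Messages API request, mapping tools, web search options, thinking/reasoning settings, stop sequences and message content (including images and tool results).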
+func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
+ claudeTools := make([]any, 0, len(textRequest.Tools))
+
+ for _, tool := range textRequest.Tools {
+ if params, ok := tool.Function.Parameters.(map[string]any); ok {
+ claudeTool := dto.Tool{
+ Name: tool.Function.Name,
+ Description: tool.Function.Description,
+ }
+ claudeTool.InputSchema = make(map[string]interface{})
+ if params["type"] != nil {
+ claudeTool.InputSchema["type"] = params["type"].(string)
+ }
+ claudeTool.InputSchema["properties"] = params["properties"]
+ claudeTool.InputSchema["required"] = params["required"]
+ for s, a := range params {
+ if s == "type" || s == "properties" || s == "required" {
+ continue
+ }
+ claudeTool.InputSchema[s] = a
+ }
+ claudeTools = append(claudeTools, &claudeTool)
+ }
+ }
+
+ // Web search tool
+ // https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
+ if textRequest.WebSearchOptions != nil {
+ webSearchTool := dto.ClaudeWebSearchTool{
+ Type: "web_search_20250305",
+ Name: "web_search",
+ }
+
+		// Handle user_location
+ if textRequest.WebSearchOptions.UserLocation != nil {
+ anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
+				Type: "approximate", // always "approximate"
+ }
+
+			// Parse the UserLocation JSON
+ var userLocationMap map[string]interface{}
+ if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
+				// Check for an "approximate" object
+ if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
+ if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
+ anthropicUserLocation.Timezone = timezone
+ }
+ if country, ok := approximateData["country"].(string); ok && country != "" {
+ anthropicUserLocation.Country = country
+ }
+ if region, ok := approximateData["region"].(string); ok && region != "" {
+ anthropicUserLocation.Region = region
+ }
+ if city, ok := approximateData["city"].(string); ok && city != "" {
+ anthropicUserLocation.City = city
+ }
+ }
+ }
+
+ webSearchTool.UserLocation = anthropicUserLocation
+ }
+
+		// Map search_context_size to max_uses
+ if textRequest.WebSearchOptions.SearchContextSize != "" {
+ switch textRequest.WebSearchOptions.SearchContextSize {
+ case "low":
+ webSearchTool.MaxUses = WebSearchMaxUsesLow
+ case "medium":
+ webSearchTool.MaxUses = WebSearchMaxUsesMedium
+ case "high":
+ webSearchTool.MaxUses = WebSearchMaxUsesHigh
+ }
+ }
+
+ claudeTools = append(claudeTools, &webSearchTool)
+ }
+
+ claudeRequest := dto.ClaudeRequest{
+ Model: textRequest.Model,
+ MaxTokens: textRequest.MaxTokens,
+ StopSequences: nil,
+ Temperature: textRequest.Temperature,
+ TopP: textRequest.TopP,
+ TopK: textRequest.TopK,
+ Stream: textRequest.Stream,
+ Tools: claudeTools,
+ }
+
+	// Handle tool_choice and parallel_tool_calls
+ if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
+ claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
+ if claudeToolChoice != nil {
+ claudeRequest.ToolChoice = claudeToolChoice
+ }
+ }
+
+ if claudeRequest.MaxTokens == 0 {
+ claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
+ }
+
+ if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
+ strings.HasSuffix(textRequest.Model, "-thinking") {
+
+		// BudgetTokens must be greater than 1024
+ if claudeRequest.MaxTokens < 1280 {
+ claudeRequest.MaxTokens = 1280
+ }
+
+		// BudgetTokens is 80% of max_tokens (ThinkingAdapterBudgetTokensPercentage)
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
+ }
+		// TODO: temporary workaround
+ // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
+ claudeRequest.TopP = 0
+ claudeRequest.Temperature = common.GetPointer[float64](1.0)
+ claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
+ }
+
+ if textRequest.ReasoningEffort != "" {
+ switch textRequest.ReasoningEffort {
+ case "low":
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](1280),
+ }
+ case "medium":
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](2048),
+ }
+ case "high":
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](4096),
+ }
+ }
+ }
+
+	// If the reasoning parameter is specified, it overrides BudgetTokens
+ if textRequest.Reasoning != nil {
+ var reasoning openrouter.RequestReasoning
+ if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
+ return nil, err
+ }
+
+ budgetTokens := reasoning.MaxTokens
+ if budgetTokens > 0 {
+ claudeRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: &budgetTokens,
+ }
+ }
+ }
+
+ if textRequest.Stop != nil {
+ // stop maybe string/array string, convert to array string
+ switch textRequest.Stop.(type) {
+ case string:
+ claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
+ case []interface{}:
+ stopSequences := make([]string, 0)
+ for _, stop := range textRequest.Stop.([]interface{}) {
+ stopSequences = append(stopSequences, stop.(string))
+ }
+ claudeRequest.StopSequences = stopSequences
+ }
+ }
+ formatMessages := make([]dto.Message, 0)
+ lastMessage := dto.Message{
+ Role: "tool",
+ }
+ for i, message := range textRequest.Messages {
+ if message.Role == "" {
+ textRequest.Messages[i].Role = "user"
+ }
+ fmtMessage := dto.Message{
+ Role: message.Role,
+ Content: message.Content,
+ }
+ if message.Role == "tool" {
+ fmtMessage.ToolCallId = message.ToolCallId
+ }
+ if message.Role == "assistant" && message.ToolCalls != nil {
+ fmtMessage.ToolCalls = message.ToolCalls
+ }
+ if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
+ if lastMessage.IsStringContent() && message.IsStringContent() {
+ fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
+ // delete last message
+ formatMessages = formatMessages[:len(formatMessages)-1]
+ }
+ }
+ if fmtMessage.Content == nil {
+ fmtMessage.SetStringContent("...")
+ }
+ formatMessages = append(formatMessages, fmtMessage)
+ lastMessage = fmtMessage
+ }
+
+ claudeMessages := make([]dto.ClaudeMessage, 0)
+ isFirstMessage := true
+ for _, message := range formatMessages {
+ if message.Role == "system" {
+ if message.IsStringContent() {
+ claudeRequest.System = message.StringContent()
+ } else {
+ contents := message.ParseContent()
+ content := ""
+ for _, ctx := range contents {
+ if ctx.Type == "text" {
+ content += ctx.Text
+ }
+ }
+ claudeRequest.System = content
+ }
+ } else {
+ if isFirstMessage {
+ isFirstMessage = false
+ if message.Role != "user" {
+ // fix: first message is assistant, add user message
+ claudeMessage := dto.ClaudeMessage{
+ Role: "user",
+ Content: []dto.ClaudeMediaMessage{
+ {
+ Type: "text",
+ Text: common.GetPointer[string]("..."),
+ },
+ },
+ }
+ claudeMessages = append(claudeMessages, claudeMessage)
+ }
+ }
+ claudeMessage := dto.ClaudeMessage{
+ Role: message.Role,
+ }
+ if message.Role == "tool" {
+ if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
+ lastMessage := claudeMessages[len(claudeMessages)-1]
+ if content, ok := lastMessage.Content.(string); ok {
+ lastMessage.Content = []dto.ClaudeMediaMessage{
+ {
+ Type: "text",
+ Text: common.GetPointer[string](content),
+ },
+ }
+ }
+ lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
+ Type: "tool_result",
+ ToolUseId: message.ToolCallId,
+ Content: message.Content,
+ })
+ claudeMessages[len(claudeMessages)-1] = lastMessage
+ continue
+ } else {
+ claudeMessage.Role = "user"
+ claudeMessage.Content = []dto.ClaudeMediaMessage{
+ {
+ Type: "tool_result",
+ ToolUseId: message.ToolCallId,
+ Content: message.Content,
+ },
+ }
+ }
+ } else if message.IsStringContent() && message.ToolCalls == nil {
+ claudeMessage.Content = message.StringContent()
+ } else {
+ claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
+ for _, mediaMessage := range message.ParseContent() {
+ claudeMediaMessage := dto.ClaudeMediaMessage{
+ Type: mediaMessage.Type,
+ }
+ if mediaMessage.Type == "text" {
+ claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
+ } else {
+ imageUrl := mediaMessage.GetImageMedia()
+ claudeMediaMessage.Type = "image"
+ claudeMediaMessage.Source = &dto.ClaudeMessageSource{
+ Type: "base64",
+ }
+						// Check whether the image is a URL
+ if strings.HasPrefix(imageUrl.Url, "http") {
+							// It is a URL: fetch the image MIME type and base64-encoded data
+ fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+ if err != nil {
+ return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
+ }
+ claudeMediaMessage.Source.MediaType = fileData.MimeType
+ claudeMediaMessage.Source.Data = fileData.Base64Data
+ } else {
+ _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
+ if err != nil {
+ return nil, err
+ }
+ claudeMediaMessage.Source.MediaType = "image/" + format
+ claudeMediaMessage.Source.Data = base64String
+ }
+ }
+ claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
+ }
+ if message.ToolCalls != nil {
+ for _, toolCall := range message.ParseToolCalls() {
+ inputObj := make(map[string]any)
+ if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
+ common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
+ continue
+ }
+ claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
+ Type: "tool_use",
+ Id: toolCall.ID,
+ Name: toolCall.Function.Name,
+ Input: inputObj,
+ })
+ }
+ }
+ claudeMessage.Content = claudeMediaMessages
+ }
+ claudeMessages = append(claudeMessages, claudeMessage)
+ }
+ }
+ claudeRequest.Prompt = ""
+ claudeRequest.Messages = claudeMessages
+ return &claudeRequest, nil
+}
+
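+// StreamResponseClaude2OpenAI converts a single Claude streaming event into an OpenAI chat.completion.chunk; it returns nil for events that produce no chunk.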
+func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
+ var response dto.ChatCompletionsStreamResponse
+ response.Object = "chat.completion.chunk"
+ response.Model = claudeResponse.Model
+ response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
+ tools := make([]dto.ToolCallResponse, 0)
+ fcIdx := 0
+ if claudeResponse.Index != nil {
+ fcIdx = *claudeResponse.Index - 1
+ if fcIdx < 0 {
+ fcIdx = 0
+ }
+ }
+ var choice dto.ChatCompletionsStreamResponseChoice
+ if reqMode == RequestModeCompletion {
+ choice.Delta.SetContentString(claudeResponse.Completion)
+ finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
+ if finishReason != "null" {
+ choice.FinishReason = &finishReason
+ }
+ } else {
+ if claudeResponse.Type == "message_start" {
+ response.Id = claudeResponse.Message.Id
+ response.Model = claudeResponse.Message.Model
+ //claudeUsage = &claudeResponse.Message.Usage
+ choice.Delta.SetContentString("")
+ choice.Delta.Role = "assistant"
+ } else if claudeResponse.Type == "content_block_start" {
+ if claudeResponse.ContentBlock != nil {
+ //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
+ if claudeResponse.ContentBlock.Type == "tool_use" {
+ tools = append(tools, dto.ToolCallResponse{
+ Index: common.GetPointer(fcIdx),
+ ID: claudeResponse.ContentBlock.Id,
+ Type: "function",
+ Function: dto.FunctionResponse{
+ Name: claudeResponse.ContentBlock.Name,
+ Arguments: "",
+ },
+ })
+ }
+ } else {
+ return nil
+ }
+ } else if claudeResponse.Type == "content_block_delta" {
+ if claudeResponse.Delta != nil {
+ choice.Delta.Content = claudeResponse.Delta.Text
+ switch claudeResponse.Delta.Type {
+ case "input_json_delta":
+ tools = append(tools, dto.ToolCallResponse{
+ Type: "function",
+ Index: common.GetPointer(fcIdx),
+ Function: dto.FunctionResponse{
+ Arguments: *claudeResponse.Delta.PartialJson,
+ },
+ })
+ case "signature_delta":
+					// Encrypted signature data is not processed
+ signatureContent := "\n"
+ choice.Delta.ReasoningContent = &signatureContent
+ case "thinking_delta":
+ thinkingContent := claudeResponse.Delta.Thinking
+ choice.Delta.ReasoningContent = &thinkingContent
+ }
+ }
+ } else if claudeResponse.Type == "message_delta" {
+ finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
+ if finishReason != "null" {
+ choice.FinishReason = &finishReason
+ }
+ //claudeUsage = &claudeResponse.Usage
+ } else if claudeResponse.Type == "message_stop" {
+ return nil
+ } else {
+ return nil
+ }
+ }
+ if len(tools) > 0 {
+ choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
+ choice.Delta.ToolCalls = tools
+ }
+ response.Choices = append(response.Choices, choice)
+
+ return &response
+}
+
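+// ResponseClaude2OpenAI converts a non-streaming Claude response into an OpenAI chat completion response, including tool calls and reasoning content.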
+func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
+ choices := make([]dto.OpenAITextResponseChoice, 0)
+ fullTextResponse := dto.OpenAITextResponse{
+ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ }
+ var responseText string
+ var responseThinking string
+ if len(claudeResponse.Content) > 0 {
+ responseText = claudeResponse.Content[0].GetText()
+ responseThinking = claudeResponse.Content[0].Thinking
+ }
+ tools := make([]dto.ToolCallResponse, 0)
+ thinkingContent := ""
+
+ if reqMode == RequestModeCompletion {
+ choice := dto.OpenAITextResponseChoice{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: strings.TrimPrefix(claudeResponse.Completion, " "),
+ Name: nil,
+ },
+ FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+ }
+ choices = append(choices, choice)
+ } else {
+ fullTextResponse.Id = claudeResponse.Id
+ for _, message := range claudeResponse.Content {
+ switch message.Type {
+ case "tool_use":
+ args, _ := json.Marshal(message.Input)
+ tools = append(tools, dto.ToolCallResponse{
+ ID: message.Id,
+ Type: "function", // compatible with other OpenAI derivative applications
+ Function: dto.FunctionResponse{
+ Name: message.Name,
+ Arguments: string(args),
+ },
+ })
+ case "thinking":
+			// Ignore encrypted reasoning; only plaintext thinking content is output
+ thinkingContent = message.Thinking
+ case "text":
+ responseText = message.GetText()
+ }
+ }
+ }
+ choice := dto.OpenAITextResponseChoice{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ },
+ FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
+ }
+ choice.SetStringContent(responseText)
+ if len(responseThinking) > 0 {
+ choice.ReasoningContent = responseThinking
+ }
+ if len(tools) > 0 {
+ choice.Message.SetToolCalls(tools)
+ }
+ choice.Message.ReasoningContent = thinkingContent
+ fullTextResponse.Model = claudeResponse.Model
+ choices = append(choices, choice)
+ fullTextResponse.Choices = choices
+ return &fullTextResponse
+}
+
+type ClaudeResponseInfo struct {
+ ResponseId string
+ Created int64
+ Model string
+ ResponseText strings.Builder
+ Usage *dto.Usage
+ Done bool
+}
+
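+// FormatClaudeResponseInfo accumulates response text and usage from a Claude streaming event into claudeInfo and, when an OpenAI chunk is provided, stamps it with the response id, created time and model. It returns false for event types that should be ignored.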
+func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
+ if requestMode == RequestModeCompletion {
+ claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
+ } else {
+ if claudeResponse.Type == "message_start" {
+ claudeInfo.ResponseId = claudeResponse.Message.Id
+ claudeInfo.Model = claudeResponse.Message.Model
+
+			// message_start: capture usage
+ claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
+ } else if claudeResponse.Type == "content_block_delta" {
+ if claudeResponse.Delta.Text != nil {
+ claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
+ }
+ if claudeResponse.Delta.Thinking != "" {
+ claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
+ }
+ } else if claudeResponse.Type == "message_delta" {
+			// Capture the final usage
+ if claudeResponse.Usage.InputTokens > 0 {
+				// Do not accumulate; take only the latest value
+ claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
+ }
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
+
+			// Mark the stream as complete
+ claudeInfo.Done = true
+ } else if claudeResponse.Type == "content_block_start" {
+ } else {
+ return false
+ }
+ }
+ if oaiResponse != nil {
+ oaiResponse.Id = claudeInfo.ResponseId
+ oaiResponse.Created = claudeInfo.Created
+ oaiResponse.Model = claudeInfo.Model
+ }
+ return true
+}
+
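+// HandleStreamResponseData parses one line of Claude stream data and forwards it to the client in Claude or OpenAI format, depending on the relay format.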
+func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
+ var claudeResponse dto.ClaudeResponse
+ err := common.UnmarshalJsonStr(data, &claudeResponse)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
+ return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
+ }
+ if info.RelayFormat == relaycommon.RelayFormatClaude {
+ FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
+
+ if requestMode == RequestModeCompletion {
+ } else {
+ if claudeResponse.Type == "message_start" {
+				// message_start: record the upstream model name
+ info.UpstreamModelName = claudeResponse.Message.Model
+ } else if claudeResponse.Type == "content_block_delta" {
+ } else if claudeResponse.Type == "message_delta" {
+ }
+ }
+ helper.ClaudeChunkData(c, claudeResponse, data)
+ } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
+ response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
+
+ if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
+ return nil
+ }
+
+ err = helper.ObjectData(c, response)
+ if err != nil {
+ common.LogError(c, "send_stream_response_failed: "+err.Error())
+ }
+ }
+ return nil
+}
+
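+// HandleStreamFinalResponse finalizes usage (falling back to token counting when the upstream usage is incomplete) and, for the OpenAI relay format, emits the optional usage chunk and the [DONE] terminator.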
+func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
+
+ if requestMode == RequestModeCompletion {
+ claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ if claudeInfo.Usage.PromptTokens == 0 {
+			// upstream error
+ }
+ if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
+ if common.DebugEnabled {
+ common.SysError("claude response usage is not complete, maybe upstream error")
+ }
+ claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+ }
+ }
+
+ if info.RelayFormat == relaycommon.RelayFormatClaude {
+ //
+ } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
+
+ if info.ShouldIncludeUsage {
+ response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
+ err := helper.ObjectData(c, response)
+ if err != nil {
+ common.SysError("send final response failed: " + err.Error())
+ }
+ }
+ helper.Done(c)
+ }
+}
+
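+// ClaudeStreamHandler consumes the upstream Claude SSE stream, relays each event to the client and returns the accumulated usage.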
+func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
+ claudeInfo := &ClaudeResponseInfo{
+ ResponseId: helper.GetResponseID(c),
+ Created: common.GetTimestamp(),
+ Model: info.UpstreamModelName,
+ ResponseText: strings.Builder{},
+ Usage: &dto.Usage{},
+ }
+ var err *types.NewAPIError
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
+ if err != nil {
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err, nil
+ }
+
+ HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
+ return nil, claudeInfo.Usage
+}
+
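+// HandleClaudeResponseData parses a non-streaming Claude response body, records usage and web search counts, and writes the response to the client in the requested relay format.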
+func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
+ var claudeResponse dto.ClaudeResponse
+ err := common.Unmarshal(data, &claudeResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
+ return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
+ }
+ if requestMode == RequestModeCompletion {
+ completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
+ claudeInfo.Usage.PromptTokens = info.PromptTokens
+ claudeInfo.Usage.CompletionTokens = completionTokens
+ claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
+ } else {
+ claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
+ claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
+ claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
+ }
+ var responseData []byte
+ switch info.RelayFormat {
+ case relaycommon.RelayFormatOpenAI:
+ openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
+ openaiResponse.Usage = *claudeInfo.Usage
+ responseData, err = json.Marshal(openaiResponse)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ case relaycommon.RelayFormatClaude:
+ responseData = data
+ }
+
+ if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 {
+ c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
+ }
+
+ common.IOCopyBytesGracefully(c, nil, responseData)
+ return nil
+}
+
+func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ claudeInfo := &ClaudeResponseInfo{
+ ResponseId: helper.GetResponseID(c),
+ Created: common.GetTimestamp(),
+ Model: info.UpstreamModelName,
+ ResponseText: strings.Builder{},
+ Usage: &dto.Usage{},
+ }
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ if common.DebugEnabled {
+ println("responseBody: ", string(responseBody))
+ }
+ handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
+ if handleErr != nil {
+ return handleErr, nil
+ }
+ return nil, claudeInfo.Usage
+}
+
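+// mapToolChoice converts an OpenAI tool_choice value and the parallel_tool_calls flag into a Claude tool_choice object.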
+func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
+ var claudeToolChoice *dto.ClaudeToolChoice
+
+	// Handle tool_choice string values
+ if toolChoiceStr, ok := toolChoice.(string); ok {
+ switch toolChoiceStr {
+ case "auto":
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "auto",
+ }
+ case "required":
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "any",
+ }
+ case "none":
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "none",
+ }
+ }
+ } else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
+		// Handle tool_choice object values
+ if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
+ if toolName, ok := function["name"].(string); ok {
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "tool",
+ Name: toolName,
+ }
+ }
+ }
+ }
+
+	// Handle parallel_tool_calls
+ if parallelToolCalls != nil {
+ if claudeToolChoice == nil {
+			// No tool_choice but parallel_tool_calls is set: default to the "auto" type
+ claudeToolChoice = &dto.ClaudeToolChoice{
+ Type: "auto",
+ }
+ }
+
+		// Set disable_parallel_tool_use:
+		// if parallel_tool_calls is true, disable_parallel_tool_use is false
+ claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
+ }
+
+ return claudeToolChoice
+}
diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go
new file mode 100644
index 00000000..6e59ad71
--- /dev/null
+++ b/relay/channel/cloudflare/adaptor.go
@@ -0,0 +1,122 @@
+package cloudflare
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ switch info.RelayMode {
+ case constant.RelayModeChatCompletions:
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
+ case constant.RelayModeEmbeddings:
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
+ default:
+ return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch info.RelayMode {
+ case constant.RelayModeCompletions:
+ return convertCf2CompletionsRequest(*request), nil
+ default:
+ return request, nil
+ }
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	// Get the uploaded file from the multipart form
+ file, _, err := c.Request.FormFile("file")
+ if err != nil {
+ return nil, errors.New("file is required")
+ }
+ defer file.Close()
+	// Buffer to hold the uploaded file content
+	requestBody := &bytes.Buffer{}
+
+	// Copy the uploaded file content into the buffer
+ if _, err := io.Copy(requestBody, file); err != nil {
+ return nil, err
+ }
+ return requestBody, nil
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayMode {
+ case constant.RelayModeEmbeddings:
+ fallthrough
+ case constant.RelayModeChatCompletions:
+ if info.IsStream {
+ err, usage = cfStreamHandler(c, info, resp)
+ } else {
+ err, usage = cfHandler(c, info, resp)
+ }
+ case constant.RelayModeAudioTranslation:
+ fallthrough
+ case constant.RelayModeAudioTranscription:
+ err, usage = cfSTTHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/cloudflare/constant.go b/relay/channel/cloudflare/constant.go
new file mode 100644
index 00000000..0e2aec2b
--- /dev/null
+++ b/relay/channel/cloudflare/constant.go
@@ -0,0 +1,39 @@
+package cloudflare
+
+var ModelList = []string{
+ "@cf/meta/llama-3.1-8b-instruct",
+ "@cf/meta/llama-2-7b-chat-fp16",
+ "@cf/meta/llama-2-7b-chat-int8",
+ "@cf/mistral/mistral-7b-instruct-v0.1",
+ "@hf/thebloke/deepseek-coder-6.7b-base-awq",
+ "@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
+ "@cf/deepseek-ai/deepseek-math-7b-base",
+ "@cf/deepseek-ai/deepseek-math-7b-instruct",
+ "@cf/thebloke/discolm-german-7b-v1-awq",
+ "@cf/tiiuae/falcon-7b-instruct",
+ "@cf/google/gemma-2b-it-lora",
+ "@hf/google/gemma-7b-it",
+ "@cf/google/gemma-7b-it-lora",
+ "@hf/nousresearch/hermes-2-pro-mistral-7b",
+ "@hf/thebloke/llama-2-13b-chat-awq",
+ "@cf/meta-llama/llama-2-7b-chat-hf-lora",
+ "@cf/meta/llama-3-8b-instruct",
+ "@hf/thebloke/llamaguard-7b-awq",
+ "@hf/thebloke/mistral-7b-instruct-v0.1-awq",
+ "@hf/mistralai/mistral-7b-instruct-v0.2",
+ "@cf/mistral/mistral-7b-instruct-v0.2-lora",
+ "@hf/thebloke/neural-chat-7b-v3-1-awq",
+ "@cf/openchat/openchat-3.5-0106",
+ "@hf/thebloke/openhermes-2.5-mistral-7b-awq",
+ "@cf/microsoft/phi-2",
+ "@cf/qwen/qwen1.5-0.5b-chat",
+ "@cf/qwen/qwen1.5-1.8b-chat",
+ "@cf/qwen/qwen1.5-14b-chat-awq",
+ "@cf/qwen/qwen1.5-7b-chat-awq",
+ "@cf/defog/sqlcoder-7b-2",
+ "@hf/nexusflow/starling-lm-7b-beta",
+ "@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
+ "@hf/thebloke/zephyr-7b-beta-awq",
+}
+
+var ChannelName = "cloudflare"
diff --git a/relay/channel/cloudflare/dto.go b/relay/channel/cloudflare/dto.go
new file mode 100644
index 00000000..62a45c40
--- /dev/null
+++ b/relay/channel/cloudflare/dto.go
@@ -0,0 +1,21 @@
+package cloudflare
+
+import "one-api/dto"
+
+type CfRequest struct {
+ Messages []dto.Message `json:"messages,omitempty"`
+ Lora string `json:"lora,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Raw bool `json:"raw,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+}
+
+type CfAudioResponse struct {
+ Result CfSTTResult `json:"result"`
+}
+
+type CfSTTResult struct {
+ Text string `json:"text"`
+}
diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go
new file mode 100644
index 00000000..5e8fe7f9
--- /dev/null
+++ b/relay/channel/cloudflare/relay_cloudflare.go
@@ -0,0 +1,150 @@
+package cloudflare
+
+import (
+ "bufio"
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
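+// convertCf2CompletionsRequest converts an OpenAI completions request into a Cloudflare Workers AI prompt request.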
+func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
+ p, _ := textRequest.Prompt.(string)
+ return &CfRequest{
+ Prompt: p,
+ MaxTokens: textRequest.GetMaxTokens(),
+ Stream: textRequest.Stream,
+ Temperature: textRequest.Temperature,
+ }
+}
+
+func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(bufio.ScanLines)
+
+ helper.SetEventStreamHeaders(c)
+ id := helper.GetResponseID(c)
+ var responseText string
+ isFirst := true
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < len("data: ") {
+ continue
+ }
+ data = strings.TrimPrefix(data, "data: ")
+ data = strings.TrimSuffix(data, "\r")
+
+ if data == "[DONE]" {
+ break
+ }
+
+ var response dto.ChatCompletionsStreamResponse
+ err := json.Unmarshal([]byte(data), &response)
+ if err != nil {
+ common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
+ continue
+ }
+ for _, choice := range response.Choices {
+ choice.Delta.Role = "assistant"
+ responseText += choice.Delta.GetContentString()
+ }
+ response.Id = id
+ response.Model = info.UpstreamModelName
+ err = helper.ObjectData(c, response)
+ if isFirst {
+ isFirst = false
+ info.FirstResponseTime = time.Now()
+ }
+ if err != nil {
+ common.LogError(c, "error_rendering_stream_response: "+err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ common.LogError(c, "error_scanning_stream_response: "+err.Error())
+ }
+ usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ if info.ShouldIncludeUsage {
+ response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
+ err := helper.ObjectData(c, response)
+ if err != nil {
+ common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
+ }
+ }
+ helper.Done(c)
+
+ common.CloseResponseBodyGracefully(resp)
+
+ return nil, usage
+}
+
+func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ common.CloseResponseBodyGracefully(resp)
+ var response dto.TextResponse
+ err = json.Unmarshal(responseBody, &response)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ response.Model = info.UpstreamModelName
+ var responseText string
+ for _, choice := range response.Choices {
+ responseText += choice.Message.StringContent()
+ }
+ usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ response.Usage = *usage
+ response.Id = helper.GetResponseID(c)
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(jsonResponse)
+ return nil, usage
+}
+
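+// cfSTTHandler converts a Cloudflare speech-to-text response into an OpenAI audio response and estimates usage from the transcribed text.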
+func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+ var cfResp CfAudioResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &cfResp)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+
+ audioResp := &dto.AudioResponse{
+ Text: cfResp.Result.Text,
+ }
+
+ jsonResponse, err := json.Marshal(audioResp)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(jsonResponse)
+
+ usage := &dto.Usage{}
+ usage.PromptTokens = info.PromptTokens
+ usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
+ return nil, usage
+}
diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go
new file mode 100644
index 00000000..4f3a96c3
--- /dev/null
+++ b/relay/channel/cohere/adaptor.go
@@ -0,0 +1,94 @@
+package cohere
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if info.RelayMode == constant.RelayModeRerank {
+ return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+ } else {
+ return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ return requestOpenAI2Cohere(*request), nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return requestConvertRerank2Cohere(request), nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.RelayMode == constant.RelayModeRerank {
+ usage, err = cohereRerankHandler(c, resp, info)
+ } else {
+ if info.IsStream {
+ usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
+ } else {
+ usage, err = cohereHandler(c, info, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/cohere/constant.go b/relay/channel/cohere/constant.go
new file mode 100644
index 00000000..f2d2e559
--- /dev/null
+++ b/relay/channel/cohere/constant.go
@@ -0,0 +1,12 @@
+package cohere
+
+var ModelList = []string{
+ "command-a-03-2025",
+ "command-r", "command-r-plus",
+ "command-r-08-2024", "command-r-plus-08-2024",
+ "c4ai-aya-23-35b", "c4ai-aya-23-8b",
+ "command-light", "command-light-nightly", "command", "command-nightly",
+ "rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
+}
+
+var ChannelName = "cohere"
diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go
new file mode 100644
index 00000000..410540c0
--- /dev/null
+++ b/relay/channel/cohere/dto.go
@@ -0,0 +1,60 @@
+package cohere
+
+import "one-api/dto"
+
+type CohereRequest struct {
+ Model string `json:"model"`
+ ChatHistory []ChatHistory `json:"chat_history"`
+ Message string `json:"message"`
+ Stream bool `json:"stream"`
+ MaxTokens int `json:"max_tokens"`
+ SafetyMode string `json:"safety_mode,omitempty"`
+}
+
+type ChatHistory struct {
+ Role string `json:"role"`
+ Message string `json:"message"`
+}
+
+type CohereResponse struct {
+ IsFinished bool `json:"is_finished"`
+ EventType string `json:"event_type"`
+ Text string `json:"text,omitempty"`
+ FinishReason string `json:"finish_reason,omitempty"`
+ Response *CohereResponseResult `json:"response"`
+}
+
+type CohereResponseResult struct {
+ ResponseId string `json:"response_id"`
+ FinishReason string `json:"finish_reason,omitempty"`
+ Text string `json:"text"`
+ Meta CohereMeta `json:"meta"`
+}
+
+type CohereRerankRequest struct {
+ Documents []any `json:"documents"`
+ Query string `json:"query"`
+ Model string `json:"model"`
+ TopN int `json:"top_n"`
+ ReturnDocuments bool `json:"return_documents"`
+}
+
+type CohereRerankResponseResult struct {
+ Results []dto.RerankResponseResult `json:"results"`
+ Meta CohereMeta `json:"meta"`
+}
+
+type CohereMeta struct {
+ //Tokens CohereTokens `json:"tokens"`
+ BilledUnits CohereBilledUnits `json:"billed_units"`
+}
+
+type CohereBilledUnits struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+}
+
+type CohereTokens struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+}
diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go
new file mode 100644
index 00000000..fcfb12b7
--- /dev/null
+++ b/relay/channel/cohere/relay-cohere.go
@@ -0,0 +1,248 @@
+package cohere
+
+import (
+ "bufio"
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
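+// requestOpenAI2Cohere converts an OpenAI chat request into a Cohere chat request, mapping non-user messages into chat_history and using the last user message as the message field.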
+func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
+ cohereReq := CohereRequest{
+ Model: textRequest.Model,
+ ChatHistory: []ChatHistory{},
+ Message: "",
+ Stream: textRequest.Stream,
+ MaxTokens: textRequest.GetMaxTokens(),
+ }
+ if common.CohereSafetySetting != "NONE" {
+ cohereReq.SafetyMode = common.CohereSafetySetting
+ }
+ if cohereReq.MaxTokens == 0 {
+ cohereReq.MaxTokens = 4000
+ }
+ for _, msg := range textRequest.Messages {
+ if msg.Role == "user" {
+ cohereReq.Message = msg.StringContent()
+ } else {
+ var role string
+ if msg.Role == "assistant" {
+ role = "CHATBOT"
+ } else if msg.Role == "system" {
+ role = "SYSTEM"
+ } else {
+ role = "USER"
+ }
+ cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
+ Role: role,
+ Message: msg.StringContent(),
+ })
+ }
+ }
+
+ return &cohereReq
+}
+
+func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
+ if rerankRequest.TopN == 0 {
+ rerankRequest.TopN = 1
+ }
+ cohereReq := CohereRerankRequest{
+ Query: rerankRequest.Query,
+ Documents: rerankRequest.Documents,
+ Model: rerankRequest.Model,
+ TopN: rerankRequest.TopN,
+ ReturnDocuments: true,
+ }
+ return &cohereReq
+}
+
+func stopReasonCohere2OpenAI(reason string) string {
+ switch reason {
+ case "COMPLETE":
+ return "stop"
+ case "MAX_TOKENS":
+ return "max_tokens"
+ default:
+ return reason
+ }
+}
+
+func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseId := helper.GetResponseID(c)
+ createdTime := common.GetTimestamp()
+ usage := &dto.Usage{}
+ responseText := ""
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if atEOF && len(data) == 0 {
+ return 0, nil, nil
+ }
+ if i := strings.Index(string(data), "\n"); i >= 0 {
+ return i + 1, data[0:i], nil
+ }
+ if atEOF {
+ return len(data), data, nil
+ }
+ return 0, nil, nil
+ })
+ dataChan := make(chan string)
+ stopChan := make(chan bool)
+ go func() {
+ for scanner.Scan() {
+ data := scanner.Text()
+ dataChan <- data
+ }
+ stopChan <- true
+ }()
+ helper.SetEventStreamHeaders(c)
+ isFirst := true
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case data := <-dataChan:
+ if isFirst {
+ isFirst = false
+ info.FirstResponseTime = time.Now()
+ }
+ data = strings.TrimSuffix(data, "\r")
+ var cohereResp CohereResponse
+ err := json.Unmarshal([]byte(data), &cohereResp)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return true
+ }
+ var openaiResp dto.ChatCompletionsStreamResponse
+ openaiResp.Id = responseId
+ openaiResp.Created = createdTime
+ openaiResp.Object = "chat.completion.chunk"
+ openaiResp.Model = info.UpstreamModelName
+ if cohereResp.IsFinished {
+ finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
+ openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
+ {
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
+ Index: 0,
+ FinishReason: &finishReason,
+ },
+ }
+ if cohereResp.Response != nil {
+ usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
+ usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
+ }
+ } else {
+ openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
+ {
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+ Role: "assistant",
+ Content: &cohereResp.Text,
+ },
+ Index: 0,
+ },
+ }
+ responseText += cohereResp.Text
+ }
+ jsonStr, err := json.Marshal(openaiResp)
+ if err != nil {
+ common.SysError("error marshalling stream response: " + err.Error())
+ return true
+ }
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
+ return true
+ case <-stopChan:
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+ return false
+ }
+ })
+ if usage.PromptTokens == 0 {
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ }
+ return usage, nil
+}
+
+func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ createdTime := common.GetTimestamp()
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ var cohereResp CohereResponseResult
+ err = json.Unmarshal(responseBody, &cohereResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ usage := dto.Usage{}
+ usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
+ usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
+ usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
+
+ var openaiResp dto.TextResponse
+ openaiResp.Id = cohereResp.ResponseId
+ openaiResp.Created = createdTime
+ openaiResp.Object = "chat.completion"
+ openaiResp.Model = info.UpstreamModelName
+ openaiResp.Usage = usage
+
+ openaiResp.Choices = []dto.OpenAITextResponseChoice{
+ {
+ Index: 0,
+ Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
+ FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
+ },
+ }
+
+ jsonResponse, err := json.Marshal(openaiResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(jsonResponse)
+ return &usage, nil
+}
+
+func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ var cohereResp CohereRerankResponseResult
+ err = json.Unmarshal(responseBody, &cohereResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ usage := dto.Usage{}
+ if cohereResp.Meta.BilledUnits.InputTokens == 0 {
+ usage.PromptTokens = info.PromptTokens
+ usage.CompletionTokens = 0
+ usage.TotalTokens = info.PromptTokens
+ } else {
+ usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
+ usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
+ usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
+ }
+
+ var rerankResp dto.RerankResponse
+ rerankResp.Results = cohereResp.Results
+ rerankResp.Usage = usage
+
+ jsonResponse, err := json.Marshal(rerankResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return &usage, nil
+}
diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go
new file mode 100644
index 00000000..fe5f5f00
--- /dev/null
+++ b/relay/channel/coze/adaptor.go
@@ -0,0 +1,133 @@
+package coze
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/common"
+ "one-api/types"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+// ConvertAudioRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ return nil, errors.New("not implemented")
+}
+
+// ConvertClaudeRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+// ConvertEmbeddingRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+// ConvertImageRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+// ConvertOpenAIRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return convertCozeChatRequest(c, *request), nil
+}
+
+// ConvertOpenAIResponsesRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+// ConvertRerankRequest implements channel.Adaptor.
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+// DoRequest implements channel.Adaptor.
+func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
+ if info.IsStream {
+ return channel.DoApiRequest(a, c, info, requestBody)
+ }
+	// First create the chat; once it succeeds, fetch the resulting messages
+	// Send the chat-creation request
+ resp, err := channel.DoApiRequest(a, c, info, requestBody)
+ if err != nil {
+ return nil, err
+ }
+	// Parse the response
+ var cozeResponse CozeChatResponse
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+	err = json.Unmarshal(respBody, &cozeResponse)
+	if err != nil {
+		return nil, err
+	}
+	if cozeResponse.Code != 0 {
+		return nil, errors.New(cozeResponse.Msg)
+	}
+ c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
+ c.Set("coze_chat_id", cozeResponse.Data.Id)
+	// Poll until the chat is complete
+	for {
+		err, isComplete := checkIfChatComplete(a, c, info)
+		if err != nil {
+			return nil, err
+		}
+		if isComplete {
+			break
+		}
+		time.Sleep(time.Second * 1)
+	}
+	// Fetch the chat detail
+ return getChatDetail(a, c, info)
+}
+
+// DoResponse implements channel.Adaptor.
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = cozeChatStreamHandler(c, info, resp)
+ } else {
+ usage, err = cozeChatHandler(c, info, resp)
+ }
+ return
+}
+
+// GetChannelName implements channel.Adaptor.
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
+
+// GetModelList implements channel.Adaptor.
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+// GetRequestURL implements channel.Adaptor.
+func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
+}
+
+// Init implements channel.Adaptor.
+func (a *Adaptor) Init(info *common.RelayInfo) {
+
+}
+
+// SetupRequestHeader implements channel.Adaptor.
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
diff --git a/relay/channel/coze/constants.go b/relay/channel/coze/constants.go
new file mode 100644
index 00000000..873ffe24
--- /dev/null
+++ b/relay/channel/coze/constants.go
@@ -0,0 +1,30 @@
+package coze
+
+var ModelList = []string{
+ "moonshot-v1-8k",
+ "moonshot-v1-32k",
+ "moonshot-v1-128k",
+ "Baichuan4",
+ "abab6.5s-chat-pro",
+ "glm-4-0520",
+ "qwen-max",
+ "deepseek-r1",
+ "deepseek-v3",
+ "deepseek-r1-distill-qwen-32b",
+ "deepseek-r1-distill-qwen-7b",
+ "step-1v-8k",
+ "step-1.5v-mini",
+ "Doubao-pro-32k",
+ "Doubao-pro-256k",
+ "Doubao-lite-128k",
+ "Doubao-lite-32k",
+ "Doubao-vision-lite-32k",
+ "Doubao-vision-pro-32k",
+ "Doubao-1.5-pro-vision-32k",
+ "Doubao-1.5-lite-32k",
+ "Doubao-1.5-pro-32k",
+ "Doubao-1.5-thinking-pro",
+ "Doubao-1.5-pro-256k",
+}
+
+var ChannelName = "coze"
diff --git a/relay/channel/coze/dto.go b/relay/channel/coze/dto.go
new file mode 100644
index 00000000..d5dc9a81
--- /dev/null
+++ b/relay/channel/coze/dto.go
@@ -0,0 +1,78 @@
+package coze
+
+import "encoding/json"
+
+type CozeError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
+
+type CozeEnterMessage struct {
+ Role string `json:"role"`
+ Type string `json:"type,omitempty"`
+ Content any `json:"content,omitempty"`
+ MetaData json.RawMessage `json:"meta_data,omitempty"`
+ ContentType string `json:"content_type,omitempty"`
+}
+
+type CozeChatRequest struct {
+ BotId string `json:"bot_id"`
+ UserId string `json:"user_id"`
+ AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ CustomVariables json.RawMessage `json:"custom_variables,omitempty"`
+ AutoSaveHistory bool `json:"auto_save_history,omitempty"`
+ MetaData json.RawMessage `json:"meta_data,omitempty"`
+ ExtraParams json.RawMessage `json:"extra_params,omitempty"`
+ ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"`
+ Parameters json.RawMessage `json:"parameters,omitempty"`
+}
+
+type CozeChatResponse struct {
+ Code int `json:"code"`
+ Msg string `json:"msg"`
+ Data CozeChatResponseData `json:"data"`
+}
+
+type CozeChatResponseData struct {
+ Id string `json:"id"`
+ ConversationId string `json:"conversation_id"`
+ BotId string `json:"bot_id"`
+ CreatedAt int64 `json:"created_at"`
+ LastError CozeError `json:"last_error"`
+ Status string `json:"status"`
+ Usage CozeChatUsage `json:"usage"`
+}
+
+type CozeChatUsage struct {
+ TokenCount int `json:"token_count"`
+ OutputCount int `json:"output_count"`
+ InputCount int `json:"input_count"`
+}
+
+type CozeChatDetailResponse struct {
+ Data []CozeChatV3MessageDetail `json:"data"`
+ Code int `json:"code"`
+ Msg string `json:"msg"`
+ Detail CozeResponseDetail `json:"detail"`
+}
+
+type CozeChatV3MessageDetail struct {
+ Id string `json:"id"`
+ Role string `json:"role"`
+ Type string `json:"type"`
+ BotId string `json:"bot_id"`
+ ChatId string `json:"chat_id"`
+ Content json.RawMessage `json:"content"`
+ MetaData json.RawMessage `json:"meta_data"`
+ CreatedAt int64 `json:"created_at"`
+ SectionId string `json:"section_id"`
+ UpdatedAt int64 `json:"updated_at"`
+ ContentType string `json:"content_type"`
+ ConversationId string `json:"conversation_id"`
+ ReasoningContent string `json:"reasoning_content"`
+}
+
+type CozeResponseDetail struct {
+ Logid string `json:"logid"`
+}
diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go
new file mode 100644
index 00000000..32cc6937
--- /dev/null
+++ b/relay/channel/coze/relay-coze.go
@@ -0,0 +1,296 @@
+package coze
+
+import (
+ "bufio"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
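+// convertCozeChatRequest converts an OpenAI chat request into a Coze v3 chat request, forwarding user messages as additional_messages and falling back to the response id when no user is set.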
+func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
+ var messages []CozeEnterMessage
+	// Convert the content of user-role messages into Coze messages
+ for _, message := range request.Messages {
+ if message.Role == "user" {
+ messages = append(messages, CozeEnterMessage{
+ Role: "user",
+ Content: message.Content,
+				// TODO: support more content types
+ ContentType: "text",
+ })
+ }
+ }
+ user := request.User
+ if user == "" {
+ user = helper.GetResponseID(c)
+ }
+ cozeRequest := &CozeChatRequest{
+ BotId: c.GetString("bot_id"),
+ UserId: user,
+ AdditionalMessages: messages,
+ Stream: request.Stream,
+ }
+ return cozeRequest
+}
+
+func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ // convert coze response to openai response
+ var response dto.TextResponse
+ var cozeResponse CozeChatDetailResponse
+ response.Model = info.UpstreamModelName
+ err = json.Unmarshal(responseBody, &cozeResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if cozeResponse.Code != 0 {
+ return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
+ }
+ // get usage from the gin context
+ var usage dto.Usage
+ usage.PromptTokens = c.GetInt("coze_input_count")
+ usage.CompletionTokens = c.GetInt("coze_output_count")
+ usage.TotalTokens = c.GetInt("coze_token_count")
+ response.Usage = usage
+ response.Id = helper.GetResponseID(c)
+
+ var responseContent json.RawMessage
+ for _, data := range cozeResponse.Data {
+ if data.Type == "answer" {
+ responseContent = data.Content
+ response.Created = data.CreatedAt
+ }
+ }
+ // build response.Choices
+ response.Choices = []dto.OpenAITextResponseChoice{
+ {
+ Index: 0,
+ Message: dto.Message{Role: "assistant", Content: responseContent},
+ FinishReason: "stop",
+ },
+ }
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(jsonResponse)
+
+ return &usage, nil
+}
+
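+// cozeChatStreamHandler relays Coze SSE output as OpenAI chat.completion.chunk events.
+// Coze frames arrive as "event:" / "data:" line pairs separated by blank lines, so the
+// scanner buffers both fields and dispatches them to handleCozeEvent on each blank line.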
+func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(bufio.ScanLines)
+ helper.SetEventStreamHeaders(c)
+ id := helper.GetResponseID(c)
+ var responseText string
+
+ var currentEvent string
+ var currentData string
+ var usage = &dto.Usage{}
+
+ for scanner.Scan() {
+ line := scanner.Text()
+
+ if line == "" {
+ if currentEvent != "" && currentData != "" {
+ // handle last event
+ handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
+ currentEvent = ""
+ currentData = ""
+ }
+ continue
+ }
+
+ if strings.HasPrefix(line, "event:") {
+ currentEvent = strings.TrimSpace(line[6:])
+ continue
+ }
+
+ if strings.HasPrefix(line, "data:") {
+ currentData = strings.TrimSpace(line[5:])
+ continue
+ }
+ }
+
+ // Last event
+ if currentEvent != "" && currentData != "" {
+ handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
+ }
+
+ if err := scanner.Err(); err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ helper.Done(c)
+
+ if usage.TotalTokens == 0 {
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
+ }
+
+ return usage, nil
+}
+
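+// handleCozeEvent processes one buffered SSE frame: "conversation.message.delta" frames
+// carry incremental answer text, "conversation.chat.completed" carries the final usage,
+// and "error" frames are only logged. An illustrative frame (field values depend on the
+// upstream Coze API) looks like:
+//
+// event: conversation.message.delta
+// data: {"id":"...","role":"assistant","content":"hello","content_type":"text"}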
+func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
+ switch event {
+ case "conversation.chat.completed":
+ // parse data as CozeChatResponseData
+ var chatData CozeChatResponseData
+ err := json.Unmarshal([]byte(data), &chatData)
+ if err != nil {
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ return
+ }
+
+ usage.PromptTokens = chatData.Usage.InputCount
+ usage.CompletionTokens = chatData.Usage.OutputCount
+ usage.TotalTokens = chatData.Usage.TokenCount
+
+ finishReason := "stop"
+ stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
+ helper.ObjectData(c, stopResponse)
+
+ case "conversation.message.delta":
+ // parse data as CozeChatV3MessageDetail
+ var messageData CozeChatV3MessageDetail
+ err := json.Unmarshal([]byte(data), &messageData)
+ if err != nil {
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ return
+ }
+
+ var content string
+ err = json.Unmarshal(messageData.Content, &content)
+ if err != nil {
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ return
+ }
+
+ *responseText += content
+
+ openaiResponse := dto.ChatCompletionsStreamResponse{
+ Id: id,
+ Object: "chat.completion.chunk",
+ Created: common.GetTimestamp(),
+ Model: info.UpstreamModelName,
+ }
+
+ choice := dto.ChatCompletionsStreamResponseChoice{
+ Index: 0,
+ }
+ choice.Delta.SetContentString(content)
+ openaiResponse.Choices = append(openaiResponse.Choices, choice)
+
+ helper.ObjectData(c, openaiResponse)
+
+ case "error":
+ var errorData CozeError
+ err := json.Unmarshal([]byte(data), &errorData)
+ if err != nil {
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
+ return
+ }
+
+ common.SysError(fmt.Sprintf("stream event error: code=%v, message=%v", errorData.Code, errorData.Message))
+ }
+}
+
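+// checkIfChatComplete polls GET /v3/chat/retrieve for the current conversation_id and
+// chat_id. When the chat reaches the "completed" status it caches the reported usage in
+// the gin context; failed/canceled/requires_action statuses are returned as errors.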
+func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
+ requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
+
+ requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+ // send a GET request with the conversation_id and chat_id as query parameters
+ req, err := http.NewRequest("GET", requestURL, nil)
+ if err != nil {
+ return err, false
+ }
+ err = a.SetupRequestHeader(c, &req.Header, info)
+ if err != nil {
+ return err, false
+ }
+
+ resp, err := doRequest(req, info) // issue the request
+ if err != nil {
+ return err, false
+ }
+ if resp == nil { // guard against a nil resp so the code below does not panic
+ return fmt.Errorf("resp is nil"), false
+ }
+ defer resp.Body.Close() // ensure the response body is closed
+
+ // parse resp into CozeChatResponse
+ var cozeResponse CozeChatResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("read response body failed: %w", err), false
+ }
+ err = json.Unmarshal(responseBody, &cozeResponse)
+ if err != nil {
+ return fmt.Errorf("unmarshal response body failed: %w", err), false
+ }
+ if cozeResponse.Data.Status == "completed" {
+ // store usage in the gin context
+ c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
+ c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
+ c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
+ return nil, true
+ } else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
+ return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
+ } else {
+ return nil, false
+ }
+}
+
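+// getChatDetail fetches the finished chat's messages via GET /v3/chat/message/list so
+// the non-stream handler can assemble the final answer.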
+func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
+ requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
+
+ requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
+ req, err := http.NewRequest("GET", requestURL, nil)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ err = a.SetupRequestHeader(c, &req.Header, info)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := doRequest(req, info)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
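+// doRequest sends the request with the channel's proxy client when a proxy is
+// configured, otherwise with the shared HTTP client. The caller is responsible for
+// closing resp.Body.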
+func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
+ var client *http.Client
+ var err error // declared here so both branches below can assign it
+ if info.ChannelSetting.Proxy != "" {
+ client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
+ if err != nil {
+ return nil, fmt.Errorf("new proxy http client failed: %w", err)
+ }
+ } else {
+ client = service.GetHttpClient()
+ }
+ resp, err := client.Do(req)
+ if err != nil { // check the error returned by client.Do
+ return nil, fmt.Errorf("client.Do failed: %w", err)
+ }
+ // _ = resp.Body.Close()
+ return resp, nil
+}
diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go
new file mode 100644
index 00000000..edfc7fd3
--- /dev/null
+++ b/relay/channel/deepseek/adaptor.go
@@ -0,0 +1,100 @@
+package deepseek
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ fimBaseUrl := info.BaseUrl
+ if !strings.HasSuffix(info.BaseUrl, "/beta") {
+ fimBaseUrl += "/beta"
+ }
+ switch info.RelayMode {
+ case constant.RelayModeCompletions:
+ return fmt.Sprintf("%s/completions", fimBaseUrl), nil
+ default:
+ return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/deepseek/constants.go b/relay/channel/deepseek/constants.go
new file mode 100644
index 00000000..1d7b1e32
--- /dev/null
+++ b/relay/channel/deepseek/constants.go
@@ -0,0 +1,7 @@
+package deepseek
+
+var ModelList = []string{
+ "deepseek-chat", "deepseek-reasoner",
+}
+
+var ChannelName = "deepseek"
diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go
new file mode 100644
index 00000000..4ad16766
--- /dev/null
+++ b/relay/channel/dify/adaptor.go
@@ -0,0 +1,115 @@
+package dify
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ BotTypeChatFlow = 1 // chatflow default
+ BotTypeAgent = 2
+ BotTypeWorkFlow = 3
+ BotTypeCompletion = 4
+)
+
+type Adaptor struct {
+ BotType int
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+ //if strings.HasPrefix(info.UpstreamModelName, "agent") {
+ // a.BotType = BotTypeAgent
+ //} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
+ // a.BotType = BotTypeWorkFlow
+ //} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
+ // a.BotType = BotTypeCompletion
+ //} else {
+ //}
+ a.BotType = BotTypeChatFlow
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ switch a.BotType {
+ case BotTypeWorkFlow:
+ return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
+ case BotTypeCompletion:
+ return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
+ case BotTypeAgent:
+ fallthrough
+ default:
+ return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return requestOpenAI2Dify(c, info, *request), nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ return difyStreamHandler(c, info, resp)
+ } else {
+ return difyHandler(c, info, resp)
+ }
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/dify/constants.go b/relay/channel/dify/constants.go
new file mode 100644
index 00000000..db3e67c7
--- /dev/null
+++ b/relay/channel/dify/constants.go
@@ -0,0 +1,5 @@
+package dify
+
+var ModelList []string
+
+var ChannelName = "dify"
diff --git a/relay/channel/dify/dto.go b/relay/channel/dify/dto.go
new file mode 100644
index 00000000..7c6f39b6
--- /dev/null
+++ b/relay/channel/dify/dto.go
@@ -0,0 +1,45 @@
+package dify
+
+import "one-api/dto"
+
+type DifyChatRequest struct {
+ Inputs map[string]interface{} `json:"inputs"`
+ Query string `json:"query"`
+ ResponseMode string `json:"response_mode"`
+ User string `json:"user"`
+ AutoGenerateName bool `json:"auto_generate_name"`
+ Files []DifyFile `json:"files"`
+}
+
+type DifyFile struct {
+ Type string `json:"type"`
+ TransferMode string `json:"transfer_mode"`
+ URL string `json:"url,omitempty"`
+ UploadFileId string `json:"upload_file_id,omitempty"`
+}
+
+type DifyMetaData struct {
+ Usage dto.Usage `json:"usage"`
+}
+
+type DifyData struct {
+ WorkflowId string `json:"workflow_id"`
+ NodeId string `json:"node_id"`
+ NodeType string `json:"node_type"`
+ Status string `json:"status"`
+}
+
+type DifyChatCompletionResponse struct {
+ ConversationId string `json:"conversation_id"`
+ Answer string `json:"answer"`
+ CreateAt int64 `json:"create_at"`
+ MetaData DifyMetaData `json:"metadata"`
+}
+
+type DifyChunkChatCompletionResponse struct {
+ Event string `json:"event"`
+ ConversationId string `json:"conversation_id"`
+ Answer string `json:"answer"`
+ Data DifyData `json:"data"`
+ MetaData DifyMetaData `json:"metadata"`
+}
diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go
new file mode 100644
index 00000000..47337127
--- /dev/null
+++ b/relay/channel/dify/relay-dify.go
@@ -0,0 +1,289 @@
+package dify
+
+import (
+ "bytes"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "os"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
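+// uploadDifyFile pushes a base64 image from the request to Dify's /v1/files/upload
+// endpoint as a multipart form and returns a DifyFile referencing the upload id.
+// Errors are logged and reported as a nil result so the message conversion can continue.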
+func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
+ uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
+ switch media.Type {
+ case dto.ContentTypeImageURL:
+ // Decode base64 data
+ imageMedia := media.GetImageMedia()
+ base64Data := imageMedia.Url
+ // Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
+ if idx := strings.Index(base64Data, ","); idx != -1 {
+ base64Data = base64Data[idx+1:]
+ }
+
+ // Decode base64 string
+ decodedData, err := base64.StdEncoding.DecodeString(base64Data)
+ if err != nil {
+ common.SysError("failed to decode base64: " + err.Error())
+ return nil
+ }
+
+ // Create temporary file
+ tempFile, err := os.CreateTemp("", "dify-upload-*")
+ if err != nil {
+ common.SysError("failed to create temp file: " + err.Error())
+ return nil
+ }
+ defer tempFile.Close()
+ defer os.Remove(tempFile.Name())
+
+ // Write decoded data to temp file
+ if _, err := tempFile.Write(decodedData); err != nil {
+ common.SysError("failed to write to temp file: " + err.Error())
+ return nil
+ }
+
+ // Create multipart form
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+
+ // Add user field
+ if err := writer.WriteField("user", user); err != nil {
+ common.SysError("failed to add user field: " + err.Error())
+ return nil
+ }
+
+ // Create form file with proper mime type
+ mimeType := imageMedia.MimeType
+ if mimeType == "" {
+ mimeType = "image/jpeg" // default mime type
+ }
+
+ // Create form file
+ part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
+ if err != nil {
+ common.SysError("failed to create form file: " + err.Error())
+ return nil
+ }
+
+ // Copy file content to form
+ if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
+ common.SysError("failed to copy file content: " + err.Error())
+ return nil
+ }
+ writer.Close()
+
+ // Create HTTP request
+ req, err := http.NewRequest("POST", uploadUrl, body)
+ if err != nil {
+ common.SysError("failed to create request: " + err.Error())
+ return nil
+ }
+
+ req.Header.Set("Content-Type", writer.FormDataContentType())
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+
+ // Send request
+ client := service.GetHttpClient()
+ resp, err := client.Do(req)
+ if err != nil {
+ common.SysError("failed to send request: " + err.Error())
+ return nil
+ }
+ defer resp.Body.Close()
+
+ // Parse response
+ var result struct {
+ Id string `json:"id"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ common.SysError("failed to decode response: " + err.Error())
+ return nil
+ }
+
+ return &DifyFile{
+ UploadFileId: result.Id,
+ Type: "image",
+ TransferMode: "local_file",
+ }
+ }
+ return nil
+}
+
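+// requestOpenAI2Dify flattens the OpenAI message history into a single Dify query
+// string with SYSTEM/ASSISTANT/USER prefixes, collects image attachments as Dify files,
+// and selects streaming or blocking response_mode based on the original request.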
+func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
+ difyReq := DifyChatRequest{
+ Inputs: make(map[string]interface{}),
+ AutoGenerateName: false,
+ }
+
+ user := request.User
+ if user == "" {
+ user = helper.GetResponseID(c)
+ }
+ difyReq.User = user
+
+ files := make([]DifyFile, 0)
+ var content strings.Builder
+ for _, message := range request.Messages {
+ if message.Role == "system" {
+ content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
+ } else if message.Role == "assistant" {
+ content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
+ } else {
+ parseContent := message.ParseContent()
+ for _, mediaContent := range parseContent {
+ switch mediaContent.Type {
+ case dto.ContentTypeText:
+ content.WriteString("USER: \n" + mediaContent.Text + "\n")
+ case dto.ContentTypeImageURL:
+ media := mediaContent.GetImageMedia()
+ var file *DifyFile
+ if media.IsRemoteImage() {
+ // allocate the struct before setting fields to avoid a nil pointer dereference
+ file = &DifyFile{
+ Type: media.MimeType,
+ TransferMode: "remote_url",
+ URL: media.Url,
+ }
+ } else {
+ file = uploadDifyFile(c, info, difyReq.User, mediaContent)
+ }
+ if file != nil {
+ files = append(files, *file)
+ }
+ }
+ }
+ }
+ }
+ difyReq.Query = content.String()
+ difyReq.Files = files
+ mode := "blocking"
+ if request.Stream {
+ mode = "streaming"
+ }
+ difyReq.ResponseMode = mode
+ return &difyReq
+}
+
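+// streamResponseDify2OpenAI maps one Dify stream event to an OpenAI chunk. Workflow and
+// node events are surfaced as reasoning content only when DIFY_DEBUG is enabled;
+// message/agent_message events carry the actual answer delta.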
+func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
+ response := dto.ChatCompletionsStreamResponse{
+ Object: "chat.completion.chunk",
+ Created: common.GetTimestamp(),
+ Model: "dify",
+ }
+ var choice dto.ChatCompletionsStreamResponseChoice
+ if strings.HasPrefix(difyResponse.Event, "workflow_") {
+ if constant.DifyDebug {
+ text := "Workflow: " + difyResponse.Data.WorkflowId
+ if difyResponse.Event == "workflow_finished" {
+ text += " " + difyResponse.Data.Status
+ }
+ choice.Delta.SetReasoningContent(text + "\n")
+ }
+ } else if strings.HasPrefix(difyResponse.Event, "node_") {
+ if constant.DifyDebug {
+ text := "Node: " + difyResponse.Data.NodeType
+ if difyResponse.Event == "node_finished" {
+ text += " " + difyResponse.Data.Status
+ }
+ choice.Delta.SetReasoningContent(text + "\n")
+ }
+ } else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
+ if difyResponse.Answer == " Thinking...\n" {
+ difyResponse.Answer = ""
+ } else if difyResponse.Answer == " " {
+ difyResponse.Answer = ""
+ }
+
+ choice.Delta.SetContentString(difyResponse.Answer)
+ }
+ response.Choices = append(response.Choices, choice)
+ return &response
+}
+
+func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var responseText string
+ usage := &dto.Usage{}
+ var nodeToken int
+ helper.SetEventStreamHeaders(c)
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ var difyResponse DifyChunkChatCompletionResponse
+ err := json.Unmarshal([]byte(data), &difyResponse)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return true
+ }
+ var openaiResponse dto.ChatCompletionsStreamResponse
+ if difyResponse.Event == "message_end" {
+ usage = &difyResponse.MetaData.Usage
+ return false
+ } else if difyResponse.Event == "error" {
+ return false
+ } else {
+ openaiResponse = *streamResponseDify2OpenAI(difyResponse)
+ if len(openaiResponse.Choices) != 0 {
+ responseText += openaiResponse.Choices[0].Delta.GetContentString()
+ if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
+ nodeToken += 1
+ }
+ }
+ }
+ err = helper.ObjectData(c, openaiResponse)
+ if err != nil {
+ common.SysError(err.Error())
+ }
+ return true
+ })
+ helper.Done(c)
+ if usage.TotalTokens == 0 {
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ }
+ usage.CompletionTokens += nodeToken
+ return usage, nil
+}
+
+func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var difyResponse DifyChatCompletionResponse
+ responseBody, err := io.ReadAll(resp.Body)
+
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &difyResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ fullTextResponse := dto.OpenAITextResponse{
+ Id: difyResponse.ConversationId,
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ Usage: difyResponse.MetaData.Usage,
+ }
+ choice := dto.OpenAITextResponseChoice{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: difyResponse.Answer,
+ },
+ FinishReason: "stop",
+ }
+ fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ c.Writer.Write(jsonResponse)
+ return &difyResponse.MetaData.Usage, nil
+}
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
new file mode 100644
index 00000000..71eb9ba4
--- /dev/null
+++ b/relay/channel/gemini/adaptor.go
@@ -0,0 +1,271 @@
+package gemini
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ return nil, errors.New("not supported model for image generation")
+ }
+
+ // convert size to aspect ratio
+ aspectRatio := "1:1" // default aspect ratio
+ switch request.Size {
+ case "1024x1024":
+ aspectRatio = "1:1"
+ case "1024x1792":
+ aspectRatio = "9:16"
+ case "1792x1024":
+ aspectRatio = "16:9"
+ }
+
+ // build gemini imagen request
+ geminiRequest := GeminiImageRequest{
+ Instances: []GeminiImageInstance{
+ {
+ Prompt: request.Prompt,
+ },
+ },
+ Parameters: GeminiImageParameters{
+ SampleCount: request.N,
+ AspectRatio: aspectRatio,
+ PersonGeneration: "allow_adult", // default allow adult
+ },
+ }
+
+ return geminiRequest, nil
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+
+ if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
+ // new logic: handle the -thinking-<budget> suffix format
+ if strings.Contains(info.UpstreamModelName, "-thinking-") {
+ parts := strings.Split(info.UpstreamModelName, "-thinking-")
+ info.UpstreamModelName = parts[0]
+ } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // legacy -thinking suffix
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
+ } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
+ }
+ }
+
+ version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
+
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
+ }
+
+ if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+ strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+ strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+ return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
+ }
+
+ action := "generateContent"
+ if info.IsStream {
+ action = "streamGenerateContent?alt=sse"
+ }
+ return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("x-goog-api-key", info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+
+ geminiRequest, err := CovertGemini2OpenAI(*request, info)
+ if err != nil {
+ return nil, err
+ }
+
+ return geminiRequest, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ if request.Input == nil {
+ return nil, errors.New("input is required")
+ }
+
+ inputs := request.ParseInput()
+ if len(inputs) == 0 {
+ return nil, errors.New("input is empty")
+ }
+
+ // only process the first input
+ geminiRequest := GeminiEmbeddingRequest{
+ Content: GeminiChatContent{
+ Parts: []GeminiPart{
+ {
+ Text: inputs[0],
+ },
+ },
+ },
+ }
+
+ // set specific parameters for different models
+ // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
+ switch info.UpstreamModelName {
+ case "text-embedding-004":
+ // embedding models other than embedding-001 support setting `OutputDimensionality`
+ if request.Dimensions > 0 {
+ geminiRequest.OutputDimensionality = request.Dimensions
+ }
+ }
+
+ return geminiRequest, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.RelayMode == constant.RelayModeGemini {
+ if info.IsStream {
+ return GeminiTextGenerationStreamHandler(c, info, resp)
+ } else {
+ return GeminiTextGenerationHandler(c, info, resp)
+ }
+ }
+
+ if strings.HasPrefix(info.UpstreamModelName, "imagen") {
+ return GeminiImageHandler(c, info, resp)
+ }
+
+ // check if the model is an embedding model
+ if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
+ strings.HasPrefix(info.UpstreamModelName, "embedding") ||
+ strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
+ return GeminiEmbeddingHandler(c, info, resp)
+ }
+
+ if info.IsStream {
+ return GeminiChatStreamHandler(c, info, resp)
+ } else {
+ return GeminiChatHandler(c, info, resp)
+ }
+
+ //if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
+ // // 没有请求-thinking的情况下,产生思考token,则按照思考模型计费
+ // if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
+ // !strings.HasSuffix(info.OriginModelName, "-nothinking") {
+ // thinkingModelName := info.OriginModelName + "-thinking"
+ // if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
+ // info.OriginModelName = thinkingModelName
+ // }
+ // }
+ //}
+
+ return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
+}
+
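+// GeminiImageHandler converts an Imagen :predict response into an OpenAI images
+// response, skipping RAI-filtered predictions and billing a fixed 258 tokens per
+// generated image.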
+func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseBody, readErr := io.ReadAll(resp.Body)
+ if readErr != nil {
+ return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
+ }
+ _ = resp.Body.Close()
+
+ var geminiResponse GeminiImageResponse
+ if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+ return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
+ }
+
+ if len(geminiResponse.Predictions) == 0 {
+ return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
+ }
+
+ // convert to openai format response
+ openAIResponse := dto.ImageResponse{
+ Created: common.GetTimestamp(),
+ Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
+ }
+
+ for _, prediction := range geminiResponse.Predictions {
+ if prediction.RaiFilteredReason != "" {
+ continue // skip filtered image
+ }
+ openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
+ B64Json: prediction.BytesBase64Encoded,
+ })
+ }
+
+ jsonResponse, jsonErr := json.Marshal(openAIResponse)
+ if jsonErr != nil {
+ return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
+ }
+
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, _ = c.Writer.Write(jsonResponse)
+
+ // https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
+ // each image has fixed 258 tokens
+ const imageTokens = 258
+ generatedImages := len(openAIResponse.Data)
+
+ usage := &dto.Usage{
+ PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
+ CompletionTokens: 0, // image generation does not calculate completion tokens
+ TotalTokens: imageTokens * generatedImages,
+ }
+
+ return usage, nil
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go
new file mode 100644
index 00000000..2c972e37
--- /dev/null
+++ b/relay/channel/gemini/constant.go
@@ -0,0 +1,37 @@
+package gemini
+
+var ModelList = []string{
+ // stable version
+ "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
+ "gemini-2.0-flash",
+ // latest version
+ "gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
+ // preview version
+ "gemini-2.0-flash-lite-preview",
+ // gemini exp
+ "gemini-exp-1206",
+ // flash exp
+ "gemini-2.0-flash-exp",
+ // pro exp
+ "gemini-2.0-pro-exp",
+ // thinking exp
+ "gemini-2.0-flash-thinking-exp",
+ "gemini-2.5-pro-exp-03-25",
+ "gemini-2.5-pro-preview-03-25",
+ // imagen models
+ "imagen-3.0-generate-002",
+ // embedding models
+ "gemini-embedding-exp-03-07",
+ "text-embedding-004",
+ "embedding-001",
+}
+
+var SafetySettingList = []string{
+ "HARM_CATEGORY_HARASSMENT",
+ "HARM_CATEGORY_HATE_SPEECH",
+ "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "HARM_CATEGORY_CIVIC_INTEGRITY",
+}
+
+var ChannelName = "google gemini"
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
new file mode 100644
index 00000000..b22e092a
--- /dev/null
+++ b/relay/channel/gemini/dto.go
@@ -0,0 +1,222 @@
+package gemini
+
+import "encoding/json"
+
+type GeminiChatRequest struct {
+ Contents []GeminiChatContent `json:"contents"`
+ SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
+ GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
+ Tools []GeminiChatTool `json:"tools,omitempty"`
+ SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
+}
+
+type GeminiThinkingConfig struct {
+ IncludeThoughts bool `json:"includeThoughts,omitempty"`
+ ThinkingBudget *int `json:"thinkingBudget,omitempty"`
+}
+
+func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
+ c.ThinkingBudget = &budget
+}
+
+type GeminiInlineData struct {
+ MimeType string `json:"mimeType"`
+ Data string `json:"data"`
+}
+
+// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
+func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
+ type Alias GeminiInlineData // Use type alias to avoid recursion
+ var aux struct {
+ Alias
+ MimeTypeSnake string `json:"mime_type"`
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ *g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
+
+ // Prioritize snake_case if present
+ if aux.MimeTypeSnake != "" {
+ g.MimeType = aux.MimeTypeSnake
+ } else if aux.MimeType != "" { // Fallback to camelCase from Alias
+ g.MimeType = aux.MimeType
+ }
+ // g.Data would be populated by aux.Alias.Data
+ return nil
+}
+
+type FunctionCall struct {
+ FunctionName string `json:"name"`
+ Arguments any `json:"args"`
+}
+
+type FunctionResponse struct {
+ Name string `json:"name"`
+ Response map[string]interface{} `json:"response"`
+}
+
+type GeminiPartExecutableCode struct {
+ Language string `json:"language,omitempty"`
+ Code string `json:"code,omitempty"`
+}
+
+type GeminiPartCodeExecutionResult struct {
+ Outcome string `json:"outcome,omitempty"`
+ Output string `json:"output,omitempty"`
+}
+
+type GeminiFileData struct {
+ MimeType string `json:"mimeType,omitempty"`
+ FileUri string `json:"fileUri,omitempty"`
+}
+
+type GeminiPart struct {
+ Text string `json:"text,omitempty"`
+ Thought bool `json:"thought,omitempty"`
+ InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+ FunctionCall *FunctionCall `json:"functionCall,omitempty"`
+ FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
+ FileData *GeminiFileData `json:"fileData,omitempty"`
+ ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
+ CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
+}
+
+// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
+func (p *GeminiPart) UnmarshalJSON(data []byte) error {
+ // Alias to avoid recursion during unmarshalling
+ type Alias GeminiPart
+ var aux struct {
+ Alias
+ InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ // Assign fields from alias
+ *p = GeminiPart(aux.Alias)
+
+ // Prioritize snake_case for InlineData if present
+ if aux.InlineDataSnake != nil {
+ p.InlineData = aux.InlineDataSnake
+ } else if aux.InlineData != nil { // Fallback to camelCase from Alias
+ p.InlineData = aux.InlineData
+ }
+ // Other fields like Text, FunctionCall etc. are already populated via aux.Alias
+
+ return nil
+}
+
+type GeminiChatContent struct {
+ Role string `json:"role,omitempty"`
+ Parts []GeminiPart `json:"parts"`
+}
+
+type GeminiChatSafetySettings struct {
+ Category string `json:"category"`
+ Threshold string `json:"threshold"`
+}
+
+type GeminiChatTool struct {
+ GoogleSearch any `json:"googleSearch,omitempty"`
+ GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
+ CodeExecution any `json:"codeExecution,omitempty"`
+ FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+type GeminiChatGenerationConfig struct {
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK float64 `json:"topK,omitempty"`
+ MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+ ResponseMimeType string `json:"responseMimeType,omitempty"`
+ ResponseSchema any `json:"responseSchema,omitempty"`
+ Seed int64 `json:"seed,omitempty"`
+ ResponseModalities []string `json:"responseModalities,omitempty"`
+ ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
+ SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
+}
+
+type GeminiChatCandidate struct {
+ Content GeminiChatContent `json:"content"`
+ FinishReason *string `json:"finishReason"`
+ Index int64 `json:"index"`
+ SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatSafetyRating struct {
+ Category string `json:"category"`
+ Probability string `json:"probability"`
+}
+
+type GeminiChatPromptFeedback struct {
+ SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatResponse struct {
+ Candidates []GeminiChatCandidate `json:"candidates"`
+ PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+ UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
+}
+
+type GeminiUsageMetadata struct {
+ PromptTokenCount int `json:"promptTokenCount"`
+ CandidatesTokenCount int `json:"candidatesTokenCount"`
+ TotalTokenCount int `json:"totalTokenCount"`
+ ThoughtsTokenCount int `json:"thoughtsTokenCount"`
+ PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
+}
+
+type GeminiPromptTokensDetails struct {
+ Modality string `json:"modality"`
+ TokenCount int `json:"tokenCount"`
+}
+
+// Imagen related structs
+type GeminiImageRequest struct {
+ Instances []GeminiImageInstance `json:"instances"`
+ Parameters GeminiImageParameters `json:"parameters"`
+}
+
+type GeminiImageInstance struct {
+ Prompt string `json:"prompt"`
+}
+
+type GeminiImageParameters struct {
+ SampleCount int `json:"sampleCount,omitempty"`
+ AspectRatio string `json:"aspectRatio,omitempty"`
+ PersonGeneration string `json:"personGeneration,omitempty"`
+}
+
+type GeminiImageResponse struct {
+ Predictions []GeminiImagePrediction `json:"predictions"`
+}
+
+type GeminiImagePrediction struct {
+ MimeType string `json:"mimeType"`
+ BytesBase64Encoded string `json:"bytesBase64Encoded"`
+ RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
+ SafetyAttributes any `json:"safetyAttributes,omitempty"`
+}
+
+// Embedding related structs
+type GeminiEmbeddingRequest struct {
+ Content GeminiChatContent `json:"content"`
+ TaskType string `json:"taskType,omitempty"`
+ Title string `json:"title,omitempty"`
+ OutputDimensionality int `json:"outputDimensionality,omitempty"`
+}
+
+type GeminiEmbeddingResponse struct {
+ Embedding ContentEmbedding `json:"embedding"`
+}
+
+type ContentEmbedding struct {
+ Values []float64 `json:"values"`
+}
diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go
new file mode 100644
index 00000000..0870e3fa
--- /dev/null
+++ b/relay/channel/gemini/relay-gemini-native.go
@@ -0,0 +1,138 @@
+package gemini
+
+import (
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
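+// GeminiTextGenerationHandler handles the native Gemini relay mode without format
+// conversion: it extracts usage from usageMetadata and echoes the upstream JSON back to
+// the client.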
+func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ // read the response body
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ if common.DebugEnabled {
+ println(string(responseBody))
+ }
+
+ // parse into the native Gemini response format
+ var geminiResponse GeminiChatResponse
+ err = common.Unmarshal(responseBody, &geminiResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ // compute usage from UsageMetadata
+ usage := dto.Usage{
+ PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
+ CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
+ TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
+ }
+
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
+
+ // return the native Gemini JSON response as-is
+ jsonResponse, err := common.Marshal(geminiResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ common.IOCopyBytesGracefully(c, resp, jsonResponse)
+
+ return &usage, nil
+}
+
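+// GeminiTextGenerationStreamHandler forwards native Gemini SSE chunks unchanged while
+// accumulating usage from usageMetadata; if the upstream reports no completion tokens it
+// falls back to 258 tokens per inline image, or to counting the streamed text.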
+func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var usage = &dto.Usage{}
+ var imageCount int
+
+ helper.SetEventStreamHeaders(c)
+
+ responseText := strings.Builder{}
+
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ var geminiResponse GeminiChatResponse
+ err := common.UnmarshalJsonStr(data, &geminiResponse)
+ if err != nil {
+ common.LogError(c, "error unmarshalling stream response: "+err.Error())
+ return false
+ }
+
+ // count generated images
+ for _, candidate := range geminiResponse.Candidates {
+ for _, part := range candidate.Content.Parts {
+ if part.InlineData != nil && part.InlineData.MimeType != "" {
+ imageCount++
+ }
+ if part.Text != "" {
+ responseText.WriteString(part.Text)
+ }
+ }
+ }
+
+ // update usage statistics
+ if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+ usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+ usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
+ usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
+ }
+
+ // forward the GeminiChatResponse chunk to the client as-is
+ err = helper.StringData(c, data)
+ if err != nil {
+ common.LogError(c, err.Error())
+ }
+
+ return true
+ })
+
+ if imageCount != 0 {
+ if usage.CompletionTokens == 0 {
+ usage.CompletionTokens = imageCount * 258
+ }
+ }
+
+ // if usage.CompletionTokens is 0, fall back to locally counted completion tokens
+ if usage.CompletionTokens == 0 {
+ str := responseText.String()
+ if len(str) > 0 {
+ usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
+ } else {
+ // empty completion, no usage needed
+ usage = &dto.Usage{}
+ }
+ }
+
+ // do not send a trailing [DONE]; the Gemini API itself never sends one
+ //helper.Done(c)
+
+ return usage, nil
+}
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
new file mode 100644
index 00000000..6f3babeb
--- /dev/null
+++ b/relay/channel/gemini/relay-gemini.go
@@ -0,0 +1,958 @@
+package gemini
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strconv"
+ "strings"
+ "unicode/utf8"
+
+ "github.com/gin-gonic/gin"
+)
+
+var geminiSupportedMimeTypes = map[string]bool{
+ "application/pdf": true,
+ "audio/mpeg": true,
+ "audio/mp3": true,
+ "audio/wav": true,
+ "image/png": true,
+ "image/jpeg": true,
+ "text/plain": true,
+ "video/mov": true,
+ "video/mpeg": true,
+ "video/mp4": true,
+ "video/mpg": true,
+ "video/avi": true,
+ "video/wmv": true,
+ "video/mpegps": true,
+ "video/flv": true,
+}
+
+// Thinking budget ranges allowed by Gemini
+const (
+ pro25MinBudget = 128
+ pro25MaxBudget = 32768
+ flash25MaxBudget = 24576
+ flash25LiteMinBudget = 512
+ flash25LiteMaxBudget = 24576
+)
+
+// clampThinkingBudget clamps the budget to the range allowed for the given model name
+func clampThinkingBudget(modelName string, budget int) int {
+ isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+ is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
+
+ if is25FlashLite {
+ if budget < flash25LiteMinBudget {
+ return flash25LiteMinBudget
+ }
+ if budget > flash25LiteMaxBudget {
+ return flash25LiteMaxBudget
+ }
+ } else if isNew25Pro {
+ if budget < pro25MinBudget {
+ return pro25MinBudget
+ }
+ if budget > pro25MaxBudget {
+ return pro25MaxBudget
+ }
+ } else { // other models
+ if budget < 0 {
+ return 0
+ }
+ if budget > flash25MaxBudget {
+ return flash25MaxBudget
+ }
+ }
+ return budget
+}
+
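+// ThinkingAdaptor derives a Gemini thinkingConfig from the model-name suffix:
+// "-thinking-<n>" sets an explicit budget, "-thinking" enables thoughts (optionally
+// budgeted as a percentage of maxOutputTokens), and "-nothinking" sets the budget to 0
+// for models that allow disabling thinking.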
+func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
+ if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
+ modelName := info.UpstreamModelName
+ isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
+ !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
+
+ if strings.Contains(modelName, "-thinking-") {
+ parts := strings.SplitN(modelName, "-thinking-", 2)
+ if len(parts) == 2 && parts[1] != "" {
+ if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
+ clampedBudget := clampThinkingBudget(modelName, budgetTokens)
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ ThinkingBudget: common.GetPointer(clampedBudget),
+ IncludeThoughts: true,
+ }
+ }
+ }
+ } else if strings.HasSuffix(modelName, "-thinking") {
+ unsupportedModels := []string{
+ "gemini-2.5-pro-preview-05-06",
+ "gemini-2.5-pro-preview-03-25",
+ }
+ isUnsupported := false
+ for _, unsupportedModel := range unsupportedModels {
+ if strings.HasPrefix(modelName, unsupportedModel) {
+ isUnsupported = true
+ break
+ }
+ }
+
+ if isUnsupported {
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ IncludeThoughts: true,
+ }
+ } else {
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ IncludeThoughts: true,
+ }
+ if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
+ budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
+ clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
+ geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
+ }
+ }
+ } else if strings.HasSuffix(modelName, "-nothinking") {
+ if !isNew25Pro {
+ geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
+ ThinkingBudget: common.GetPointer(0),
+ }
+ }
+ }
+ }
+}
+
+// Setting safety to the lowest possible values since Gemini is already powerless enough
+func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
+
+ geminiRequest := GeminiChatRequest{
+ Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
+ GenerationConfig: GeminiChatGenerationConfig{
+ Temperature: textRequest.Temperature,
+ TopP: textRequest.TopP,
+ MaxOutputTokens: textRequest.MaxTokens,
+ Seed: int64(textRequest.Seed),
+ },
+ }
+
+ if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
+ geminiRequest.GenerationConfig.ResponseModalities = []string{
+ "TEXT",
+ "IMAGE",
+ }
+ }
+
+ ThinkingAdaptor(&geminiRequest, info)
+
+ safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
+ for _, category := range SafetySettingList {
+ safetySettings = append(safetySettings, GeminiChatSafetySettings{
+ Category: category,
+ Threshold: model_setting.GetGeminiSafetySetting(category),
+ })
+ }
+ geminiRequest.SafetySettings = safetySettings
+
+ // openaiContent.FuncToToolCalls()
+ if textRequest.Tools != nil {
+ functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
+ googleSearch := false
+ codeExecution := false
+ for _, tool := range textRequest.Tools {
+ if tool.Function.Name == "googleSearch" {
+ googleSearch = true
+ continue
+ }
+ if tool.Function.Name == "codeExecution" {
+ codeExecution = true
+ continue
+ }
+ if tool.Function.Parameters != nil {
+
+ params, ok := tool.Function.Parameters.(map[string]interface{})
+ if ok {
+ if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
+ if len(props) == 0 {
+ tool.Function.Parameters = nil
+ }
+ }
+ }
+ }
+ // Clean the parameters before appending
+ cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
+ tool.Function.Parameters = cleanedParams
+ functions = append(functions, tool.Function)
+ }
+ if codeExecution {
+ geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+ CodeExecution: make(map[string]string),
+ })
+ }
+ if googleSearch {
+ geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+ GoogleSearch: make(map[string]string),
+ })
+ }
+ if len(functions) > 0 {
+ geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
+ FunctionDeclarations: functions,
+ })
+ }
+ // common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
+ // json_data, _ := json.Marshal(geminiRequest.Tools)
+ // common.SysLog("tools_json: " + string(json_data))
+ }
+
+ if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
+ geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
+
+ if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
+ cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
+ geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
+ }
+ }
+ tool_call_ids := make(map[string]string)
+ var system_content []string
+ //shouldAddDummyModelMessage := false
+ for _, message := range textRequest.Messages {
+ if message.Role == "system" {
+ system_content = append(system_content, message.StringContent())
+ continue
+ } else if message.Role == "tool" || message.Role == "function" {
+ if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
+ geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+ Role: "user",
+ })
+ }
+ var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
+ name := ""
+ if message.Name != nil {
+ name = *message.Name
+ } else if val, exists := tool_call_ids[message.ToolCallId]; exists {
+ name = val
+ }
+ var contentMap map[string]interface{}
+ contentStr := message.StringContent()
+
+ // 1. try to parse as a JSON object
+ if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
+ // 2. if that fails, try to parse as a JSON array
+ var contentSlice []interface{}
+ if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
+ // if it is an array, wrap it in an object
+ contentMap = map[string]interface{}{"result": contentSlice}
+ } else {
+ // 3. if that also fails, treat it as plain text
+ contentMap = map[string]interface{}{"content": contentStr}
+ }
+ }
+
+ functionResp := &FunctionResponse{
+ Name: name,
+ Response: contentMap,
+ }
+
+ *parts = append(*parts, GeminiPart{
+ FunctionResponse: functionResp,
+ })
+ continue
+ }
+ var parts []GeminiPart
+ content := GeminiChatContent{
+ Role: message.Role,
+ }
+ // isToolCall := false
+ if message.ToolCalls != nil {
+ // message.Role = "model"
+ // isToolCall = true
+ for _, call := range message.ParseToolCalls() {
+ args := map[string]interface{}{}
+ if call.Function.Arguments != "" {
+ if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
+ return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
+ }
+ }
+ toolCall := GeminiPart{
+ FunctionCall: &FunctionCall{
+ FunctionName: call.Function.Name,
+ Arguments: args,
+ },
+ }
+ parts = append(parts, toolCall)
+ tool_call_ids[call.ID] = call.Function.Name
+ }
+ }
+
+ openaiContent := message.ParseContent()
+ imageNum := 0
+ for _, part := range openaiContent {
+ if part.Type == dto.ContentTypeText {
+ if part.Text == "" {
+ continue
+ }
+ parts = append(parts, GeminiPart{
+ Text: part.Text,
+ })
+ } else if part.Type == dto.ContentTypeImageURL {
+ imageNum += 1
+
+ if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
+ return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
+ }
+ // check whether the image is a URL
+ if strings.HasPrefix(part.GetImageMedia().Url, "http") {
+ // it is a URL: fetch the file's MIME type and base64-encoded data
+ fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
+ if err != nil {
+ return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
+ }
+
+ // verify the MIME type is on Gemini's supported whitelist
+ if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
+ url := part.GetImageMedia().Url
+ return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
+ }
+
+ parts = append(parts, GeminiPart{
+ InlineData: &GeminiInlineData{
+ MimeType: fileData.MimeType, // keep the original MimeType, since casing may matter to the API
+ Data: fileData.Base64Data,
+ },
+ })
+ } else {
+ format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
+ if err != nil {
+ return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
+ }
+ parts = append(parts, GeminiPart{
+ InlineData: &GeminiInlineData{
+ MimeType: format,
+ Data: base64String,
+ },
+ })
+ }
+ } else if part.Type == dto.ContentTypeFile {
+ if part.GetFile().FileId != "" {
+ return nil, fmt.Errorf("only base64 file is supported in gemini")
+ }
+ format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
+ if err != nil {
+ return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
+ }
+ parts = append(parts, GeminiPart{
+ InlineData: &GeminiInlineData{
+ MimeType: format,
+ Data: base64String,
+ },
+ })
+ } else if part.Type == dto.ContentTypeInputAudio {
+ if part.GetInputAudio().Data == "" {
+ return nil, fmt.Errorf("only base64 audio is supported in gemini")
+ }
+ base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
+ if err != nil {
+ return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
+ }
+ parts = append(parts, GeminiPart{
+ InlineData: &GeminiInlineData{
+ MimeType: "audio/" + part.GetInputAudio().Format,
+ Data: base64String,
+ },
+ })
+ }
+ }
+
+ content.Parts = parts
+
+ // Gemini has no "assistant" role; the API rejects any role other than "user" or "model"
+ if content.Role == "assistant" {
+ content.Role = "model"
+ }
+ if len(content.Parts) > 0 {
+ geminiRequest.Contents = append(geminiRequest.Contents, content)
+ }
+ }
+
+ if len(system_content) > 0 {
+ geminiRequest.SystemInstructions = &GeminiChatContent{
+ Parts: []GeminiPart{
+ {
+ Text: strings.Join(system_content, "\n"),
+ },
+ },
+ }
+ }
+
+ return &geminiRequest, nil
+}
+
+// Helper function to get a list of supported MIME types for error messages
+func getSupportedMimeTypesList() []string {
+ keys := make([]string, 0, len(geminiSupportedMimeTypes))
+ for k := range geminiSupportedMimeTypes {
+ keys = append(keys, k)
+ }
+ return keys
+}
+
+// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
+func cleanFunctionParameters(params interface{}) interface{} {
+ if params == nil {
+ return nil
+ }
+
+ switch v := params.(type) {
+ case map[string]interface{}:
+ // Create a copy to avoid modifying the original
+ cleanedMap := make(map[string]interface{})
+ for k, val := range v {
+ cleanedMap[k] = val
+ }
+
+ // Remove unsupported root-level fields
+ delete(cleanedMap, "default")
+ delete(cleanedMap, "exclusiveMaximum")
+ delete(cleanedMap, "exclusiveMinimum")
+ delete(cleanedMap, "$schema")
+ delete(cleanedMap, "additionalProperties")
+
+ // Check and clean 'format' for string types
+ if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
+ if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
+ if formatValue != "enum" && formatValue != "date-time" {
+ delete(cleanedMap, "format")
+ }
+ }
+ }
+
+ // Clean properties
+ if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
+ cleanedProps := make(map[string]interface{})
+ for propName, propValue := range props {
+ cleanedProps[propName] = cleanFunctionParameters(propValue)
+ }
+ cleanedMap["properties"] = cleanedProps
+ }
+
+ // Recursively clean items in arrays
+ if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
+ cleanedMap["items"] = cleanFunctionParameters(items)
+ }
+ // Also handle items if it's an array of schemas
+ if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
+ cleanedItemsArray := make([]interface{}, len(itemsArray))
+ for i, item := range itemsArray {
+ cleanedItemsArray[i] = cleanFunctionParameters(item)
+ }
+ cleanedMap["items"] = cleanedItemsArray
+ }
+
+ // Recursively clean other schema composition keywords
+ for _, field := range []string{"allOf", "anyOf", "oneOf"} {
+ if nested, ok := cleanedMap[field].([]interface{}); ok {
+ cleanedNested := make([]interface{}, len(nested))
+ for i, item := range nested {
+ cleanedNested[i] = cleanFunctionParameters(item)
+ }
+ cleanedMap[field] = cleanedNested
+ }
+ }
+
+ // Recursively clean patternProperties
+ if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
+ cleanedPatternProps := make(map[string]interface{})
+ for pattern, schema := range patternProps {
+ cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
+ }
+ cleanedMap["patternProperties"] = cleanedPatternProps
+ }
+
+ // Recursively clean definitions
+ if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
+ cleanedDefinitions := make(map[string]interface{})
+ for defName, defSchema := range definitions {
+ cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
+ }
+ cleanedMap["definitions"] = cleanedDefinitions
+ }
+
+ // Recursively clean $defs (newer JSON Schema draft)
+ if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
+ cleanedDefs := make(map[string]interface{})
+ for defName, defSchema := range defs {
+ cleanedDefs[defName] = cleanFunctionParameters(defSchema)
+ }
+ cleanedMap["$defs"] = cleanedDefs
+ }
+
+ // Clean conditional keywords
+ for _, field := range []string{"if", "then", "else", "not"} {
+ if nested, ok := cleanedMap[field]; ok {
+ cleanedMap[field] = cleanFunctionParameters(nested)
+ }
+ }
+
+ return cleanedMap
+
+ case []interface{}:
+ // Handle arrays of schemas
+ cleanedArray := make([]interface{}, len(v))
+ for i, item := range v {
+ cleanedArray[i] = cleanFunctionParameters(item)
+ }
+ return cleanedArray
+
+ default:
+ // Not a map or array, return as is (e.g., could be a primitive)
+ return params
+ }
+}
+
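+// removeAdditionalPropertiesWithDepth strips schema fields Gemini does not accept: it removes title and
+// $schema everywhere and additionalProperties from object schemas, recursing into properties,
+// allOf/anyOf/oneOf and array items up to a nesting depth of 5.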
+func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
+ if depth >= 5 {
+ return schema
+ }
+
+ v, ok := schema.(map[string]interface{})
+ if !ok || len(v) == 0 {
+ return schema
+ }
+ // Remove the title and $schema fields
+ delete(v, "title")
+ delete(v, "$schema")
+ // Return as-is if type is neither object nor array
+ if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
+ return schema
+ }
+ switch v["type"] {
+ case "object":
+ delete(v, "additionalProperties")
+ // Process nested properties
+ if properties, ok := v["properties"].(map[string]interface{}); ok {
+ for key, value := range properties {
+ properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
+ }
+ }
+ for _, field := range []string{"allOf", "anyOf", "oneOf"} {
+ if nested, ok := v[field].([]interface{}); ok {
+ for i, item := range nested {
+ nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
+ }
+ }
+ }
+ case "array":
+ if items, ok := v["items"].(map[string]interface{}); ok {
+ v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
+ }
+ }
+
+ return v
+}
+
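+// unescapeString resolves JSON-style backslash escapes (\", \\, \n, \t, ...) in s,
+// leaves unknown escape sequences unchanged, and returns an error on invalid UTF-8 input.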
+func unescapeString(s string) (string, error) {
+ var result []rune
+ escaped := false
+ i := 0
+
+ for i < len(s) {
+ r, size := utf8.DecodeRuneInString(s[i:]) // decode the next UTF-8 rune
+ if r == utf8.RuneError {
+ return "", fmt.Errorf("invalid UTF-8 encoding")
+ }
+
+ if escaped {
+ // The character follows a backslash: resolve the escape sequence
+ switch r {
+ case '"':
+ result = append(result, '"')
+ case '\\':
+ result = append(result, '\\')
+ case '/':
+ result = append(result, '/')
+ case 'b':
+ result = append(result, '\b')
+ case 'f':
+ result = append(result, '\f')
+ case 'n':
+ result = append(result, '\n')
+ case 'r':
+ result = append(result, '\r')
+ case 't':
+ result = append(result, '\t')
+ case '\'':
+ result = append(result, '\'')
+ default:
+ // Unknown escape sequence: emit it unchanged
+ result = append(result, '\\', r)
+ }
+ escaped = false
+ } else {
+ if r == '\\' {
+ escaped = true // the backslash starts an escape sequence
+ } else {
+ result = append(result, r)
+ }
+ }
+ i += size // advance to the next rune
+ }
+
+ return string(result), nil
+}
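+
+// unescapeMapOrSlice walks a decoded JSON value and unescapes every string it contains in place.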
+func unescapeMapOrSlice(data interface{}) interface{} {
+ switch v := data.(type) {
+ case map[string]interface{}:
+ for k, val := range v {
+ v[k] = unescapeMapOrSlice(val)
+ }
+ case []interface{}:
+ for i, val := range v {
+ v[i] = unescapeMapOrSlice(val)
+ }
+ case string:
+ if unescaped, err := unescapeString(v); err != nil {
+ return v
+ } else {
+ return unescaped
+ }
+ }
+ return data
+}
+
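+// getResponseToolCall converts a Gemini functionCall part into an OpenAI-style tool call with a generated call ID.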
+func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
+ var argsBytes []byte
+ var err error
+ if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
+ argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
+ } else {
+ argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
+ }
+
+ if err != nil {
+ return nil
+ }
+ return &dto.ToolCallResponse{
+ ID: fmt.Sprintf("call_%s", common.GetUUID()),
+ Type: "function",
+ Function: dto.FunctionResponse{
+ Arguments: string(argsBytes),
+ Name: item.FunctionCall.FunctionName,
+ },
+ }
+}
+
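+// responseGeminiChat2OpenAI converts a non-streaming Gemini response into an OpenAI chat completion,
+// mapping function calls to tool calls and thought parts to reasoning content.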
+func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
+ fullTextResponse := dto.OpenAITextResponse{
+ Id: helper.GetResponseID(c),
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
+ }
+ isToolCall := false
+ for _, candidate := range response.Candidates {
+ choice := dto.OpenAITextResponseChoice{
+ Index: int(candidate.Index),
+ Message: dto.Message{
+ Role: "assistant",
+ Content: "",
+ },
+ FinishReason: constant.FinishReasonStop,
+ }
+ if len(candidate.Content.Parts) > 0 {
+ var texts []string
+ var toolCalls []dto.ToolCallResponse
+ for _, part := range candidate.Content.Parts {
+ if part.FunctionCall != nil {
+ choice.FinishReason = constant.FinishReasonToolCalls
+ if call := getResponseToolCall(&part); call != nil {
+ toolCalls = append(toolCalls, *call)
+ }
+ } else if part.Thought {
+ choice.Message.ReasoningContent = part.Text
+ } else {
+ if part.ExecutableCode != nil {
+ texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
+ } else if part.CodeExecutionResult != nil {
+ texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
+ } else {
+ // Skip newline-only parts
+ if part.Text != "\n" {
+ texts = append(texts, part.Text)
+ }
+ }
+ }
+ }
+ if len(toolCalls) > 0 {
+ choice.Message.SetToolCalls(toolCalls)
+ isToolCall = true
+ }
+ choice.Message.SetStringContent(strings.Join(texts, "\n"))
+
+ }
+ if candidate.FinishReason != nil {
+ switch *candidate.FinishReason {
+ case "STOP":
+ choice.FinishReason = constant.FinishReasonStop
+ case "MAX_TOKENS":
+ choice.FinishReason = constant.FinishReasonLength
+ default:
+ choice.FinishReason = constant.FinishReasonContentFilter
+ }
+ }
+ if isToolCall {
+ choice.FinishReason = constant.FinishReasonToolCalls
+ }
+
+ fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+ }
+ return &fullTextResponse
+}
+
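+// streamResponseGeminiChat2OpenAI converts one Gemini stream chunk into an OpenAI chunk and reports
+// whether the candidate finished with STOP and whether the chunk carried inline image data.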
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
+ choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
+ isStop := false
+ hasImage := false
+ for _, candidate := range geminiResponse.Candidates {
+ if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
+ isStop = true
+ candidate.FinishReason = nil
+ }
+ choice := dto.ChatCompletionsStreamResponseChoice{
+ Index: int(candidate.Index),
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+ Role: "assistant",
+ },
+ }
+ var texts []string
+ isTools := false
+ isThought := false
+ if candidate.FinishReason != nil {
+ // p := GeminiConvertFinishReason(*candidate.FinishReason)
+ switch *candidate.FinishReason {
+ case "STOP":
+ choice.FinishReason = &constant.FinishReasonStop
+ case "MAX_TOKENS":
+ choice.FinishReason = &constant.FinishReasonLength
+ default:
+ choice.FinishReason = &constant.FinishReasonContentFilter
+ }
+ }
+ for _, part := range candidate.Content.Parts {
+ if part.InlineData != nil {
+ if strings.HasPrefix(part.InlineData.MimeType, "image") {
+ imgText := ""
+ texts = append(texts, imgText)
+ hasImage = true
+ }
+ } else if part.FunctionCall != nil {
+ isTools = true
+ if call := getResponseToolCall(&part); call != nil {
+ call.SetIndex(len(choice.Delta.ToolCalls))
+ choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
+ }
+ } else if part.Thought {
+ isThought = true
+ texts = append(texts, part.Text)
+ } else {
+ if part.ExecutableCode != nil {
+ texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
+ } else if part.CodeExecutionResult != nil {
+ texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
+ } else {
+ if part.Text != "\n" {
+ texts = append(texts, part.Text)
+ }
+ }
+ }
+ }
+ if isThought {
+ choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
+ } else {
+ choice.Delta.SetContentString(strings.Join(texts, "\n"))
+ }
+ if isTools {
+ choice.FinishReason = &constant.FinishReasonToolCalls
+ }
+ choices = append(choices, choice)
+ }
+
+ var response dto.ChatCompletionsStreamResponse
+ response.Object = "chat.completion.chunk"
+ response.Choices = choices
+ return &response, isStop, hasImage
+}
+
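+// GeminiChatStreamHandler relays the Gemini streaming response to the client as OpenAI-style chunks
+// and aggregates token usage from the upstream metadata.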
+func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ // responseText := ""
+ id := helper.GetResponseID(c)
+ createAt := common.GetTimestamp()
+ var usage = &dto.Usage{}
+ var imageCount int
+
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ var geminiResponse GeminiChatResponse
+ err := common.UnmarshalJsonStr(data, &geminiResponse)
+ if err != nil {
+ common.LogError(c, "error unmarshalling stream response: "+err.Error())
+ return false
+ }
+
+ response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
+ if hasImage {
+ imageCount++
+ }
+ response.Id = id
+ response.Created = createAt
+ response.Model = info.UpstreamModelName
+ if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+ usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+ usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+ usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
+ }
+ err = helper.ObjectData(c, response)
+ if err != nil {
+ common.LogError(c, err.Error())
+ }
+ if isStop {
+ response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
+ helper.ObjectData(c, response)
+ }
+ return true
+ })
+
+ var response *dto.ChatCompletionsStreamResponse
+
+ if imageCount != 0 {
+ if usage.CompletionTokens == 0 {
+ usage.CompletionTokens = imageCount * 258
+ }
+ }
+
+ usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+ usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+ if info.ShouldIncludeUsage {
+ response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+ err := helper.ObjectData(c, response)
+ if err != nil {
+ common.SysError("send final response failed: " + err.Error())
+ }
+ }
+ helper.Done(c)
+ //resp.Body.Close()
+ return usage, nil
+}
+
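+// GeminiChatHandler parses the full Gemini response, converts it to OpenAI format, computes usage from
+// the upstream metadata, and writes the result to the client.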
+func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ if common.DebugEnabled {
+ println(string(responseBody))
+ }
+ var geminiResponse GeminiChatResponse
+ err = common.Unmarshal(responseBody, &geminiResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if len(geminiResponse.Candidates) == 0 {
+ return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
+ }
+ fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
+ fullTextResponse.Model = info.UpstreamModelName
+ usage := dto.Usage{
+ PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
+ CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+ TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
+ }
+
+ usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
+ usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+ for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
+ if detail.Modality == "AUDIO" {
+ usage.PromptTokensDetails.AudioTokens = detail.TokenCount
+ } else if detail.Modality == "TEXT" {
+ usage.PromptTokensDetails.TextTokens = detail.TokenCount
+ }
+ }
+
+ fullTextResponse.Usage = usage
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ c.Writer.Write(jsonResponse)
+ return &usage, nil
+}
+
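+// GeminiEmbeddingHandler converts the Gemini embedding response into the OpenAI embedding format.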
+func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ responseBody, readErr := io.ReadAll(resp.Body)
+ if readErr != nil {
+ return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
+ }
+
+ var geminiResponse GeminiEmbeddingResponse
+ if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
+ return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
+ }
+
+ // convert to openai format response
+ openAIResponse := dto.OpenAIEmbeddingResponse{
+ Object: "list",
+ Data: []dto.OpenAIEmbeddingResponseItem{
+ {
+ Object: "embedding",
+ Embedding: geminiResponse.Embedding.Values,
+ Index: 0,
+ },
+ },
+ Model: info.UpstreamModelName,
+ }
+
+ // calculate usage
+ // https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
+// Google has not yet clarified how embedding models are billed,
+// so we follow OpenAI's convention and bill by input tokens
+ // https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
+ usage := &dto.Usage{
+ PromptTokens: info.PromptTokens,
+ CompletionTokens: 0,
+ TotalTokens: info.PromptTokens,
+ }
+ openAIResponse.Usage = *usage
+
+ jsonResponse, jsonErr := common.Marshal(openAIResponse)
+ if jsonErr != nil {
+ return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
+ }
+
+ common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return usage, nil
+}
diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go
new file mode 100644
index 00000000..0b743879
--- /dev/null
+++ b/relay/channel/jimeng/adaptor.go
@@ -0,0 +1,136 @@
+package jimeng
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/types"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
+ return errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+type LogoInfo struct {
+ AddLogo bool `json:"add_logo,omitempty"`
+ Position int `json:"position,omitempty"`
+ Language int `json:"language,omitempty"`
+ Opacity float64 `json:"opacity,omitempty"`
+ LogoTextContent string `json:"logo_text_content,omitempty"`
+}
+
+type imageRequestPayload struct {
+ ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L
+ Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English
+ Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random)
+ Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768]
+ Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768]
+ UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true
+ UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true
+ ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours)
+ LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information
+ ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input
+ BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ payload := imageRequestPayload{
+ ReqKey: request.Model,
+ Prompt: request.Prompt,
+ }
+ if request.ResponseFormat == "" || request.ResponseFormat == "url" {
+ payload.ReturnURL = true // Default to returning image URLs
+ }
+
+ if len(request.ExtraFields) > 0 {
+ if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
+ }
+ }
+
+ return payload, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ fullRequestURL, err := a.GetRequestURL(info)
+ if err != nil {
+ return nil, fmt.Errorf("get request url failed: %w", err)
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ if err != nil {
+ return nil, fmt.Errorf("new request failed: %w", err)
+ }
+ err = Sign(c, req, info.ApiKey)
+ if err != nil {
+ return nil, fmt.Errorf("setup request header failed: %w", err)
+ }
+ resp, err := channel.DoRequest(c, req, info)
+ if err != nil {
+ return nil, fmt.Errorf("do request failed: %w", err)
+ }
+ return resp, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+ usage, err = jimengImageHandler(c, resp, info)
+ } else if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/jimeng/constants.go b/relay/channel/jimeng/constants.go
new file mode 100644
index 00000000..0d1764e5
--- /dev/null
+++ b/relay/channel/jimeng/constants.go
@@ -0,0 +1,9 @@
+package jimeng
+
+const (
+ ChannelName = "jimeng"
+)
+
+var ModelList = []string{
+ "jimeng_high_aes_general_v21_L",
+}
diff --git a/relay/channel/jimeng/image.go b/relay/channel/jimeng/image.go
new file mode 100644
index 00000000..3c6a1d99
--- /dev/null
+++ b/relay/channel/jimeng/image.go
@@ -0,0 +1,89 @@
+package jimeng
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type ImageResponse struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data struct {
+ BinaryDataBase64 []string `json:"binary_data_base64"`
+ ImageUrls []string `json:"image_urls"`
+ RephraseResult string `json:"rephraser_result"`
+ RequestID string `json:"request_id"`
+ // Other fields are omitted for brevity
+ } `json:"data"`
+ RequestID string `json:"request_id"`
+ Status int `json:"status"`
+ TimeElapsed string `json:"time_elapsed"`
+}
+
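+// responseJimeng2OpenAIImage converts a Jimeng image response (base64 payloads and/or URLs) into the OpenAI image format.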
+func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse {
+ imageResponse := dto.ImageResponse{
+ Created: info.StartTime.Unix(),
+ }
+
+ for _, base64Data := range response.Data.BinaryDataBase64 {
+ imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+ B64Json: base64Data,
+ })
+ }
+ for _, imageUrl := range response.Data.ImageUrls {
+ imageResponse.Data = append(imageResponse.Data, dto.ImageData{
+ Url: imageUrl,
+ })
+ }
+
+ return &imageResponse
+}
+
+// jimengImageHandler handles the Jimeng image generation response
+func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+ var jimengResponse ImageResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ common.CloseResponseBodyGracefully(resp)
+
+ err = json.Unmarshal(responseBody, &jimengResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ // Check if the response indicates an error
+ if jimengResponse.Code != 10000 {
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: jimengResponse.Message,
+ Type: "jimeng_error",
+ Param: "",
+ Code: fmt.Sprintf("%d", jimengResponse.Code),
+ }, resp.StatusCode)
+ }
+
+ // Convert Jimeng response to OpenAI format
+ fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ return &dto.Usage{}, nil
+}
diff --git a/relay/channel/jimeng/sign.go b/relay/channel/jimeng/sign.go
new file mode 100644
index 00000000..c9db6630
--- /dev/null
+++ b/relay/channel/jimeng/sign.go
@@ -0,0 +1,176 @@
+package jimeng
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "net/url"
+ "one-api/common"
+ "sort"
+ "strings"
+ "time"
+)
+
+// SignRequestForJimeng signs a Jimeng API request, accepting either an http.Request or a separate header+url+body
+//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
+// var bodyBytes []byte
+// var err error
+//
+// if req.Body != nil {
+// bodyBytes, err = io.ReadAll(req.Body)
+// if err != nil {
+// return fmt.Errorf("read request body failed: %w", err)
+// }
+// _ = req.Body.Close()
+// req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
+// } else {
+// bodyBytes = []byte{}
+// }
+//
+// return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
+//}
+
+const HexPayloadHashKey = "HexPayloadHash"
+
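+// SetPayloadHash pre-computes the SHA-256 hash of the JSON-encoded request body and stores its hex form in the gin context.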
+func SetPayloadHash(c *gin.Context, req any) error {
+ body, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+ common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
+ payloadHash := sha256.Sum256(body)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+ c.Set(HexPayloadHashKey, hexPayloadHash)
+ return nil
+}
+func getPayloadHash(c *gin.Context) string {
+ return c.GetString(HexPayloadHashKey)
+}
+
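+// Sign signs the request for the Jimeng API with a SigV4-style HMAC-SHA256 signature: it builds a canonical
+// request, derives a signing key from the 'ak|sk' pair in apiKey, and sets the Host, X-Date,
+// X-Content-Sha256 and Authorization headers.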
+func Sign(c *gin.Context, req *http.Request, apiKey string) error {
+ header := req.Header
+
+ var bodyBytes []byte
+ var err error
+
+ if req.Body != nil {
+ bodyBytes, err = io.ReadAll(req.Body)
+ if err != nil {
+ return err
+ }
+ _ = req.Body.Close()
+ req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
+ }
+
+ payloadHash := sha256.Sum256(bodyBytes)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+
+ method := c.Request.Method
+ u := req.URL
+ keyParts := strings.Split(apiKey, "|")
+ if len(keyParts) != 2 {
+ return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
+ }
+ accessKey := strings.TrimSpace(keyParts[0])
+ secretKey := strings.TrimSpace(keyParts[1])
+ t := time.Now().UTC()
+ xDate := t.Format("20060102T150405Z")
+ shortDate := t.Format("20060102")
+
+ host := u.Host
+ header.Set("Host", host)
+ header.Set("X-Date", xDate)
+ header.Set("X-Content-Sha256", hexPayloadHash)
+
+ // Sort and encode query parameters to create canonical query string
+ queryParams := u.Query()
+ sortedKeys := make([]string, 0, len(queryParams))
+ for k := range queryParams {
+ sortedKeys = append(sortedKeys, k)
+ }
+ sort.Strings(sortedKeys)
+ var queryParts []string
+ for _, k := range sortedKeys {
+ values := queryParams[k]
+ sort.Strings(values)
+ for _, v := range values {
+ queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
+ }
+ }
+ canonicalQueryString := strings.Join(queryParts, "&")
+
+ headersToSign := map[string]string{
+ "host": host,
+ "x-date": xDate,
+ "x-content-sha256": hexPayloadHash,
+ }
+ if header.Get("Content-Type") == "" {
+ header.Set("Content-Type", "application/json")
+ }
+ headersToSign["content-type"] = header.Get("Content-Type")
+
+ var signedHeaderKeys []string
+ for k := range headersToSign {
+ signedHeaderKeys = append(signedHeaderKeys, k)
+ }
+ sort.Strings(signedHeaderKeys)
+
+ var canonicalHeaders strings.Builder
+ for _, k := range signedHeaderKeys {
+ canonicalHeaders.WriteString(k)
+ canonicalHeaders.WriteString(":")
+ canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
+ canonicalHeaders.WriteString("\n")
+ }
+ signedHeaders := strings.Join(signedHeaderKeys, ";")
+
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+ method,
+ u.Path,
+ canonicalQueryString,
+ canonicalHeaders.String(),
+ signedHeaders,
+ hexPayloadHash,
+ )
+
+ hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
+ hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
+
+ region := "cn-north-1"
+ serviceName := "cv"
+ credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
+ stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
+ xDate,
+ credentialScope,
+ hexHashedCanonicalRequest,
+ )
+
+ kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
+ kRegion := hmacSHA256(kDate, []byte(region))
+ kService := hmacSHA256(kRegion, []byte(serviceName))
+ kSigning := hmacSHA256(kService, []byte("request"))
+ signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
+
+ authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+ accessKey,
+ credentialScope,
+ signedHeaders,
+ signature,
+ )
+ header.Set("Authorization", authorization)
+ return nil
+}
+
+// hmacSHA256 computes the HMAC-SHA256 of data keyed with key
+func hmacSHA256(key []byte, data []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(data)
+ return h.Sum(nil)
+}
diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go
new file mode 100644
index 00000000..408a5c6e
--- /dev/null
+++ b/relay/channel/jina/adaptor.go
@@ -0,0 +1,92 @@
+package jina
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/common_handler"
+ "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if info.RelayMode == constant.RelayModeRerank {
+ return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeEmbeddings {
+ return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
+ }
+ return "", errors.New("invalid relay mode")
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.RelayMode == constant.RelayModeRerank {
+ usage, err = common_handler.RerankHandler(c, info, resp)
+ } else if info.RelayMode == constant.RelayModeEmbeddings {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/jina/constant.go b/relay/channel/jina/constant.go
new file mode 100644
index 00000000..be290fb6
--- /dev/null
+++ b/relay/channel/jina/constant.go
@@ -0,0 +1,9 @@
+package jina
+
+var ModelList = []string{
+ "jina-clip-v1",
+ "jina-reranker-v2-base-multilingual",
+ "jina-reranker-m0",
+}
+
+var ChannelName = "jina"
diff --git a/relay/channel/jina/relay-jina.go b/relay/channel/jina/relay-jina.go
new file mode 100644
index 00000000..d83b5854
--- /dev/null
+++ b/relay/channel/jina/relay-jina.go
@@ -0,0 +1 @@
+package jina
diff --git a/relay/channel/lingyiwanwu/constrants.go b/relay/channel/lingyiwanwu/constrants.go
new file mode 100644
index 00000000..a6345071
--- /dev/null
+++ b/relay/channel/lingyiwanwu/constrants.go
@@ -0,0 +1,9 @@
+package lingyiwanwu
+
+// https://platform.lingyiwanwu.com/docs
+
+var ModelList = []string{
+ "yi-large", "yi-medium", "yi-vision", "yi-medium-200k", "yi-spark", "yi-large-rag", "yi-large-turbo", "yi-large-preview", "yi-large-rag-preview",
+}
+
+var ChannelName = "lingyiwanwu"
diff --git a/relay/channel/minimax/constants.go b/relay/channel/minimax/constants.go
new file mode 100644
index 00000000..c480cac9
--- /dev/null
+++ b/relay/channel/minimax/constants.go
@@ -0,0 +1,13 @@
+package minimax
+
+// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
+
+var ModelList = []string{
+ "abab6.5-chat",
+ "abab6.5s-chat",
+ "abab6-chat",
+ "abab5.5-chat",
+ "abab5.5s-chat",
+}
+
+var ChannelName = "minimax"
diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go
new file mode 100644
index 00000000..d0a15b0d
--- /dev/null
+++ b/relay/channel/minimax/relay-minimax.go
@@ -0,0 +1,10 @@
+package minimax
+
+import (
+ "fmt"
+ relaycommon "one-api/relay/common"
+)
+
+func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil
+}
diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go
new file mode 100644
index 00000000..434a1031
--- /dev/null
+++ b/relay/channel/mistral/adaptor.go
@@ -0,0 +1,88 @@
+package mistral
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return requestOpenAI2Mistral(request), nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/mistral/constants.go b/relay/channel/mistral/constants.go
new file mode 100644
index 00000000..7f5f3aca
--- /dev/null
+++ b/relay/channel/mistral/constants.go
@@ -0,0 +1,12 @@
+package mistral
+
+var ModelList = []string{
+ "open-mistral-7b",
+ "open-mixtral-8x7b",
+ "mistral-small-latest",
+ "mistral-medium-latest",
+ "mistral-large-latest",
+ "mistral-embed",
+}
+
+var ChannelName = "mistral"
diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go
new file mode 100644
index 00000000..e26c6101
--- /dev/null
+++ b/relay/channel/mistral/text.go
@@ -0,0 +1,78 @@
+package mistral
+
+import (
+ "one-api/common"
+ "one-api/dto"
+ "regexp"
+)
+
+var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
+
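+// requestOpenAI2Mistral rewrites tool call IDs that do not match Mistral's 9-character alphanumeric format
+// (keeping tool_call_id references consistent via idMap), flattens image content, and forwards the supported fields.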
+func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+ messages := make([]dto.Message, 0, len(request.Messages))
+ idMap := make(map[string]string)
+ for _, message := range request.Messages {
+ // 1. tool_calls.id
+ toolCalls := message.ParseToolCalls()
+ if toolCalls != nil {
+ for i := range toolCalls {
+ if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
+ if newId, ok := idMap[toolCalls[i].ID]; ok {
+ toolCalls[i].ID = newId
+ } else {
+ newId, err := common.GenerateRandomCharsKey(9)
+ if err == nil {
+ idMap[toolCalls[i].ID] = newId
+ toolCalls[i].ID = newId
+ }
+ }
+ }
+ }
+ message.SetToolCalls(toolCalls)
+ }
+
+ // 2. tool_call_id
+ if message.ToolCallId != "" {
+ if newId, ok := idMap[message.ToolCallId]; ok {
+ message.ToolCallId = newId
+ } else {
+ if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
+ newId, err := common.GenerateRandomCharsKey(9)
+ if err == nil {
+ idMap[message.ToolCallId] = newId
+ message.ToolCallId = newId
+ }
+ }
+ }
+ }
+
+ mediaMessages := message.ParseContent()
+ if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
+ mediaMessages = []dto.MediaContent{}
+ }
+ for j, mediaMessage := range mediaMessages {
+ if mediaMessage.Type == dto.ContentTypeImageURL {
+ imageUrl := mediaMessage.GetImageMedia()
+ mediaMessage.ImageUrl = imageUrl.Url
+ mediaMessages[j] = mediaMessage
+ }
+ }
+ message.SetMediaContent(mediaMessages)
+ messages = append(messages, dto.Message{
+ Role: message.Role,
+ Content: message.Content,
+ ToolCalls: message.ToolCalls,
+ ToolCallId: message.ToolCallId,
+ })
+ }
+ return &dto.GeneralOpenAIRequest{
+ Model: request.Model,
+ Stream: request.Stream,
+ Messages: messages,
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ MaxTokens: request.MaxTokens,
+ Tools: request.Tools,
+ ToolChoice: request.ToolChoice,
+ }
+}
diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go
new file mode 100644
index 00000000..b0b54b0c
--- /dev/null
+++ b/relay/channel/mokaai/adaptor.go
@@ -0,0 +1,106 @@
+package mokaai
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return request, nil
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
+ suffix := "chat/"
+ if strings.HasPrefix(info.UpstreamModelName, "m3e") {
+ suffix = "embeddings"
+ }
+ fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
+ return fullRequestURL, nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ switch info.RelayMode {
+ case constant.RelayModeEmbeddings:
+ baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
+ return baiduEmbeddingRequest, nil
+ default:
+ return nil, errors.New("not implemented")
+ }
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+
+ switch info.RelayMode {
+ case constant.RelayModeEmbeddings:
+ return mokaEmbeddingHandler(c, info, resp)
+ default:
+ // err, usage = mokaHandler(c, resp)
+
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/mokaai/constants.go b/relay/channel/mokaai/constants.go
new file mode 100644
index 00000000..415d83b7
--- /dev/null
+++ b/relay/channel/mokaai/constants.go
@@ -0,0 +1,9 @@
+package mokaai
+
+var ModelList = []string{
+ "m3e-large",
+ "m3e-base",
+ "m3e-small",
+}
+
+var ChannelName = "mokaai"
\ No newline at end of file
diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go
new file mode 100644
index 00000000..78f96d6d
--- /dev/null
+++ b/relay/channel/mokaai/relay-mokaai.go
@@ -0,0 +1,82 @@
+package mokaai
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
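+// embeddingRequestOpenAI2Moka normalizes the OpenAI embedding input (string, []string or []interface{}) into a []string request.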
+func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
+ var input []string // normalize the input into a []string
+
+ switch v := request.Input.(type) {
+ case string:
+ input = []string{v} // Convert string to []string
+ case []string:
+ input = v // Already a []string, no conversion needed
+ case []interface{}:
+ for _, part := range v {
+ if str, ok := part.(string); ok {
+ input = append(input, str) // Append each string to the slice
+ }
+ }
+ }
+ return &dto.EmbeddingRequest{
+ Input: input,
+ Model: request.Model,
+ }
+}
+
+func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
+ openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
+ Object: "list",
+ Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
+ Model: "baidu-embedding",
+ Usage: response.Usage,
+ }
+ for _, item := range response.Data {
+ openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
+ Object: item.Object,
+ Index: item.Index,
+ Embedding: item.Embedding,
+ })
+ }
+ return &openAIEmbeddingResponse
+}
+
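+// mokaEmbeddingHandler converts the upstream embedding response into the OpenAI format and writes it to the client.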
+func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var baiduResponse dto.EmbeddingResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &baiduResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ // if baiduResponse.ErrorMsg != "" {
+ // return &dto.OpenAIErrorWithStatusCode{
+ // Error: dto.OpenAIError{
+ // Type: "baidu_error",
+ // Param: "",
+ // },
+ // StatusCode: resp.StatusCode,
+ // }, nil
+ // }
+ fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
+ jsonResponse, err := common.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return &fullTextResponse.Usage, nil
+}
diff --git a/relay/channel/moonshot/constants.go b/relay/channel/moonshot/constants.go
new file mode 100644
index 00000000..a7da54b3
--- /dev/null
+++ b/relay/channel/moonshot/constants.go
@@ -0,0 +1,9 @@
+package moonshot
+
+var ModelList = []string{
+ "moonshot-v1-8k",
+ "moonshot-v1-32k",
+ "moonshot-v1-128k",
+}
+
+var ChannelName = "moonshot"
diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go
new file mode 100644
index 00000000..b9e304fc
--- /dev/null
+++ b/relay/channel/ollama/adaptor.go
@@ -0,0 +1,97 @@
+package ollama
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ switch info.RelayMode {
+ case relayconstant.RelayModeEmbeddings:
+ return info.BaseUrl + "/api/embed", nil
+ default:
+ return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return requestOpenAI2Ollama(*request)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return requestOpenAI2Embeddings(request), nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ if info.RelayMode == relayconstant.RelayModeEmbeddings {
+ usage, err = ollamaEmbeddingHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/ollama/constants.go b/relay/channel/ollama/constants.go
new file mode 100644
index 00000000..682626a2
--- /dev/null
+++ b/relay/channel/ollama/constants.go
@@ -0,0 +1,7 @@
+package ollama
+
+var ModelList = []string{
+ "llama3-7b",
+}
+
+var ChannelName = "ollama"
diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go
new file mode 100644
index 00000000..15c64cdc
--- /dev/null
+++ b/relay/channel/ollama/dto.go
@@ -0,0 +1,45 @@
+package ollama
+
+import "one-api/dto"
+
+type OllamaRequest struct {
+ Model string `json:"model,omitempty"`
+ Messages []dto.Message `json:"messages,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ Seed float64 `json:"seed,omitempty"`
+ Topp float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop any `json:"stop,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ Tools []dto.ToolCallRequest `json:"tools,omitempty"`
+ ResponseFormat any `json:"response_format,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ Suffix any `json:"suffix,omitempty"`
+ StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
+ Prompt any `json:"prompt,omitempty"`
+}
+
+type Options struct {
+ Seed int `json:"seed,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ NumPredict int `json:"num_predict,omitempty"`
+ NumCtx int `json:"num_ctx,omitempty"`
+}
+
+type OllamaEmbeddingRequest struct {
+ Model string `json:"model,omitempty"`
+ Input []string `json:"input"`
+ Options *Options `json:"options,omitempty"`
+}
+
+type OllamaEmbeddingResponse struct {
+ Error string `json:"error,omitempty"`
+ Model string `json:"model"`
+ Embedding [][]float64 `json:"embeddings,omitempty"`
+}
diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go
new file mode 100644
index 00000000..295349e3
--- /dev/null
+++ b/relay/channel/ollama/relay-ollama.go
@@ -0,0 +1,132 @@
+package ollama
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
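+// requestOpenAI2Ollama converts an OpenAI chat request into an Ollama request, inlining remote image URLs as base64 data URIs.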
+func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
+ messages := make([]dto.Message, 0, len(request.Messages))
+ for _, message := range request.Messages {
+ if !message.IsStringContent() {
+ mediaMessages := message.ParseContent()
+ for j, mediaMessage := range mediaMessages {
+ if mediaMessage.Type == dto.ContentTypeImageURL {
+ imageUrl := mediaMessage.GetImageMedia()
+ // check if not base64
+ if strings.HasPrefix(imageUrl.Url, "http") {
+ fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
+ if err != nil {
+ return nil, err
+ }
+ imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
+ }
+ mediaMessage.ImageUrl = imageUrl
+ mediaMessages[j] = mediaMessage
+ }
+ }
+ message.SetMediaContent(mediaMessages)
+ }
+ messages = append(messages, dto.Message{
+ Role: message.Role,
+ Content: message.Content,
+ ToolCalls: message.ToolCalls,
+ ToolCallId: message.ToolCallId,
+ })
+ }
+ str, ok := request.Stop.(string)
+ var Stop []string
+ if ok {
+ Stop = []string{str}
+ } else {
+ Stop, _ = request.Stop.([]string)
+ }
+ return &OllamaRequest{
+ Model: request.Model,
+ Messages: messages,
+ Stream: request.Stream,
+ Temperature: request.Temperature,
+ Seed: request.Seed,
+ Topp: request.TopP,
+ TopK: request.TopK,
+ Stop: Stop,
+ Tools: request.Tools,
+ MaxTokens: request.MaxTokens,
+ ResponseFormat: request.ResponseFormat,
+ FrequencyPenalty: request.FrequencyPenalty,
+ PresencePenalty: request.PresencePenalty,
+ Prompt: request.Prompt,
+ StreamOptions: request.StreamOptions,
+ Suffix: request.Suffix,
+ }, nil
+}
+
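+// requestOpenAI2Embeddings maps an OpenAI embedding request onto Ollama's /api/embed request shape.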
+func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
+ return &OllamaEmbeddingRequest{
+ Model: request.Model,
+ Input: request.ParseInput(),
+ Options: &Options{
+ Seed: int(request.Seed),
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ FrequencyPenalty: request.FrequencyPenalty,
+ PresencePenalty: request.PresencePenalty,
+ },
+ }
+}
+
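+// ollamaEmbeddingHandler converts an Ollama embedding response into the OpenAI embedding format, billing only prompt tokens.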
+func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var ollamaEmbeddingResponse OllamaEmbeddingResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if ollamaEmbeddingResponse.Error != "" {
+ return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
+ }
+ flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
+ data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
+ data = append(data, dto.OpenAIEmbeddingResponseItem{
+ Embedding: flattenedEmbeddings,
+ Object: "embedding",
+ })
+ usage := &dto.Usage{
+ TotalTokens: info.PromptTokens,
+ CompletionTokens: 0,
+ PromptTokens: info.PromptTokens,
+ }
+ embeddingResponse := &dto.OpenAIEmbeddingResponse{
+ Object: "list",
+ Data: data,
+ Model: info.UpstreamModelName,
+ Usage: *usage,
+ }
+ doResponseBody, err := common.Marshal(embeddingResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ common.IOCopyBytesGracefully(c, resp, doResponseBody)
+ return usage, nil
+}
+
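+// flattenEmbeddings concatenates all returned embedding vectors into a single flat vector for the single-item OpenAI-style response.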
+func flattenEmbeddings(embeddings [][]float64) []float64 {
+ flattened := []float64{}
+ for _, row := range embeddings {
+ flattened = append(flattened, row...)
+ }
+ return flattened
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
new file mode 100644
index 00000000..efd22878
--- /dev/null
+++ b/relay/channel/openai/adaptor.go
@@ -0,0 +1,491 @@
+package openai
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/textproto"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/ai360"
+ "one-api/relay/channel/lingyiwanwu"
+ "one-api/relay/channel/minimax"
+ "one-api/relay/channel/moonshot"
+ "one-api/relay/channel/openrouter"
+ "one-api/relay/channel/xinference"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/common_handler"
+ relayconstant "one-api/relay/constant"
+ "one-api/service"
+ "one-api/types"
+ "path/filepath"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+ ChannelType int
+ ResponseFormat string
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+ //if !strings.Contains(request.Model, "claude") {
+ // return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
+ //}
+ aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
+ if err != nil {
+ return nil, err
+ }
+ if info.SupportStreamOptions {
+ aiRequest.StreamOptions = &dto.StreamOptions{
+ IncludeUsage: true,
+ }
+ }
+ return a.ConvertOpenAIRequest(c, info, aiRequest)
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+ a.ChannelType = info.ChannelType
+
+ // initialize ThinkingContentInfo when thinking_to_content is enabled
+ if info.ChannelSetting.ThinkingToContent {
+ info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
+ IsFirstThinkingContent: true,
+ SendLastThinkingContent: false,
+ HasSentThinkingContent: false,
+ }
+ }
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if info.RelayFormat == relaycommon.RelayFormatClaude {
+ return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ }
+ if info.RelayMode == relayconstant.RelayModeRealtime {
+ if strings.HasPrefix(info.BaseUrl, "https://") {
+ baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
+ baseUrl = "wss://" + baseUrl
+ info.BaseUrl = baseUrl
+ } else if strings.HasPrefix(info.BaseUrl, "http://") {
+ baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
+ baseUrl = "ws://" + baseUrl
+ info.BaseUrl = baseUrl
+ }
+ }
+ switch info.ChannelType {
+ case constant.ChannelTypeAzure:
+ apiVersion := info.ApiVersion
+ if apiVersion == "" {
+ apiVersion = constant.AzureDefaultAPIVersion
+ }
+ // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
+ requestURL := strings.Split(info.RequestURLPath, "?")[0]
+ requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
+ task := strings.TrimPrefix(requestURL, "/v1/")
+
+ // special-case the Responses API
+ if info.RelayMode == relayconstant.RelayModeResponses {
+ requestURL = "/openai/v1/responses?api-version=preview"
+ return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
+ }
+
+ model_ := info.UpstreamModelName
+ // channels created after May 10, 2025 do not strip the dot from the model name
+ if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
+ model_ = strings.Replace(model_, ".", "", -1)
+ }
+ // https://github.com/songquanpeng/one-api/issues/67
+ requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+ if info.RelayMode == relayconstant.RelayModeRealtime {
+ requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
+ }
+ return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
+ case constant.ChannelTypeMiniMax:
+ return minimax.GetRequestURL(info)
+ case constant.ChannelTypeCustom:
+ url := info.BaseUrl
+ url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
+ return url, nil
+ default:
+ return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, header)
+ if info.ChannelType == constant.ChannelTypeAzure {
+ header.Set("api-key", info.ApiKey)
+ return nil
+ }
+ if info.ChannelType == constant.ChannelTypeOpenAI && info.Organization != "" {
+ header.Set("OpenAI-Organization", info.Organization)
+ }
+ if info.RelayMode == relayconstant.RelayModeRealtime {
+ swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
+ if swp != "" {
+ items := []string{
+ "realtime",
+ "openai-insecure-api-key." + info.ApiKey,
+ "openai-beta.realtime-v1",
+ }
+ header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
+ //req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
+ //req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
+ //req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
+ } else {
+ header.Set("openai-beta", "realtime=v1")
+ header.Set("Authorization", "Bearer "+info.ApiKey)
+ }
+ } else {
+ header.Set("Authorization", "Bearer "+info.ApiKey)
+ }
+ if info.ChannelType == constant.ChannelTypeOpenRouter {
+ header.Set("HTTP-Referer", "https://www.newapi.ai")
+ header.Set("X-Title", "New API")
+ }
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
+ request.StreamOptions = nil
+ }
+ if info.ChannelType == constant.ChannelTypeOpenRouter {
+ if len(request.Usage) == 0 {
+ request.Usage = json.RawMessage(`{"include":true}`)
+ }
+ }
+ if strings.HasPrefix(request.Model, "o") {
+ if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
+ request.MaxCompletionTokens = request.MaxTokens
+ request.MaxTokens = 0
+ }
+ request.Temperature = nil
+ if strings.HasSuffix(request.Model, "-high") {
+ request.ReasoningEffort = "high"
+ request.Model = strings.TrimSuffix(request.Model, "-high")
+ } else if strings.HasSuffix(request.Model, "-low") {
+ request.ReasoningEffort = "low"
+ request.Model = strings.TrimSuffix(request.Model, "-low")
+ } else if strings.HasSuffix(request.Model, "-medium") {
+ request.ReasoningEffort = "medium"
+ request.Model = strings.TrimSuffix(request.Model, "-medium")
+ }
+ info.ReasoningEffort = request.ReasoningEffort
+ info.UpstreamModelName = request.Model
+
+ // o-series models expect the developer role (o1-mini and o1-preview excluded)
+ if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") {
+ // rewrite the first message's role from system to developer
+ if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
+ request.Messages[0].Role = "developer"
+ }
+ }
+ }
+
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ a.ResponseFormat = request.ResponseFormat
+ if info.RelayMode == relayconstant.RelayModeAudioSpeech {
+ jsonData, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("error marshalling object: %w", err)
+ }
+ return bytes.NewReader(jsonData), nil
+ } else {
+ var requestBody bytes.Buffer
+ writer := multipart.NewWriter(&requestBody)
+
+ writer.WriteField("model", request.Model)
+
+ // collect all form fields
+ formData := c.Request.PostForm
+
+ // copy every form field except "model", which was already written
+ for key, values := range formData {
+ if key == "model" {
+ continue
+ }
+ for _, value := range values {
+ writer.WriteField(key, value)
+ }
+ }
+
+ // add the file field
+ file, header, err := c.Request.FormFile("file")
+ if err != nil {
+ return nil, errors.New("file is required")
+ }
+ defer file.Close()
+
+ part, err := writer.CreateFormFile("file", header.Filename)
+ if err != nil {
+ return nil, errors.New("create form file failed")
+ }
+ if _, err := io.Copy(part, file); err != nil {
+ return nil, errors.New("copy file failed")
+ }
+
+ // close the multipart writer to finalize the boundary
+ writer.Close()
+ c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+ return &requestBody, nil
+ }
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ switch info.RelayMode {
+ case relayconstant.RelayModeImagesEdits:
+
+ var requestBody bytes.Buffer
+ writer := multipart.NewWriter(&requestBody)
+
+ writer.WriteField("model", request.Model)
+ // collect all form fields
+ formData := c.Request.PostForm
+ // copy every form field except "model", which was already written
+ for key, values := range formData {
+ if key == "model" {
+ continue
+ }
+ for _, value := range values {
+ writer.WriteField(key, value)
+ }
+ }
+
+ // Parse the multipart form to handle both single image and multiple images
+ if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
+ return nil, errors.New("failed to parse multipart form")
+ }
+
+ if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+ // Check if "image" field exists in any form, including array notation
+ var imageFiles []*multipart.FileHeader
+ var exists bool
+
+ // First check for standard "image" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+ // If not found, check for "image[]" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+ // If still not found, iterate through all fields to find any that start with "image["
+ foundArrayImages := false
+ for fieldName, files := range c.Request.MultipartForm.File {
+ if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
+ foundArrayImages = true
+ imageFiles = append(imageFiles, files...)
+ }
+ }
+
+ // If no image fields found at all
+ if !foundArrayImages && (len(imageFiles) == 0) {
+ return nil, errors.New("image is required")
+ }
+ }
+ }
+
+ // Process all image files
+ for i, fileHeader := range imageFiles {
+ file, err := fileHeader.Open()
+ if err != nil {
+ return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
+ }
+ defer file.Close()
+
+ // If multiple images, use image[] as the field name
+ fieldName := "image"
+ if len(imageFiles) > 1 {
+ fieldName = "image[]"
+ }
+
+ // Determine MIME type based on file extension
+ mimeType := detectImageMimeType(fileHeader.Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
+ h.Set("Content-Type", mimeType)
+
+ part, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
+ }
+
+ if _, err := io.Copy(part, file); err != nil {
+ return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
+ }
+ }
+
+ // Handle mask file if present
+ if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+ maskFile, err := maskFiles[0].Open()
+ if err != nil {
+ return nil, errors.New("failed to open mask file")
+ }
+ defer maskFile.Close()
+
+ // Determine MIME type for mask file
+ mimeType := detectImageMimeType(maskFiles[0].Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
+ h.Set("Content-Type", mimeType)
+
+ maskPart, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, errors.New("create form file failed for mask")
+ }
+
+ if _, err := io.Copy(maskPart, maskFile); err != nil {
+ return nil, errors.New("copy mask file failed")
+ }
+ }
+ } else {
+ return nil, errors.New("no multipart form data found")
+ }
+
+ // close the multipart writer to finalize the boundary
+ writer.Close()
+ c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+ return bytes.NewReader(requestBody.Bytes()), nil
+
+ default:
+ return request, nil
+ }
+}
+
+// detectImageMimeType determines the MIME type based on the file extension
+func detectImageMimeType(filename string) string {
+ ext := strings.ToLower(filepath.Ext(filename))
+ switch ext {
+ case ".jpg", ".jpeg":
+ return "image/jpeg"
+ case ".png":
+ return "image/png"
+ case ".webp":
+ return "image/webp"
+ default:
+ // Try to detect from extension if possible
+ if strings.HasPrefix(ext, ".jp") {
+ return "image/jpeg"
+ }
+ // Default to png as a fallback
+ return "image/png"
+ }
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // map the model-name suffix to reasoning effort
+ if strings.HasSuffix(request.Model, "-high") {
+ request.Reasoning.Effort = "high"
+ request.Model = strings.TrimSuffix(request.Model, "-high")
+ } else if strings.HasSuffix(request.Model, "-low") {
+ request.Reasoning.Effort = "low"
+ request.Model = strings.TrimSuffix(request.Model, "-low")
+ } else if strings.HasSuffix(request.Model, "-medium") {
+ request.Reasoning.Effort = "medium"
+ request.Model = strings.TrimSuffix(request.Model, "-medium")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
+ info.RelayMode == relayconstant.RelayModeAudioTranslation ||
+ info.RelayMode == relayconstant.RelayModeImagesEdits {
+ return channel.DoFormRequest(a, c, info, requestBody)
+ } else if info.RelayMode == relayconstant.RelayModeRealtime {
+ return channel.DoWssRequest(a, c, info, requestBody)
+ } else {
+ return channel.DoApiRequest(a, c, info, requestBody)
+ }
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayMode {
+ case relayconstant.RelayModeRealtime:
+ err, usage = OpenaiRealtimeHandler(c, info)
+ case relayconstant.RelayModeAudioSpeech:
+ usage = OpenaiTTSHandler(c, resp, info)
+ case relayconstant.RelayModeAudioTranslation:
+ fallthrough
+ case relayconstant.RelayModeAudioTranscription:
+ err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
+ case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
+ usage, err = OpenaiHandlerWithUsage(c, info, resp)
+ case relayconstant.RelayModeRerank:
+ usage, err = common_handler.RerankHandler(c, info, resp)
+ case relayconstant.RelayModeResponses:
+ if info.IsStream {
+ usage, err = OaiResponsesStreamHandler(c, info, resp)
+ } else {
+ usage, err = OaiResponsesHandler(c, info, resp)
+ }
+ default:
+ if info.IsStream {
+ usage, err = OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = OpenaiHandler(c, info, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ switch a.ChannelType {
+ case constant.ChannelType360:
+ return ai360.ModelList
+ case constant.ChannelTypeMoonshot:
+ return moonshot.ModelList
+ case constant.ChannelTypeLingYiWanWu:
+ return lingyiwanwu.ModelList
+ case constant.ChannelTypeMiniMax:
+ return minimax.ModelList
+ case constant.ChannelTypeXinference:
+ return xinference.ModelList
+ case constant.ChannelTypeOpenRouter:
+ return openrouter.ModelList
+ default:
+ return ModelList
+ }
+}
+
+func (a *Adaptor) GetChannelName() string {
+ switch a.ChannelType {
+ case constant.ChannelType360:
+ return ai360.ChannelName
+ case constant.ChannelTypeMoonshot:
+ return moonshot.ChannelName
+ case constant.ChannelTypeLingYiWanWu:
+ return lingyiwanwu.ChannelName
+ case constant.ChannelTypeMiniMax:
+ return minimax.ChannelName
+ case constant.ChannelTypeXinference:
+ return xinference.ChannelName
+ case constant.ChannelTypeOpenRouter:
+ return openrouter.ChannelName
+ default:
+ return ChannelName
+ }
+}
diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go
new file mode 100644
index 00000000..c703e414
--- /dev/null
+++ b/relay/channel/openai/constant.go
@@ -0,0 +1,35 @@
+package openai
+
+var ModelList = []string{
+ "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
+ "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
+ "gpt-3.5-turbo-instruct",
+ "gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
+ "gpt-4-32k", "gpt-4-32k-0613",
+ "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
+ "gpt-4-vision-preview",
+ "chatgpt-4o-latest",
+ "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
+ "gpt-4o-mini", "gpt-4o-mini-2024-07-18",
+ "gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
+ "o1-preview", "o1-preview-2024-09-12",
+ "o1-mini", "o1-mini-2024-09-12",
+ "o3-mini", "o3-mini-2025-01-31",
+ "o3-mini-high", "o3-mini-2025-01-31-high",
+ "o3-mini-low", "o3-mini-2025-01-31-low",
+ "o3-mini-medium", "o3-mini-2025-01-31-medium",
+ "o1", "o1-2024-12-17",
+ "gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
+ "gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
+ "gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
+ "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
+ "text-curie-001", "text-babbage-001", "text-ada-001",
+ "text-moderation-latest", "text-moderation-stable",
+ "text-davinci-edit-001",
+ "davinci-002", "babbage-002",
+ "dall-e-3",
+ "whisper-1",
+ "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
+}
+
+var ChannelName = "openai"
diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go
new file mode 100644
index 00000000..a068c544
--- /dev/null
+++ b/relay/channel/openai/helper.go
@@ -0,0 +1,196 @@
+package openai
+
+import (
+ "encoding/json"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+// helper functions for stream relaying
+func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
+ info.SendResponseCount++
+ switch info.RelayFormat {
+ case relaycommon.RelayFormatOpenAI:
+ return sendStreamData(c, info, data, forceFormat, thinkToContent)
+ case relaycommon.RelayFormatClaude:
+ return handleClaudeFormat(c, data, info)
+ }
+ return nil
+}
+
+func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
+ var streamResponse dto.ChatCompletionsStreamResponse
+ if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
+ return err
+ }
+
+ if streamResponse.Usage != nil {
+ info.ClaudeConvertInfo.Usage = streamResponse.Usage
+ }
+ claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
+ for _, resp := range claudeResponses {
+ helper.ClaudeData(c, *resp)
+ }
+ return nil
+}
+
+func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
+ for _, choice := range streamResponse.Choices {
+ responseTextBuilder.WriteString(choice.Delta.GetContentString())
+ responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
+ if choice.Delta.ToolCalls != nil {
+ if len(choice.Delta.ToolCalls) > *toolCount {
+ *toolCount = len(choice.Delta.ToolCalls)
+ }
+ for _, tool := range choice.Delta.ToolCalls {
+ responseTextBuilder.WriteString(tool.Function.Name)
+ responseTextBuilder.WriteString(tool.Function.Arguments)
+ }
+ }
+ }
+ return nil
+}
+
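+// processTokens re-parses the buffered stream chunks to rebuild the response text and tool-call count used for token counting when the upstream reports no usage.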
+func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
+ streamResp := "[" + strings.Join(streamItems, ",") + "]"
+
+ switch relayMode {
+ case relayconstant.RelayModeChatCompletions:
+ return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
+ case relayconstant.RelayModeCompletions:
+ return processCompletions(streamResp, streamItems, responseTextBuilder)
+ }
+ return nil
+}
+
+func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
+ var streamResponses []dto.ChatCompletionsStreamResponse
+ if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
+ // batch unmarshal failed, fall back to parsing each chunk individually
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ for _, item := range streamItems {
+ var streamResponse dto.ChatCompletionsStreamResponse
+ if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
+ return err
+ }
+ if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
+ common.SysError("error processing stream response: " + err.Error())
+ }
+ }
+ return nil
+ }
+
+ // process all responses in one pass
+ for _, streamResponse := range streamResponses {
+ for _, choice := range streamResponse.Choices {
+ responseTextBuilder.WriteString(choice.Delta.GetContentString())
+ responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
+ if choice.Delta.ToolCalls != nil {
+ if len(choice.Delta.ToolCalls) > *toolCount {
+ *toolCount = len(choice.Delta.ToolCalls)
+ }
+ for _, tool := range choice.Delta.ToolCalls {
+ responseTextBuilder.WriteString(tool.Function.Name)
+ responseTextBuilder.WriteString(tool.Function.Arguments)
+ }
+ }
+ }
+ }
+ return nil
+}
+
+func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
+ var streamResponses []dto.CompletionsStreamResponse
+ if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
+ // batch unmarshal failed, fall back to parsing each chunk individually
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ for _, item := range streamItems {
+ var streamResponse dto.CompletionsStreamResponse
+ if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
+ continue
+ }
+ for _, choice := range streamResponse.Choices {
+ responseTextBuilder.WriteString(choice.Text)
+ }
+ }
+ return nil
+ }
+
+ // process all responses in one pass
+ for _, streamResponse := range streamResponses {
+ for _, choice := range streamResponse.Choices {
+ responseTextBuilder.WriteString(choice.Text)
+ }
+ }
+ return nil
+}
+
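+// handleLastResponse extracts the response id, timestamp, system fingerprint, model and, if present, the usage block from the final stream chunk, and decides whether that chunk still needs to be forwarded.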
+func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
+ systemFingerprint *string, model *string, usage **dto.Usage,
+ containStreamUsage *bool, info *relaycommon.RelayInfo,
+ shouldSendLastResp *bool) error {
+
+ var lastStreamResponse dto.ChatCompletionsStreamResponse
+ if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
+ return err
+ }
+
+ *responseId = lastStreamResponse.Id
+ *createAt = lastStreamResponse.Created
+ *systemFingerprint = lastStreamResponse.GetSystemFingerprint()
+ *model = lastStreamResponse.Model
+
+ if service.ValidUsage(lastStreamResponse.Usage) {
+ *containStreamUsage = true
+ *usage = lastStreamResponse.Usage
+ if !info.ShouldIncludeUsage {
+ *shouldSendLastResp = false
+ }
+ }
+
+ return nil
+}
+
+func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
+ responseId string, createAt int64, model string, systemFingerprint string,
+ usage *dto.Usage, containStreamUsage bool) {
+
+ switch info.RelayFormat {
+ case relaycommon.RelayFormatOpenAI:
+ if info.ShouldIncludeUsage && !containStreamUsage {
+ response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
+ response.SetSystemFingerprint(systemFingerprint)
+ helper.ObjectData(c, response)
+ }
+ helper.Done(c)
+
+ case relaycommon.RelayFormatClaude:
+ info.ClaudeConvertInfo.Done = true
+ var streamResponse dto.ChatCompletionsStreamResponse
+ if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return
+ }
+
+ info.ClaudeConvertInfo.Usage = usage
+
+ claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
+ for _, resp := range claudeResponses {
+ helper.ClaudeData(c, *resp)
+ }
+ }
+}
+
+func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
+ if data == "" {
+ return
+ }
+ helper.ResponseChunkData(c, streamResponse, data)
+}
diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go
new file mode 100644
index 00000000..bfe8bcd3
--- /dev/null
+++ b/relay/channel/openai/relay-openai.go
@@ -0,0 +1,587 @@
+package openai
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "math"
+ "mime/multipart"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "one-api/types"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+ "github.com/pkg/errors"
+)
+
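+// sendStreamData forwards one SSE chunk to the client; with forceFormat it re-marshals the chunk, and with thinkToContent it rewrites reasoning content into regular content wrapped in <think> tags.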
+func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
+ if data == "" {
+ return nil
+ }
+
+ if !forceFormat && !thinkToContent {
+ return helper.StringData(c, data)
+ }
+
+ var lastStreamResponse dto.ChatCompletionsStreamResponse
+ if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
+ return err
+ }
+
+ if !thinkToContent {
+ return helper.ObjectData(c, lastStreamResponse)
+ }
+
+ hasThinkingContent := false
+ hasContent := false
+ var thinkingContent strings.Builder
+ for _, choice := range lastStreamResponse.Choices {
+ if len(choice.Delta.GetReasoningContent()) > 0 {
+ hasThinkingContent = true
+ thinkingContent.WriteString(choice.Delta.GetReasoningContent())
+ }
+ if len(choice.Delta.GetContentString()) > 0 {
+ hasContent = true
+ }
+ }
+
+ // Handle think to content conversion
+ if info.ThinkingContentInfo.IsFirstThinkingContent {
+ if hasThinkingContent {
+ response := lastStreamResponse.Copy()
+ for i := range response.Choices {
+ // send the opening `<think>` tag together with the thinking content
+ response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
+ response.Choices[i].Delta.ReasoningContent = nil
+ response.Choices[i].Delta.Reasoning = nil
+ }
+ info.ThinkingContentInfo.IsFirstThinkingContent = false
+ info.ThinkingContentInfo.HasSentThinkingContent = true
+ return helper.ObjectData(c, response)
+ }
+ }
+
+ if len(lastStreamResponse.Choices) == 0 {
+ return helper.ObjectData(c, lastStreamResponse)
+ }
+
+ // Process each choice
+ for i, choice := range lastStreamResponse.Choices {
+ // Handle transition from thinking to content
+ // only send the closing `</think>` tag when thinking content has already been sent
+ if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
+ response := lastStreamResponse.Copy()
+ for j := range response.Choices {
+ response.Choices[j].Delta.SetContentString("\n</think>\n\n")
+ response.Choices[j].Delta.ReasoningContent = nil
+ response.Choices[j].Delta.Reasoning = nil
+ }
+ info.ThinkingContentInfo.SendLastThinkingContent = true
+ helper.ObjectData(c, response)
+ }
+
+ // Convert reasoning content to regular content if any
+ if len(choice.Delta.GetReasoningContent()) > 0 {
+ lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
+ lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
+ lastStreamResponse.Choices[i].Delta.Reasoning = nil
+ } else if !hasThinkingContent && !hasContent {
+ // flush thinking content
+ lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
+ lastStreamResponse.Choices[i].Delta.Reasoning = nil
+ }
+ }
+
+ return helper.ObjectData(c, lastStreamResponse)
+}
+
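+// OaiStreamHandler relays an OpenAI-compatible SSE stream to the client while buffering chunks so usage can be reconstructed when the upstream does not report it.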
+func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ if resp == nil || resp.Body == nil {
+ common.LogError(c, "invalid response or response body")
+ return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
+ }
+
+ defer common.CloseResponseBodyGracefully(resp)
+
+ model := info.UpstreamModelName
+ var responseId string
+ var createAt int64 = 0
+ var systemFingerprint string
+ var containStreamUsage bool
+ var responseTextBuilder strings.Builder
+ var toolCount int
+ var usage = &dto.Usage{}
+ var streamItems []string // store stream items
+ var forceFormat bool
+ var thinkToContent bool
+
+ if info.ChannelSetting.ForceFormat {
+ forceFormat = true
+ }
+
+ if info.ChannelSetting.ThinkingToContent {
+ thinkToContent = true
+ }
+
+ var (
+ lastStreamData string
+ )
+
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ if lastStreamData != "" {
+ err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
+ if err != nil {
+ common.SysError("error handling stream format: " + err.Error())
+ }
+ }
+ lastStreamData = data
+ streamItems = append(streamItems, data)
+ return true
+ })
+
+ // handle the final buffered chunk
+ shouldSendLastResp := true
+ if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
+ &containStreamUsage, info, &shouldSendLastResp); err != nil {
+ common.SysError("error handling last response: " + err.Error())
+ }
+
+ if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
+ _ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
+ }
+
+ // token accounting
+ if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
+ common.SysError("error processing tokens: " + err.Error())
+ }
+
+ if !containStreamUsage {
+ usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+ usage.CompletionTokens += toolCount * 7
+ } else {
+ if info.ChannelType == constant.ChannelTypeDeepSeek {
+ if usage.PromptCacheHitTokens != 0 {
+ usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
+ }
+ }
+ }
+
+ handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
+
+ return usage, nil
+}
+
+func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ var simpleResponse dto.OpenAITextResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ err = common.Unmarshal(responseBody, &simpleResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
+ return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
+ }
+
+ forceFormat := false
+ if info.ChannelSetting.ForceFormat {
+ forceFormat = true
+ }
+
+ if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
+ completionTokens := 0
+ for _, choice := range simpleResponse.Choices {
+ ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+ completionTokens += ctkm
+ }
+ simpleResponse.Usage = dto.Usage{
+ PromptTokens: info.PromptTokens,
+ CompletionTokens: completionTokens,
+ TotalTokens: info.PromptTokens + completionTokens,
+ }
+ }
+
+ switch info.RelayFormat {
+ case relaycommon.RelayFormatOpenAI:
+ if forceFormat {
+ responseBody, err = common.Marshal(simpleResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ } else {
+ break
+ }
+ case relaycommon.RelayFormatClaude:
+ claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
+ claudeRespStr, err := common.Marshal(claudeResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ responseBody = claudeRespStr
+ }
+
+ common.IOCopyBytesGracefully(c, resp, responseBody)
+
+ return &simpleResponse.Usage, nil
+}
+
+func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
+ // the status code has been judged before, if there is a body reading failure,
+ // it should be regarded as a non-recoverable error, so it should not return err for external retry.
+ // Analogous to nginx's load balancing, it will only retry if it can't be requested or
+ // if the upstream returns a specific status code, once the upstream has already written the header,
+ // the subsequent failure of the response body should be regarded as a non-recoverable error,
+ // and can be terminated directly.
+ defer common.CloseResponseBodyGracefully(resp)
+ usage := &dto.Usage{}
+ usage.PromptTokens = info.PromptTokens
+ usage.TotalTokens = info.PromptTokens
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.WriteHeader(resp.StatusCode)
+ c.Writer.WriteHeaderNow()
+ _, err := io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ common.LogError(c, err.Error())
+ }
+ return usage
+}
+
+func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ // count tokens by audio file duration
+ audioTokens, err := countAudioTokens(c)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
+ }
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
+ }
+ // write the response body back to the client
+ common.IOCopyBytesGracefully(c, resp, responseBody)
+
+ usage := &dto.Usage{}
+ usage.PromptTokens = audioTokens
+ usage.CompletionTokens = 0
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ return nil, usage
+}
+
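+// countAudioTokens bills speech-to-text requests by audio duration: the uploaded file is copied to a temporary file, its duration is probed, and each minute counts as 1k tokens.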
+func countAudioTokens(c *gin.Context) (int, error) {
+ body, err := common.GetRequestBody(c)
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+
+ var reqBody struct {
+ File *multipart.FileHeader `form:"file" binding:"required"`
+ }
+ c.Request.Body = io.NopCloser(bytes.NewReader(body))
+ if err = c.ShouldBind(&reqBody); err != nil {
+ return 0, errors.WithStack(err)
+ }
+ ext := filepath.Ext(reqBody.File.Filename) // file extension
+ reqFp, err := reqBody.File.Open()
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+ defer reqFp.Close()
+
+ tmpFp, err := os.CreateTemp("", "audio-*"+ext)
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+ defer os.Remove(tmpFp.Name())
+
+ _, err = io.Copy(tmpFp, reqFp)
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+ if err = tmpFp.Close(); err != nil {
+ return 0, errors.WithStack(err)
+ }
+
+ duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
+ if err != nil {
+ return 0, errors.WithStack(err)
+ }
+
+ return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute of audio is billed as 1k tokens
+}
+
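+// OpenaiRealtimeHandler proxies the realtime WebSocket session in both directions, counting tokens per event and pre-deducting quota after each completed response.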
+func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
+ if info == nil || info.ClientWs == nil || info.TargetWs == nil {
+ return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
+ }
+
+ info.IsStream = true
+ clientConn := info.ClientWs
+ targetConn := info.TargetWs
+
+ clientClosed := make(chan struct{})
+ targetClosed := make(chan struct{})
+ sendChan := make(chan []byte, 100)
+ receiveChan := make(chan []byte, 100)
+ errChan := make(chan error, 2)
+
+ usage := &dto.RealtimeUsage{}
+ localUsage := &dto.RealtimeUsage{}
+ sumUsage := &dto.RealtimeUsage{}
+
+ gopool.Go(func() {
+ defer func() {
+ if r := recover(); r != nil {
+ errChan <- fmt.Errorf("panic in client reader: %v", r)
+ }
+ }()
+ for {
+ select {
+ case <-c.Done():
+ return
+ default:
+ _, message, err := clientConn.ReadMessage()
+ if err != nil {
+ if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+ errChan <- fmt.Errorf("error reading from client: %v", err)
+ }
+ close(clientClosed)
+ return
+ }
+
+ realtimeEvent := &dto.RealtimeEvent{}
+ err = common.Unmarshal(message, realtimeEvent)
+ if err != nil {
+ errChan <- fmt.Errorf("error unmarshalling message: %v", err)
+ return
+ }
+
+ if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
+ if realtimeEvent.Session != nil {
+ if realtimeEvent.Session.Tools != nil {
+ info.RealtimeTools = realtimeEvent.Session.Tools
+ }
+ }
+ }
+
+ textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+ if err != nil {
+ errChan <- fmt.Errorf("error counting text token: %v", err)
+ return
+ }
+ common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ localUsage.TotalTokens += textToken + audioToken
+ localUsage.InputTokens += textToken + audioToken
+ localUsage.InputTokenDetails.TextTokens += textToken
+ localUsage.InputTokenDetails.AudioTokens += audioToken
+
+ err = helper.WssString(c, targetConn, string(message))
+ if err != nil {
+ errChan <- fmt.Errorf("error writing to target: %v", err)
+ return
+ }
+
+ select {
+ case sendChan <- message:
+ default:
+ }
+ }
+ }
+ })
+
+ gopool.Go(func() {
+ defer func() {
+ if r := recover(); r != nil {
+ errChan <- fmt.Errorf("panic in target reader: %v", r)
+ }
+ }()
+ for {
+ select {
+ case <-c.Done():
+ return
+ default:
+ _, message, err := targetConn.ReadMessage()
+ if err != nil {
+ if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+ errChan <- fmt.Errorf("error reading from target: %v", err)
+ }
+ close(targetClosed)
+ return
+ }
+ info.SetFirstResponseTime()
+ realtimeEvent := &dto.RealtimeEvent{}
+ err = common.Unmarshal(message, realtimeEvent)
+ if err != nil {
+ errChan <- fmt.Errorf("error unmarshalling message: %v", err)
+ return
+ }
+
+ if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
+ realtimeUsage := realtimeEvent.Response.Usage
+ if realtimeUsage != nil {
+ usage.TotalTokens += realtimeUsage.TotalTokens
+ usage.InputTokens += realtimeUsage.InputTokens
+ usage.OutputTokens += realtimeUsage.OutputTokens
+ usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
+ usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
+ usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
+ usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
+ usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
+ err := preConsumeUsage(c, info, usage, sumUsage)
+ if err != nil {
+ errChan <- fmt.Errorf("error consume usage: %v", err)
+ return
+ }
+ // this round has been billed, reset the counters
+ usage = &dto.RealtimeUsage{}
+
+ localUsage = &dto.RealtimeUsage{}
+ } else {
+ textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+ if err != nil {
+ errChan <- fmt.Errorf("error counting text token: %v", err)
+ return
+ }
+ common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ localUsage.TotalTokens += textToken + audioToken
+ info.IsFirstRequest = false
+ localUsage.InputTokens += textToken + audioToken
+ localUsage.InputTokenDetails.TextTokens += textToken
+ localUsage.InputTokenDetails.AudioTokens += audioToken
+ err = preConsumeUsage(c, info, localUsage, sumUsage)
+ if err != nil {
+ errChan <- fmt.Errorf("error consume usage: %v", err)
+ return
+ }
+ // this round has been billed, reset the counter
+ localUsage = &dto.RealtimeUsage{}
+ // current usage is logged below
+ }
+ common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
+ common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
+
+ } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
+ realtimeSession := realtimeEvent.Session
+ if realtimeSession != nil {
+ // update audio format
+ info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
+ info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
+ }
+ } else {
+ textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+ if err != nil {
+ errChan <- fmt.Errorf("error counting text token: %v", err)
+ return
+ }
+ common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
+ localUsage.TotalTokens += textToken + audioToken
+ localUsage.OutputTokens += textToken + audioToken
+ localUsage.OutputTokenDetails.TextTokens += textToken
+ localUsage.OutputTokenDetails.AudioTokens += audioToken
+ }
+
+ err = helper.WssString(c, clientConn, string(message))
+ if err != nil {
+ errChan <- fmt.Errorf("error writing to client: %v", err)
+ return
+ }
+
+ select {
+ case receiveChan <- message:
+ default:
+ }
+ }
+ }
+ })
+
+ select {
+ case <-clientClosed:
+ case <-targetClosed:
+ case err := <-errChan:
+ //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
+ common.LogError(c, "realtime error: "+err.Error())
+ case <-c.Done():
+ }
+
+ if usage.TotalTokens != 0 {
+ _ = preConsumeUsage(c, info, usage, sumUsage)
+ }
+
+ if localUsage.TotalTokens != 0 {
+ _ = preConsumeUsage(c, info, localUsage, sumUsage)
+ }
+
+ // any remaining usage has already been folded into sumUsage above
+
+ return nil, sumUsage
+}
+
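+// preConsumeUsage folds the per-turn usage into the running total and pre-deducts quota for it.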
+func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
+ if usage == nil || totalUsage == nil {
+ return fmt.Errorf("invalid usage pointer")
+ }
+
+ totalUsage.TotalTokens += usage.TotalTokens
+ totalUsage.InputTokens += usage.InputTokens
+ totalUsage.OutputTokens += usage.OutputTokens
+ totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
+ totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
+ totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
+ totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
+ totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
+ // clear usage
+ err := service.PreWssConsumeQuota(ctx, info, usage)
+ return err
+}
+
+func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+
+ var usageResp dto.SimpleResponse
+ err = common.Unmarshal(responseBody, &usageResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ // write the response body back to the client
+ common.IOCopyBytesGracefully(c, resp, responseBody)
+
+ // Once we've written to the client, we should not return errors anymore
+ // because the upstream has already consumed resources and returned content
+ // We should still perform billing even if parsing fails
+ // map Responses-style input/output token fields onto the OpenAI-style usage fields
+ if usageResp.InputTokens > 0 {
+ usageResp.PromptTokens += usageResp.InputTokens
+ }
+ if usageResp.OutputTokens > 0 {
+ usageResp.CompletionTokens += usageResp.OutputTokens
+ }
+ if usageResp.InputTokensDetails != nil {
+ usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
+ usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
+ }
+ return &usageResp.Usage, nil
+}
diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go
new file mode 100644
index 00000000..d9dd96b9
--- /dev/null
+++ b/relay/channel/openai/relay_responses.go
@@ -0,0 +1,97 @@
+package openai
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
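+// OaiResponsesHandler passes a non-streaming Responses API reply through to the client and derives usage, including built-in tool calls, from it.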
+func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ // read response body
+ var responsesResponse dto.OpenAIResponsesResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ err = common.Unmarshal(responseBody, &responsesResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if responsesResponse.Error != nil {
+ return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
+ }
+
+ // write the response body back to the client
+ common.IOCopyBytesGracefully(c, resp, responseBody)
+
+ // compute usage
+ usage := dto.Usage{}
+ usage.PromptTokens = responsesResponse.Usage.InputTokens
+ usage.CompletionTokens = responsesResponse.Usage.OutputTokens
+ usage.TotalTokens = responsesResponse.Usage.TotalTokens
+ // count built-in tool usage
+ for _, tool := range responsesResponse.Tools {
+ info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
+ }
+ return &usage, nil
+}
+
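+// OaiResponsesStreamHandler forwards the Responses API event stream and extracts usage from the response.completed event, falling back to counting the output text.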
+func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ if resp == nil || resp.Body == nil {
+ common.LogError(c, "invalid response or response body")
+ return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
+ }
+
+ var usage = &dto.Usage{}
+ var responseTextBuilder strings.Builder
+
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+
+ // check whether this chunk carries the completed status and usage info
+ var streamResponse dto.ResponsesStreamResponse
+ if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
+ sendResponsesStreamData(c, streamResponse, data)
+ switch streamResponse.Type {
+ case "response.completed":
+ usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+ usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+ usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+ case "response.output_text.delta":
+ // accumulate the output text
+ responseTextBuilder.WriteString(streamResponse.Delta)
+ case dto.ResponsesOutputTypeItemDone:
+ // handle built-in tool call items
+ if streamResponse.Item != nil {
+ switch streamResponse.Item.Type {
+ case dto.BuildInCallWebSearchCall:
+ info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
+ }
+ }
+ }
+ }
+ return true
+ })
+
+ if usage.CompletionTokens == 0 {
+ // count tokens of the accumulated output text
+ tempStr := responseTextBuilder.String()
+ if len(tempStr) > 0 {
+ // the stream did not finish normally, fall back to counting the output text
+ completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
+ usage.CompletionTokens = completionTokens
+ }
+ }
+
+ return usage, nil
+}
diff --git a/relay/channel/openrouter/constant.go b/relay/channel/openrouter/constant.go
new file mode 100644
index 00000000..0372eb9a
--- /dev/null
+++ b/relay/channel/openrouter/constant.go
@@ -0,0 +1,5 @@
+package openrouter
+
+var ModelList = []string{}
+
+var ChannelName = "openrouter"
diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go
new file mode 100644
index 00000000..607f495b
--- /dev/null
+++ b/relay/channel/openrouter/dto.go
@@ -0,0 +1,9 @@
+package openrouter
+
+type RequestReasoning struct {
+ // One of the following (not both):
+ Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
+ MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
+ // Optional: Default is false. All models support this.
+ Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
+}
diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go
new file mode 100644
index 00000000..a60dc4b2
--- /dev/null
+++ b/relay/channel/palm/adaptor.go
@@ -0,0 +1,91 @@
+package palm
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("x-goog-api-key", info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ var responseText string
+ err, responseText = palmStreamHandler(c, resp)
+ usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+ } else {
+ usage, err = palmHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/palm/constants.go b/relay/channel/palm/constants.go
new file mode 100644
index 00000000..b5c881bf
--- /dev/null
+++ b/relay/channel/palm/constants.go
@@ -0,0 +1,7 @@
+package palm
+
+var ModelList = []string{
+ "PaLM-2",
+}
+
+var ChannelName = "google palm"
diff --git a/relay/channel/palm/dto.go b/relay/channel/palm/dto.go
new file mode 100644
index 00000000..b8a48e73
--- /dev/null
+++ b/relay/channel/palm/dto.go
@@ -0,0 +1,38 @@
+package palm
+
+import "one-api/dto"
+
+type PaLMChatMessage struct {
+ Author string `json:"author"`
+ Content string `json:"content"`
+}
+
+type PaLMFilter struct {
+ Reason string `json:"reason"`
+ Message string `json:"message"`
+}
+
+type PaLMPrompt struct {
+ Messages []PaLMChatMessage `json:"messages"`
+}
+
+type PaLMChatRequest struct {
+ Prompt PaLMPrompt `json:"prompt"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ CandidateCount int `json:"candidateCount,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+ TopK uint `json:"topK,omitempty"`
+}
+
+type PaLMError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Status string `json:"status"`
+}
+
+type PaLMChatResponse struct {
+ Candidates []PaLMChatMessage `json:"candidates"`
+ Messages []dto.Message `json:"messages"`
+ Filters []PaLMFilter `json:"filters"`
+ Error PaLMError `json:"error"`
+}
diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go
new file mode 100644
index 00000000..4db31573
--- /dev/null
+++ b/relay/channel/palm/relay-palm.go
@@ -0,0 +1,162 @@
+package palm
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
+// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
+
+func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
+ palmRequest := PaLMChatRequest{
+ Prompt: PaLMPrompt{
+ Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
+ },
+ Temperature: textRequest.Temperature,
+ CandidateCount: textRequest.N,
+ TopP: textRequest.TopP,
+ TopK: textRequest.MaxTokens,
+ }
+ for _, message := range textRequest.Messages {
+ palmMessage := PaLMChatMessage{
+ Content: message.StringContent(),
+ }
+ if message.Role == "user" {
+ palmMessage.Author = "0"
+ } else {
+ palmMessage.Author = "1"
+ }
+ palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
+ }
+ return &palmRequest
+}
+
+func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
+ fullTextResponse := dto.OpenAITextResponse{
+ Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
+ }
+ for i, candidate := range response.Candidates {
+ choice := dto.OpenAITextResponseChoice{
+ Index: i,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: candidate.Content,
+ },
+ FinishReason: "stop",
+ }
+ fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+ }
+ return &fullTextResponse
+}
+
+func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
+ var choice dto.ChatCompletionsStreamResponseChoice
+ if len(palmResponse.Candidates) > 0 {
+ choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
+ }
+ choice.FinishReason = &constant.FinishReasonStop
+ var response dto.ChatCompletionsStreamResponse
+ response.Object = "chat.completion.chunk"
+ response.Model = "palm2"
+ response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
+ return &response
+}
+
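+// palmStreamHandler reads the complete PaLM response and replays it to the client as a single SSE chunk followed by [DONE], returning the response text for token counting.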
+func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
+ responseText := ""
+ responseId := helper.GetResponseID(c)
+ createdTime := common.GetTimestamp()
+ dataChan := make(chan string)
+ stopChan := make(chan bool)
+ go func() {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ common.SysError("error reading stream response: " + err.Error())
+ stopChan <- true
+ return
+ }
+ common.CloseResponseBodyGracefully(resp)
+ var palmResponse PaLMChatResponse
+ err = json.Unmarshal(responseBody, &palmResponse)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ stopChan <- true
+ return
+ }
+ fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
+ fullTextResponse.Id = responseId
+ fullTextResponse.Created = createdTime
+ if len(palmResponse.Candidates) > 0 {
+ responseText = palmResponse.Candidates[0].Content
+ }
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ common.SysError("error marshalling stream response: " + err.Error())
+ stopChan <- true
+ return
+ }
+ dataChan <- string(jsonResponse)
+ stopChan <- true
+ }()
+ helper.SetEventStreamHeaders(c)
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case data := <-dataChan:
+ c.Render(-1, common.CustomEvent{Data: "data: " + data})
+ return true
+ case <-stopChan:
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+ return false
+ }
+ })
+ common.CloseResponseBodyGracefully(resp)
+ return nil, responseText
+}
+
+func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ var palmResponse PaLMChatResponse
+ err = json.Unmarshal(responseBody, &palmResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: palmResponse.Error.Message,
+ Type: palmResponse.Error.Status,
+ Param: "",
+ Code: palmResponse.Error.Code,
+ }, resp.StatusCode)
+ }
+ fullTextResponse := responsePaLM2OpenAI(&palmResponse)
+ completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
+ usage := dto.Usage{
+ PromptTokens: info.PromptTokens,
+ CompletionTokens: completionTokens,
+ TotalTokens: info.PromptTokens + completionTokens,
+ }
+ fullTextResponse.Usage = usage
+ jsonResponse, err := common.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return &usage, nil
+}
diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go
new file mode 100644
index 00000000..19830aca
--- /dev/null
+++ b/relay/channel/perplexity/adaptor.go
@@ -0,0 +1,92 @@
+package perplexity
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
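+ // Clamp top_p below 1 before forwarding (the upstream API appears not to accept values >= 1).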
+ if request.TopP >= 1 {
+ request.TopP = 0.99
+ }
+ return requestOpenAI2Perplexity(*request), nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/perplexity/constants.go b/relay/channel/perplexity/constants.go
new file mode 100644
index 00000000..f9f030e0
--- /dev/null
+++ b/relay/channel/perplexity/constants.go
@@ -0,0 +1,7 @@
+package perplexity
+
+var ModelList = []string{
+ "llama-3-sonar-small-32k-chat", "llama-3-sonar-small-32k-online", "llama-3-sonar-large-32k-chat", "llama-3-sonar-large-32k-online", "llama-3-8b-instruct", "llama-3-70b-instruct", "mixtral-8x7b-instruct",
+}
+
+var ChannelName = "perplexity"
diff --git a/relay/channel/perplexity/relay-perplexity.go b/relay/channel/perplexity/relay-perplexity.go
new file mode 100644
index 00000000..9772aead
--- /dev/null
+++ b/relay/channel/perplexity/relay-perplexity.go
@@ -0,0 +1,21 @@
+package perplexity
+
+import "one-api/dto"
+
+func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+ messages := make([]dto.Message, 0, len(request.Messages))
+ for _, message := range request.Messages {
+ messages = append(messages, dto.Message{
+ Role: message.Role,
+ Content: message.Content,
+ })
+ }
+ return &dto.GeneralOpenAIRequest{
+ Model: request.Model,
+ Stream: request.Stream,
+ Messages: messages,
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ MaxTokens: request.MaxTokens,
+ }
+}
diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go
new file mode 100644
index 00000000..63c1c84d
--- /dev/null
+++ b/relay/channel/siliconflow/adaptor.go
@@ -0,0 +1,104 @@
+package siliconflow
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ if info.RelayMode == constant.RelayModeRerank {
+ return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeEmbeddings {
+ return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeChatCompletions {
+ return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
+ } else if info.RelayMode == constant.RelayModeCompletions {
+ return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
+ }
+ return "", errors.New("invalid relay mode")
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayMode {
+ case constant.RelayModeRerank:
+ usage, err = siliconflowRerankHandler(c, info, resp)
+ case constant.RelayModeCompletions:
+ fallthrough
+ case constant.RelayModeChatCompletions:
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ case constant.RelayModeEmbeddings:
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/siliconflow/constant.go b/relay/channel/siliconflow/constant.go
new file mode 100644
index 00000000..fea6fcd4
--- /dev/null
+++ b/relay/channel/siliconflow/constant.go
@@ -0,0 +1,51 @@
+package siliconflow
+
+var ModelList = []string{
+ "THUDM/glm-4-9b-chat",
+ //"stabilityai/stable-diffusion-xl-base-1.0",
+ //"TencentARC/PhotoMaker",
+ "InstantX/InstantID",
+ //"stabilityai/stable-diffusion-2-1",
+ //"stabilityai/sd-turbo",
+ //"stabilityai/sdxl-turbo",
+ "ByteDance/SDXL-Lightning",
+ "deepseek-ai/deepseek-llm-67b-chat",
+ "Qwen/Qwen1.5-14B-Chat",
+ "Qwen/Qwen1.5-7B-Chat",
+ "Qwen/Qwen1.5-110B-Chat",
+ "Qwen/Qwen1.5-32B-Chat",
+ "01-ai/Yi-1.5-6B-Chat",
+ "01-ai/Yi-1.5-9B-Chat-16K",
+ "01-ai/Yi-1.5-34B-Chat-16K",
+ "THUDM/chatglm3-6b",
+ "deepseek-ai/DeepSeek-V2-Chat",
+ "Qwen/Qwen2-72B-Instruct",
+ "Qwen/Qwen2-7B-Instruct",
+ "Qwen/Qwen2-57B-A14B-Instruct",
+ //"stabilityai/stable-diffusion-3-medium",
+ "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+ "Qwen/Qwen2-1.5B-Instruct",
+ "internlm/internlm2_5-7b-chat",
+ "BAAI/bge-large-en-v1.5",
+ "BAAI/bge-large-zh-v1.5",
+ "Pro/Qwen/Qwen2-7B-Instruct",
+ "Pro/Qwen/Qwen2-1.5B-Instruct",
+ "Pro/Qwen/Qwen1.5-7B-Chat",
+ "Pro/THUDM/glm-4-9b-chat",
+ "Pro/THUDM/chatglm3-6b",
+ "Pro/01-ai/Yi-1.5-9B-Chat-16K",
+ "Pro/01-ai/Yi-1.5-6B-Chat",
+ "Pro/google/gemma-2-9b-it",
+ "Pro/internlm/internlm2_5-7b-chat",
+ "Pro/meta-llama/Meta-Llama-3-8B-Instruct",
+ "Pro/mistralai/Mistral-7B-Instruct-v0.2",
+ "black-forest-labs/FLUX.1-schnell",
+ "FunAudioLLM/SenseVoiceSmall",
+ "netease-youdao/bce-embedding-base_v1",
+ "BAAI/bge-m3",
+ "internlm/internlm2_5-20b-chat",
+ "Qwen/Qwen2-Math-72B-Instruct",
+ "netease-youdao/bce-reranker-base_v1",
+ "BAAI/bge-reranker-v2-m3",
+}
+var ChannelName = "siliconflow"
diff --git a/relay/channel/siliconflow/dto.go b/relay/channel/siliconflow/dto.go
new file mode 100644
index 00000000..add0fd07
--- /dev/null
+++ b/relay/channel/siliconflow/dto.go
@@ -0,0 +1,17 @@
+package siliconflow
+
+import "one-api/dto"
+
+type SFTokens struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+}
+
+type SFMeta struct {
+ Tokens SFTokens `json:"tokens"`
+}
+
+type SFRerankResponse struct {
+ Results []dto.RerankResponseResult `json:"results"`
+ Meta SFMeta `json:"meta"`
+}
diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go
new file mode 100644
index 00000000..fabaf9c6
--- /dev/null
+++ b/relay/channel/siliconflow/relay-siliconflow.go
@@ -0,0 +1,44 @@
+package siliconflow
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
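+// siliconflowRerankHandler converts SiliconFlow's rerank response, which reports
+// token counts under meta.tokens, into the OpenAI-style rerank response and Usage
+// returned to the client.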
+func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ var siliconflowResp SFRerankResponse
+ err = json.Unmarshal(responseBody, &siliconflowResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ usage := &dto.Usage{
+ PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
+ CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
+ TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
+ }
+ rerankResp := &dto.RerankResponse{
+ Results: siliconflowResp.Results,
+ Usage: *usage,
+ }
+
+ jsonResponse, err := json.Marshal(rerankResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return usage, nil
+}
diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go
new file mode 100644
index 00000000..8d057513
--- /dev/null
+++ b/relay/channel/task/jimeng/adaptor.go
@@ -0,0 +1,380 @@
+package jimeng
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "one-api/model"
+ "sort"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/pkg/errors"
+
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type requestPayload struct {
+ ReqKey string `json:"req_key"`
+ BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
+ ImageUrls []string `json:"image_urls,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Seed int64 `json:"seed"`
+ AspectRatio string `json:"aspect_ratio"`
+}
+
+type responsePayload struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+ Data struct {
+ TaskID string `json:"task_id"`
+ } `json:"data"`
+}
+
+type responseTask struct {
+ Code int `json:"code"`
+ Data struct {
+ BinaryDataBase64 []interface{} `json:"binary_data_base64"`
+ ImageUrls interface{} `json:"image_urls"`
+ RespData string `json:"resp_data"`
+ Status string `json:"status"`
+ VideoUrl string `json:"video_url"`
+ } `json:"data"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+ Status int `json:"status"`
+ TimeElapsed string `json:"time_elapsed"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+ ChannelType int
+ accessKey string
+ secretKey string
+ baseURL string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+ a.ChannelType = info.ChannelType
+ a.baseURL = info.BaseUrl
+
+ // apiKey format: "access_key|secret_key"
+ keyParts := strings.Split(info.ApiKey, "|")
+ if len(keyParts) == 2 {
+ a.accessKey = strings.TrimSpace(keyParts[0])
+ a.secretKey = strings.TrimSpace(keyParts[1])
+ }
+}
+
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+ // Accept only POST /v1/video/generations as "generate" action.
+ action := constant.TaskActionGenerate
+ info.Action = action
+
+ req := relaycommon.TaskSubmitReq{}
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+ if strings.TrimSpace(req.Prompt) == "" {
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+ return
+ }
+
+ // Store into context for later usage
+ c.Set("task_request", req)
+ return nil
+}
+
+// BuildRequestURL constructs the upstream URL.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+ return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ return a.signRequest(req, a.accessKey, a.secretKey)
+}
+
+// BuildRequestBody converts the request into the Jimeng-specific format.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+ v, exists := c.Get("task_request")
+ if !exists {
+ return nil, fmt.Errorf("request not found in context")
+ }
+ req := v.(relaycommon.TaskSubmitReq)
+
+ body, err := a.convertToRequestPayload(&req)
+ if err != nil {
+ return nil, errors.Wrap(err, "convert request payload failed")
+ }
+ data, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+ _ = resp.Body.Close()
+
+ // Parse Jimeng response
+ var jResp responsePayload
+ if err := json.Unmarshal(responseBody, &jResp); err != nil {
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ if jResp.Code != 10000 {
+ taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
+ return jResp.Data.TaskID, responseBody, nil
+}
+
+// FetchTask fetches the status of a submitted task.
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+ taskID, ok := body["task_id"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid task_id")
+ }
+
+ uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
+ payload := map[string]string{
+ "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
+ "task_id": taskID,
+ }
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return nil, errors.Wrap(err, "marshal fetch task payload failed")
+ }
+
+ req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Content-Type", "application/json")
+
+ keyParts := strings.Split(key, "|")
+ if len(keyParts) != 2 {
+ return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
+ }
+ accessKey := strings.TrimSpace(keyParts[0])
+ secretKey := strings.TrimSpace(keyParts[1])
+
+ if err := a.signRequest(req, accessKey, secretKey); err != nil {
+ return nil, errors.Wrap(err, "sign request failed")
+ }
+
+ return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+ return []string{"jimeng_vgfm_t2v_l20"}
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+ return "jimeng"
+}
+
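+// signRequest signs the request with Volcano Engine's HMAC-SHA256 scheme
+// (similar to AWS SigV4): hash the body, build a canonical request from the
+// method, path, sorted query string and signed headers, derive a signing key
+// from the secret key, date, region and service, then set the Authorization header.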
+func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
+ var bodyBytes []byte
+ var err error
+
+ if req.Body != nil {
+ bodyBytes, err = io.ReadAll(req.Body)
+ if err != nil {
+ return errors.Wrap(err, "read request body failed")
+ }
+ _ = req.Body.Close()
+ req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
+ } else {
+ bodyBytes = []byte{}
+ }
+
+ payloadHash := sha256.Sum256(bodyBytes)
+ hexPayloadHash := hex.EncodeToString(payloadHash[:])
+
+ t := time.Now().UTC()
+ xDate := t.Format("20060102T150405Z")
+ shortDate := t.Format("20060102")
+
+ req.Header.Set("Host", req.URL.Host)
+ req.Header.Set("X-Date", xDate)
+ req.Header.Set("X-Content-Sha256", hexPayloadHash)
+
+ // Sort and encode query parameters to create canonical query string
+ queryParams := req.URL.Query()
+ sortedKeys := make([]string, 0, len(queryParams))
+ for k := range queryParams {
+ sortedKeys = append(sortedKeys, k)
+ }
+ sort.Strings(sortedKeys)
+ var queryParts []string
+ for _, k := range sortedKeys {
+ values := queryParams[k]
+ sort.Strings(values)
+ for _, v := range values {
+ queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
+ }
+ }
+ canonicalQueryString := strings.Join(queryParts, "&")
+
+ headersToSign := map[string]string{
+ "host": req.URL.Host,
+ "x-date": xDate,
+ "x-content-sha256": hexPayloadHash,
+ }
+ if req.Header.Get("Content-Type") != "" {
+ headersToSign["content-type"] = req.Header.Get("Content-Type")
+ }
+
+ var signedHeaderKeys []string
+ for k := range headersToSign {
+ signedHeaderKeys = append(signedHeaderKeys, k)
+ }
+ sort.Strings(signedHeaderKeys)
+
+ var canonicalHeaders strings.Builder
+ for _, k := range signedHeaderKeys {
+ canonicalHeaders.WriteString(k)
+ canonicalHeaders.WriteString(":")
+ canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
+ canonicalHeaders.WriteString("\n")
+ }
+ signedHeaders := strings.Join(signedHeaderKeys, ";")
+
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+ req.Method,
+ req.URL.Path,
+ canonicalQueryString,
+ canonicalHeaders.String(),
+ signedHeaders,
+ hexPayloadHash,
+ )
+
+ hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
+ hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
+
+ region := "cn-north-1"
+ serviceName := "cv"
+ credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
+ stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
+ xDate,
+ credentialScope,
+ hexHashedCanonicalRequest,
+ )
+
+ kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
+ kRegion := hmacSHA256(kDate, []byte(region))
+ kService := hmacSHA256(kRegion, []byte(serviceName))
+ kSigning := hmacSHA256(kService, []byte("request"))
+ signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
+
+ authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+ accessKey,
+ credentialScope,
+ signedHeaders,
+ signature,
+ )
+ req.Header.Set("Authorization", authorization)
+ return nil
+}
+
+func hmacSHA256(key []byte, data []byte) []byte {
+ h := hmac.New(sha256.New, key)
+ h.Write(data)
+ return h.Sum(nil)
+}
+
+func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
+ r := requestPayload{
+ ReqKey: "jimeng_vgfm_i2v_l20",
+ Prompt: req.Prompt,
+ AspectRatio: "16:9", // Default aspect ratio
+ Seed: -1, // Default to random
+ }
+
+ // Handle one-of image_urls or binary_data_base64
+ if req.Image != "" {
+ if strings.HasPrefix(req.Image, "http") {
+ r.ImageUrls = []string{req.Image}
+ } else {
+ r.BinaryDataBase64 = []string{req.Image}
+ }
+ }
+ metadata := req.Metadata
+ medaBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, errors.Wrap(err, "metadata marshal metadata failed")
+ }
+ err = json.Unmarshal(medaBytes, &r)
+ if err != nil {
+ return nil, errors.Wrap(err, "unmarshal metadata failed")
+ }
+ return &r, nil
+}
+
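+// ParseTaskResult maps Jimeng's polling response onto the generic TaskInfo:
+// a non-10000 code marks the task as failed, while data.status distinguishes
+// queued ("in_queue") from finished ("done") tasks and carries the video URL.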
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+ resTask := responseTask{}
+ if err := json.Unmarshal(respBody, &resTask); err != nil {
+ return nil, errors.Wrap(err, "unmarshal task result failed")
+ }
+ taskResult := relaycommon.TaskInfo{}
+ if resTask.Code == 10000 {
+ taskResult.Code = 0
+ } else {
+ taskResult.Code = resTask.Code // TODO: unify error codes
+ taskResult.Reason = resTask.Message
+ taskResult.Status = model.TaskStatusFailure
+ taskResult.Progress = "100%"
+ }
+ switch resTask.Data.Status {
+ case "in_queue":
+ taskResult.Status = model.TaskStatusQueued
+ taskResult.Progress = "10%"
+ case "done":
+ taskResult.Status = model.TaskStatusSuccess
+ taskResult.Progress = "100%"
+ }
+ taskResult.Url = resTask.Data.VideoUrl
+ return &taskResult, nil
+}
diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go
new file mode 100644
index 00000000..afa39201
--- /dev/null
+++ b/relay/channel/task/kling/adaptor.go
@@ -0,0 +1,346 @@
+package kling
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "github.com/samber/lo"
+ "io"
+ "net/http"
+ "one-api/model"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt"
+ "github.com/pkg/errors"
+
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type SubmitReq struct {
+ Prompt string `json:"prompt"`
+ Model string `json:"model,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Image string `json:"image,omitempty"`
+ Size string `json:"size,omitempty"`
+ Duration int `json:"duration,omitempty"`
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+type requestPayload struct {
+ Prompt string `json:"prompt,omitempty"`
+ Image string `json:"image,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Duration string `json:"duration,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty"`
+ ModelName string `json:"model_name,omitempty"`
+ CfgScale float64 `json:"cfg_scale,omitempty"`
+}
+
+type responsePayload struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+ Data struct {
+ TaskId string `json:"task_id"`
+ TaskStatus string `json:"task_status"`
+ TaskStatusMsg string `json:"task_status_msg"`
+ TaskResult struct {
+ Videos []struct {
+ Id string `json:"id"`
+ Url string `json:"url"`
+ Duration string `json:"duration"`
+ } `json:"videos"`
+ } `json:"task_result"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+ } `json:"data"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+ ChannelType int
+ accessKey string
+ secretKey string
+ baseURL string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+ a.ChannelType = info.ChannelType
+ a.baseURL = info.BaseUrl
+
+ // apiKey format: "access_key|secret_key"
+ keyParts := strings.Split(info.ApiKey, "|")
+ if len(keyParts) == 2 {
+ a.accessKey = strings.TrimSpace(keyParts[0])
+ a.secretKey = strings.TrimSpace(keyParts[1])
+ }
+}
+
+// ValidateRequestAndSetAction parses body, validates fields and sets default action.
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+ // Accept only POST /v1/video/generations as "generate" action.
+ action := constant.TaskActionGenerate
+ info.Action = action
+
+ var req SubmitReq
+ if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+ if strings.TrimSpace(req.Prompt) == "" {
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
+ return
+ }
+
+ // Store into context for later usage
+ c.Set("task_request", req)
+ return nil
+}
+
+// BuildRequestURL constructs the upstream URL.
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+ path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
+ return fmt.Sprintf("%s%s", a.baseURL, path), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+ token, err := a.createJWTToken()
+ if err != nil {
+ return fmt.Errorf("failed to create JWT token: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
+ return nil
+}
+
+// BuildRequestBody converts the request into the Kling-specific format.
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+ v, exists := c.Get("task_request")
+ if !exists {
+ return nil, fmt.Errorf("request not found in context")
+ }
+ req := v.(SubmitReq)
+
+ body, err := a.convertToRequestPayload(&req)
+ if err != nil {
+ return nil, err
+ }
+ data, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ if action := c.GetString("action"); action != "" {
+ info.Action = action
+ }
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Attempt Kling response parse first.
+ var kResp responsePayload
+ if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
+ c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
+ return kResp.Data.TaskId, responseBody, nil
+ }
+
+ // Fallback generic task response.
+ var generic dto.TaskResponse[string]
+ if err := json.Unmarshal(responseBody, &generic); err != nil {
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ if !generic.IsSuccess() {
+ taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", generic.Message), generic.Code, http.StatusInternalServerError)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
+ return generic.Data, responseBody, nil
+}
+
+// FetchTask fetches the status of a submitted task.
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+ taskID, ok := body["task_id"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid task_id")
+ }
+ action, ok := body["action"].(string)
+ if !ok {
+ return nil, fmt.Errorf("invalid action")
+ }
+ path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
+ url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
+
+ req, err := http.NewRequest(http.MethodGet, url, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ token, err := a.createJWTTokenWithKey(key)
+ if err != nil {
+ token = key
+ }
+
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "kling-sdk/1.0")
+
+ return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+ return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+ return "kling"
+}
+
+// ============================
+// helpers
+// ============================
+
+func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
+ r := requestPayload{
+ Prompt: req.Prompt,
+ Image: req.Image,
+ Mode: defaultString(req.Mode, "std"),
+ Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
+ AspectRatio: a.getAspectRatio(req.Size),
+ ModelName: req.Model,
+ CfgScale: 0.5,
+ }
+ if r.ModelName == "" {
+ r.ModelName = "kling-v1"
+ }
+ metadata := req.Metadata
+ medaBytes, err := json.Marshal(metadata)
+ if err != nil {
+ return nil, errors.Wrap(err, "metadata marshal metadata failed")
+ }
+ err = json.Unmarshal(medaBytes, &r)
+ if err != nil {
+ return nil, errors.Wrap(err, "unmarshal metadata failed")
+ }
+ return &r, nil
+}
+
+func (a *TaskAdaptor) getAspectRatio(size string) string {
+ switch size {
+ case "1024x1024", "512x512":
+ return "1:1"
+ case "1280x720", "1920x1080":
+ return "16:9"
+ case "720x1280", "1080x1920":
+ return "9:16"
+ default:
+ return "1:1"
+ }
+}
+
+func defaultString(s, def string) string {
+ if strings.TrimSpace(s) == "" {
+ return def
+ }
+ return s
+}
+
+func defaultInt(v int, def int) int {
+ if v == 0 {
+ return def
+ }
+ return v
+}
+
+// ============================
+// JWT helpers
+// ============================
+
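+// Kling authenticates with a short-lived JWT: the access key is the issuer,
+// the token is HMAC-SHA256 signed with the secret key, expires after 30 minutes
+// and is back-dated by 5 seconds to tolerate clock skew.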
+func (a *TaskAdaptor) createJWTToken() (string, error) {
+ return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
+ parts := strings.Split(apiKey, "|")
+ if len(parts) != 2 {
+ return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
+ }
+ return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
+}
+
+func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
+ if accessKey == "" || secretKey == "" {
+ return "", fmt.Errorf("access key and secret key are required")
+ }
+ now := time.Now().Unix()
+ claims := jwt.MapClaims{
+ "iss": accessKey,
+ "exp": now + 1800, // 30 minutes
+ "nbf": now - 5,
+ }
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ token.Header["typ"] = "JWT"
+ return token.SignedString([]byte(secretKey))
+}
+
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+ resPayload := responsePayload{}
+ err := json.Unmarshal(respBody, &resPayload)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to unmarshal response body")
+ }
+ taskInfo := &relaycommon.TaskInfo{}
+ taskInfo.Code = resPayload.Code
+ taskInfo.TaskID = resPayload.Data.TaskId
+ taskInfo.Reason = resPayload.Message
+ // Task status enum: submitted, processing, succeed, failed
+ status := resPayload.Data.TaskStatus
+ switch status {
+ case "submitted":
+ taskInfo.Status = model.TaskStatusSubmitted
+ case "processing":
+ taskInfo.Status = model.TaskStatusInProgress
+ case "succeed":
+ taskInfo.Status = model.TaskStatusSuccess
+ case "failed":
+ taskInfo.Status = model.TaskStatusFailure
+ default:
+ return nil, fmt.Errorf("unknown task status: %s", status)
+ }
+ if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
+ video := videos[0]
+ taskInfo.Url = video.Url
+ }
+ return taskInfo, nil
+}
diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go
new file mode 100644
index 00000000..9c04c7ad
--- /dev/null
+++ b/relay/channel/task/suno/adaptor.go
@@ -0,0 +1,176 @@
+package suno
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "strings"
+ "time"
+)
+
+type TaskAdaptor struct {
+ ChannelType int
+}
+
+func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
+ return nil, fmt.Errorf("not implement") // todo implement this method if needed
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+ a.ChannelType = info.ChannelType
+}
+
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+ action := strings.ToUpper(c.Param("action"))
+
+ var sunoRequest *dto.SunoSubmitReq
+ err := common.UnmarshalBodyReusable(c, &sunoRequest)
+ if err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+ err = actionValidate(c, sunoRequest, action)
+ if err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+
+ if sunoRequest.ContinueClipId != "" {
+ if sunoRequest.TaskID == "" {
+ taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
+ return
+ }
+ info.OriginTaskID = sunoRequest.TaskID
+ }
+
+ info.Action = action
+ c.Set("task_request", sunoRequest)
+ return nil
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+ baseURL := info.BaseUrl
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
+ return fullRequestURL, nil
+}
+
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+ sunoRequest, ok := c.Get("task_request")
+ if !ok {
+ err := common.UnmarshalBodyReusable(c, &sunoRequest)
+ if err != nil {
+ return nil, err
+ }
+ }
+ data, err := json.Marshal(sunoRequest)
+ if err != nil {
+ return nil, err
+ }
+ return bytes.NewReader(data), nil
+}
+
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+ var sunoResponse dto.TaskResponse[string]
+ err = json.Unmarshal(responseBody, &sunoResponse)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+ if !sunoResponse.IsSuccess() {
+ taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
+ return
+ }
+
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+
+ return sunoResponse.Data, nil, nil
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+ return ChannelName
+}
+
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+ requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
+ byteBody, err := json.Marshal(body)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
+ if err != nil {
+ common.SysError(fmt.Sprintf("Get Task error: %v", err))
+ return nil, err
+ }
+ defer req.Body.Close()
+ // Apply a 15-second timeout to the fetch request
+ timeout := time.Second * 15
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ // Re-create the request with the timeout-bound context
+ req = req.WithContext(ctx)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+key)
+ resp, err := service.GetHttpClient().Do(req)
+ if err != nil {
+ return nil, err
+ }
+ return resp, nil
+}
+
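+// actionValidate applies per-action defaults and checks: MUSIC falls back to the
+// chirp-v3-0 model version when none is given, LYRICS requires a prompt, and any
+// other action is rejected.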
+func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
+ switch action {
+ case constant.SunoActionMusic:
+ if sunoRequest.Mv == "" {
+ sunoRequest.Mv = "chirp-v3-0"
+ }
+ case constant.SunoActionLyrics:
+ if sunoRequest.Prompt == "" {
+ err = fmt.Errorf("prompt_empty")
+ return
+ }
+ default:
+ err = fmt.Errorf("invalid_action")
+ }
+ return
+}
diff --git a/relay/channel/task/suno/models.go b/relay/channel/task/suno/models.go
new file mode 100644
index 00000000..967cf1b1
--- /dev/null
+++ b/relay/channel/task/suno/models.go
@@ -0,0 +1,7 @@
+package suno
+
+var ModelList = []string{
+ "suno_music", "suno_lyrics",
+}
+
+var ChannelName = "suno"
diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go
new file mode 100644
index 00000000..520276a7
--- /dev/null
+++ b/relay/channel/tencent/adaptor.go
@@ -0,0 +1,113 @@
+package tencent
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+ "strconv"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+ Sign string
+ AppID int64
+ Action string
+ Version string
+ Timestamp int64
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ panic("implement me")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+ a.Action = "ChatCompletions"
+ a.Version = "2023-09-01"
+ a.Timestamp = common.GetTimestamp()
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return fmt.Sprintf("%s/", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", a.Sign)
+ req.Set("X-TC-Action", a.Action)
+ req.Set("X-TC-Version", a.Version)
+ req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+ appId, secretId, secretKey, err := parseTencentConfig(apiKey)
+ a.AppID = appId
+ if err != nil {
+ return nil, err
+ }
+ tencentRequest := requestOpenAI2Tencent(a, *request)
+ // we have to calculate the sign here
+ a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
+ return tencentRequest, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = tencentStreamHandler(c, info, resp)
+ } else {
+ usage, err = tencentHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/tencent/constants.go b/relay/channel/tencent/constants.go
new file mode 100644
index 00000000..d4d9cc1f
--- /dev/null
+++ b/relay/channel/tencent/constants.go
@@ -0,0 +1,10 @@
+package tencent
+
+var ModelList = []string{
+ "hunyuan-lite",
+ "hunyuan-standard",
+ "hunyuan-standard-256K",
+ "hunyuan-pro",
+}
+
+var ChannelName = "tencent"
diff --git a/relay/channel/tencent/dto.go b/relay/channel/tencent/dto.go
new file mode 100644
index 00000000..65c548a9
--- /dev/null
+++ b/relay/channel/tencent/dto.go
@@ -0,0 +1,75 @@
+package tencent
+
+type TencentMessage struct {
+ Role string `json:"Role"`
+ Content string `json:"Content"`
+}
+
+type TencentChatRequest struct {
+ // Model name. Valid values: hunyuan-lite, hunyuan-standard, hunyuan-standard-256K, hunyuan-pro.
+ // See the [product overview](https://cloud.tencent.com/document/product/1729/104753) for a description of each model.
+ //
+ // Note:
+ // Different models are billed differently; see the [purchase guide](https://cloud.tencent.com/document/product/1729/97731) before calling.
+ Model *string `json:"Model"`
+ // Chat context.
+ // Notes:
+ // 1. At most 40 messages, ordered from oldest to newest.
+ // 2. Message.Role may be system, user or assistant.
+ // The system role is optional and, if present, must come first. user and assistant must alternate (one question, one answer), starting and ending with user, and Content must not be empty. Example order: [system(optional) user assistant user assistant user ...].
+ // 3. The total Content length in Messages must not exceed the model's input limit (see the [product overview](https://cloud.tencent.com/document/product/1729/104753)); if it does, the oldest content is truncated and only the tail is kept.
+ Messages []*TencentMessage `json:"Messages"`
+ // Streaming switch.
+ // Notes:
+ // 1. Defaults to non-streaming (false) when omitted.
+ // 2. When streaming, results are returned incrementally over SSE (read Choices[n].Delta and concatenate the increments to get the full result).
+ // 3. When not streaming:
+ // The call behaves like an ordinary HTTP request.
+ // The response takes longer; **set this to true if you need lower latency**.
+ // The final result is returned once (read Choices[n].Message).
+ //
+ // Note:
+ // When calling through an SDK, streaming and non-streaming results are read in **different ways**; see the comments or examples in the SDK (under examples/hunyuan/v20230901/ in each language's SDK repository).
+ Stream *bool `json:"Stream,omitempty"`
+ // Notes:
+ // 1. Controls output diversity; higher values produce more diverse text.
+ // 2. Range [0.0, 1.0]; the model's recommended value is used when omitted.
+ // 3. Not recommended unless necessary; unreasonable values can degrade results.
+ TopP *float64 `json:"TopP,omitempty"`
+ // Notes:
+ // 1. Higher values make the output more random; lower values make it more focused and deterministic.
+ // 2. Range [0.0, 2.0]; the model's recommended value is used when omitted.
+ // 3. Not recommended unless necessary; unreasonable values can degrade results.
+ Temperature *float64 `json:"Temperature,omitempty"`
+}
+
+type TencentError struct {
+ Code int `json:"Code"`
+ Message string `json:"Message"`
+}
+
+type TencentUsage struct {
+ PromptTokens int `json:"PromptTokens"`
+ CompletionTokens int `json:"CompletionTokens"`
+ TotalTokens int `json:"TotalTokens"`
+}
+
+type TencentResponseChoices struct {
+ FinishReason string `json:"FinishReason,omitempty"` // streaming end flag; "stop" marks the final packet
+ Messages TencentMessage `json:"Message,omitempty"` // content returned in synchronous mode, null in streaming mode; output content is capped at 1024 tokens
+ Delta TencentMessage `json:"Delta,omitempty"` // content returned in streaming mode, null in synchronous mode; output content is capped at 1024 tokens
+}
+
+type TencentChatResponse struct {
+ Choices []TencentResponseChoices `json:"Choices,omitempty"` // results
+ Created int64 `json:"Created,omitempty"` // unix timestamp
+ Id string `json:"Id,omitempty"` // session id
+ Usage TencentUsage `json:"Usage,omitempty"` // token usage counts
+ Error TencentError `json:"Error,omitempty"` // error information; may be null when no valid value is available
+ Note string `json:"Note,omitempty"` // note
+ ReqID string `json:"Req_id,omitempty"` // unique request id returned with every request; include it when reporting API issues
+}
+
+type TencentChatResponseSB struct {
+ Response TencentChatResponse `json:"Response,omitempty"`
+}
diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go
new file mode 100644
index 00000000..c3d96c49
--- /dev/null
+++ b/relay/channel/tencent/relay-tencent.go
@@ -0,0 +1,233 @@
+package tencent
+
+import (
+ "bufio"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+// https://cloud.tencent.com/document/product/1729/97732
+
+func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
+ messages := make([]*TencentMessage, 0, len(request.Messages))
+ for i := 0; i < len(request.Messages); i++ {
+ message := request.Messages[i]
+ messages = append(messages, &TencentMessage{
+ Content: message.StringContent(),
+ Role: message.Role,
+ })
+ }
+ var req = TencentChatRequest{
+ Stream: &request.Stream,
+ Messages: messages,
+ Model: &request.Model,
+ }
+ if request.TopP != 0 {
+ req.TopP = &request.TopP
+ }
+ req.Temperature = request.Temperature
+ return &req
+}
+
+func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
+ fullTextResponse := dto.OpenAITextResponse{
+ Id: response.Id,
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ Usage: dto.Usage{
+ PromptTokens: response.Usage.PromptTokens,
+ CompletionTokens: response.Usage.CompletionTokens,
+ TotalTokens: response.Usage.TotalTokens,
+ },
+ }
+ if len(response.Choices) > 0 {
+ choice := dto.OpenAITextResponseChoice{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: response.Choices[0].Messages.Content,
+ },
+ FinishReason: response.Choices[0].FinishReason,
+ }
+ fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+ }
+ return &fullTextResponse
+}
+
+func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
+ response := dto.ChatCompletionsStreamResponse{
+ Object: "chat.completion.chunk",
+ Created: common.GetTimestamp(),
+ Model: "tencent-hunyuan",
+ }
+ if len(TencentResponse.Choices) > 0 {
+ var choice dto.ChatCompletionsStreamResponseChoice
+ choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
+ if TencentResponse.Choices[0].FinishReason == "stop" {
+ choice.FinishReason = &constant.FinishReasonStop
+ }
+ response.Choices = append(response.Choices, choice)
+ }
+ return &response
+}
+
+func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var responseText string
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(bufio.ScanLines)
+
+ helper.SetEventStreamHeaders(c)
+
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 5 || !strings.HasPrefix(data, "data:") {
+ continue
+ }
+ data = strings.TrimPrefix(data, "data:")
+
+ var tencentResponse TencentChatResponse
+ err := json.Unmarshal([]byte(data), &tencentResponse)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ continue
+ }
+
+ response := streamResponseTencent2OpenAI(&tencentResponse)
+ if len(response.Choices) != 0 {
+ responseText += response.Choices[0].Delta.GetContentString()
+ }
+
+ err = helper.ObjectData(c, response)
+ if err != nil {
+ common.SysError(err.Error())
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ common.SysError("error reading stream: " + err.Error())
+ }
+
+ helper.Done(c)
+
+ common.CloseResponseBodyGracefully(resp)
+
+ return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
+}
+
+func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var tencentSb TencentChatResponseSB
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &tencentSb)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if tencentSb.Response.Error.Code != 0 {
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: tencentSb.Response.Error.Message,
+ Code: tencentSb.Response.Error.Code,
+ }, resp.StatusCode)
+ }
+ fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
+ jsonResponse, err := common.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ common.IOCopyBytesGracefully(c, resp, jsonResponse)
+ return &fullTextResponse.Usage, nil
+}
+
+func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
+ parts := strings.Split(config, "|")
+ if len(parts) != 3 {
+ err = errors.New("invalid tencent config")
+ return
+ }
+ appId, err = strconv.ParseInt(parts[0], 10, 64)
+ secretId = parts[1]
+ secretKey = parts[2]
+ return
+}
+
+func sha256hex(s string) string {
+ b := sha256.Sum256([]byte(s))
+ return hex.EncodeToString(b[:])
+}
+
+func hmacSha256(s, key string) string {
+ hashed := hmac.New(sha256.New, []byte(key))
+ hashed.Write([]byte(s))
+ return string(hashed.Sum(nil))
+}
+
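+// getTencentSign builds the TC3-HMAC-SHA256 Authorization header used by the
+// Tencent Cloud v3 API: hash the canonical request, derive the signing key from
+// the secret key, date and service ("hunyuan"), then sign the string-to-sign.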
+func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
+ // build canonical request string
+ host := "hunyuan.tencentcloudapi.com"
+ httpRequestMethod := "POST"
+ canonicalURI := "/"
+ canonicalQueryString := ""
+ canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
+ "application/json", host, strings.ToLower(adaptor.Action))
+ signedHeaders := "content-type;host;x-tc-action"
+ payload, _ := json.Marshal(req)
+ hashedRequestPayload := sha256hex(string(payload))
+ canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
+ httpRequestMethod,
+ canonicalURI,
+ canonicalQueryString,
+ canonicalHeaders,
+ signedHeaders,
+ hashedRequestPayload)
+ // build string to sign
+ algorithm := "TC3-HMAC-SHA256"
+ requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
+ timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
+ t := time.Unix(timestamp, 0).UTC()
+ // must be the format 2006-01-02, ref to package time for more info
+ date := t.Format("2006-01-02")
+ credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
+ hashedCanonicalRequest := sha256hex(canonicalRequest)
+ string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
+ algorithm,
+ requestTimestamp,
+ credentialScope,
+ hashedCanonicalRequest)
+
+ // sign string
+ secretDate := hmacSha256(date, "TC3"+secKey)
+ secretService := hmacSha256("hunyuan", secretDate)
+ secretKey := hmacSha256("tc3_request", secretService)
+ signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
+
+ // build authorization
+ authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
+ algorithm,
+ secId,
+ credentialScope,
+ signedHeaders,
+ signature)
+ return authorization
+}
diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go
new file mode 100644
index 00000000..fa895de0
--- /dev/null
+++ b/relay/channel/vertex/adaptor.go
@@ -0,0 +1,262 @@
+package vertex
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/claude"
+ "one-api/relay/channel/gemini"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ RequestModeClaude = 1
+ RequestModeGemini = 2
+ RequestModeLlama = 3
+)
+
+var claudeModelMap = map[string]string{
+ "claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
+ "claude-3-opus-20240229": "claude-3-opus@20240229",
+ "claude-3-haiku-20240307": "claude-3-haiku@20240307",
+ "claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
+ "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
+ "claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
+ "claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
+ "claude-opus-4-20250514": "claude-opus-4@20250514",
+}
+
+const anthropicVersion = "vertex-2023-10-16"
+
+type Adaptor struct {
+ RequestMode int
+ AccountCredentials Credentials
+}
+
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
+ if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
+ c.Set("request_model", v)
+ } else {
+ c.Set("request_model", request.Model)
+ }
+ vertexClaudeReq := copyRequest(request, anthropicVersion)
+ return vertexClaudeReq, nil
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+ if strings.HasPrefix(info.UpstreamModelName, "claude") {
+ a.RequestMode = RequestModeClaude
+ } else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
+ a.RequestMode = RequestModeGemini
+ } else if strings.Contains(info.UpstreamModelName, "llama") {
+ a.RequestMode = RequestModeLlama
+ }
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ adc := &Credentials{}
+ if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
+ return "", fmt.Errorf("failed to decode credentials file: %w", err)
+ }
+ region := GetModelRegion(info.ApiVersion, info.OriginModelName)
+ a.AccountCredentials = *adc
+ suffix := ""
+ if a.RequestMode == RequestModeGemini {
+ if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
+ // handle the newer "-thinking-*" model name format (strip everything from "-thinking-")
+ if strings.Contains(info.UpstreamModelName, "-thinking-") {
+ parts := strings.Split(info.UpstreamModelName, "-thinking-")
+ info.UpstreamModelName = parts[0]
+ } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // legacy "-thinking" suffix
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
+ } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
+ }
+ }
+
+ if info.IsStream {
+ suffix = "streamGenerateContent?alt=sse"
+ } else {
+ suffix = "generateContent"
+ }
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
+ adc.ProjectID,
+ info.UpstreamModelName,
+ suffix,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
+ region,
+ adc.ProjectID,
+ region,
+ info.UpstreamModelName,
+ suffix,
+ ), nil
+ }
+ } else if a.RequestMode == RequestModeClaude {
+ if info.IsStream {
+ suffix = "streamRawPredict?alt=sse"
+ } else {
+ suffix = "rawPredict"
+ }
+ model := info.UpstreamModelName
+ if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
+ model = v
+ }
+ if region == "global" {
+ return fmt.Sprintf(
+ "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
+ adc.ProjectID,
+ model,
+ suffix,
+ ), nil
+ } else {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
+ region,
+ adc.ProjectID,
+ region,
+ model,
+ suffix,
+ ), nil
+ }
+ } else if a.RequestMode == RequestModeLlama {
+ return fmt.Sprintf(
+ "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
+ region,
+ adc.ProjectID,
+ region,
+ ), nil
+ }
+ return "", errors.New("unsupported request mode")
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ accessToken, err := getAccessToken(a, info)
+ if err != nil {
+ return err
+ }
+ req.Set("Authorization", "Bearer "+accessToken)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ if a.RequestMode == RequestModeClaude {
+ claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
+ if err != nil {
+ return nil, err
+ }
+ vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
+ c.Set("request_model", claudeReq.Model)
+ info.UpstreamModelName = claudeReq.Model
+ return vertexClaudeReq, nil
+ } else if a.RequestMode == RequestModeGemini {
+ geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
+ if err != nil {
+ return nil, err
+ }
+ c.Set("request_model", request.Model)
+ return geminiRequest, nil
+ } else if a.RequestMode == RequestModeLlama {
+ return request, nil
+ }
+ return nil, errors.New("unsupported request mode")
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ switch a.RequestMode {
+ case RequestModeClaude:
+ err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
+ case RequestModeGemini:
+ if info.RelayMode == constant.RelayModeGemini {
+ usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp)
+ } else {
+ usage, err = gemini.GeminiChatStreamHandler(c, info, resp)
+ }
+ case RequestModeLlama:
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ }
+ } else {
+ switch a.RequestMode {
+ case RequestModeClaude:
+ err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
+ case RequestModeGemini:
+ if info.RelayMode == constant.RelayModeGemini {
+ usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
+ } else {
+ usage, err = gemini.GeminiChatHandler(c, info, resp)
+ }
+ case RequestModeLlama:
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ var modelList []string
+ modelList = append(modelList, ModelList...)
+ modelList = append(modelList, claude.ModelList...)
+ modelList = append(modelList, gemini.ModelList...)
+ return modelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/vertex/constants.go b/relay/channel/vertex/constants.go
new file mode 100644
index 00000000..c39e23d1
--- /dev/null
+++ b/relay/channel/vertex/constants.go
@@ -0,0 +1,15 @@
+package vertex
+
+var ModelList = []string{
+ //"claude-3-sonnet-20240229",
+ //"claude-3-opus-20240229",
+ //"claude-3-haiku-20240307",
+ //"claude-3-5-sonnet-20240620",
+
+ //"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
+ //"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
+
+ "meta/llama3-405b-instruct-maas",
+}
+
+var ChannelName = "vertex-ai"
diff --git a/relay/channel/vertex/dto.go b/relay/channel/vertex/dto.go
new file mode 100644
index 00000000..4a571612
--- /dev/null
+++ b/relay/channel/vertex/dto.go
@@ -0,0 +1,37 @@
+package vertex
+
+import (
+ "one-api/dto"
+)
+
+type VertexAIClaudeRequest struct {
+ AnthropicVersion string `json:"anthropic_version"`
+ Messages []dto.ClaudeMessage `json:"messages"`
+ System any `json:"system,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ StopSequences []string `json:"stop_sequences,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Tools any `json:"tools,omitempty"`
+ ToolChoice any `json:"tool_choice,omitempty"`
+ Thinking *dto.Thinking `json:"thinking,omitempty"`
+}
+
+func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
+ return &VertexAIClaudeRequest{
+ AnthropicVersion: version,
+ System: req.System,
+ Messages: req.Messages,
+ MaxTokens: req.MaxTokens,
+ Stream: req.Stream,
+ Temperature: req.Temperature,
+ TopP: req.TopP,
+ TopK: req.TopK,
+ StopSequences: req.StopSequences,
+ Tools: req.Tools,
+ ToolChoice: req.ToolChoice,
+ Thinking: req.Thinking,
+ }
+}
diff --git a/relay/channel/vertex/relay-vertex.go b/relay/channel/vertex/relay-vertex.go
new file mode 100644
index 00000000..5ed87665
--- /dev/null
+++ b/relay/channel/vertex/relay-vertex.go
@@ -0,0 +1,19 @@
+package vertex
+
+import "one-api/common"
+
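+// GetModelRegion resolves the Vertex AI region for a model. The channel's "other" field can be
+// either a plain region string or a JSON object mapping model names to regions with a "default"
+// fallback, e.g. {"default": "us-central1", "gemini-1.5-pro": "europe-west4"} (illustrative values).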
+func GetModelRegion(other string, localModelName string) string {
+ // if other is json string
+ if common.IsJsonObject(other) {
+ m, err := common.StrToMap(other)
+ if err != nil {
+ return other // return original if parsing fails
+ }
+ if v, ok := m[localModelName].(string); ok {
+ return v
+ }
+ if v, ok := m["default"].(string); ok {
+ return v
+ }
+ }
+ return other
+}
diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go
new file mode 100644
index 00000000..5a97c021
--- /dev/null
+++ b/relay/channel/vertex/service_account.go
@@ -0,0 +1,134 @@
+package vertex
+
+import (
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "errors"
+ "github.com/bytedance/gopkg/cache/asynccache"
+ "github.com/golang-jwt/jwt"
+ "net/http"
+ "net/url"
+ relaycommon "one-api/relay/common"
+ "one-api/service"
+ "strings"
+
+ "fmt"
+ "time"
+)
+
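+// Credentials mirrors the subset of a Google service-account JSON key used by this adaptor.
+// Illustrative shape (placeholder values): {"project_id":"my-project","private_key_id":"...",
+// "private_key":"-----BEGIN PRIVATE KEY-----\n...","client_email":"sa@my-project.iam.gserviceaccount.com","client_id":"..."}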
+type Credentials struct {
+ ProjectID string `json:"project_id"`
+ PrivateKeyID string `json:"private_key_id"`
+ PrivateKey string `json:"private_key"`
+ ClientEmail string `json:"client_email"`
+ ClientID string `json:"client_id"`
+}
+
+var Cache = asynccache.NewAsyncCache(asynccache.Options{
+ RefreshDuration: time.Minute * 35,
+ EnableExpire: true,
+ ExpireDuration: time.Minute * 30,
+ Fetcher: func(key string) (interface{}, error) {
+ return nil, errors.New("not found")
+ },
+})
+
+func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
+ cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
+ val, err := Cache.Get(cacheKey)
+ if err == nil {
+ return val.(string), nil
+ }
+
+ signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
+ if err != nil {
+ return "", fmt.Errorf("failed to create signed JWT: %w", err)
+ }
+ newToken, err := exchangeJwtForAccessToken(signedJWT, info)
+ if err != nil {
+ return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
+ }
+ // cache the fresh token; the boolean returned by SetDefault is not needed here
+ _ = Cache.SetDefault(cacheKey, newToken)
+ return newToken, nil
+}
+
+func createSignedJWT(email, privateKeyPEM string) (string, error) {
+
+ privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
+ privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
+ privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
+ privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
+ privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
+
+ block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
+ if block == nil {
+ return "", fmt.Errorf("failed to parse PEM block containing the private key")
+ }
+
+ privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
+ if err != nil {
+ return "", err
+ }
+
+ rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
+ if !ok {
+ return "", fmt.Errorf("not an RSA private key")
+ }
+
+ now := time.Now()
+ claims := jwt.MapClaims{
+ "iss": email,
+ "scope": "https://www.googleapis.com/auth/cloud-platform",
+ "aud": "https://www.googleapis.com/oauth2/v4/token",
+ "exp": now.Add(time.Minute * 35).Unix(),
+ "iat": now.Unix(),
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ signedToken, err := token.SignedString(rsaPrivateKey)
+ if err != nil {
+ return "", err
+ }
+
+ return signedToken, nil
+}
+
+func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
+
+ authURL := "https://www.googleapis.com/oauth2/v4/token"
+ data := url.Values{}
+ data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
+ data.Set("assertion", signedJWT)
+
+ var client *http.Client
+ var err error
+ if info.ChannelSetting.Proxy != "" {
+ client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
+ if err != nil {
+ return "", fmt.Errorf("new proxy http client failed: %w", err)
+ }
+ } else {
+ client = service.GetHttpClient()
+ }
+
+ resp, err := client.PostForm(authURL, data)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ var result map[string]interface{}
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return "", err
+ }
+
+ if accessToken, ok := result["access_token"].(string); ok {
+ return accessToken, nil
+ }
+
+ return "", fmt.Errorf("failed to get access token: %v", result)
+}
diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go
new file mode 100644
index 00000000..af15d636
--- /dev/null
+++ b/relay/channel/volcengine/adaptor.go
@@ -0,0 +1,251 @@
+package volcengine
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/textproto"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/constant"
+ "one-api/types"
+ "path/filepath"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ switch info.RelayMode {
+ case constant.RelayModeImagesEdits:
+
+ var requestBody bytes.Buffer
+ writer := multipart.NewWriter(&requestBody)
+
+ writer.WriteField("model", request.Model)
+ // collect the remaining form fields
+ formData := c.Request.PostForm
+ // copy every non-model field into the new multipart body
+ for key, values := range formData {
+ if key == "model" {
+ continue
+ }
+ for _, value := range values {
+ writer.WriteField(key, value)
+ }
+ }
+
+ // Parse the multipart form to handle both single image and multiple images
+ if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
+ return nil, errors.New("failed to parse multipart form")
+ }
+
+ if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
+ // Check if "image" field exists in any form, including array notation
+ var imageFiles []*multipart.FileHeader
+ var exists bool
+
+ // First check for standard "image" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
+ // If not found, check for "image[]" field
+ if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
+ // If still not found, iterate through all fields to find any that start with "image["
+ foundArrayImages := false
+ for fieldName, files := range c.Request.MultipartForm.File {
+ if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
+ foundArrayImages = true
+ imageFiles = append(imageFiles, files...)
+ }
+ }
+
+ // If no image fields found at all
+ if !foundArrayImages && (len(imageFiles) == 0) {
+ return nil, errors.New("image is required")
+ }
+ }
+ }
+
+ // Process all image files
+ for i, fileHeader := range imageFiles {
+ file, err := fileHeader.Open()
+ if err != nil {
+ return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
+ }
+ defer file.Close()
+
+ // If multiple images, use image[] as the field name
+ fieldName := "image"
+ if len(imageFiles) > 1 {
+ fieldName = "image[]"
+ }
+
+ // Determine MIME type based on file extension
+ mimeType := detectImageMimeType(fileHeader.Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
+ h.Set("Content-Type", mimeType)
+
+ part, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
+ }
+
+ if _, err := io.Copy(part, file); err != nil {
+ return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
+ }
+ }
+
+ // Handle mask file if present
+ if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
+ maskFile, err := maskFiles[0].Open()
+ if err != nil {
+ return nil, errors.New("failed to open mask file")
+ }
+ defer maskFile.Close()
+
+ // Determine MIME type for mask file
+ mimeType := detectImageMimeType(maskFiles[0].Filename)
+
+ // Create a form file with the appropriate content type
+ h := make(textproto.MIMEHeader)
+ h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
+ h.Set("Content-Type", mimeType)
+
+ maskPart, err := writer.CreatePart(h)
+ if err != nil {
+ return nil, errors.New("create form file failed for mask")
+ }
+
+ if _, err := io.Copy(maskPart, maskFile); err != nil {
+ return nil, errors.New("copy mask file failed")
+ }
+ }
+ } else {
+ return nil, errors.New("no multipart form data found")
+ }
+
+ // close the multipart writer to finalize the boundary
+ writer.Close()
+ c.Request.Header.Set("Content-Type", writer.FormDataContentType())
+ return bytes.NewReader(requestBody.Bytes()), nil
+
+ default:
+ return request, nil
+ }
+}
+
+// detectImageMimeType determines the MIME type based on the file extension
+func detectImageMimeType(filename string) string {
+ ext := strings.ToLower(filepath.Ext(filename))
+ switch ext {
+ case ".jpg", ".jpeg":
+ return "image/jpeg"
+ case ".png":
+ return "image/png"
+ case ".webp":
+ return "image/webp"
+ default:
+ // Try to detect from extension if possible
+ if strings.HasPrefix(ext, ".jp") {
+ return "image/jpeg"
+ }
+ // Default to png as a fallback
+ return "image/png"
+ }
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
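+// GetRequestURL maps relay modes to the Volcengine v3 endpoints; model names starting with
+// "bot" are routed to the bots chat-completions path instead of the plain chat path.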
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ switch info.RelayMode {
+ case constant.RelayModeChatCompletions:
+ if strings.HasPrefix(info.UpstreamModelName, "bot") {
+ return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
+ }
+ return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
+ case constant.RelayModeEmbeddings:
+ return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
+ case constant.RelayModeImagesGenerations:
+ return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
+ default:
+ return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayMode {
+ case constant.RelayModeChatCompletions:
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ case constant.RelayModeEmbeddings:
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+ usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go
new file mode 100644
index 00000000..30cc902e
--- /dev/null
+++ b/relay/channel/volcengine/constants.go
@@ -0,0 +1,13 @@
+package volcengine
+
+var ModelList = []string{
+ "Doubao-pro-128k",
+ "Doubao-pro-32k",
+ "Doubao-pro-4k",
+ "Doubao-lite-128k",
+ "Doubao-lite-32k",
+ "Doubao-lite-4k",
+ "Doubao-embedding",
+}
+
+var ChannelName = "volcengine"
diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go
new file mode 100644
index 00000000..8d880137
--- /dev/null
+++ b/relay/channel/xai/adaptor.go
@@ -0,0 +1,128 @@
+package xai
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+ "strings"
+
+ "one-api/relay/constant"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ //panic("implement me")
+ return nil, errors.New("not available")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //not available
+ return nil, errors.New("not available")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ xaiRequest := ImageRequest{
+ Model: request.Model,
+ Prompt: request.Prompt,
+ N: request.N,
+ ResponseFormat: request.ResponseFormat,
+ }
+ return xaiRequest, nil
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ req.Set("Authorization", "Bearer "+info.ApiKey)
+ return nil
+}
+
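+// ConvertOpenAIRequest rewrites xAI-specific model name suffixes before forwarding:
+// a "-search" suffix turns on search_parameters (e.g. "grok-4-0709-search" -> "grok-4-0709"),
+// and grok-3-mini "-high"/"-medium"/"-low" suffixes are mapped to reasoning_effort
+// (e.g. "grok-3-mini-beta-high" -> model "grok-3-mini-beta" with reasoning_effort "high").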
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ if strings.HasSuffix(info.UpstreamModelName, "-search") {
+ info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
+ request.Model = info.UpstreamModelName
+ toMap := request.ToMap()
+ toMap["search_parameters"] = map[string]any{
+ "mode": "on",
+ }
+ return toMap, nil
+ }
+ if strings.HasPrefix(request.Model, "grok-3-mini") {
+ if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
+ request.MaxCompletionTokens = request.MaxTokens
+ request.MaxTokens = 0
+ }
+ if strings.HasSuffix(request.Model, "-high") {
+ request.ReasoningEffort = "high"
+ request.Model = strings.TrimSuffix(request.Model, "-high")
+ } else if strings.HasSuffix(request.Model, "-low") {
+ request.ReasoningEffort = "low"
+ request.Model = strings.TrimSuffix(request.Model, "-low")
+ } else if strings.HasSuffix(request.Model, "-medium") {
+ request.ReasoningEffort = "medium"
+ request.Model = strings.TrimSuffix(request.Model, "-medium")
+ }
+ info.ReasoningEffort = request.ReasoningEffort
+ info.UpstreamModelName = request.Model
+ }
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //not available
+ return nil, errors.New("not available")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ switch info.RelayMode {
+ case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
+ usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
+ default:
+ if info.IsStream {
+ usage, err = xAIStreamHandler(c, info, resp)
+ } else {
+ usage, err = xAIHandler(c, info, resp)
+ }
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/xai/constants.go b/relay/channel/xai/constants.go
new file mode 100644
index 00000000..311b4bb6
--- /dev/null
+++ b/relay/channel/xai/constants.go
@@ -0,0 +1,20 @@
+package xai
+
+var ModelList = []string{
+ // grok-4
+ "grok-4", "grok-4-0709", "grok-4-0709-search",
+ // grok-3
+ "grok-3-beta", "grok-3-mini-beta",
+ // grok-3 fast
+ "grok-3-fast-beta", "grok-3-mini-fast-beta",
+ // extend grok-3-mini reasoning
+ "grok-3-mini-beta-high", "grok-3-mini-beta-low", "grok-3-mini-beta-medium",
+ "grok-3-mini-fast-beta-high", "grok-3-mini-fast-beta-low", "grok-3-mini-fast-beta-medium",
+ // image model
+ "grok-2-image",
+ // legacy models
+ "grok-2", "grok-2-vision",
+ "grok-beta", "grok-vision-beta",
+}
+
+var ChannelName = "xai"
diff --git a/relay/channel/xai/dto.go b/relay/channel/xai/dto.go
new file mode 100644
index 00000000..107a980a
--- /dev/null
+++ b/relay/channel/xai/dto.go
@@ -0,0 +1,27 @@
+package xai
+
+import "one-api/dto"
+
+// ChatCompletionResponse represents the response from XAI chat completion API
+type ChatCompletionResponse struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ Choices []dto.OpenAITextResponseChoice `json:"choices"`
+ Usage *dto.Usage `json:"usage"`
+ SystemFingerprint string `json:"system_fingerprint"`
+}
+
+// quality, size or style are not supported by xAI API at the moment.
+type ImageRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt" binding:"required"`
+ N int `json:"n,omitempty"`
+ // Size string `json:"size,omitempty"`
+ // Quality string `json:"quality,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ // Style string `json:"style,omitempty"`
+ // User string `json:"user,omitempty"`
+ // ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
+}
diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go
new file mode 100644
index 00000000..4d098102
--- /dev/null
+++ b/relay/channel/xai/text.go
@@ -0,0 +1,107 @@
+package xai
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
+ if xAIResp == nil {
+ return nil
+ }
+ if xAIResp.Usage != nil {
+ xAIResp.Usage.CompletionTokens = usage.CompletionTokens
+ }
+ openAIResp := &dto.ChatCompletionsStreamResponse{
+ Id: xAIResp.Id,
+ Object: xAIResp.Object,
+ Created: xAIResp.Created,
+ Model: xAIResp.Model,
+ Choices: xAIResp.Choices,
+ Usage: xAIResp.Usage,
+ }
+
+ return openAIResp
+}
+
+func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ usage := &dto.Usage{}
+ var responseTextBuilder strings.Builder
+ var toolCount int
+ var containStreamUsage bool
+
+ helper.SetEventStreamHeaders(c)
+
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+ var xAIResp *dto.ChatCompletionsStreamResponse
+ err := json.Unmarshal([]byte(data), &xAIResp)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return true
+ }
+
+ // convert xAI usage accounting to the OpenAI shape (completion = total - prompt)
+ if xAIResp.Usage != nil {
+ containStreamUsage = true
+ usage.PromptTokens = xAIResp.Usage.PromptTokens
+ usage.TotalTokens = xAIResp.Usage.TotalTokens
+ usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+ }
+
+ openaiResponse := streamResponseXAI2OpenAI(xAIResp, usage)
+ _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
+ err = helper.ObjectData(c, openaiResponse)
+ if err != nil {
+ common.SysError(err.Error())
+ }
+ return true
+ })
+
+ if !containStreamUsage {
+ usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+ usage.CompletionTokens += toolCount * 7
+ }
+
+ helper.Done(c)
+ common.CloseResponseBodyGracefully(resp)
+ return usage, nil
+}
+
+func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ defer common.CloseResponseBodyGracefully(resp)
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ var xaiResponse ChatCompletionResponse
+ err = common.Unmarshal(responseBody, &xaiResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if xaiResponse.Usage != nil {
+ xaiResponse.Usage.CompletionTokens = xaiResponse.Usage.TotalTokens - xaiResponse.Usage.PromptTokens
+ xaiResponse.Usage.CompletionTokenDetails.TextTokens = xaiResponse.Usage.CompletionTokens - xaiResponse.Usage.CompletionTokenDetails.ReasoningTokens
+ }
+
+ // new body
+ encodeJson, err := common.Marshal(xaiResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+
+ common.IOCopyBytesGracefully(c, resp, encodeJson)
+
+ return xaiResponse.Usage, nil
+}
diff --git a/relay/channel/xinference/constant.go b/relay/channel/xinference/constant.go
new file mode 100644
index 00000000..a119084f
--- /dev/null
+++ b/relay/channel/xinference/constant.go
@@ -0,0 +1,8 @@
+package xinference
+
+var ModelList = []string{
+ "bge-reranker-v2-m3",
+ "jina-reranker-v2",
+}
+
+var ChannelName = "xinference"
diff --git a/relay/channel/xinference/dto.go b/relay/channel/xinference/dto.go
new file mode 100644
index 00000000..35f339fe
--- /dev/null
+++ b/relay/channel/xinference/dto.go
@@ -0,0 +1,11 @@
+package xinference
+
+type XinRerankResponseDocument struct {
+ Document any `json:"document,omitempty"`
+ Index int `json:"index"`
+ RelevanceScore float64 `json:"relevance_score"`
+}
+
+type XinRerankResponse struct {
+ Results []XinRerankResponseDocument `json:"results"`
+}
diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go
new file mode 100644
index 00000000..0d218ada
--- /dev/null
+++ b/relay/channel/xunfei/adaptor.go
@@ -0,0 +1,99 @@
+package xunfei
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+ request *dto.GeneralOpenAIRequest
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ return "", nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ a.request = request
+ return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ // xunfei's request is not http request, so we don't need to do anything here
+ dummyResp := &http.Response{}
+ dummyResp.StatusCode = http.StatusOK
+ return dummyResp, nil
+}
+
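+// DoResponse expects the channel key in the form "appId|apiSecret|apiKey" (three "|"-separated parts).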
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ splits := strings.Split(info.ApiKey, "|")
+ if len(splits) != 3 {
+ return nil, types.NewError(errors.New("invalid auth"), types.ErrorCodeChannelInvalidKey)
+ }
+ if a.request == nil {
+ return nil, types.NewError(errors.New("request is nil"), types.ErrorCodeInvalidRequest)
+ }
+ if info.IsStream {
+ usage, err = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
+ } else {
+ usage, err = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2])
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/xunfei/constants.go b/relay/channel/xunfei/constants.go
new file mode 100644
index 00000000..e19f0113
--- /dev/null
+++ b/relay/channel/xunfei/constants.go
@@ -0,0 +1,12 @@
+package xunfei
+
+var ModelList = []string{
+ "SparkDesk",
+ "SparkDesk-v1.1",
+ "SparkDesk-v2.1",
+ "SparkDesk-v3.1",
+ "SparkDesk-v3.5",
+ "SparkDesk-v4.0",
+}
+
+var ChannelName = "xunfei"
diff --git a/relay/channel/xunfei/dto.go b/relay/channel/xunfei/dto.go
new file mode 100644
index 00000000..c169e5f7
--- /dev/null
+++ b/relay/channel/xunfei/dto.go
@@ -0,0 +1,59 @@
+package xunfei
+
+import "one-api/dto"
+
+type XunfeiMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type XunfeiChatRequest struct {
+ Header struct {
+ AppId string `json:"app_id"`
+ } `json:"header"`
+ Parameter struct {
+ Chat struct {
+ Domain string `json:"domain,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ MaxTokens uint `json:"max_tokens,omitempty"`
+ Auditing bool `json:"auditing,omitempty"`
+ } `json:"chat"`
+ } `json:"parameter"`
+ Payload struct {
+ Message struct {
+ Text []XunfeiMessage `json:"text"`
+ } `json:"message"`
+ } `json:"payload"`
+}
+
+type XunfeiChatResponseTextItem struct {
+ Content string `json:"content"`
+ Role string `json:"role"`
+ Index int `json:"index"`
+}
+
+type XunfeiChatResponse struct {
+ Header struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Sid string `json:"sid"`
+ Status int `json:"status"`
+ } `json:"header"`
+ Payload struct {
+ Choices struct {
+ Status int `json:"status"`
+ Seq int `json:"seq"`
+ Text []XunfeiChatResponseTextItem `json:"text"`
+ } `json:"choices"`
+ Usage struct {
+ //Text struct {
+ // QuestionTokens string `json:"question_tokens"`
+ // PromptTokens string `json:"prompt_tokens"`
+ // CompletionTokens string `json:"completion_tokens"`
+ // TotalTokens string `json:"total_tokens"`
+ //} `json:"text"`
+ Text dto.Usage `json:"text"`
+ } `json:"usage"`
+ } `json:"payload"`
+}
diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go
new file mode 100644
index 00000000..373ad605
--- /dev/null
+++ b/relay/channel/xunfei/relay-xunfei.go
@@ -0,0 +1,287 @@
+package xunfei
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/url"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/helper"
+ "one-api/types"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+)
+
+// https://console.xfyun.cn/services/cbm
+// https://www.xfyun.cn/doc/spark/Web.html
+
+func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
+ messages := make([]XunfeiMessage, 0, len(request.Messages))
+ shouldConvertSystemMessage := !strings.HasSuffix(request.Model, "3.5")
+ for _, message := range request.Messages {
+ if message.Role == "system" && shouldConvertSystemMessage {
+ messages = append(messages, XunfeiMessage{
+ Role: "user",
+ Content: message.StringContent(),
+ })
+ messages = append(messages, XunfeiMessage{
+ Role: "assistant",
+ Content: "Okay",
+ })
+ } else {
+ messages = append(messages, XunfeiMessage{
+ Role: message.Role,
+ Content: message.StringContent(),
+ })
+ }
+ }
+ xunfeiRequest := XunfeiChatRequest{}
+ xunfeiRequest.Header.AppId = xunfeiAppId
+ xunfeiRequest.Parameter.Chat.Domain = domain
+ xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
+ xunfeiRequest.Parameter.Chat.TopK = request.N
+ xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
+ xunfeiRequest.Payload.Message.Text = messages
+ return &xunfeiRequest
+}
+
+func responseXunfei2OpenAI(response *XunfeiChatResponse) *dto.OpenAITextResponse {
+ if len(response.Payload.Choices.Text) == 0 {
+ response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+ {
+ Content: "",
+ },
+ }
+ }
+ choice := dto.OpenAITextResponseChoice{
+ Index: 0,
+ Message: dto.Message{
+ Role: "assistant",
+ Content: response.Payload.Choices.Text[0].Content,
+ },
+ FinishReason: constant.FinishReasonStop,
+ }
+ fullTextResponse := dto.OpenAITextResponse{
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ Choices: []dto.OpenAITextResponseChoice{choice},
+ Usage: response.Payload.Usage.Text,
+ }
+ return &fullTextResponse
+}
+
+func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *dto.ChatCompletionsStreamResponse {
+ if len(xunfeiResponse.Payload.Choices.Text) == 0 {
+ xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+ {
+ Content: "",
+ },
+ }
+ }
+ var choice dto.ChatCompletionsStreamResponseChoice
+ choice.Delta.SetContentString(xunfeiResponse.Payload.Choices.Text[0].Content)
+ if xunfeiResponse.Payload.Choices.Status == 2 {
+ choice.FinishReason = &constant.FinishReasonStop
+ }
+ response := dto.ChatCompletionsStreamResponse{
+ Object: "chat.completion.chunk",
+ Created: common.GetTimestamp(),
+ Model: "SparkDesk",
+ Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
+ }
+ return &response
+}
+
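+// buildXunfeiAuthUrl implements Xunfei's HMAC-SHA256 URL signing: the host, date and request line
+// are signed with apiSecret, wrapped into an "hmac ..." authorization string, base64-encoded and
+// appended to the websocket URL as query parameters.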
+func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
+ HmacWithShaToBase64 := func(algorithm, data, key string) string {
+ mac := hmac.New(sha256.New, []byte(key))
+ mac.Write([]byte(data))
+ encodeData := mac.Sum(nil)
+ return base64.StdEncoding.EncodeToString(encodeData)
+ }
+ ul, err := url.Parse(hostUrl)
+ if err != nil {
+ common.SysError("failed to parse xunfei host url: " + err.Error())
+ }
+ date := time.Now().UTC().Format(time.RFC1123)
+ signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
+ sign := strings.Join(signString, "\n")
+ sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret)
+ authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
+ "hmac-sha256", "host date request-line", sha)
+ authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
+ v := url.Values{}
+ v.Add("host", ul.Host)
+ v.Add("date", date)
+ v.Add("authorization", authorization)
+ callUrl := hostUrl + "?" + v.Encode()
+ return callUrl
+}
+
+func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
+ dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
+ }
+ helper.SetEventStreamHeaders(c)
+ var usage dto.Usage
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case xunfeiResponse := <-dataChan:
+ usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
+ usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
+ usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
+ response := streamResponseXunfei2OpenAI(&xunfeiResponse)
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ common.SysError("error marshalling stream response: " + err.Error())
+ return true
+ }
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+ return true
+ case <-stopChan:
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+ return false
+ }
+ })
+ return &usage, nil
+}
+
+func xunfeiHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*dto.Usage, *types.NewAPIError) {
+ domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model)
+ dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeDoRequestFailed)
+ }
+ var usage dto.Usage
+ var content string
+ var xunfeiResponse XunfeiChatResponse
+ stop := false
+ for !stop {
+ select {
+ case xunfeiResponse = <-dataChan:
+ if len(xunfeiResponse.Payload.Choices.Text) == 0 {
+ continue
+ }
+ content += xunfeiResponse.Payload.Choices.Text[0].Content
+ usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
+ usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
+ usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
+ case stop = <-stopChan:
+ }
+ }
+ if len(xunfeiResponse.Payload.Choices.Text) == 0 {
+ xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
+ {
+ Content: "",
+ },
+ }
+ }
+ xunfeiResponse.Payload.Choices.Text[0].Content = content
+
+ response := responseXunfei2OpenAI(&xunfeiResponse)
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ _, _ = c.Writer.Write(jsonResponse)
+ return &usage, nil
+}
+
+func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
+ d := websocket.Dialer{
+ HandshakeTimeout: 5 * time.Second,
+ }
+ conn, resp, err := d.Dial(authUrl, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ if resp.StatusCode != 101 {
+ return nil, nil, fmt.Errorf("websocket handshake failed with status code %d", resp.StatusCode)
+ }
+ data := requestOpenAI2Xunfei(textRequest, appId, domain)
+ err = conn.WriteJSON(data)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dataChan := make(chan XunfeiChatResponse)
+ stopChan := make(chan bool)
+ go func() {
+ for {
+ _, msg, err := conn.ReadMessage()
+ if err != nil {
+ common.SysError("error reading stream response: " + err.Error())
+ break
+ }
+ var response XunfeiChatResponse
+ err = json.Unmarshal(msg, &response)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ break
+ }
+ dataChan <- response
+ if response.Payload.Choices.Status == 2 {
+ err := conn.Close()
+ if err != nil {
+ common.SysError("error closing websocket connection: " + err.Error())
+ }
+ break
+ }
+ }
+ stopChan <- true
+ }()
+
+ return dataChan, stopChan, nil
+}
+
+func apiVersion2domain(apiVersion string) string {
+ switch apiVersion {
+ case "v1.1":
+ return "lite"
+ case "v2.1":
+ return "generalv2"
+ case "v3.1":
+ return "generalv3"
+ case "v3.5":
+ return "generalv3.5"
+ case "v4.0":
+ return "4.0Ultra"
+ }
+ return "general" + apiVersion
+}
+
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) {
+ apiVersion := getAPIVersion(c, modelName)
+ domain := apiVersion2domain(apiVersion)
+ authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+ return domain, authUrl
+}
+
+func getAPIVersion(c *gin.Context, modelName string) string {
+ query := c.Request.URL.Query()
+ apiVersion := query.Get("api-version")
+ if apiVersion != "" {
+ return apiVersion
+ }
+ parts := strings.Split(modelName, "-")
+ if len(parts) == 2 {
+ apiVersion = parts[1]
+ return apiVersion
+ }
+ apiVersion = c.GetString("api_version")
+ if apiVersion != "" {
+ return apiVersion
+ }
+ apiVersion = "v1.1"
+ common.SysLog("api_version not found, using default: " + apiVersion)
+ return apiVersion
+}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
new file mode 100644
index 00000000..43344428
--- /dev/null
+++ b/relay/channel/zhipu/adaptor.go
@@ -0,0 +1,96 @@
+package zhipu
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
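+// GetRequestURL targets the Zhipu v3 model API, e.g. {base}/api/paas/v3/model-api/chatglm_std/invoke
+// (or .../sse-invoke when streaming).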
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ method := "invoke"
+ if info.IsStream {
+ method = "sse-invoke"
+ }
+ return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ token := getZhipuToken(info.ApiKey)
+ req.Set("Authorization", token)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ if request.TopP >= 1 {
+ request.TopP = 0.99
+ }
+ return requestOpenAI2Zhipu(*request), nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = zhipuStreamHandler(c, info, resp)
+ } else {
+ usage, err = zhipuHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
diff --git a/relay/channel/zhipu/constants.go b/relay/channel/zhipu/constants.go
new file mode 100644
index 00000000..81b18d63
--- /dev/null
+++ b/relay/channel/zhipu/constants.go
@@ -0,0 +1,7 @@
+package zhipu
+
+var ModelList = []string{
+ "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
+}
+
+var ChannelName = "zhipu"
diff --git a/relay/channel/zhipu/dto.go b/relay/channel/zhipu/dto.go
new file mode 100644
index 00000000..2682dd3a
--- /dev/null
+++ b/relay/channel/zhipu/dto.go
@@ -0,0 +1,46 @@
+package zhipu
+
+import (
+ "one-api/dto"
+ "time"
+)
+
+type ZhipuMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+type ZhipuRequest struct {
+ Prompt []ZhipuMessage `json:"prompt"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ RequestId string `json:"request_id,omitempty"`
+ Incremental bool `json:"incremental,omitempty"`
+}
+
+type ZhipuResponseData struct {
+ TaskId string `json:"task_id"`
+ RequestId string `json:"request_id"`
+ TaskStatus string `json:"task_status"`
+ Choices []ZhipuMessage `json:"choices"`
+ dto.Usage `json:"usage"`
+}
+
+type ZhipuResponse struct {
+ Code int `json:"code"`
+ Msg string `json:"msg"`
+ Success bool `json:"success"`
+ Data ZhipuResponseData `json:"data"`
+}
+
+type ZhipuStreamMetaResponse struct {
+ RequestId string `json:"request_id"`
+ TaskId string `json:"task_id"`
+ TaskStatus string `json:"task_status"`
+ dto.Usage `json:"usage"`
+}
+
+type zhipuTokenData struct {
+ Token string
+ ExpiryTime time.Time
+}
diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go
new file mode 100644
index 00000000..916a200d
--- /dev/null
+++ b/relay/channel/zhipu/relay-zhipu.go
@@ -0,0 +1,245 @@
+package zhipu
+
+import (
+ "bufio"
+ "encoding/json"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/types"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt"
+)
+
+// https://open.bigmodel.cn/doc/api#chatglm_std
+// chatglm_std, chatglm_lite
+// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
+// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
+
+var zhipuTokens sync.Map
+var expSeconds int64 = 24 * 3600
+
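+// getZhipuToken builds (and caches) a signed JWT from a Zhipu API key of the form "{id}.{secret}".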
+func getZhipuToken(apikey string) string {
+ data, ok := zhipuTokens.Load(apikey)
+ if ok {
+ tokenData := data.(zhipuTokenData)
+ if time.Now().Before(tokenData.ExpiryTime) {
+ return tokenData.Token
+ }
+ }
+
+ split := strings.Split(apikey, ".")
+ if len(split) != 2 {
+ common.SysError("invalid zhipu key: " + apikey)
+ return ""
+ }
+
+ id := split[0]
+ secret := split[1]
+
+ expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
+ expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
+
+ timestamp := time.Now().UnixNano() / 1e6
+
+ payload := jwt.MapClaims{
+ "api_key": id,
+ "exp": expMillis,
+ "timestamp": timestamp,
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
+
+ token.Header["alg"] = "HS256"
+ token.Header["sign_type"] = "SIGN"
+
+ tokenString, err := token.SignedString([]byte(secret))
+ if err != nil {
+ return ""
+ }
+
+ zhipuTokens.Store(apikey, zhipuTokenData{
+ Token: tokenString,
+ ExpiryTime: expiryTime,
+ })
+
+ return tokenString
+}
+
+func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *ZhipuRequest {
+ messages := make([]ZhipuMessage, 0, len(request.Messages))
+ for _, message := range request.Messages {
+ if message.Role == "system" {
+ messages = append(messages, ZhipuMessage{
+ Role: "system",
+ Content: message.StringContent(),
+ })
+ messages = append(messages, ZhipuMessage{
+ Role: "user",
+ Content: "Okay",
+ })
+ } else {
+ messages = append(messages, ZhipuMessage{
+ Role: message.Role,
+ Content: message.StringContent(),
+ })
+ }
+ }
+ return &ZhipuRequest{
+ Prompt: messages,
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ Incremental: false,
+ }
+}
+
+func responseZhipu2OpenAI(response *ZhipuResponse) *dto.OpenAITextResponse {
+ fullTextResponse := dto.OpenAITextResponse{
+ Id: response.Data.TaskId,
+ Object: "chat.completion",
+ Created: common.GetTimestamp(),
+ Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Data.Choices)),
+ Usage: response.Data.Usage,
+ }
+ for i, choice := range response.Data.Choices {
+ openaiChoice := dto.OpenAITextResponseChoice{
+ Index: i,
+ Message: dto.Message{
+ Role: choice.Role,
+ Content: strings.Trim(choice.Content, "\""),
+ },
+ FinishReason: "",
+ }
+ if i == len(response.Data.Choices)-1 {
+ openaiChoice.FinishReason = "stop"
+ }
+ fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
+ }
+ return &fullTextResponse
+}
+
+func streamResponseZhipu2OpenAI(zhipuResponse string) *dto.ChatCompletionsStreamResponse {
+ var choice dto.ChatCompletionsStreamResponseChoice
+ choice.Delta.SetContentString(zhipuResponse)
+ response := dto.ChatCompletionsStreamResponse{
+ Object: "chat.completion.chunk",
+ Created: common.GetTimestamp(),
+ Model: "chatglm",
+ Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
+ }
+ return &response
+}
+
+func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
+ var choice dto.ChatCompletionsStreamResponseChoice
+ choice.Delta.SetContentString("")
+ choice.FinishReason = &constant.FinishReasonStop
+ response := dto.ChatCompletionsStreamResponse{
+ Id: zhipuResponse.RequestId,
+ Object: "chat.completion.chunk",
+ Created: common.GetTimestamp(),
+ Model: "chatglm",
+ Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
+ }
+ return &response, &zhipuResponse.Usage
+}
+
+func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var usage *dto.Usage
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(bufio.ScanLines)
+ dataChan := make(chan string)
+ metaChan := make(chan string)
+ stopChan := make(chan bool)
+ go func() {
+ for scanner.Scan() {
+ data := scanner.Text()
+ lines := strings.Split(data, "\n")
+ for i, line := range lines {
+ if len(line) < 5 {
+ continue
+ }
+ if line[:5] == "data:" {
+ dataChan <- line[5:]
+ if i != len(lines)-1 {
+ dataChan <- "\n"
+ }
+ } else if line[:5] == "meta:" {
+ metaChan <- line[5:]
+ }
+ }
+ }
+ stopChan <- true
+ }()
+ helper.SetEventStreamHeaders(c)
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case data := <-dataChan:
+ response := streamResponseZhipu2OpenAI(data)
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ common.SysError("error marshalling stream response: " + err.Error())
+ return true
+ }
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+ return true
+ case data := <-metaChan:
+ var zhipuResponse ZhipuStreamMetaResponse
+ err := json.Unmarshal([]byte(data), &zhipuResponse)
+ if err != nil {
+ common.SysError("error unmarshalling stream response: " + err.Error())
+ return true
+ }
+ response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ common.SysError("error marshalling stream response: " + err.Error())
+ return true
+ }
+ usage = zhipuUsage
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+ return true
+ case <-stopChan:
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+ return false
+ }
+ })
+ common.CloseResponseBodyGracefully(resp)
+ return usage, nil
+}
+
+func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ var zhipuResponse ZhipuResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ err = json.Unmarshal(responseBody, &zhipuResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ if !zhipuResponse.Success {
+ return nil, types.WithOpenAIError(types.OpenAIError{
+ Message: zhipuResponse.Msg,
+ Code: zhipuResponse.Code,
+ }, resp.StatusCode)
+ }
+ fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return &fullTextResponse.Usage, nil
+}
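
Context for the stream handler above: Zhipu's legacy v3 SSE stream interleaves "data:" lines (content deltas) with a final "meta:" line carrying usage, which is why zhipuStreamHandler fans lines out into two channels. A minimal, self-contained sketch of that dispatch, with an illustrative payload (the channel fan-out and gin rendering are omitted):

package main

import (
    "bufio"
    "fmt"
    "strings"
)

// dispatchZhipuSSE mirrors the scanning goroutine in zhipuStreamHandler:
// "data:" lines are content deltas, a "meta:" line carries the usage payload.
func dispatchZhipuSSE(body string, onData, onMeta func(string)) {
    scanner := bufio.NewScanner(strings.NewReader(body))
    for scanner.Scan() {
        line := scanner.Text()
        if len(line) < 5 {
            continue
        }
        switch line[:5] {
        case "data:":
            onData(line[5:])
        case "meta:":
            onMeta(line[5:])
        }
    }
}

func main() {
    // Illustrative payload only; real frames come from the upstream response body.
    sample := "data:Hello\ndata:, world\nmeta:{\"usage\":{\"total_tokens\":7}}\n"
    dispatchZhipuSSE(sample,
        func(d string) { fmt.Println("delta:", d) },
        func(m string) { fmt.Println("usage frame:", m) },
    )
}
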
diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go
new file mode 100644
index 00000000..edd7a534
--- /dev/null
+++ b/relay/channel/zhipu_4v/adaptor.go
@@ -0,0 +1,99 @@
+package zhipu_4v
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/dto"
+ "one-api/relay/channel"
+ "one-api/relay/channel/openai"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
+ // TODO: implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+ //TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+ baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
+ switch info.RelayMode {
+ case relayconstant.RelayModeEmbeddings:
+ return fmt.Sprintf("%s/embeddings", baseUrl), nil
+ default:
+ return fmt.Sprintf("%s/chat/completions", baseUrl), nil
+ }
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+ channel.SetupApiRequestHeader(info, c, req)
+ token := getZhipuToken(info.ApiKey)
+ req.Set("Authorization", token)
+ return nil
+}
+
+func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ if request.TopP >= 1 {
+ request.TopP = 0.99
+ }
+ return requestOpenAI2Zhipu(*request), nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+ return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+ return request, nil
+}
+
+func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
+ // TODO implement me
+ return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+ return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
+ if info.IsStream {
+ usage, err = openai.OaiStreamHandler(c, info, resp)
+ } else {
+ usage, err = openai.OpenaiHandler(c, info, resp)
+ }
+ return
+}
+
+func (a *Adaptor) GetModelList() []string {
+ return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+ return ChannelName
+}
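
The v4 adaptor's only URL decision is the relay mode: embeddings go to /embeddings, everything else to /chat/completions under /api/paas/v4. A hedged sketch of that routing (the base URL below is an example value, not a requirement of the adaptor):

package main

import "fmt"

// zhipuV4URL mirrors Adaptor.GetRequestURL: embeddings map to /embeddings,
// all other relay modes map to /chat/completions under /api/paas/v4.
func zhipuV4URL(baseUrl string, isEmbedding bool) string {
    base := fmt.Sprintf("%s/api/paas/v4", baseUrl)
    if isEmbedding {
        return base + "/embeddings"
    }
    return base + "/chat/completions"
}

func main() {
    // "https://open.bigmodel.cn" is used here purely as an example base URL.
    fmt.Println(zhipuV4URL("https://open.bigmodel.cn", false))
    fmt.Println(zhipuV4URL("https://open.bigmodel.cn", true))
}
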
diff --git a/relay/channel/zhipu_4v/constants.go b/relay/channel/zhipu_4v/constants.go
new file mode 100644
index 00000000..816fa536
--- /dev/null
+++ b/relay/channel/zhipu_4v/constants.go
@@ -0,0 +1,7 @@
+package zhipu_4v
+
+var ModelList = []string{
+ "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", "glm-4-plus", "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-long", "glm-4-flash", "glm-4v-plus",
+}
+
+var ChannelName = "zhipu_4v"
diff --git a/relay/channel/zhipu_4v/dto.go b/relay/channel/zhipu_4v/dto.go
new file mode 100644
index 00000000..4d867679
--- /dev/null
+++ b/relay/channel/zhipu_4v/dto.go
@@ -0,0 +1,59 @@
+package zhipu_4v
+
+import (
+ "one-api/dto"
+ "time"
+)
+
+// type ZhipuMessage struct {
+// Role string `json:"role,omitempty"`
+// Content string `json:"content,omitempty"`
+// ToolCalls any `json:"tool_calls,omitempty"`
+// ToolCallId any `json:"tool_call_id,omitempty"`
+// }
+//
+// type ZhipuRequest struct {
+// Model string `json:"model"`
+// Stream bool `json:"stream,omitempty"`
+// Messages []ZhipuMessage `json:"messages"`
+// Temperature float64 `json:"temperature,omitempty"`
+// TopP float64 `json:"top_p,omitempty"`
+// MaxTokens int `json:"max_tokens,omitempty"`
+// Stop []string `json:"stop,omitempty"`
+// RequestId string `json:"request_id,omitempty"`
+// Tools any `json:"tools,omitempty"`
+// ToolChoice any `json:"tool_choice,omitempty"`
+// }
+//
+// type ZhipuV4TextResponseChoice struct {
+// Index int `json:"index"`
+// ZhipuMessage `json:"message"`
+// FinishReason string `json:"finish_reason"`
+// }
+type ZhipuV4Response struct {
+ Id string `json:"id"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ TextResponseChoices []dto.OpenAITextResponseChoice `json:"choices"`
+ Usage dto.Usage `json:"usage"`
+ Error dto.OpenAIError `json:"error"`
+}
+
+//
+//type ZhipuV4StreamResponseChoice struct {
+// Index int `json:"index,omitempty"`
+// Delta ZhipuMessage `json:"delta"`
+// FinishReason *string `json:"finish_reason,omitempty"`
+//}
+
+type ZhipuV4StreamResponse struct {
+ Id string `json:"id"`
+ Created int64 `json:"created"`
+ Choices []dto.ChatCompletionsStreamResponseChoice `json:"choices"`
+ Usage dto.Usage `json:"usage"`
+}
+
+type tokenData struct {
+ Token string
+ ExpiryTime time.Time
+}
diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go
new file mode 100644
index 00000000..271dda8f
--- /dev/null
+++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go
@@ -0,0 +1,113 @@
+package zhipu_4v
+
+import (
+ "github.com/golang-jwt/jwt"
+ "one-api/common"
+ "one-api/dto"
+ "strings"
+ "sync"
+ "time"
+)
+
+// https://open.bigmodel.cn/doc/api#chatglm_std
+// chatglm_std, chatglm_lite
+// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
+// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
+
+var zhipuTokens sync.Map
+var expSeconds int64 = 24 * 3600
+
+func getZhipuToken(apikey string) string {
+ data, ok := zhipuTokens.Load(apikey)
+ if ok {
+ tokenData := data.(tokenData)
+ if time.Now().Before(tokenData.ExpiryTime) {
+ return tokenData.Token
+ }
+ }
+
+ split := strings.Split(apikey, ".")
+ if len(split) != 2 {
+ common.SysError("invalid zhipu key: " + apikey)
+ return ""
+ }
+
+ id := split[0]
+ secret := split[1]
+
+ expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
+ expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
+
+ timestamp := time.Now().UnixNano() / 1e6
+
+ payload := jwt.MapClaims{
+ "api_key": id,
+ "exp": expMillis,
+ "timestamp": timestamp,
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
+
+ token.Header["alg"] = "HS256"
+ token.Header["sign_type"] = "SIGN"
+
+ tokenString, err := token.SignedString([]byte(secret))
+ if err != nil {
+ return ""
+ }
+
+ zhipuTokens.Store(apikey, tokenData{
+ Token: tokenString,
+ ExpiryTime: expiryTime,
+ })
+
+ return tokenString
+}
+
+func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+ messages := make([]dto.Message, 0, len(request.Messages))
+ for _, message := range request.Messages {
+ if !message.IsStringContent() {
+ mediaMessages := message.ParseContent()
+ for j, mediaMessage := range mediaMessages {
+ if mediaMessage.Type == dto.ContentTypeImageURL {
+ imageUrl := mediaMessage.GetImageMedia()
+ // check if base64
+ if strings.HasPrefix(imageUrl.Url, "data:image/") {
+ // strip the data URL prefix from the base64 payload (if present)
+ if idx := strings.Index(imageUrl.Url, ","); idx != -1 {
+ imageUrl.Url = imageUrl.Url[idx+1:]
+ }
+ }
+ mediaMessage.ImageUrl = imageUrl
+ mediaMessages[j] = mediaMessage
+ }
+ }
+ message.SetMediaContent(mediaMessages)
+ }
+ messages = append(messages, dto.Message{
+ Role: message.Role,
+ Content: message.Content,
+ ToolCalls: message.ToolCalls,
+ ToolCallId: message.ToolCallId,
+ })
+ }
+ var stopSequences []string
+ if str, ok := request.Stop.(string); ok {
+ stopSequences = []string{str}
+ } else {
+ stopSequences, _ = request.Stop.([]string)
+ }
+ return &dto.GeneralOpenAIRequest{
+ Model: request.Model,
+ Stream: request.Stream,
+ Messages: messages,
+ Temperature: request.Temperature,
+ TopP: request.TopP,
+ MaxTokens: request.MaxTokens,
+ Stop: stopSequences,
+ Tools: request.Tools,
+ ToolChoice: request.ToolChoice,
+ }
+}
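
getZhipuToken above expects the channel key in the form "id.secret", signs an HS256 JWT whose exp and timestamp claims are in milliseconds, sets the non-standard sign_type header, and caches the result per key. A compact sketch of the signing step without the cache, using the same github.com/golang-jwt/jwt calls (the sample key is fake):

package main

import (
    "fmt"
    "strings"
    "time"

    "github.com/golang-jwt/jwt"
)

// signZhipuToken reproduces the signing logic of getZhipuToken, minus the sync.Map cache:
// split "id.secret", build HS256 claims with millisecond exp/timestamp, add the custom
// sign_type header, and sign with the secret half of the key.
func signZhipuToken(apiKey string, ttl time.Duration) (string, error) {
    parts := strings.Split(apiKey, ".")
    if len(parts) != 2 {
        return "", fmt.Errorf("invalid zhipu key")
    }
    now := time.Now()
    claims := jwt.MapClaims{
        "api_key":   parts[0],
        "exp":       now.Add(ttl).UnixNano() / 1e6, // milliseconds
        "timestamp": now.UnixNano() / 1e6,
    }
    token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
    token.Header["alg"] = "HS256"
    token.Header["sign_type"] = "SIGN"
    return token.SignedString([]byte(parts[1]))
}

func main() {
    // Fake key, for illustration only.
    t, err := signZhipuToken("my-id.my-secret", 24*time.Hour)
    fmt.Println(t, err)
}
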
diff --git a/relay/claude_handler.go b/relay/claude_handler.go
new file mode 100644
index 00000000..5f38960e
--- /dev/null
+++ b/relay/claude_handler.go
@@ -0,0 +1,162 @@
+package relay
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
+ textRequest = &dto.ClaudeRequest{}
+ err = c.ShouldBindJSON(textRequest)
+ if err != nil {
+ return nil, err
+ }
+ if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+ return nil, errors.New("field messages is required")
+ }
+ if textRequest.Model == "" {
+ return nil, errors.New("field model is required")
+ }
+ return textRequest, nil
+}
+
+func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+
+ relayInfo := relaycommon.GenRelayInfoClaude(c)
+
+ // get & validate textRequest 获取并验证文本请求
+ textRequest, err := getAndValidateClaudeRequest(c)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ if textRequest.Stream {
+ relayInfo.IsStream = true
+ }
+
+ err = helper.ModelMappedHelper(c, relayInfo, textRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
+ // count messages token error 计算promptTokens错误
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeCountTokenFailed)
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+
+ // pre-consume quota 预消耗配额
+ preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+
+ if newAPIError != nil {
+ return newAPIError
+ }
+ defer func() {
+ if newAPIError != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
+ var requestBody io.Reader
+
+ if textRequest.MaxTokens == 0 {
+ textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
+ }
+
+ if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
+ strings.HasSuffix(textRequest.Model, "-thinking") {
+ if textRequest.Thinking == nil {
+ // BudgetTokens must be greater than 1024
+ if textRequest.MaxTokens < 1280 {
+ textRequest.MaxTokens = 1280
+ }
+
+ // BudgetTokens is the configured percentage of max_tokens (80% by default)
+ textRequest.Thinking = &dto.Thinking{
+ Type: "enabled",
+ BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
+ }
+ // TODO: temporary workaround
+ // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
+ textRequest.TopP = 0
+ textRequest.Temperature = common.GetPointer[float64](1.0)
+ }
+ textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
+ relayInfo.UpstreamModelName = textRequest.Model
+ }
+
+ convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ jsonData, err := common.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ if common.DebugEnabled {
+ println("requestBody: ", string(jsonData))
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+ var httpResp *http.Response
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ //log.Printf("usage: %v", usage)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ return nil
+}
+
+func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
+ var promptTokens int
+ var err error
+ switch info.RelayMode {
+ default:
+ promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
+ }
+ info.PromptTokens = promptTokens
+ return promptTokens, err
+}
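
On the -thinking adapter in ClaudeHelper above: max_tokens is first raised to at least 1280 so that the derived budget does not fall below Anthropic's 1024-token floor, and BudgetTokens is then set to a configured percentage of max_tokens. A worked sketch of that arithmetic (0.8 is an assumed default for ThinkingAdapterBudgetTokensPercentage):

package main

import "fmt"

// thinkingBudget mirrors the clamp-then-percentage logic in ClaudeHelper.
func thinkingBudget(maxTokens uint, pct float64) (uint, int) {
    if maxTokens < 1280 {
        maxTokens = 1280 // keeps pct*maxTokens at or above the 1024 minimum for BudgetTokens
    }
    return maxTokens, int(float64(maxTokens) * pct)
}

func main() {
    for _, mt := range []uint{512, 4096, 16384} {
        adjusted, budget := thinkingBudget(mt, 0.8) // 0.8 is an assumed default
        fmt.Printf("max_tokens %d -> %d, budget_tokens %d\n", mt, adjusted, budget)
    }
}
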
diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go
new file mode 100644
index 00000000..45fde019
--- /dev/null
+++ b/relay/common/relay_info.go
@@ -0,0 +1,344 @@
+package common
+
+import (
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relayconstant "one-api/relay/constant"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+)
+
+type ThinkingContentInfo struct {
+ IsFirstThinkingContent bool
+ SendLastThinkingContent bool
+ HasSentThinkingContent bool
+}
+
+const (
+ LastMessageTypeNone = "none"
+ LastMessageTypeText = "text"
+ LastMessageTypeTools = "tools"
+ LastMessageTypeThinking = "thinking"
+)
+
+type ClaudeConvertInfo struct {
+ LastMessagesType string
+ Index int
+ Usage *dto.Usage
+ FinishReason string
+ Done bool
+}
+
+const (
+ RelayFormatOpenAI = "openai"
+ RelayFormatClaude = "claude"
+ RelayFormatGemini = "gemini"
+ RelayFormatOpenAIResponses = "openai_responses"
+ RelayFormatOpenAIAudio = "openai_audio"
+ RelayFormatOpenAIImage = "openai_image"
+ RelayFormatRerank = "rerank"
+ RelayFormatEmbedding = "embedding"
+)
+
+type RerankerInfo struct {
+ Documents []any
+ ReturnDocuments bool
+}
+
+type BuildInToolInfo struct {
+ ToolName string
+ CallCount int
+ SearchContextSize string
+}
+
+type ResponsesUsageInfo struct {
+ BuiltInTools map[string]*BuildInToolInfo
+}
+
+type RelayInfo struct {
+ ChannelType int
+ ChannelId int
+ TokenId int
+ TokenKey string
+ UserId int
+ UsingGroup string // group used for this request
+ UserGroup string // group the user belongs to
+ TokenUnlimited bool
+ StartTime time.Time
+ FirstResponseTime time.Time
+ isFirstResponse bool
+ //SendLastReasoningResponse bool
+ ApiType int
+ IsStream bool
+ IsPlayground bool
+ UsePrice bool
+ RelayMode int
+ UpstreamModelName string
+ OriginModelName string
+ //RecodeModelName string
+ RequestURLPath string
+ ApiVersion string
+ PromptTokens int
+ ApiKey string
+ Organization string
+ BaseUrl string
+ SupportStreamOptions bool
+ ShouldIncludeUsage bool
+ IsModelMapped bool
+ ClientWs *websocket.Conn
+ TargetWs *websocket.Conn
+ InputAudioFormat string
+ OutputAudioFormat string
+ RealtimeTools []dto.RealTimeTool
+ IsFirstRequest bool
+ AudioUsage bool
+ ReasoningEffort string
+ ChannelSetting dto.ChannelSettings
+ ParamOverride map[string]interface{}
+ UserSetting dto.UserSetting
+ UserEmail string
+ UserQuota int
+ RelayFormat string
+ SendResponseCount int
+ ChannelCreateTime int64
+ ThinkingContentInfo
+ *ClaudeConvertInfo
+ *RerankerInfo
+ *ResponsesUsageInfo
+}
+
+// channel types that support stream_options
+var streamSupportedChannels = map[int]bool{
+ constant.ChannelTypeOpenAI: true,
+ constant.ChannelTypeAnthropic: true,
+ constant.ChannelTypeAws: true,
+ constant.ChannelTypeGemini: true,
+ constant.ChannelCloudflare: true,
+ constant.ChannelTypeAzure: true,
+ constant.ChannelTypeVolcEngine: true,
+ constant.ChannelTypeOllama: true,
+ constant.ChannelTypeXai: true,
+ constant.ChannelTypeDeepSeek: true,
+ constant.ChannelTypeBaiduV2: true,
+}
+
+func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.ClientWs = ws
+ info.InputAudioFormat = "pcm16"
+ info.OutputAudioFormat = "pcm16"
+ info.IsFirstRequest = true
+ return info
+}
+
+func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatClaude
+ info.ShouldIncludeUsage = false
+ info.ClaudeConvertInfo = &ClaudeConvertInfo{
+ LastMessagesType: LastMessageTypeNone,
+ }
+ return info
+}
+
+func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayMode = relayconstant.RelayModeRerank
+ info.RelayFormat = RelayFormatRerank
+ info.RerankerInfo = &RerankerInfo{
+ Documents: req.Documents,
+ ReturnDocuments: req.GetReturnDocuments(),
+ }
+ return info
+}
+
+func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatOpenAIAudio
+ return info
+}
+
+func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatEmbedding
+ return info
+}
+
+func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayMode = relayconstant.RelayModeResponses
+ info.RelayFormat = RelayFormatOpenAIResponses
+
+ info.SupportStreamOptions = false
+
+ info.ResponsesUsageInfo = &ResponsesUsageInfo{
+ BuiltInTools: make(map[string]*BuildInToolInfo),
+ }
+ if len(req.Tools) > 0 {
+ for _, tool := range req.Tools {
+ toolType := common.Interface2String(tool["type"])
+ info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{
+ ToolName: toolType,
+ CallCount: 0,
+ }
+ switch toolType {
+ case dto.BuildInToolWebSearchPreview:
+ searchContextSize := common.Interface2String(tool["search_context_size"])
+ if searchContextSize == "" {
+ searchContextSize = "medium"
+ }
+ info.ResponsesUsageInfo.BuiltInTools[toolType].SearchContextSize = searchContextSize
+ }
+ }
+ }
+ info.IsStream = req.Stream
+ return info
+}
+
+func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatGemini
+ info.ShouldIncludeUsage = false
+ return info
+}
+
+func GenRelayInfoImage(c *gin.Context) *RelayInfo {
+ info := GenRelayInfo(c)
+ info.RelayFormat = RelayFormatOpenAIImage
+ return info
+}
+
+func GenRelayInfo(c *gin.Context) *RelayInfo {
+ channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
+ channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
+ paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyChannelParamOverride)
+
+ tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
+ tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
+ userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
+ tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
+ startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
+ // firstResponseTime = time.Now() - 1 second
+
+ apiType, _ := common.ChannelType2APIType(channelType)
+
+ info := &RelayInfo{
+ UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
+ UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
+ isFirstResponse: true,
+ RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
+ BaseUrl: common.GetContextKeyString(c, constant.ContextKeyChannelBaseUrl),
+ RequestURLPath: c.Request.URL.String(),
+ ChannelType: channelType,
+ ChannelId: channelId,
+ TokenId: tokenId,
+ TokenKey: tokenKey,
+ UserId: userId,
+ UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
+ UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
+ TokenUnlimited: tokenUnlimited,
+ StartTime: startTime,
+ FirstResponseTime: startTime.Add(-time.Second),
+ OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+ UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
+ //RecodeModelName: c.GetString("original_model"),
+ IsModelMapped: false,
+ ApiType: apiType,
+ ApiVersion: c.GetString("api_version"),
+ ApiKey: common.GetContextKeyString(c, constant.ContextKeyChannelKey),
+ Organization: c.GetString("channel_organization"),
+
+ ChannelCreateTime: c.GetInt64("channel_create_time"),
+ ParamOverride: paramOverride,
+ RelayFormat: RelayFormatOpenAI,
+ ThinkingContentInfo: ThinkingContentInfo{
+ IsFirstThinkingContent: true,
+ SendLastThinkingContent: false,
+ },
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/pg") {
+ info.IsPlayground = true
+ info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
+ info.RequestURLPath = "/v1" + info.RequestURLPath
+ }
+ if info.BaseUrl == "" {
+ info.BaseUrl = constant.ChannelBaseURLs[channelType]
+ }
+ if info.ChannelType == constant.ChannelTypeAzure {
+ info.ApiVersion = GetAPIVersion(c)
+ }
+ if info.ChannelType == constant.ChannelTypeVertexAi {
+ info.ApiVersion = c.GetString("region")
+ }
+ if streamSupportedChannels[info.ChannelType] {
+ info.SupportStreamOptions = true
+ }
+
+ channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
+ if ok {
+ info.ChannelSetting = channelSetting
+ }
+ userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting)
+ if ok {
+ info.UserSetting = userSetting
+ }
+
+ return info
+}
+
+func (info *RelayInfo) SetPromptTokens(promptTokens int) {
+ info.PromptTokens = promptTokens
+}
+
+func (info *RelayInfo) SetIsStream(isStream bool) {
+ info.IsStream = isStream
+}
+
+func (info *RelayInfo) SetFirstResponseTime() {
+ if info.isFirstResponse {
+ info.FirstResponseTime = time.Now()
+ info.isFirstResponse = false
+ }
+}
+
+func (info *RelayInfo) HasSendResponse() bool {
+ return info.FirstResponseTime.After(info.StartTime)
+}
+
+type TaskRelayInfo struct {
+ *RelayInfo
+ Action string
+ OriginTaskID string
+
+ ConsumeQuota bool
+}
+
+func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
+ info := &TaskRelayInfo{
+ RelayInfo: GenRelayInfo(c),
+ }
+ return info
+}
+
+type TaskSubmitReq struct {
+ Prompt string `json:"prompt"`
+ Model string `json:"model,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Image string `json:"image,omitempty"`
+ Size string `json:"size,omitempty"`
+ Duration int `json:"duration,omitempty"`
+ Metadata map[string]interface{} `json:"metadata,omitempty"`
+}
+
+type TaskInfo struct {
+ Code int `json:"code"`
+ TaskID string `json:"task_id"`
+ Status string `json:"status"`
+ Reason string `json:"reason,omitempty"`
+ Url string `json:"url,omitempty"`
+ Progress string `json:"progress,omitempty"`
+}
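
GenRelayInfo above rewrites playground traffic so that everything downstream only sees /v1 paths. A small sketch of the rewrite (paths are examples):

package main

import (
    "fmt"
    "strings"
)

// normalizePlaygroundPath mirrors the /pg handling in GenRelayInfo:
// strip the /pg prefix and re-root the request under /v1.
func normalizePlaygroundPath(p string) (string, bool) {
    if !strings.HasPrefix(p, "/pg") {
        return p, false
    }
    return "/v1" + strings.TrimPrefix(p, "/pg"), true
}

func main() {
    fmt.Println(normalizePlaygroundPath("/pg/chat/completions")) // /v1/chat/completions true
    fmt.Println(normalizePlaygroundPath("/v1/embeddings"))       // unchanged
}
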
diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go
new file mode 100644
index 00000000..29086585
--- /dev/null
+++ b/relay/common/relay_utils.go
@@ -0,0 +1,34 @@
+package common
+
+import (
+ "fmt"
+ "github.com/gin-gonic/gin"
+ _ "image/gif"
+ _ "image/jpeg"
+ _ "image/png"
+ "one-api/constant"
+ "strings"
+)
+
+func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+
+ if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+ switch channelType {
+ case constant.ChannelTypeOpenAI:
+ fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+ case constant.ChannelTypeAzure:
+ fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
+ }
+ }
+ return fullRequestURL
+}
+
+func GetAPIVersion(c *gin.Context) string {
+ query := c.Request.URL.Query()
+ apiVersion := query.Get("api-version")
+ if apiVersion == "" {
+ apiVersion = c.GetString("api_version")
+ }
+ return apiVersion
+}
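
GetFullRequestURL above special-cases the Cloudflare AI Gateway: because the gateway URL already names the provider, the /v1 prefix (OpenAI) or /openai/deployments prefix (Azure) is trimmed before concatenation. A sketch of the OpenAI branch (account and gateway segments are placeholders):

package main

import (
    "fmt"
    "strings"
)

// cloudflareURL sketches the OpenAI branch of GetFullRequestURL:
// the /v1 prefix is dropped because the gateway path already names the provider.
func cloudflareURL(baseURL, requestURL string) string {
    if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
        requestURL = strings.TrimPrefix(requestURL, "/v1")
    }
    return baseURL + requestURL
}

func main() {
    base := "https://gateway.ai.cloudflare.com/v1/acct/gw/openai" // placeholder account/gateway
    fmt.Println(cloudflareURL(base, "/v1/chat/completions"))
    // -> https://gateway.ai.cloudflare.com/v1/acct/gw/openai/chat/completions
}
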
diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go
new file mode 100644
index 00000000..ce823b3a
--- /dev/null
+++ b/relay/common_handler/rerank.go
@@ -0,0 +1,73 @@
+package common_handler
+
+import (
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel/xinference"
+ relaycommon "one-api/relay/common"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
+ }
+ common.CloseResponseBodyGracefully(resp)
+ if common.DebugEnabled {
+ println("reranker response body: ", string(responseBody))
+ }
+ var jinaResp dto.RerankResponse
+ if info.ChannelType == constant.ChannelTypeXinference {
+ var xinRerankResponse xinference.XinRerankResponse
+ err = common.Unmarshal(responseBody, &xinRerankResponse)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
+ for i, result := range xinRerankResponse.Results {
+ respResult := dto.RerankResponseResult{
+ Index: result.Index,
+ RelevanceScore: result.RelevanceScore,
+ }
+ if info.ReturnDocuments {
+ var document any
+ if result.Document != nil {
+ if doc, ok := result.Document.(string); ok {
+ if doc == "" {
+ document = info.Documents[result.Index]
+ } else {
+ document = doc
+ }
+ } else {
+ document = result.Document
+ }
+ }
+ respResult.Document = document
+ }
+ jinaRespResults[i] = respResult
+ }
+ jinaResp = dto.RerankResponse{
+ Results: jinaRespResults,
+ Usage: dto.Usage{
+ PromptTokens: info.PromptTokens,
+ TotalTokens: info.PromptTokens,
+ },
+ }
+ } else {
+ err = common.Unmarshal(responseBody, &jinaResp)
+ if err != nil {
+ return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+ }
+ jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
+ }
+
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.JSON(http.StatusOK, jinaResp)
+ return &jinaResp.Usage, nil
+}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
new file mode 100644
index 00000000..b5195752
--- /dev/null
+++ b/relay/constant/relay_mode.go
@@ -0,0 +1,167 @@
+package constant
+
+import (
+ "net/http"
+ "strings"
+)
+
+const (
+ RelayModeUnknown = iota
+ RelayModeChatCompletions
+ RelayModeCompletions
+ RelayModeEmbeddings
+ RelayModeModerations
+ RelayModeImagesGenerations
+ RelayModeImagesEdits
+ RelayModeEdits
+
+ RelayModeMidjourneyImagine
+ RelayModeMidjourneyDescribe
+ RelayModeMidjourneyBlend
+ RelayModeMidjourneyChange
+ RelayModeMidjourneySimpleChange
+ RelayModeMidjourneyNotify
+ RelayModeMidjourneyTaskFetch
+ RelayModeMidjourneyTaskImageSeed
+ RelayModeMidjourneyTaskFetchByCondition
+ RelayModeMidjourneyAction
+ RelayModeMidjourneyModal
+ RelayModeMidjourneyShorten
+ RelayModeSwapFace
+ RelayModeMidjourneyUpload
+ RelayModeMidjourneyVideo
+ RelayModeMidjourneyEdits
+
+ RelayModeAudioSpeech // tts
+ RelayModeAudioTranscription // whisper
+ RelayModeAudioTranslation // whisper
+
+ RelayModeSunoFetch
+ RelayModeSunoFetchByID
+ RelayModeSunoSubmit
+
+ RelayModeKlingFetchByID
+ RelayModeKlingSubmit
+
+ RelayModeJimengFetchByID
+ RelayModeJimengSubmit
+
+ RelayModeRerank
+
+ RelayModeResponses
+
+ RelayModeRealtime
+
+ RelayModeGemini
+)
+
+func Path2RelayMode(path string) int {
+ relayMode := RelayModeUnknown
+ if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/pg/chat/completions") {
+ relayMode = RelayModeChatCompletions
+ } else if strings.HasPrefix(path, "/v1/completions") {
+ relayMode = RelayModeCompletions
+ } else if strings.HasPrefix(path, "/v1/embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasSuffix(path, "embeddings") {
+ relayMode = RelayModeEmbeddings
+ } else if strings.HasPrefix(path, "/v1/moderations") {
+ relayMode = RelayModeModerations
+ } else if strings.HasPrefix(path, "/v1/images/generations") {
+ relayMode = RelayModeImagesGenerations
+ } else if strings.HasPrefix(path, "/v1/images/edits") {
+ relayMode = RelayModeImagesEdits
+ } else if strings.HasPrefix(path, "/v1/edits") {
+ relayMode = RelayModeEdits
+ } else if strings.HasPrefix(path, "/v1/responses") {
+ relayMode = RelayModeResponses
+ } else if strings.HasPrefix(path, "/v1/audio/speech") {
+ relayMode = RelayModeAudioSpeech
+ } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+ relayMode = RelayModeAudioTranscription
+ } else if strings.HasPrefix(path, "/v1/audio/translations") {
+ relayMode = RelayModeAudioTranslation
+ } else if strings.HasPrefix(path, "/v1/rerank") {
+ relayMode = RelayModeRerank
+ } else if strings.HasPrefix(path, "/v1/realtime") {
+ relayMode = RelayModeRealtime
+ } else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
+ relayMode = RelayModeGemini
+ }
+ return relayMode
+}
+
+func Path2RelayModeMidjourney(path string) int {
+ relayMode := RelayModeUnknown
+ if strings.HasSuffix(path, "/mj/submit/action") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyAction
+ } else if strings.HasSuffix(path, "/mj/submit/modal") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyModal
+ } else if strings.HasSuffix(path, "/mj/submit/shorten") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyShorten
+ } else if strings.HasSuffix(path, "/mj/insight-face/swap") {
+ // midjourney plus
+ relayMode = RelayModeSwapFace
+ } else if strings.HasSuffix(path, "/submit/upload-discord-images") {
+ // midjourney plus
+ relayMode = RelayModeMidjourneyUpload
+ } else if strings.HasSuffix(path, "/mj/submit/imagine") {
+ relayMode = RelayModeMidjourneyImagine
+ } else if strings.HasSuffix(path, "/mj/submit/video") {
+ relayMode = RelayModeMidjourneyVideo
+ } else if strings.HasSuffix(path, "/mj/submit/edits") {
+ relayMode = RelayModeMidjourneyEdits
+ } else if strings.HasSuffix(path, "/mj/submit/blend") {
+ relayMode = RelayModeMidjourneyBlend
+ } else if strings.HasSuffix(path, "/mj/submit/describe") {
+ relayMode = RelayModeMidjourneyDescribe
+ } else if strings.HasSuffix(path, "/mj/notify") {
+ relayMode = RelayModeMidjourneyNotify
+ } else if strings.HasSuffix(path, "/mj/submit/change") {
+ relayMode = RelayModeMidjourneyChange
+ } else if strings.HasSuffix(path, "/mj/submit/simple-change") {
+ relayMode = RelayModeMidjourneyChange
+ } else if strings.HasSuffix(path, "/fetch") {
+ relayMode = RelayModeMidjourneyTaskFetch
+ } else if strings.HasSuffix(path, "/image-seed") {
+ relayMode = RelayModeMidjourneyTaskImageSeed
+ } else if strings.HasSuffix(path, "/list-by-condition") {
+ relayMode = RelayModeMidjourneyTaskFetchByCondition
+ }
+ return relayMode
+}
+
+func Path2RelaySuno(method, path string) int {
+ relayMode := RelayModeUnknown
+ if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
+ relayMode = RelayModeSunoFetch
+ } else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
+ relayMode = RelayModeSunoFetchByID
+ } else if strings.Contains(path, "/submit/") {
+ relayMode = RelayModeSunoSubmit
+ }
+ return relayMode
+}
+
+func Path2RelayKling(method, path string) int {
+ relayMode := RelayModeUnknown
+ if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
+ relayMode = RelayModeKlingSubmit
+ } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
+ relayMode = RelayModeKlingFetchByID
+ }
+ return relayMode
+}
+
+func Path2RelayJimeng(method, path string) int {
+ relayMode := RelayModeUnknown
+ if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
+ relayMode = RelayModeJimengSubmit
+ } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
+ relayMode = RelayModeJimengFetchByID
+ }
+ return relayMode
+}
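
A few representative mappings of the prefix/suffix routing above, written as a throwaway check against Path2RelayMode (this assumes the repository's own module path "one-api"; the printed values are the iota constants defined above):

package main

import (
    "fmt"

    relayconstant "one-api/relay/constant"
)

func main() {
    // Quick sanity check of the prefix/suffix routing in Path2RelayMode.
    for _, p := range []string{
        "/v1/chat/completions",
        "/pg/chat/completions",
        "/openai/deployments/foo/embeddings", // matched by the "embeddings" suffix rule
        "/v1/rerank",
        "/v1beta/models/gemini-2.0-flash:generateContent",
    } {
        fmt.Println(p, "->", relayconstant.Path2RelayMode(p))
    }
}
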
diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go
new file mode 100644
index 00000000..be11bb2b
--- /dev/null
+++ b/relay/embedding_handler.go
@@ -0,0 +1,116 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
+func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
+ token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
+ return token
+}
+
+// validateEmbeddingRequest takes a pointer so the default-model fallbacks below take effect
+func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest *dto.EmbeddingRequest) error {
+ if embeddingRequest.Input == nil {
+ return fmt.Errorf("input is empty")
+ }
+ if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
+ embeddingRequest.Model = "omni-moderation-latest"
+ }
+ if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
+ embeddingRequest.Model = c.Param("model")
+ }
+ return nil
+}
+
+func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+ relayInfo := relaycommon.GenRelayInfoEmbedding(c)
+
+ var embeddingRequest *dto.EmbeddingRequest
+ err := common.UnmarshalBodyReusable(c, &embeddingRequest)
+ if err != nil {
+ common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ err = validateEmbeddingRequest(c, relayInfo, embeddingRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ promptToken := getEmbeddingPromptToken(*embeddingRequest)
+ relayInfo.PromptTokens = promptToken
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+ // pre-consume quota 预消耗配额
+ preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newAPIError != nil {
+ return newAPIError
+ }
+ defer func() {
+ if newAPIError != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
+
+ convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
+
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ requestBody := bytes.NewBuffer(jsonData)
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ return nil
+}
diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go
new file mode 100644
index 00000000..e448b491
--- /dev/null
+++ b/relay/gemini_handler.go
@@ -0,0 +1,234 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/relay/channel/gemini"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
+ request := &gemini.GeminiChatRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ if len(request.Contents) == 0 {
+ return nil, errors.New("contents is required")
+ }
+ return request, nil
+}
+
+// Stream mode, e.g.:
+// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
+func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+ if c.Query("alt") == "sse" {
+ relayInfo.IsStream = true
+ }
+
+ // if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+ // relayInfo.IsStream = true
+ // }
+}
+
+func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
+ var inputTexts []string
+ for _, content := range textRequest.Contents {
+ for _, part := range content.Parts {
+ if part.Text != "" {
+ inputTexts = append(inputTexts, part.Text)
+ }
+ }
+ }
+ if len(inputTexts) == 0 {
+ return nil, nil
+ }
+
+ sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
+ return sensitiveWords, err
+}
+
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
+ // count input tokens
+ var inputTexts []string
+ for _, content := range req.Contents {
+ for _, part := range content.Parts {
+ if part.Text != "" {
+ inputTexts = append(inputTexts, part.Text)
+ }
+ }
+ }
+
+ inputText := strings.Join(inputTexts, "\n")
+ inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
+ info.PromptTokens = inputTokens
+ return inputTokens
+}
+
+func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
+ if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
+ return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
+ }
+ return false
+}
+
+func trimModelThinking(modelName string) string {
+ // strip the -nothinking suffix from the model name
+ if strings.HasSuffix(modelName, "-nothinking") {
+ return strings.TrimSuffix(modelName, "-nothinking")
+ }
+ // strip the -thinking suffix from the model name
+ if strings.HasSuffix(modelName, "-thinking") {
+ return strings.TrimSuffix(modelName, "-thinking")
+ }
+
+ // collapse a -thinking-<number> suffix back to -thinking
+ if strings.Contains(modelName, "-thinking-") {
+ parts := strings.Split(modelName, "-thinking-")
+ if len(parts) > 1 {
+ return parts[0] + "-thinking"
+ }
+ }
+ return modelName
+}
+
+func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+ req, err := getAndValidateGeminiRequest(c)
+ if err != nil {
+ common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ relayInfo := relaycommon.GenRelayInfoGemini(c)
+
+ // check Gemini streaming mode
+ checkGeminiStreamMode(c, relayInfo)
+
+ if setting.ShouldCheckPromptSensitive() {
+ sensitiveWords, err := checkGeminiInputSensitive(req)
+ if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
+ return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+ }
+ }
+
+ // model mapped 模型映射
+ err = helper.ModelMappedHelper(c, relayInfo, req)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ if value, exists := c.Get("prompt_tokens"); exists {
+ promptTokens := value.(int)
+ relayInfo.SetPromptTokens(promptTokens)
+ } else {
+ promptTokens := getGeminiInputTokens(req, relayInfo)
+ c.Set("prompt_tokens", promptTokens)
+ }
+
+ if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
+ if isNoThinkingRequest(req) {
+ // check is thinking
+ if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
+ // try to get no thinking model price
+ noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
+ containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
+ if containPrice {
+ relayInfo.OriginModelName = noThinkingModelName
+ relayInfo.UpstreamModelName = noThinkingModelName
+ }
+ }
+ }
+ if req.GenerationConfig.ThinkingConfig == nil {
+ gemini.ThinkingAdaptor(req, relayInfo)
+ }
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+
+ // pre consume quota
+ preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newAPIError != nil {
+ return newAPIError
+ }
+ defer func() {
+ if newAPIError != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+
+ adaptor.Init(relayInfo)
+
+ // Clean up empty system instruction
+ if req.SystemInstructions != nil {
+ hasContent := false
+ for _, part := range req.SystemInstructions.Parts {
+ if part.Text != "" {
+ hasContent = true
+ break
+ }
+ }
+ if !hasContent {
+ req.SystemInstructions = nil
+ }
+ }
+
+ requestBody, err := json.Marshal(req)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+
+ if common.DebugEnabled {
+ println("Gemini request body: %s", string(requestBody))
+ }
+
+ resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
+ if err != nil {
+ common.LogError(c, "Do gemini request failed: "+err.Error())
+ return types.NewError(err, types.ErrorCodeDoRequestFailed)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+ if openaiErr != nil {
+ service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+ return openaiErr
+ }
+
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ return nil
+}
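
A note on the thinking handling in GeminiHelper above: when thinkingBudget <= 0 and a price or ratio exists for "<model>-nothinking", billing switches to that variant, and the suffix helpers later strip the decoration before the upstream call. A small sketch mirroring trimModelThinking's suffix normalization (model names are examples):

package main

import (
    "fmt"
    "strings"
)

// trimThinkingSuffix mirrors trimModelThinking: "-nothinking" and "-thinking"
// are stripped, and "-thinking-<budget>" collapses back to "-thinking".
func trimThinkingSuffix(name string) string {
    if strings.HasSuffix(name, "-nothinking") {
        return strings.TrimSuffix(name, "-nothinking")
    }
    if strings.HasSuffix(name, "-thinking") {
        return strings.TrimSuffix(name, "-thinking")
    }
    if i := strings.Index(name, "-thinking-"); i != -1 {
        return name[:i] + "-thinking"
    }
    return name
}

func main() {
    for _, m := range []string{
        "gemini-2.5-flash-nothinking",
        "gemini-2.5-flash-thinking",
        "gemini-2.5-flash-thinking-1024",
    } {
        fmt.Println(m, "->", trimThinkingSuffix(m))
    }
}
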
diff --git a/relay/helper/common.go b/relay/helper/common.go
new file mode 100644
index 00000000..5d23b512
--- /dev/null
+++ b/relay/helper/common.go
@@ -0,0 +1,167 @@
+package helper
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+)
+
+func SetEventStreamHeaders(c *gin.Context) {
+ // skip if the event-stream headers have already been set
+ if _, exists := c.Get("event_stream_headers_set"); exists {
+ return
+ }
+
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("Transfer-Encoding", "chunked")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+
+ // mark the headers as set
+ c.Set("event_stream_headers_set", true)
+}
+
+func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
+ jsonData, err := json.Marshal(resp)
+ if err != nil {
+ common.SysError("error marshalling stream response: " + err.Error())
+ } else {
+ c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
+ }
+ if flusher, ok := c.Writer.(http.Flusher); ok {
+ flusher.Flush()
+ } else {
+ return errors.New("streaming error: flusher not found")
+ }
+ return nil
+}
+
+func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
+ c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
+ c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
+ if flusher, ok := c.Writer.(http.Flusher); ok {
+ flusher.Flush()
+ }
+}
+
+func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
+ c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
+ c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
+ if flusher, ok := c.Writer.(http.Flusher); ok {
+ flusher.Flush()
+ }
+}
+
+func StringData(c *gin.Context, str string) error {
+ //str = strings.TrimPrefix(str, "data: ")
+ //str = strings.TrimSuffix(str, "\r")
+ c.Render(-1, common.CustomEvent{Data: "data: " + str})
+ if flusher, ok := c.Writer.(http.Flusher); ok {
+ flusher.Flush()
+ } else {
+ return errors.New("streaming error: flusher not found")
+ }
+ return nil
+}
+
+func PingData(c *gin.Context) error {
+ if _, err := c.Writer.Write([]byte(": PING\n\n")); err != nil {
+ return err
+ }
+ if flusher, ok := c.Writer.(http.Flusher); ok {
+ flusher.Flush()
+ } else {
+ return errors.New("streaming error: flusher not found")
+ }
+ return nil
+}
+
+func ObjectData(c *gin.Context, object interface{}) error {
+ if object == nil {
+ return errors.New("object is nil")
+ }
+ jsonData, err := common.Marshal(object)
+ if err != nil {
+ return fmt.Errorf("error marshalling object: %w", err)
+ }
+ return StringData(c, string(jsonData))
+}
+
+func Done(c *gin.Context) {
+ _ = StringData(c, "[DONE]")
+}
+
+func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
+ if ws == nil {
+ common.LogError(c, "websocket connection is nil")
+ return errors.New("websocket connection is nil")
+ }
+ //common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
+ return ws.WriteMessage(1, []byte(str))
+}
+
+func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
+ jsonData, err := json.Marshal(object)
+ if err != nil {
+ return fmt.Errorf("error marshalling object: %w", err)
+ }
+ if ws == nil {
+ common.LogError(c, "websocket connection is nil")
+ return errors.New("websocket connection is nil")
+ }
+ //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
+ return ws.WriteMessage(1, jsonData)
+}
+
+func WssError(c *gin.Context, ws *websocket.Conn, openaiError types.OpenAIError) {
+ errorObj := &dto.RealtimeEvent{
+ Type: "error",
+ EventId: GetLocalRealtimeID(c),
+ Error: &openaiError,
+ }
+ _ = WssObject(c, ws, errorObj)
+}
+
+func GetResponseID(c *gin.Context) string {
+ logID := c.GetString(common.RequestIdKey)
+ return fmt.Sprintf("chatcmpl-%s", logID)
+}
+
+func GetLocalRealtimeID(c *gin.Context) string {
+ logID := c.GetString(common.RequestIdKey)
+ return fmt.Sprintf("evt_%s", logID)
+}
+
+func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
+ return &dto.ChatCompletionsStreamResponse{
+ Id: id,
+ Object: "chat.completion.chunk",
+ Created: createAt,
+ Model: model,
+ SystemFingerprint: nil,
+ Choices: []dto.ChatCompletionsStreamResponseChoice{
+ {
+ FinishReason: &finishReason,
+ },
+ },
+ }
+}
+
+func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
+ return &dto.ChatCompletionsStreamResponse{
+ Id: id,
+ Object: "chat.completion.chunk",
+ Created: createAt,
+ Model: model,
+ SystemFingerprint: nil,
+ Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
+ Usage: &usage,
+ }
+}
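
The helpers above all emit one SSE frame per call via c.Render plus an explicit flush. A rough sketch of how a stream typically closes, reusing GenerateStopResponse and GenerateFinalUsageResponse from this file (IDs, model name, and token counts are illustrative):

package main

import (
    "encoding/json"
    "fmt"
    "time"

    "one-api/dto"
    "one-api/relay/helper"
)

func main() {
    // Sketch of a typical stream tail: a finish_reason chunk, an empty-choices
    // usage chunk, then the [DONE] sentinel.
    stop := helper.GenerateStopResponse("chatcmpl-demo", time.Now().Unix(), "glm-4", "stop")
    usage := helper.GenerateFinalUsageResponse("chatcmpl-demo", time.Now().Unix(), "glm-4",
        dto.Usage{PromptTokens: 10, TotalTokens: 30})
    for _, chunk := range []any{stop, usage} {
        b, _ := json.Marshal(chunk)
        fmt.Println("data:", string(b))
    }
    fmt.Println("data: [DONE]")
}
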
diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go
new file mode 100644
index 00000000..c1735149
--- /dev/null
+++ b/relay/helper/model_mapped.go
@@ -0,0 +1,92 @@
+package helper
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ common2 "one-api/common"
+ "one-api/dto"
+ "one-api/relay/common"
+
+ "github.com/gin-gonic/gin"
+)
+
+func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
+ // map model name
+ modelMapping := c.GetString("model_mapping")
+ if modelMapping != "" && modelMapping != "{}" {
+ modelMap := make(map[string]string)
+ err := json.Unmarshal([]byte(modelMapping), &modelMap)
+ if err != nil {
+ return fmt.Errorf("unmarshal_model_mapping_failed")
+ }
+
+ // support chained model redirects; the model at the end of the chain wins
+ currentModel := info.OriginModelName
+ visitedModels := map[string]bool{
+ currentModel: true,
+ }
+ for {
+ if mappedModel, exists := modelMap[currentModel]; exists && mappedModel != "" {
+ // detect redirect cycles to avoid looping forever
+ if visitedModels[mappedModel] {
+ if mappedModel == currentModel {
+ if currentModel == info.OriginModelName {
+ info.IsModelMapped = false
+ return nil
+ } else {
+ info.IsModelMapped = true
+ break
+ }
+ }
+ return errors.New("model_mapping_contains_cycle")
+ }
+ visitedModels[mappedModel] = true
+ currentModel = mappedModel
+ info.IsModelMapped = true
+ } else {
+ break
+ }
+ }
+ if info.IsModelMapped {
+ info.UpstreamModelName = currentModel
+ }
+ }
+ if request != nil {
+ switch info.RelayFormat {
+ case common.RelayFormatGemini:
+ // Gemini: no request-body field to rewrite here
+ case common.RelayFormatClaude:
+ if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
+ claudeRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatOpenAIResponses:
+ if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
+ openAIResponsesRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatOpenAIAudio:
+ if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
+ openAIAudioRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatOpenAIImage:
+ if imageRequest, ok := request.(*dto.ImageRequest); ok {
+ imageRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatRerank:
+ if rerankRequest, ok := request.(*dto.RerankRequest); ok {
+ rerankRequest.Model = info.UpstreamModelName
+ }
+ case common.RelayFormatEmbedding:
+ if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
+ embeddingRequest.Model = info.UpstreamModelName
+ }
+ default:
+ if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
+ openAIRequest.Model = info.UpstreamModelName
+ } else {
+ common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
+ }
+ }
+ }
+ return nil
+}
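
The mapping loop above follows redirect chains until a model with no further mapping is reached, tolerates a self-mapping, and rejects genuine cycles. A standalone sketch of that chain-following logic with illustrative mappings (the gin/context plumbing is omitted):

package main

import "fmt"

// followMapping mirrors the chain-following loop in ModelMappedHelper.
func followMapping(origin string, mapping map[string]string) (string, error) {
    current := origin
    visited := map[string]bool{current: true}
    for {
        next, ok := mapping[current]
        if !ok || next == "" {
            return current, nil
        }
        if visited[next] {
            if next == current {
                return current, nil // a self-mapping is tolerated
            }
            return "", fmt.Errorf("model_mapping_contains_cycle")
        }
        visited[next] = true
        current = next
    }
}

func main() {
    m := map[string]string{"gpt-4": "gpt-4o", "gpt-4o": "gpt-4o-mini"} // illustrative mapping
    fmt.Println(followMapping("gpt-4", m)) // -> gpt-4o-mini
    fmt.Println(followMapping("a", map[string]string{"a": "b", "b": "a"})) // -> cycle error
}
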
diff --git a/relay/helper/price.go b/relay/helper/price.go
new file mode 100644
index 00000000..e80578e5
--- /dev/null
+++ b/relay/helper/price.go
@@ -0,0 +1,161 @@
+package helper
+
+import (
+ "fmt"
+ "one-api/common"
+ relaycommon "one-api/relay/common"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+type GroupRatioInfo struct {
+ GroupRatio float64
+ GroupSpecialRatio float64
+ HasSpecialRatio bool
+}
+
+type PriceData struct {
+ ModelPrice float64
+ ModelRatio float64
+ CompletionRatio float64
+ CacheRatio float64
+ CacheCreationRatio float64
+ ImageRatio float64
+ UsePrice bool
+ ShouldPreConsumedQuota int
+ GroupRatioInfo GroupRatioInfo
+}
+
+func (p PriceData) ToSetting() string {
+ return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
+}
+
+// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
+func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
+ groupRatioInfo := GroupRatioInfo{
+ GroupRatio: 1.0, // default ratio
+ GroupSpecialRatio: -1,
+ }
+
+ // check auto group
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ if common.DebugEnabled {
+ println(fmt.Sprintf("final group: %s", autoGroup))
+ }
+ relayInfo.UsingGroup = autoGroup.(string)
+ }
+
+ // check user group special ratio
+ userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
+ if ok {
+ // user group special ratio
+ groupRatioInfo.GroupSpecialRatio = userGroupRatio
+ groupRatioInfo.GroupRatio = userGroupRatio
+ groupRatioInfo.HasSpecialRatio = true
+ } else {
+ // normal group ratio
+ groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
+ }
+
+ return groupRatioInfo
+}
+
+func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
+ modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
+
+ groupRatioInfo := HandleGroupRatio(c, info)
+
+ var preConsumedQuota int
+ var modelRatio float64
+ var completionRatio float64
+ var cacheRatio float64
+ var imageRatio float64
+ var cacheCreationRatio float64
+ if !usePrice {
+ preConsumedTokens := common.PreConsumedQuota
+ if maxTokens != 0 {
+ preConsumedTokens = promptTokens + maxTokens
+ }
+ var success bool
+ var matchName string
+ modelRatio, success, matchName = ratio_setting.GetModelRatio(info.OriginModelName)
+ if !success {
+ acceptUnsetRatio := false
+ if info.UserSetting.AcceptUnsetRatioModel {
+ acceptUnsetRatio = true
+ }
+ if !acceptUnsetRatio {
+ return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", matchName, matchName)
+ }
+ }
+ completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
+ cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
+ cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
+ imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
+ ratio := modelRatio * groupRatioInfo.GroupRatio
+ preConsumedQuota = int(float64(preConsumedTokens) * ratio)
+ } else {
+ preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
+ }
+
+ priceData := PriceData{
+ ModelPrice: modelPrice,
+ ModelRatio: modelRatio,
+ CompletionRatio: completionRatio,
+ GroupRatioInfo: groupRatioInfo,
+ UsePrice: usePrice,
+ CacheRatio: cacheRatio,
+ ImageRatio: imageRatio,
+ CacheCreationRatio: cacheCreationRatio,
+ ShouldPreConsumedQuota: preConsumedQuota,
+ }
+
+ if common.DebugEnabled {
+ println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
+ }
+
+ return priceData, nil
+}
+
+type PerCallPriceData struct {
+ ModelPrice float64
+ Quota int
+ GroupRatioInfo GroupRatioInfo
+}
+
+// ModelPriceHelperPerCall is the per-call (pay-per-use) price helper (MJ, Task)
+func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
+ groupRatioInfo := HandleGroupRatio(c, info)
+
+ modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
+ // fall back to the default price if none is configured
+ if !success {
+ defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
+ if !ok {
+ modelPrice = 0.1
+ } else {
+ modelPrice = defaultPrice
+ }
+ }
+ quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
+ priceData := PerCallPriceData{
+ ModelPrice: modelPrice,
+ Quota: quota,
+ GroupRatioInfo: groupRatioInfo,
+ }
+ return priceData
+}
+
+func ContainPriceOrRatio(modelName string) bool {
+ _, ok := ratio_setting.GetModelPrice(modelName, false)
+ if ok {
+ return true
+ }
+ _, ok, _ = ratio_setting.GetModelRatio(modelName)
+ if ok {
+ return true
+ }
+ return false
+}
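
ModelPriceHelper above pre-consumes quota in one of two ways: ratio-billed models reserve (promptTokens + maxTokens) * modelRatio * groupRatio (falling back to common.PreConsumedQuota tokens when max_tokens is 0), while price-billed models reserve modelPrice * QuotaPerUnit * groupRatio. A worked example with assumed numbers:

package main

import "fmt"

func main() {
    // All numbers below are illustrative, not the project's defaults.
    promptTokens, maxTokens := 1200, 800
    modelRatio, groupRatio := 2.5, 0.9
    preConsumed := int(float64(promptTokens+maxTokens) * modelRatio * groupRatio)
    fmt.Println("ratio-billed pre-consumed quota:", preConsumed) // 4500

    modelPrice, quotaPerUnit := 0.1, 500000.0
    fmt.Println("price-billed pre-consumed quota:", int(modelPrice*quotaPerUnit*groupRatio)) // 45000
}
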
diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go
new file mode 100644
index 00000000..b526b1c0
--- /dev/null
+++ b/relay/helper/stream_scanner.go
@@ -0,0 +1,259 @@
+package helper
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ relaycommon "one-api/relay/common"
+ "one-api/setting/operation_setting"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+
+ "github.com/gin-gonic/gin"
+)
+
+const (
+ InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
+ MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
+ DefaultPingInterval = 10 * time.Second
+)
+
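+// StreamScannerHandler reads the upstream SSE response line by line and hands each "data:"
+// payload to dataHandler. It optionally sends keep-alive pings, enforces the streaming
+// timeout, and shuts down all goroutines when the stream ends or the client disconnects.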
+func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
+
+ if resp == nil || dataHandler == nil {
+ return
+ }
+
+ // make sure the response body is always closed
+ defer func() {
+ if resp.Body != nil {
+ resp.Body.Close()
+ }
+ }()
+
+ streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
+ if strings.HasPrefix(info.UpstreamModelName, "o") {
+ // double the timeout for "o" series reasoning models
+ streamingTimeout *= 2
+ }
+
+ var (
+ stopChan = make(chan bool, 3) // buffered so senders do not block
+ scanner = bufio.NewScanner(resp.Body)
+ ticker = time.NewTicker(streamingTimeout)
+ pingTicker *time.Ticker
+ writeMutex sync.Mutex // Mutex to protect concurrent writes
+ wg sync.WaitGroup // waits for all goroutines to exit
+ )
+
+ generalSettings := operation_setting.GetGeneralSetting()
+ pingEnabled := generalSettings.PingIntervalEnabled
+ pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+ if pingInterval <= 0 {
+ pingInterval = DefaultPingInterval
+ }
+
+ if pingEnabled {
+ pingTicker = time.NewTicker(pingInterval)
+ }
+
+ if common.DebugEnabled {
+ // print timeout and ping interval for debugging
+ println("relay timeout seconds:", common.RelayTimeout)
+ println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
+ println("ping interval seconds:", int64(pingInterval.Seconds()))
+ }
+
+ // clean up resources and make sure all goroutines exit properly
+ defer func() {
+ // signal all goroutines to stop
+ common.SafeSendBool(stopChan, true)
+
+ ticker.Stop()
+ if pingTicker != nil {
+ pingTicker.Stop()
+ }
+
+ // wait for all goroutines to exit, for at most 5 seconds
+ done := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ common.LogError(c, "timeout waiting for goroutines to exit")
+ }
+
+ close(stopChan)
+ }()
+
+ scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
+ scanner.Split(bufio.ScanLines)
+ SetEventStreamHeaders(c)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ctx = context.WithValue(ctx, "stop_chan", stopChan)
+
+ // Handle ping data sending with improved error handling
+ if pingEnabled && pingTicker != nil {
+ wg.Add(1)
+ gopool.Go(func() {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
+ common.SafeSendBool(stopChan, true)
+ }
+ if common.DebugEnabled {
+ println("ping goroutine exited")
+ }
+ }()
+
+ // guard against the goroutine running forever
+ maxPingDuration := 30 * time.Minute // maximum ping duration
+ pingTimeout := time.NewTimer(maxPingDuration)
+ defer pingTimeout.Stop()
+
+ for {
+ select {
+ case <-pingTicker.C:
+ // use a timeout so a blocked write cannot hang the ping goroutine
+ done := make(chan error, 1)
+ go func() {
+ writeMutex.Lock()
+ defer writeMutex.Unlock()
+ done <- PingData(c)
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ common.LogError(c, "ping data error: "+err.Error())
+ return
+ }
+ if common.DebugEnabled {
+ println("ping data sent")
+ }
+ case <-time.After(10 * time.Second):
+ common.LogError(c, "ping data send timeout")
+ return
+ case <-ctx.Done():
+ return
+ case <-stopChan:
+ return
+ }
+ case <-ctx.Done():
+ return
+ case <-stopChan:
+ return
+ case <-c.Request.Context().Done():
+ // client disconnected
+ return
+ case <-pingTimeout.C:
+ common.LogError(c, "ping goroutine max duration reached")
+ return
+ }
+ }
+ })
+ }
+
+ // Scanner goroutine with improved error handling
+ wg.Add(1)
+ common.RelayCtxGo(ctx, func() {
+ defer func() {
+ wg.Done()
+ if r := recover(); r != nil {
+ common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
+ }
+ common.SafeSendBool(stopChan, true)
+ if common.DebugEnabled {
+ println("scanner goroutine exited")
+ }
+ }()
+
+ for scanner.Scan() {
+ // check whether we should stop
+ select {
+ case <-stopChan:
+ return
+ case <-ctx.Done():
+ return
+ case <-c.Request.Context().Done():
+ return
+ default:
+ }
+
+ ticker.Reset(streamingTimeout)
+ data := scanner.Text()
+ if common.DebugEnabled {
+ println(data)
+ }
+
+ if len(data) < 6 {
+ continue
+ }
+ if data[:5] != "data:" && data[:6] != "[DONE]" {
+ continue
+ }
+ data = data[5:]
+ data = strings.TrimLeft(data, " ")
+ data = strings.TrimSuffix(data, "\r")
+ if !strings.HasPrefix(data, "[DONE]") {
+ info.SetFirstResponseTime()
+
+ // use a timeout so a blocked write cannot hang the scanner goroutine
+ done := make(chan bool, 1)
+ go func() {
+ writeMutex.Lock()
+ defer writeMutex.Unlock()
+ done <- dataHandler(data)
+ }()
+
+ select {
+ case success := <-done:
+ if !success {
+ return
+ }
+ case <-time.After(10 * time.Second):
+ common.LogError(c, "data handler timeout")
+ return
+ case <-ctx.Done():
+ return
+ case <-stopChan:
+ return
+ }
+ }
+ }
+
+ if err := scanner.Err(); err != nil {
+ if err != io.EOF {
+ common.LogError(c, "scanner error: "+err.Error())
+ }
+ }
+ })
+
+ // main loop: wait for completion, timeout, or client disconnect
+ select {
+ case <-ticker.C:
+ // streaming timed out
+ common.LogError(c, "streaming timeout")
+ case <-stopChan:
+ // finished normally
+ common.LogInfo(c, "streaming finished")
+ case <-c.Request.Context().Done():
+ // client disconnected
+ common.LogInfo(c, "client disconnected")
+ }
+}
diff --git a/relay/image_handler.go b/relay/image_handler.go
new file mode 100644
index 00000000..8e059863
--- /dev/null
+++ b/relay/image_handler.go
@@ -0,0 +1,247 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
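+// getAndValidImageRequest parses an image generation/edit request (JSON body, or multipart
+// form for edits), applies model-specific defaults for size, quality and n, validates the
+// size for dall-e models, and optionally runs the prompt through the sensitive-word check.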
+func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) {
+ imageRequest := &dto.ImageRequest{}
+
+ switch info.RelayMode {
+ case relayconstant.RelayModeImagesEdits:
+ _, err := c.MultipartForm()
+ if err != nil {
+ return nil, err
+ }
+ formData := c.Request.PostForm
+ imageRequest.Prompt = formData.Get("prompt")
+ imageRequest.Model = formData.Get("model")
+ imageRequest.N = common.String2Int(formData.Get("n"))
+ imageRequest.Quality = formData.Get("quality")
+ imageRequest.Size = formData.Get("size")
+
+ if imageRequest.Model == "gpt-image-1" {
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "standard"
+ }
+ }
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+
+ if info.ApiType == constant.APITypeVolcEngine {
+ watermark := formData.Has("watermark")
+ imageRequest.Watermark = &watermark
+ }
+ default:
+ err := common.UnmarshalBodyReusable(c, imageRequest)
+ if err != nil {
+ return nil, err
+ }
+
+ if imageRequest.Model == "" {
+ imageRequest.Model = "dall-e-3"
+ }
+
+ if strings.Contains(imageRequest.Size, "×") {
+ return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'")
+ }
+
+ // Not "256x256", "512x512", or "1024x1024"
+ if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" {
+ if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
+ return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024 for dall-e-2 or dall-e")
+ }
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ } else if imageRequest.Model == "dall-e-3" {
+ if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" {
+ return nil, errors.New("size must be one of 1024x1024, 1024x1792 or 1792x1024 for dall-e-3")
+ }
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "standard"
+ }
+ if imageRequest.Size == "" {
+ imageRequest.Size = "1024x1024"
+ }
+ } else if imageRequest.Model == "gpt-image-1" {
+ if imageRequest.Quality == "" {
+ imageRequest.Quality = "auto"
+ }
+ }
+
+ if imageRequest.Prompt == "" {
+ return nil, errors.New("prompt is required")
+ }
+
+ if imageRequest.N == 0 {
+ imageRequest.N = 1
+ }
+ }
+
+ if setting.ShouldCheckPromptSensitive() {
+ words, err := service.CheckSensitiveInput(imageRequest.Prompt)
+ if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
+ return nil, err
+ }
+ }
+ return imageRequest, nil
+}
+
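+// ImageHelper relays image generation/edit requests: it validates the request, applies model
+// mapping and pricing, pre-consumes or checks quota, forwards the request through the channel
+// adaptor, and finally settles the consumed quota based on the returned usage.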
+func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+ relayInfo := relaycommon.GenRelayInfoImage(c)
+
+ imageRequest, err := getAndValidImageRequest(c, relayInfo)
+ if err != nil {
+ common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+ var preConsumedQuota int
+ var quota int
+ var userQuota int
+ if !priceData.UsePrice {
+ // modelRatio 16 = modelPrice $0.04
+ // per 1 modelRatio = $0.04 / 16
+ // priceData.ModelPrice = 0.0025 * priceData.ModelRatio
+ preConsumedQuota, userQuota, newAPIError = preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newAPIError != nil {
+ return newAPIError
+ }
+ defer func() {
+ if newAPIError != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ } else {
+ sizeRatio := 1.0
+ qualityRatio := 1.0
+
+ if strings.HasPrefix(imageRequest.Model, "dall-e") {
+ // Size
+ if imageRequest.Size == "256x256" {
+ sizeRatio = 0.4
+ } else if imageRequest.Size == "512x512" {
+ sizeRatio = 0.45
+ } else if imageRequest.Size == "1024x1024" {
+ sizeRatio = 1
+ } else if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
+ sizeRatio = 2
+ }
+
+ if imageRequest.Model == "dall-e-3" && imageRequest.Quality == "hd" {
+ qualityRatio = 2.0
+ if imageRequest.Size == "1024x1792" || imageRequest.Size == "1792x1024" {
+ qualityRatio = 1.5
+ }
+ }
+ }
+
+ // reset model price
+ priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
+ quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
+ userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeQueryDataError)
+ }
+ if userQuota-quota < 0 {
+ return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
+ }
+ }
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
+
+ var requestBody io.Reader
+
+ convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
+ requestBody = convertedRequest.(io.Reader)
+ } else {
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ }
+
+ if common.DebugEnabled {
+ println(fmt.Sprintf("image request body: %s", requestBody))
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ if newAPIError != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+
+ if usage.(*dto.Usage).TotalTokens == 0 {
+ usage.(*dto.Usage).TotalTokens = imageRequest.N
+ }
+ if usage.(*dto.Usage).PromptTokens == 0 {
+ usage.(*dto.Usage).PromptTokens = imageRequest.N
+ }
+ quality := "standard"
+ if imageRequest.Quality == "hd" {
+ quality = "hd"
+ }
+
+ logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, logContent)
+ return nil
+}
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
new file mode 100644
index 00000000..e7f316b9
--- /dev/null
+++ b/relay/relay-mj.go
@@ -0,0 +1,668 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
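+// RelayMidjourneyImage proxies the stored image of a Midjourney task back to the client,
+// using the original channel's proxy settings when configured.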
+func RelayMidjourneyImage(c *gin.Context) {
+ taskId := c.Param("id")
+ midjourneyTask := model.GetByOnlyMJId(taskId)
+ if midjourneyTask == nil {
+ c.JSON(400, gin.H{
+ "error": "midjourney_task_not_found",
+ })
+ return
+ }
+ var httpClient *http.Client
+ if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
+ proxy := channel.GetSetting().Proxy
+ if proxy != "" {
+ if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
+ c.JSON(400, gin.H{
+ "error": "proxy_url_invalid",
+ })
+ return
+ }
+ }
+ }
+ if httpClient == nil {
+ httpClient = service.GetHttpClient()
+ }
+ resp, err := httpClient.Get(midjourneyTask.ImageUrl)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "error": "http_get_image_failed",
+ })
+ return
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ responseBody, _ := io.ReadAll(resp.Body)
+ c.JSON(resp.StatusCode, gin.H{
+ "error": string(responseBody),
+ })
+ return
+ }
+ // take the MIME type from the Content-Type header
+ contentType := resp.Header.Get("Content-Type")
+ if contentType == "" {
+ // default to jpeg when the content type cannot be determined
+ contentType = "image/jpeg"
+ }
+ // set the content type of the response
+ c.Writer.Header().Set("Content-Type", contentType)
+ // stream the image into the response body
+ _, err = io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ log.Println("Failed to stream image:", err)
+ }
+ return
+}
+
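+// RelayMidjourneyNotify handles callback notifications from the upstream Midjourney proxy and
+// updates the locally stored task with the reported progress, URLs and status.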
+func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
+ var midjRequest dto.MidjourneyDto
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "bind_request_body_failed",
+ Properties: nil,
+ Result: "",
+ }
+ }
+ midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
+ if midjourneyTask == nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "midjourney_task_not_found",
+ Properties: nil,
+ Result: "",
+ }
+ }
+ midjourneyTask.Progress = midjRequest.Progress
+ midjourneyTask.PromptEn = midjRequest.PromptEn
+ midjourneyTask.State = midjRequest.State
+ midjourneyTask.SubmitTime = midjRequest.SubmitTime
+ midjourneyTask.StartTime = midjRequest.StartTime
+ midjourneyTask.FinishTime = midjRequest.FinishTime
+ midjourneyTask.ImageUrl = midjRequest.ImageUrl
+ midjourneyTask.VideoUrl = midjRequest.VideoUrl
+ videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls)
+ midjourneyTask.VideoUrls = string(videoUrlsStr)
+ midjourneyTask.Status = midjRequest.Status
+ midjourneyTask.FailReason = midjRequest.FailReason
+ err = midjourneyTask.Update()
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "update_midjourney_task_failed",
+ }
+ }
+
+ return nil
+}
+
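+// coverMidjourneyTaskDto converts a stored Midjourney task into its DTO form, rewriting the
+// image URL to the local forwarding endpoint when MjForwardUrlEnabled is set.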
+func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
+ midjourneyTask.MjId = originTask.MjId
+ midjourneyTask.Progress = originTask.Progress
+ midjourneyTask.PromptEn = originTask.PromptEn
+ midjourneyTask.State = originTask.State
+ midjourneyTask.SubmitTime = originTask.SubmitTime
+ midjourneyTask.StartTime = originTask.StartTime
+ midjourneyTask.FinishTime = originTask.FinishTime
+ midjourneyTask.ImageUrl = ""
+ if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
+ midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
+ if originTask.Status != "SUCCESS" {
+ midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
+ }
+ } else {
+ midjourneyTask.ImageUrl = originTask.ImageUrl
+ }
+ if originTask.VideoUrl != "" {
+ midjourneyTask.VideoUrl = originTask.VideoUrl
+ }
+ midjourneyTask.Status = originTask.Status
+ midjourneyTask.FailReason = originTask.FailReason
+ midjourneyTask.Action = originTask.Action
+ midjourneyTask.Description = originTask.Description
+ midjourneyTask.Prompt = originTask.Prompt
+ if originTask.Buttons != "" {
+ var buttons []dto.ActionButton
+ err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
+ if err == nil {
+ midjourneyTask.Buttons = buttons
+ }
+ }
+ if originTask.VideoUrls != "" {
+ var videoUrls []dto.ImgUrls
+ err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls)
+ if err == nil {
+ midjourneyTask.VideoUrls = videoUrls
+ }
+ }
+ if originTask.Properties != "" {
+ var properties dto.Properties
+ err := json.Unmarshal([]byte(originTask.Properties), &properties)
+ if err == nil {
+ midjourneyTask.Properties = &properties
+ }
+ }
+ return
+}
+
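+// RelaySwapFace handles InsightFace swap-face requests: it validates the payload, checks the
+// user's quota against the per-call price, forwards the request upstream, records the task,
+// and charges the quota once the upstream accepts the submission.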
+func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
+ startTime := time.Now().UnixNano() / int64(time.Millisecond)
+ tokenId := c.GetInt("token_id")
+ userId := c.GetInt("id")
+ //group := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ relayInfo := relaycommon.GenRelayInfo(c)
+ var swapFaceRequest dto.SwapFaceRequest
+ err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+ if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
+ }
+ modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
+
+ priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
+
+ userQuota, err := model.GetUserQuota(userId, false)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: err.Error(),
+ }
+ }
+
+ if userQuota-priceData.Quota < 0 {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "quota_not_enough",
+ }
+ }
+ requestURL := getMjRequestPath(c.Request.URL.String())
+ baseURL := c.GetString("base_url")
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+ mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
+ if err != nil {
+ return &mjResp.Response
+ }
+ defer func() {
+ if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
+ err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
+ other := service.GenerateMjOtherInfo(priceData)
+ model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: channelId,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: priceData.Quota,
+ Content: logContent,
+ TokenId: tokenId,
+ UserQuota: userQuota,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+ model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
+ model.UpdateChannelUsedQuota(channelId, priceData.Quota)
+ }
+ }()
+ midjResponse := &mjResp.Response
+ midjourneyTask := &model.Midjourney{
+ UserId: userId,
+ Code: midjResponse.Code,
+ Action: constant.MjActionSwapFace,
+ MjId: midjResponse.Result,
+ Prompt: "InsightFace",
+ PromptEn: "",
+ Description: midjResponse.Description,
+ State: "",
+ SubmitTime: startTime,
+ StartTime: time.Now().UnixNano() / int64(time.Millisecond),
+ FinishTime: 0,
+ ImageUrl: "",
+ Status: "",
+ Progress: "0%",
+ FailReason: "",
+ ChannelId: c.GetInt("channel_id"),
+ Quota: priceData.Quota,
+ }
+ err = midjourneyTask.Insert()
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
+ }
+ c.Writer.WriteHeader(mjResp.StatusCode)
+ respBody, err := json.Marshal(midjResponse)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
+ }
+ return nil
+}
+
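+// RelayMidjourneyTaskImageSeed fetches the image seed of an existing task from the channel
+// that originally processed it.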
+func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
+ taskId := c.Param("id")
+ userId := c.GetInt("id")
+ originTask := model.GetByMJId(userId, taskId)
+ if originTask == nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
+ }
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+ }
+ if channel.Status != common.ChannelStatusEnabled {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+ }
+ c.Set("channel_id", originTask.ChannelId)
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+ requestURL := getMjRequestPath(c.Request.URL.String())
+ fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
+ midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
+ if err != nil {
+ return &midjResponseWithStatus.Response
+ }
+ midjResponse := &midjResponseWithStatus.Response
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
+ respBody, err := json.Marshal(midjResponse)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
+ }
+ common.IOCopyBytesGracefully(c, nil, respBody)
+ return nil
+}
+
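+// RelayMidjourneyTask serves task queries: fetching a single task by id, or listing tasks for
+// a set of ids, and returns them as JSON.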
+func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
+ userId := c.GetInt("id")
+ var err error
+ var respBody []byte
+ switch relayMode {
+ case relayconstant.RelayModeMidjourneyTaskFetch:
+ taskId := c.Param("id")
+ originTask := model.GetByMJId(userId, taskId)
+ if originTask == nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "task_no_found",
+ }
+ }
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
+ respBody, err = json.Marshal(midjourneyTask)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "unmarshal_response_body_failed",
+ }
+ }
+ case relayconstant.RelayModeMidjourneyTaskFetchByCondition:
+ var condition = struct {
+ IDs []string `json:"ids"`
+ }{}
+ err = c.BindJSON(&condition)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "do_request_failed",
+ }
+ }
+ var tasks []dto.MidjourneyDto
+ if len(condition.IDs) != 0 {
+ originTasks := model.GetByMJIds(userId, condition.IDs)
+ for _, originTask := range originTasks {
+ midjourneyTask := coverMidjourneyTaskDto(c, originTask)
+ tasks = append(tasks, midjourneyTask)
+ }
+ }
+ if tasks == nil {
+ tasks = make([]dto.MidjourneyDto, 0)
+ }
+ respBody, err = json.Marshal(tasks)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "unmarshal_response_body_failed",
+ }
+ }
+ }
+
+ c.Writer.Header().Set("Content-Type", "application/json")
+
+ _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "copy_response_body_failed",
+ }
+ }
+ return nil
+}
+
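+// RelayMidjourneySubmit handles all Midjourney submit actions (imagine, blend, describe,
+// change, modal, video, upload, ...). For actions that operate on an existing task it routes
+// the request to the original channel, checks quota against the per-call price, forwards the
+// request, records the task, and normalizes the upstream response codes.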
+func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
+
+ //tokenId := c.GetInt("token_id")
+ //channelType := c.GetInt("channel")
+ userId := c.GetInt("id")
+ group := c.GetString("group")
+ channelId := c.GetInt("channel_id")
+ relayInfo := relaycommon.GenRelayInfo(c)
+ consumeQuota := true
+ var midjRequest dto.MidjourneyRequest
+ err := common.UnmarshalBodyReusable(c, &midjRequest)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
+ }
+
+ if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus: the task info must be recovered from the customId
+ mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
+ if mjErr != nil {
+ return mjErr
+ }
+ relayMode = relayconstant.RelayModeMidjourneyChange
+ }
+ if relayMode == relayconstant.RelayModeMidjourneyVideo {
+ midjRequest.Action = constant.MjActionVideo
+ }
+
+ if relayMode == relayconstant.RelayModeMidjourneyImagine { // drawing task, may be repeated
+ if midjRequest.Prompt == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
+ }
+ midjRequest.Action = constant.MjActionImagine
+ } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { // image-to-text task, may be repeated
+ midjRequest.Action = constant.MjActionDescribe
+ } else if relayMode == relayconstant.RelayModeMidjourneyEdits { // edit task, may be repeated
+ midjRequest.Action = constant.MjActionEdits
+ } else if relayMode == relayconstant.RelayModeMidjourneyShorten { // shorten task, may be repeated, plus only
+ midjRequest.Action = constant.MjActionShorten
+ } else if relayMode == relayconstant.RelayModeMidjourneyBlend { // blend task, may be repeated
+ midjRequest.Action = constant.MjActionBlend
+ } else if relayMode == relayconstant.RelayModeMidjourneyUpload { // upload task, may be repeated
+ midjRequest.Action = constant.MjActionUpload
+ } else if midjRequest.TaskId != "" { // upscale/variation tasks: if repeated and a result already exists, the remote API returns the final result directly
+ mjId := ""
+ if relayMode == relayconstant.RelayModeMidjourneyChange {
+ if midjRequest.TaskId == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
+ } else if midjRequest.Action == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
+ } else if midjRequest.Index == 0 {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
+ }
+ //action = midjRequest.Action
+ mjId = midjRequest.TaskId
+ } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
+ if midjRequest.Content == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
+ }
+ params := service.ConvertSimpleChangeParams(midjRequest.Content)
+ if params == nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
+ }
+ mjId = params.TaskId
+ midjRequest.Action = params.Action
+ } else if relayMode == relayconstant.RelayModeMidjourneyModal {
+ //if midjRequest.MaskBase64 == "" {
+ // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
+ //}
+ mjId = midjRequest.TaskId
+ midjRequest.Action = constant.MjActionModal
+ } else if relayMode == relayconstant.RelayModeMidjourneyVideo {
+ midjRequest.Action = constant.MjActionVideo
+ if midjRequest.TaskId == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
+ } else if midjRequest.Action == "" {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
+ }
+ mjId = midjRequest.TaskId
+ }
+
+ originTask := model.GetByMJId(userId, mjId)
+ if originTask == nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
+ } else { // when the original task's Status is SUCCESS, UPSCALE/VARIATION and similar actions are allowed; they must be sent to the original channel's address to be handled correctly
+ if setting.MjActionCheckSuccessEnabled {
+ if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
+ }
+ }
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
+ if err != nil {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
+ }
+ if channel.Status != common.ChannelStatusEnabled {
+ return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
+ }
+ c.Set("base_url", channel.GetBaseURL())
+ c.Set("channel_id", originTask.ChannelId)
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+ log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
+ }
+ midjRequest.Prompt = originTask.Prompt
+
+ //if channelType == common.ChannelTypeMidjourneyPlus {
+ // // plus
+ //} else {
+ // // 普通版渠道
+ //
+ //}
+ }
+
+ if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
+ consumeQuota = false
+ }
+
+ //baseURL := common.ChannelBaseURLs[channelType]
+ requestURL := getMjRequestPath(c.Request.URL.String())
+
+ baseURL := c.GetString("base_url")
+
+ //midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
+
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+
+ modelName := service.CoverActionToModelName(midjRequest.Action)
+
+ priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
+
+ userQuota, err := model.GetUserQuota(userId, false)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: err.Error(),
+ }
+ }
+
+ if consumeQuota && userQuota-priceData.Quota < 0 {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "quota_not_enough",
+ }
+ }
+
+ midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
+ if err != nil {
+ return &midjResponseWithStatus.Response
+ }
+ midjResponse := &midjResponseWithStatus.Response
+
+ defer func() {
+ if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
+ err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ tokenName := c.GetString("token_name")
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
+ other := service.GenerateMjOtherInfo(priceData)
+ model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: channelId,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: priceData.Quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UserQuota: userQuota,
+ Group: group,
+ Other: other,
+ })
+ model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
+ model.UpdateChannelUsedQuota(channelId, priceData.Quota)
+ }
+ }()
+
+ // Docs: https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
+ // 1  - submitted successfully
+ // 21 - task already exists (in progress or finished) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
+ // 22 - queued {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
+ // 23 - queue full, try again later {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
+ // 24 - prompt contains banned words {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
+ // other: submit error, description holds the error message
+ midjourneyTask := &model.Midjourney{
+ UserId: userId,
+ Code: midjResponse.Code,
+ Action: midjRequest.Action,
+ MjId: midjResponse.Result,
+ Prompt: midjRequest.Prompt,
+ PromptEn: "",
+ Description: midjResponse.Description,
+ State: "",
+ SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
+ StartTime: 0,
+ FinishTime: 0,
+ ImageUrl: "",
+ Status: "",
+ Progress: "0%",
+ FailReason: "",
+ ChannelId: c.GetInt("channel_id"),
+ Quota: priceData.Quota,
+ }
+ if midjResponse.Code == 3 {
+ // no available account instance: automatically disable the channel
+ channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
+ if err != nil {
+ common.SysError("get_channel_null: " + err.Error())
+ }
+ if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
+ model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
+ }
+ }
+ if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
+ // anything other than 1 (submitted), 21 (task exists) or 22 (queued): record the failure reason
+ midjourneyTask.FailReason = midjResponse.Description
+ consumeQuota = false
+ }
+
+ if midjResponse.Code == 21 { // 21 - task already exists (in progress or finished)
+ // convert properties into a map
+ properties, ok := midjResponse.Properties.(map[string]interface{})
+ if ok {
+ imageUrl, ok1 := properties["imageUrl"].(string)
+ status, ok2 := properties["status"].(string)
+ if ok1 && ok2 {
+ midjourneyTask.ImageUrl = imageUrl
+ midjourneyTask.Status = status
+ if status == "SUCCESS" {
+ midjourneyTask.Progress = "100%"
+ midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
+ midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
+ midjResponse.Code = 1
+ }
+ }
+ }
+ // rewrite the returned code
+ if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
+ newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
+ responseBody = []byte(newBody)
+ }
+ }
+ if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
+ midjourneyTask.Progress = "100%"
+ midjourneyTask.Status = "SUCCESS"
+ }
+ err = midjourneyTask.Insert()
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "insert_midjourney_task_failed",
+ }
+ }
+
+ if midjResponse.Code == 22 { // 22 - queued, meaning the task already exists
+ // rewrite the returned code
+ newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
+ responseBody = []byte(newBody)
+ }
+ //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+ bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
+
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
+
+ _, err = io.Copy(c.Writer, bodyReader)
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "copy_response_body_failed",
+ }
+ }
+ err = bodyReader.Close()
+ if err != nil {
+ return &dto.MidjourneyResponse{
+ Code: 4,
+ Description: "close_response_body_failed",
+ }
+ }
+ return nil
+}
+
+type taskChangeParams struct {
+ ID string
+ Action string
+ Index int
+}
+
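+// getMjRequestPath rewrites request paths that contain a "/mj-<mode>" prefix so that the
+// upstream receives the plain "/mj/..." path.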
+func getMjRequestPath(path string) string {
+ requestURL := path
+ if strings.Contains(requestURL, "/mj-") {
+ urls := strings.Split(requestURL, "/mj/")
+ if len(urls) < 2 {
+ return requestURL
+ }
+ requestURL = "/mj/" + urls[1]
+ }
+ return requestURL
+}
diff --git a/relay/relay-text.go b/relay/relay-text.go
new file mode 100644
index 00000000..60327074
--- /dev/null
+++ b/relay/relay-text.go
@@ -0,0 +1,571 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting"
+ "one-api/setting/model_setting"
+ "one-api/setting/operation_setting"
+ "one-api/types"
+ "strings"
+ "time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+ "github.com/shopspring/decimal"
+
+ "github.com/gin-gonic/gin"
+)
+
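+// getAndValidateTextRequest parses the OpenAI-style request body, fills in mode-specific
+// default models, validates required fields per relay mode, and records whether the request
+// is streaming.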
+func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
+ textRequest := &dto.GeneralOpenAIRequest{}
+ err := common.UnmarshalBodyReusable(c, textRequest)
+ if err != nil {
+ return nil, err
+ }
+ if relayInfo.RelayMode == relayconstant.RelayModeModerations && textRequest.Model == "" {
+ textRequest.Model = "text-moderation-latest"
+ }
+ if relayInfo.RelayMode == relayconstant.RelayModeEmbeddings && textRequest.Model == "" {
+ textRequest.Model = c.Param("model")
+ }
+
+ if textRequest.MaxTokens > math.MaxInt32/2 {
+ return nil, errors.New("max_tokens is invalid")
+ }
+ if textRequest.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if textRequest.WebSearchOptions != nil {
+ if textRequest.WebSearchOptions.SearchContextSize != "" {
+ validSizes := map[string]bool{
+ "high": true,
+ "medium": true,
+ "low": true,
+ }
+ if !validSizes[textRequest.WebSearchOptions.SearchContextSize] {
+ return nil, errors.New("invalid search_context_size, must be one of: high, medium, low")
+ }
+ } else {
+ textRequest.WebSearchOptions.SearchContextSize = "medium"
+ }
+ }
+ switch relayInfo.RelayMode {
+ case relayconstant.RelayModeCompletions:
+ if textRequest.Prompt == "" {
+ return nil, errors.New("field prompt is required")
+ }
+ case relayconstant.RelayModeChatCompletions:
+ if len(textRequest.Messages) == 0 {
+ return nil, errors.New("field messages is required")
+ }
+ case relayconstant.RelayModeEmbeddings:
+ case relayconstant.RelayModeModerations:
+ if textRequest.Input == nil || textRequest.Input == "" {
+ return nil, errors.New("field input is required")
+ }
+ case relayconstant.RelayModeEdits:
+ if textRequest.Instruction == "" {
+ return nil, errors.New("field instruction is required")
+ }
+ }
+ relayInfo.IsStream = textRequest.Stream
+ return textRequest, nil
+}
+
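+// TextHelper is the main relay flow for text requests: validate, run sensitive-word checks,
+// apply model mapping, count prompt tokens, pre-consume quota, forward the request through the
+// channel adaptor, and settle the final quota from the returned usage.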
+func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+
+ relayInfo := relaycommon.GenRelayInfo(c)
+
+ // get & validate textRequest 获取并验证文本请求
+ textRequest, err := getAndValidateTextRequest(c, relayInfo)
+
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ if textRequest.WebSearchOptions != nil {
+ c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
+ }
+
+ if setting.ShouldCheckPromptSensitive() {
+ words, err := checkRequestSensitive(textRequest, relayInfo)
+ if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
+ return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+ }
+ }
+
+ err = helper.ModelMappedHelper(c, relayInfo, textRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ // get promptTokens; reuse the value if it is already stored in the context
+ var promptTokens int
+ if value, exists := c.Get("prompt_tokens"); exists {
+ promptTokens = value.(int)
+ relayInfo.PromptTokens = promptTokens
+ } else {
+ promptTokens, err = getPromptTokens(textRequest, relayInfo)
+ // counting prompt tokens failed
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeCountTokenFailed)
+ }
+ c.Set("prompt_tokens", promptTokens)
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+
+ // pre-consume quota 预消耗配额
+ preConsumedQuota, userQuota, newApiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newApiErr != nil {
+ return newApiErr
+ }
+ defer func() {
+ if newApiErr != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+ includeUsage := false
+ // determine whether the client asked for usage to be returned
+ if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage {
+ includeUsage = true
+ }
+
+ // set StreamOptions to nil when it is not supported or the request is not streaming
+ if !relayInfo.SupportStreamOptions || !textRequest.Stream {
+ textRequest.StreamOptions = nil
+ } else {
+ // StreamOptions is supported; when ForceStreamOption is enabled, force usage to be included
+ if constant.ForceStreamOption {
+ textRequest.StreamOptions = &dto.StreamOptions{
+ IncludeUsage: true,
+ }
+ }
+ }
+
+ if includeUsage {
+ relayInfo.ShouldIncludeUsage = true
+ }
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
+ var requestBody io.Reader
+
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
+ body, err := common.GetRequestBody(c)
+ if err != nil {
+ return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
+ }
+ requestBody = bytes.NewBuffer(body)
+ } else {
+ convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+
+ // apply param override
+ if len(relayInfo.ParamOverride) > 0 {
+ reqMap := make(map[string]interface{})
+ _ = common.Unmarshal(jsonData, &reqMap)
+ for key, value := range relayInfo.ParamOverride {
+ reqMap[key] = value
+ }
+ jsonData, err = common.Marshal(reqMap)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+ }
+ }
+
+ if common.DebugEnabled {
+ println("requestBody: ", string(jsonData))
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ }
+
+ var httpResp *http.Response
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
+ if httpResp.StatusCode != http.StatusOK {
+ newApiErr = service.RelayErrorHandler(httpResp, false)
+ // reset status code 重置状态码
+ service.ResetStatusCode(newApiErr, statusCodeMappingStr)
+ return newApiErr
+ }
+ }
+
+ usage, newApiErr := adaptor.DoResponse(c, httpResp, relayInfo)
+ if newApiErr != nil {
+ // reset status code 重置状态码
+ service.ResetStatusCode(newApiErr, statusCodeMappingStr)
+ return newApiErr
+ }
+
+ if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+ service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ } else {
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ }
+ return nil
+}
+
+func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
+ var promptTokens int
+ var err error
+ switch info.RelayMode {
+ case relayconstant.RelayModeChatCompletions:
+ promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
+ case relayconstant.RelayModeCompletions:
+ promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
+ case relayconstant.RelayModeModerations:
+ promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
+ case relayconstant.RelayModeEmbeddings:
+ promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
+ default:
+ err = errors.New("unknown relay mode")
+ promptTokens = 0
+ }
+ info.PromptTokens = promptTokens
+ return promptTokens, err
+}
+
+func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) ([]string, error) {
+ var err error
+ var words []string
+ switch info.RelayMode {
+ case relayconstant.RelayModeChatCompletions:
+ words, err = service.CheckSensitiveMessages(textRequest.Messages)
+ case relayconstant.RelayModeCompletions:
+ words, err = service.CheckSensitiveInput(textRequest.Prompt)
+ case relayconstant.RelayModeModerations:
+ words, err = service.CheckSensitiveInput(textRequest.Input)
+ case relayconstant.RelayModeEmbeddings:
+ words, err = service.CheckSensitiveInput(textRequest.Input)
+ }
+ return words, err
+}
+
+// preConsumeQuota pre-deducts quota for the request and returns the pre-consumed amount and the user's remaining quota
+func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
+ userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
+ if err != nil {
+ return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
+ }
+ if userQuota <= 0 {
+ return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
+ }
+ if userQuota-preConsumedQuota < 0 {
+ return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
+ }
+ relayInfo.UserQuota = userQuota
+ if userQuota > 100*preConsumedQuota {
+ // the user has plenty of quota, so check whether the token quota is also sufficient
+ if !relayInfo.TokenUnlimited {
+ // limited token: check whether the token quota is sufficient
+ tokenQuota := c.GetInt("token_quota")
+ if tokenQuota > 100*preConsumedQuota {
+ // token quota is sufficient, trust the token
+ preConsumedQuota = 0
+ common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
+ }
+ } else {
+ // in this case, we do not pre-consume quota
+ // because the user has enough quota
+ preConsumedQuota = 0
+ common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
+ }
+ }
+
+ if preConsumedQuota > 0 {
+ err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+ if err != nil {
+ return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
+ }
+ err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
+ if err != nil {
+ return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
+ }
+ }
+ return preConsumedQuota, userQuota, nil
+}
+
+func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
+ if preConsumedQuota != 0 {
+ gopool.Go(func() {
+ relayInfoCopy := *relayInfo
+
+ err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
+ if err != nil {
+ common.SysError("error return pre-consumed quota: " + err.Error())
+ }
+ })
+ }
+}
+
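+// postConsumeQuota settles the final cost of a request from the returned usage: it applies
+// model/completion/cache/image ratios (or the fixed model price), adds tool charges such as
+// web search and file search, adjusts the pre-consumed quota, and records the consume log.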
+func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+ if usage == nil {
+ usage = &dto.Usage{
+ PromptTokens: relayInfo.PromptTokens,
+ CompletionTokens: 0,
+ TotalTokens: relayInfo.PromptTokens,
+ }
+ extraContent += "(可能是请求出错)"
+ }
+ useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+ promptTokens := usage.PromptTokens
+ cacheTokens := usage.PromptTokensDetails.CachedTokens
+ imageTokens := usage.PromptTokensDetails.ImageTokens
+ audioTokens := usage.PromptTokensDetails.AudioTokens
+ completionTokens := usage.CompletionTokens
+ modelName := relayInfo.OriginModelName
+
+ tokenName := ctx.GetString("token_name")
+ completionRatio := priceData.CompletionRatio
+ cacheRatio := priceData.CacheRatio
+ imageRatio := priceData.ImageRatio
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
+ modelPrice := priceData.ModelPrice
+
+ // Convert values to decimal for precise calculation
+ dPromptTokens := decimal.NewFromInt(int64(promptTokens))
+ dCacheTokens := decimal.NewFromInt(int64(cacheTokens))
+ dImageTokens := decimal.NewFromInt(int64(imageTokens))
+ dAudioTokens := decimal.NewFromInt(int64(audioTokens))
+ dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
+ dCompletionRatio := decimal.NewFromFloat(completionRatio)
+ dCacheRatio := decimal.NewFromFloat(cacheRatio)
+ dImageRatio := decimal.NewFromFloat(imageRatio)
+ dModelRatio := decimal.NewFromFloat(modelRatio)
+ dGroupRatio := decimal.NewFromFloat(groupRatio)
+ dModelPrice := decimal.NewFromFloat(modelPrice)
+ dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+
+ ratio := dModelRatio.Mul(dGroupRatio)
+
+ // billing for the OpenAI web search tool
+ var dWebSearchQuota decimal.Decimal
+ var webSearchPrice float64
+ // tool billing for the Responses API format
+ if relayInfo.ResponsesUsageInfo != nil {
+ if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
+ // compute the web search quota (quota = price * call count / 1000 * group ratio)
+ webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, webSearchTool.SearchContextSize)
+ dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+ Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
+ webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
+ }
+ } else if strings.HasSuffix(modelName, "search-preview") {
+ // search-preview models do not support the Responses API
+ searchContextSize := ctx.GetString("chat_completion_web_search_context_size")
+ if searchContextSize == "" {
+ searchContextSize = "medium"
+ }
+ webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
+ dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
+ searchContextSize, dWebSearchQuota.String())
+ }
+ // billing for the Claude web search tool
+ var dClaudeWebSearchQuota decimal.Decimal
+ var claudeWebSearchPrice float64
+ claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
+ if claudeWebSearchCallCount > 0 {
+ claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
+ dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
+ extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
+ claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
+ }
+ // billing for the file search tool
+ var dFileSearchQuota decimal.Decimal
+ var fileSearchPrice float64
+ if relayInfo.ResponsesUsageInfo != nil {
+ if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists && fileSearchTool.CallCount > 0 {
+ fileSearchPrice = operation_setting.GetFileSearchPricePerThousand()
+ dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
+ Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
+ Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
+ fileSearchTool.CallCount, dFileSearchQuota.String())
+ }
+ }
+
+ var quotaCalculateDecimal decimal.Decimal
+
+ var audioInputQuota decimal.Decimal
+ var audioInputPrice float64
+ if !priceData.UsePrice {
+ baseTokens := dPromptTokens
+ // cached tokens are billed at the cache ratio instead of the full prompt rate
+ var cachedTokensWithRatio decimal.Decimal
+ if !dCacheTokens.IsZero() {
+ baseTokens = baseTokens.Sub(dCacheTokens)
+ cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
+ }
+
+ // image tokens are billed at the image ratio instead of the full prompt rate
+ var imageTokensWithRatio decimal.Decimal
+ if !dImageTokens.IsZero() {
+ baseTokens = baseTokens.Sub(dImageTokens)
+ imageTokensWithRatio = dImageTokens.Mul(dImageRatio)
+ }
+
+ // Gemini audio tokens are billed separately at the audio input price
+ if !dAudioTokens.IsZero() {
+ audioInputPrice = operation_setting.GetGeminiInputAudioPricePerMillionTokens(modelName)
+ if audioInputPrice > 0 {
+ // remove audio tokens from the base prompt tokens
+ baseTokens = baseTokens.Sub(dAudioTokens)
+ audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
+ extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
+ }
+ }
+ promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
+
+ completionQuota := dCompletionTokens.Mul(dCompletionRatio)
+
+ quotaCalculateDecimal = promptQuota.Add(completionQuota).Mul(ratio)
+
+ if !ratio.IsZero() && quotaCalculateDecimal.LessThanOrEqual(decimal.Zero) {
+ quotaCalculateDecimal = decimal.NewFromInt(1)
+ }
+ } else {
+ quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio)
+ }
+ // add the quota for responses tool calls
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota)
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota)
+ // add the separately billed audio input quota
+ quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota)
+
+ quota := int(quotaCalculateDecimal.Round(0).IntPart())
+ totalTokens := promptTokens + completionTokens
+
+ var logContent string
+ if !priceData.UsePrice {
+ logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
+ } else {
+ logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
+ }
+
+ // record all the consume log even if quota is 0
+ if totalTokens == 0 {
+ // in this case, some error must have happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ logContent += "(可能是上游超时)"
+ common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ } else {
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+ }
+
+ quotaDelta := quota - preConsumedQuota
+ if quotaDelta != 0 {
+ err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ if err != nil {
+ common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ }
+ }
+
+ logModel := modelName
+ if strings.HasPrefix(logModel, "gpt-4-gizmo") {
+ logModel = "gpt-4-gizmo-*"
+ logContent += fmt.Sprintf(",模型 %s", modelName)
+ }
+ if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
+ logModel = "gpt-4o-gizmo-*"
+ logContent += fmt.Sprintf(",模型 %s", modelName)
+ }
+ if extraContent != "" {
+ logContent += ", " + extraContent
+ }
+ other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ if imageTokens != 0 {
+ other["image"] = true
+ other["image_ratio"] = imageRatio
+ other["image_output"] = imageTokens
+ }
+ if !dWebSearchQuota.IsZero() {
+ if relayInfo.ResponsesUsageInfo != nil {
+ if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
+ other["web_search"] = true
+ other["web_search_call_count"] = webSearchTool.CallCount
+ other["web_search_price"] = webSearchPrice
+ }
+ } else if strings.HasSuffix(modelName, "search-preview") {
+ other["web_search"] = true
+ other["web_search_call_count"] = 1
+ other["web_search_price"] = webSearchPrice
+ }
+ } else if !dClaudeWebSearchQuota.IsZero() {
+ other["web_search"] = true
+ other["web_search_call_count"] = claudeWebSearchCallCount
+ other["web_search_price"] = claudeWebSearchPrice
+ }
+ if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
+ if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {
+ other["file_search"] = true
+ other["file_search_call_count"] = fileSearchTool.CallCount
+ other["file_search_price"] = fileSearchPrice
+ }
+ }
+ if !audioInputQuota.IsZero() {
+ other["audio_input_seperate_price"] = true
+ other["audio_input_token_count"] = audioTokens
+ other["audio_input_price"] = audioInputPrice
+ }
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: promptTokens,
+ CompletionTokens: completionTokens,
+ ModelName: logModel,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UserQuota: userQuota,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+}
diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go
new file mode 100644
index 00000000..2ce12a87
--- /dev/null
+++ b/relay/relay_adaptor.go
@@ -0,0 +1,115 @@
+package relay
+
+import (
+ "one-api/constant"
+ commonconstant "one-api/constant"
+ "one-api/relay/channel"
+ "one-api/relay/channel/ali"
+ "one-api/relay/channel/aws"
+ "one-api/relay/channel/baidu"
+ "one-api/relay/channel/baidu_v2"
+ "one-api/relay/channel/claude"
+ "one-api/relay/channel/cloudflare"
+ "one-api/relay/channel/cohere"
+ "one-api/relay/channel/coze"
+ "one-api/relay/channel/deepseek"
+ "one-api/relay/channel/dify"
+ "one-api/relay/channel/gemini"
+ "one-api/relay/channel/jimeng"
+ "one-api/relay/channel/jina"
+ "one-api/relay/channel/mistral"
+ "one-api/relay/channel/mokaai"
+ "one-api/relay/channel/ollama"
+ "one-api/relay/channel/openai"
+ "one-api/relay/channel/palm"
+ "one-api/relay/channel/perplexity"
+ "one-api/relay/channel/siliconflow"
+ taskjimeng "one-api/relay/channel/task/jimeng"
+ "one-api/relay/channel/task/kling"
+ "one-api/relay/channel/task/suno"
+ "one-api/relay/channel/tencent"
+ "one-api/relay/channel/vertex"
+ "one-api/relay/channel/volcengine"
+ "one-api/relay/channel/xai"
+ "one-api/relay/channel/xunfei"
+ "one-api/relay/channel/zhipu"
+ "one-api/relay/channel/zhipu_4v"
+)
+
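+// GetAdaptor maps an API type constant to its channel adaptor implementation.
+// Unknown API types return nil, which callers surface as an invalid api type error.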
+func GetAdaptor(apiType int) channel.Adaptor {
+ switch apiType {
+ case constant.APITypeAli:
+ return &ali.Adaptor{}
+ case constant.APITypeAnthropic:
+ return &claude.Adaptor{}
+ case constant.APITypeBaidu:
+ return &baidu.Adaptor{}
+ case constant.APITypeGemini:
+ return &gemini.Adaptor{}
+ case constant.APITypeOpenAI:
+ return &openai.Adaptor{}
+ case constant.APITypePaLM:
+ return &palm.Adaptor{}
+ case constant.APITypeTencent:
+ return &tencent.Adaptor{}
+ case constant.APITypeXunfei:
+ return &xunfei.Adaptor{}
+ case constant.APITypeZhipu:
+ return &zhipu.Adaptor{}
+ case constant.APITypeZhipuV4:
+ return &zhipu_4v.Adaptor{}
+ case constant.APITypeOllama:
+ return &ollama.Adaptor{}
+ case constant.APITypePerplexity:
+ return &perplexity.Adaptor{}
+ case constant.APITypeAws:
+ return &aws.Adaptor{}
+ case constant.APITypeCohere:
+ return &cohere.Adaptor{}
+ case constant.APITypeDify:
+ return &dify.Adaptor{}
+ case constant.APITypeJina:
+ return &jina.Adaptor{}
+ case constant.APITypeCloudflare:
+ return &cloudflare.Adaptor{}
+ case constant.APITypeSiliconFlow:
+ return &siliconflow.Adaptor{}
+ case constant.APITypeVertexAi:
+ return &vertex.Adaptor{}
+ case constant.APITypeMistral:
+ return &mistral.Adaptor{}
+ case constant.APITypeDeepSeek:
+ return &deepseek.Adaptor{}
+ case constant.APITypeMokaAI:
+ return &mokaai.Adaptor{}
+ case constant.APITypeVolcEngine:
+ return &volcengine.Adaptor{}
+ case constant.APITypeBaiduV2:
+ return &baidu_v2.Adaptor{}
+ case constant.APITypeOpenRouter:
+ return &openai.Adaptor{}
+ case constant.APITypeXinference:
+ return &openai.Adaptor{}
+ case constant.APITypeXai:
+ return &xai.Adaptor{}
+ case constant.APITypeCoze:
+ return &coze.Adaptor{}
+ case constant.APITypeJimeng:
+ return &jimeng.Adaptor{}
+ }
+ return nil
+}
+
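+// GetTaskAdaptor maps a task platform (Suno, Kling, Jimeng) to its task adaptor; unknown platforms return nil.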
+func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
+ switch platform {
+ //case constant.APITypeAIProxyLibrary:
+ // return &aiproxy.Adaptor{}
+ case commonconstant.TaskPlatformSuno:
+ return &suno.TaskAdaptor{}
+ case commonconstant.TaskPlatformKling:
+ return &kling.TaskAdaptor{}
+ case commonconstant.TaskPlatformJimeng:
+ return &taskjimeng.TaskAdaptor{}
+ }
+ return nil
+}
diff --git a/relay/relay_task.go b/relay/relay_task.go
new file mode 100644
index 00000000..25f63d40
--- /dev/null
+++ b/relay/relay_task.go
@@ -0,0 +1,289 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ relayconstant "one-api/relay/constant"
+ "one-api/service"
+ "one-api/setting/ratio_setting"
+
+ "github.com/gin-gonic/gin"
+)
+
+/*
+RelayTaskSubmit submits a task; tasks are distinguished by platform and Action.
+*/
+func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
+ platform := constant.TaskPlatform(c.GetString("platform"))
+ relayInfo := relaycommon.GenTaskRelayInfo(c)
+
+ adaptor := GetTaskAdaptor(platform)
+ if adaptor == nil {
+ return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
+ }
+ adaptor.Init(relayInfo)
+ // get & validate the task request and set the action
+ taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
+ if taskErr != nil {
+ return
+ }
+
+ modelName := relayInfo.OriginModelName
+ if modelName == "" {
+ modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
+ }
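+ // resolve a fixed per-request price: explicit model price first, then the default ratio map, then a 0.1 fallback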
+ modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
+ if !success {
+ defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
+ if !ok {
+ modelPrice = 0.1
+ } else {
+ modelPrice = defaultPrice
+ }
+ }
+
+ // pre-deduct quota
+ groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
+ var ratio float64
+ userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
+ if hasUserGroupRatio {
+ ratio = modelPrice * userGroupRatio
+ } else {
+ ratio = modelPrice * groupRatio
+ }
+ userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+ return
+ }
+ quota := int(ratio * common.QuotaPerUnit)
+ if userQuota-quota < 0 {
+ taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
+ return
+ }
+
+ if relayInfo.OriginTaskID != "" {
+ originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+ return
+ }
+ if !exist {
+ taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+ return
+ }
+ if originTask.ChannelId != relayInfo.ChannelId {
+ channel, err := model.GetChannelById(originTask.ChannelId, true)
+ if err != nil {
+ taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+ return
+ }
+ if channel.Status != common.ChannelStatusEnabled {
+ return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
+ }
+ c.Set("base_url", channel.GetBaseURL())
+ c.Set("channel_id", originTask.ChannelId)
+ c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+ relayInfo.BaseUrl = channel.GetBaseURL()
+ relayInfo.ChannelId = originTask.ChannelId
+ }
+ }
+
+ // build body
+ requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
+ return
+ }
+ // do request
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+ return
+ }
+ // handle response
+ if resp != nil && resp.StatusCode != http.StatusOK {
+ responseBody, _ := io.ReadAll(resp.Body)
+ taskErr = service.TaskErrorWrapper(fmt.Errorf("%s", responseBody), "fail_to_fetch_task", resp.StatusCode)
+ return
+ }
+
+ defer func() {
+ // settle quota once the task has been accepted
+ if relayInfo.ConsumeQuota && taskErr == nil {
+
+ err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
+ if err != nil {
+ common.SysError("error consuming token remain quota: " + err.Error())
+ }
+ if quota != 0 {
+ tokenName := c.GetString("token_name")
+ gRatio := groupRatio
+ if hasUserGroupRatio {
+ gRatio = userGroupRatio
+ }
+ logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
+ other := make(map[string]interface{})
+ other["model_price"] = modelPrice
+ other["group_ratio"] = groupRatio
+ if hasUserGroupRatio {
+ other["user_group_ratio"] = userGroupRatio
+ }
+ model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UserQuota: userQuota,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+ }
+ }
+ }()
+
+ taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
+ if taskErr != nil {
+ return
+ }
+ relayInfo.ConsumeQuota = true
+ // insert task
+ task := model.InitTask(platform, relayInfo)
+ task.TaskID = taskID
+ task.Quota = quota
+ task.Data = taskData
+ task.Action = relayInfo.Action
+ err = task.Insert()
+ if err != nil {
+ taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
+ return
+ }
+ return nil
+}
+
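+// fetchRespBuilders maps task-fetch relay modes to the builder that produces the response body.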
+var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
+ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
+ relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
+ relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
+}
+
+func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
+ respBuilder, ok := fetchRespBuilders[relayMode]
+ if !ok {
+ taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
+ }
+
+ respBody, taskErr := respBuilder(c)
+ if taskErr != nil {
+ return taskErr
+ }
+
+ c.Writer.Header().Set("Content-Type", "application/json")
+ _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+ return
+ }
+ return
+}
+
+func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+ userId := c.GetInt("id")
+ var condition = struct {
+ IDs []any `json:"ids"`
+ Action string `json:"action"`
+ }{}
+ err := c.BindJSON(&condition)
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
+ return
+ }
+ var tasks []any
+ if len(condition.IDs) > 0 {
+ taskModels, err := model.GetByTaskIds(userId, condition.IDs)
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
+ return
+ }
+ for _, task := range taskModels {
+ tasks = append(tasks, TaskModel2Dto(task))
+ }
+ } else {
+ tasks = make([]any, 0)
+ }
+ respBody, err = json.Marshal(dto.TaskResponse[[]any]{
+ Code: "success",
+ Data: tasks,
+ })
+ return
+}
+
+func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+ taskId := c.Param("id")
+ userId := c.GetInt("id")
+
+ originTask, exist, err := model.GetByTaskId(userId, taskId)
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
+ return
+ }
+ if !exist {
+ taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
+ return
+ }
+
+ respBody, err = json.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: TaskModel2Dto(originTask),
+ })
+ return
+}
+
+func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+ taskId := c.Param("task_id")
+ userId := c.GetInt("id")
+
+ originTask, exist, err := model.GetByTaskId(userId, taskId)
+ if err != nil {
+ taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
+ return
+ }
+ if !exist {
+ taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
+ return
+ }
+
+ respBody, err = json.Marshal(dto.TaskResponse[any]{
+ Code: "success",
+ Data: TaskModel2Dto(originTask),
+ })
+ return
+}
+
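+// TaskModel2Dto converts a stored task record into its API DTO.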
+func TaskModel2Dto(task *model.Task) *dto.TaskDto {
+ return &dto.TaskDto{
+ TaskID: task.TaskID,
+ Action: task.Action,
+ Status: string(task.Status),
+ FailReason: task.FailReason,
+ SubmitTime: task.SubmitTime,
+ StartTime: task.StartTime,
+ FinishTime: task.FinishTime,
+ Progress: task.Progress,
+ Data: task.Data,
+ }
+}
diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go
new file mode 100644
index 00000000..a092de4b
--- /dev/null
+++ b/relay/rerank_handler.go
@@ -0,0 +1,110 @@
+package relay
+
+import (
+ "bytes"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+)
+
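+// getRerankPromptToken counts prompt tokens for a rerank request: the query plus every document.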
+func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
+ token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+ for _, document := range rerankRequest.Documents {
+ tkm := service.CountTokenInput(document, rerankRequest.Model)
+ token += tkm
+ }
+ return token
+}
+
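+// RerankHelper relays a rerank request: it validates the input, maps the model,
+// pre-consumes quota, forwards the converted request to the channel adaptor, and
+// settles billing from the returned usage.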
+func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError) {
+
+ var rerankRequest *dto.RerankRequest
+ err := common.UnmarshalBodyReusable(c, &rerankRequest)
+ if err != nil {
+ common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
+
+ if rerankRequest.Query == "" {
+ return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
+ }
+ if len(rerankRequest.Documents) == 0 {
+ return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
+ }
+
+ err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ promptToken := getRerankPromptToken(*rerankRequest)
+ relayInfo.PromptTokens = promptToken
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+ // pre-consume quota
+ preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newAPIError != nil {
+ return newAPIError
+ }
+ defer func() {
+ if newAPIError != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
+
+ convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ jsonData, err := common.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ requestBody := bytes.NewBuffer(jsonData)
+ if common.DebugEnabled {
+ println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
+ }
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+ var httpResp *http.Response
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ if newAPIError != nil {
+ // reset status code
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ return nil
+}
diff --git a/relay/responses_handler.go b/relay/responses_handler.go
new file mode 100644
index 00000000..52d1db6e
--- /dev/null
+++ b/relay/responses_handler.go
@@ -0,0 +1,169 @@
+package relay
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/setting"
+ "one-api/setting/model_setting"
+ "one-api/types"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
+ request := &dto.OpenAIResponsesRequest{}
+ err := common.UnmarshalBodyReusable(c, request)
+ if err != nil {
+ return nil, err
+ }
+ if request.Model == "" {
+ return nil, errors.New("model is required")
+ }
+ if len(request.Input) == 0 {
+ return nil, errors.New("input is required")
+ }
+ return request, nil
+}
+
+func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
+ sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
+ return sensitiveWords, err
+}
+
+func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
+ inputTokens := service.CountTokenInput(req.Input, req.Model)
+ info.PromptTokens = inputTokens
+ return inputTokens
+}
+
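+// ResponsesHelper relays an OpenAI Responses API request: it validates the input,
+// runs the optional sensitive-word check, maps the model, pre-consumes quota,
+// forwards the request, and settles billing from the returned usage.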
+func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
+ req, err := getAndValidateResponsesRequest(c)
+ if err != nil {
+ common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
+ return types.NewError(err, types.ErrorCodeInvalidRequest)
+ }
+
+ relayInfo := relaycommon.GenRelayInfoResponses(c, req)
+
+ if setting.ShouldCheckPromptSensitive() {
+ sensitiveWords, err := checkInputSensitive(req, relayInfo)
+ if err != nil {
+ common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
+ return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
+ }
+ }
+
+ err = helper.ModelMappedHelper(c, relayInfo, req)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ if value, exists := c.Get("prompt_tokens"); exists {
+ promptTokens := value.(int)
+ relayInfo.SetPromptTokens(promptTokens)
+ } else {
+ promptTokens := getInputTokens(req, relayInfo)
+ c.Set("prompt_tokens", promptTokens)
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+ // pre consume quota
+ preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newAPIError != nil {
+ return newAPIError
+ }
+ defer func() {
+ if newAPIError != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
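+ // build the upstream request body: forward the raw client body when pass-through is enabled, otherwise convert the request and apply any channel-level parameter overrides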
+ var requestBody io.Reader
+ if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
+ body, err := common.GetRequestBody(c)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
+ }
+ requestBody = bytes.NewBuffer(body)
+ } else {
+ convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ jsonData, err := json.Marshal(convertedRequest)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ // apply param override
+ if len(relayInfo.ParamOverride) > 0 {
+ reqMap := make(map[string]interface{})
+ err = json.Unmarshal(jsonData, &reqMap)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
+ }
+ for key, value := range relayInfo.ParamOverride {
+ reqMap[key] = value
+ }
+ jsonData, err = json.Marshal(reqMap)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeConvertRequestFailed)
+ }
+ }
+
+ if common.DebugEnabled {
+ println("requestBody: ", string(jsonData))
+ }
+ requestBody = bytes.NewBuffer(jsonData)
+ }
+
+ var httpResp *http.Response
+ resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+ if err != nil {
+ return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+ }
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+
+ if resp != nil {
+ httpResp = resp.(*http.Response)
+
+ if httpResp.StatusCode != http.StatusOK {
+ newAPIError = service.RelayErrorHandler(httpResp, false)
+ // reset status code
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
+ if newAPIError != nil {
+ // reset status code
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+
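+ // gpt-4o-audio models are settled through the audio-specific quota path; other models use the standard post-consume flow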
+ if strings.HasPrefix(relayInfo.OriginModelName, "gpt-4o-audio") {
+ service.PostAudioConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ } else {
+ postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+ }
+ return nil
+}
diff --git a/relay/websocket.go b/relay/websocket.go
new file mode 100644
index 00000000..659e27d5
--- /dev/null
+++ b/relay/websocket.go
@@ -0,0 +1,76 @@
+package relay
+
+import (
+ "fmt"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/service"
+ "one-api/types"
+
+ "github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
+)
+
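+// WssHelper relays a realtime WebSocket session: it maps the model, pre-consumes quota,
+// dials the upstream through the channel adaptor, and settles usage via PostWssConsumeQuota.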
+func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
+ relayInfo := relaycommon.GenRelayInfoWs(c, ws)
+
+ // get & validate textRequest
+ //realtimeEvent, err := getAndValidateWssRequest(c, ws)
+ //if err != nil {
+ // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error()))
+ // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
+ //}
+
+ err := helper.ModelMappedHelper(c, relayInfo, nil)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeChannelModelMappedError)
+ }
+
+ priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeModelPriceError)
+ }
+
+ // pre-consume quota
+ preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+ if newAPIError != nil {
+ return newAPIError
+ }
+
+ defer func() {
+ if newAPIError != nil {
+ returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+ }
+ }()
+
+ adaptor := GetAdaptor(relayInfo.ApiType)
+ if adaptor == nil {
+ return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
+ }
+ adaptor.Init(relayInfo)
+ //var requestBody io.Reader
+ //firstWssRequest, _ := c.Get("first_wss_request")
+ //requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
+
+ statusCodeMappingStr := c.GetString("status_code_mapping")
+ resp, err := adaptor.DoRequest(c, relayInfo, nil)
+ if err != nil {
+ return types.NewError(err, types.ErrorCodeDoRequestFailed)
+ }
+
+ if resp != nil {
+ relayInfo.TargetWs = resp.(*websocket.Conn)
+ defer relayInfo.TargetWs.Close()
+ }
+
+ usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo)
+ if newAPIError != nil {
+ // reset status code
+ service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+ return newAPIError
+ }
+ service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
+ userQuota, priceData, "")
+ return nil
+}
diff --git a/router/api-router.go b/router/api-router.go
new file mode 100644
index 00000000..bc49803a
--- /dev/null
+++ b/router/api-router.go
@@ -0,0 +1,179 @@
+package router
+
+import (
+ "one-api/controller"
+ "one-api/middleware"
+
+ "github.com/gin-contrib/gzip"
+ "github.com/gin-gonic/gin"
+)
+
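+// SetApiRouter registers the /api management routes (setup, status, OAuth, user, option,
+// ratio sync, channel, token, redemption, log, data, group, Midjourney and task endpoints),
+// each group behind its appropriate auth middleware.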
+func SetApiRouter(router *gin.Engine) {
+ apiRouter := router.Group("/api")
+ apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
+ apiRouter.Use(middleware.GlobalAPIRateLimit())
+ {
+ apiRouter.GET("/setup", controller.GetSetup)
+ apiRouter.POST("/setup", controller.PostSetup)
+ apiRouter.GET("/status", controller.GetStatus)
+ apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
+ apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
+ apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
+ apiRouter.GET("/notice", controller.GetNotice)
+ apiRouter.GET("/about", controller.GetAbout)
+ //apiRouter.GET("/midjourney", controller.GetMidjourney)
+ apiRouter.GET("/home_page_content", controller.GetHomePageContent)
+ apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing)
+ apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
+ apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
+ apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
+ apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
+ apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
+ apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
+ apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
+ apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
+ apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
+ apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
+ apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
+ apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
+ apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
+
+ apiRouter.POST("/stripe/webhook", controller.StripeWebhook)
+
+ userRoute := apiRouter.Group("/user")
+ {
+ userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
+ userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
+ //userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
+ userRoute.GET("/logout", controller.Logout)
+ userRoute.GET("/epay/notify", controller.EpayNotify)
+ userRoute.GET("/groups", controller.GetUserGroups)
+
+ selfRoute := userRoute.Group("/")
+ selfRoute.Use(middleware.UserAuth())
+ {
+ selfRoute.GET("/self/groups", controller.GetUserGroups)
+ selfRoute.GET("/self", controller.GetSelf)
+ selfRoute.GET("/models", controller.GetUserModels)
+ selfRoute.PUT("/self", controller.UpdateSelf)
+ selfRoute.DELETE("/self", controller.DeleteSelf)
+ selfRoute.GET("/token", controller.GenerateAccessToken)
+ selfRoute.GET("/aff", controller.GetAffCode)
+ selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
+ selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
+ selfRoute.POST("/amount", controller.RequestAmount)
+ selfRoute.POST("/stripe/pay", middleware.CriticalRateLimit(), controller.RequestStripePay)
+ selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
+ selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
+ selfRoute.PUT("/setting", controller.UpdateUserSetting)
+ }
+
+ adminRoute := userRoute.Group("/")
+ adminRoute.Use(middleware.AdminAuth())
+ {
+ adminRoute.GET("/", controller.GetAllUsers)
+ adminRoute.GET("/search", controller.SearchUsers)
+ adminRoute.GET("/:id", controller.GetUser)
+ adminRoute.POST("/", controller.CreateUser)
+ adminRoute.POST("/manage", controller.ManageUser)
+ adminRoute.PUT("/", controller.UpdateUser)
+ adminRoute.DELETE("/:id", controller.DeleteUser)
+ }
+ }
+ optionRoute := apiRouter.Group("/option")
+ optionRoute.Use(middleware.RootAuth())
+ {
+ optionRoute.GET("/", controller.GetOptions)
+ optionRoute.PUT("/", controller.UpdateOption)
+ optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
+ optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
+ }
+ ratioSyncRoute := apiRouter.Group("/ratio_sync")
+ ratioSyncRoute.Use(middleware.RootAuth())
+ {
+ ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
+ ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
+ }
+ channelRoute := apiRouter.Group("/channel")
+ channelRoute.Use(middleware.AdminAuth())
+ {
+ channelRoute.GET("/", controller.GetAllChannels)
+ channelRoute.GET("/search", controller.SearchChannels)
+ channelRoute.GET("/models", controller.ChannelListModels)
+ channelRoute.GET("/models_enabled", controller.EnabledListModels)
+ channelRoute.GET("/:id", controller.GetChannel)
+ channelRoute.GET("/test", controller.TestAllChannels)
+ channelRoute.GET("/test/:id", controller.TestChannel)
+ channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
+ channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
+ channelRoute.POST("/", controller.AddChannel)
+ channelRoute.PUT("/", controller.UpdateChannel)
+ channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
+ channelRoute.POST("/tag/disabled", controller.DisableTagChannels)
+ channelRoute.POST("/tag/enabled", controller.EnableTagChannels)
+ channelRoute.PUT("/tag", controller.EditTagChannels)
+ channelRoute.DELETE("/:id", controller.DeleteChannel)
+ channelRoute.POST("/batch", controller.DeleteChannelBatch)
+ channelRoute.POST("/fix", controller.FixChannelsAbilities)
+ channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
+ channelRoute.POST("/fetch_models", controller.FetchModels)
+ channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
+ channelRoute.GET("/tag/models", controller.GetTagModels)
+ channelRoute.POST("/copy/:id", controller.CopyChannel)
+ }
+ tokenRoute := apiRouter.Group("/token")
+ tokenRoute.Use(middleware.UserAuth())
+ {
+ tokenRoute.GET("/", controller.GetAllTokens)
+ tokenRoute.GET("/search", controller.SearchTokens)
+ tokenRoute.GET("/:id", controller.GetToken)
+ tokenRoute.POST("/", controller.AddToken)
+ tokenRoute.PUT("/", controller.UpdateToken)
+ tokenRoute.DELETE("/:id", controller.DeleteToken)
+ tokenRoute.POST("/batch", controller.DeleteTokenBatch)
+ }
+ redemptionRoute := apiRouter.Group("/redemption")
+ redemptionRoute.Use(middleware.AdminAuth())
+ {
+ redemptionRoute.GET("/", controller.GetAllRedemptions)
+ redemptionRoute.GET("/search", controller.SearchRedemptions)
+ redemptionRoute.GET("/:id", controller.GetRedemption)
+ redemptionRoute.POST("/", controller.AddRedemption)
+ redemptionRoute.PUT("/", controller.UpdateRedemption)
+ redemptionRoute.DELETE("/invalid", controller.DeleteInvalidRedemption)
+ redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
+ }
+ logRoute := apiRouter.Group("/log")
+ logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
+ logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
+ logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
+ logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
+ logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
+ logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
+ logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
+
+ dataRoute := apiRouter.Group("/data")
+ dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates)
+ dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates)
+
+ logRoute.Use(middleware.CORS())
+ {
+ logRoute.GET("/token", controller.GetLogByKey)
+
+ }
+ groupRoute := apiRouter.Group("/group")
+ groupRoute.Use(middleware.AdminAuth())
+ {
+ groupRoute.GET("/", controller.GetGroups)
+ }
+ mjRoute := apiRouter.Group("/mj")
+ mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
+ mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
+
+ taskRoute := apiRouter.Group("/task")
+ {
+ taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask)
+ taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask)
+ }
+ }
+}
diff --git a/router/dashboard.go b/router/dashboard.go
new file mode 100644
index 00000000..94000679
--- /dev/null
+++ b/router/dashboard.go
@@ -0,0 +1,22 @@
+package router
+
+import (
+ "github.com/gin-contrib/gzip"
+ "github.com/gin-gonic/gin"
+ "one-api/controller"
+ "one-api/middleware"
+)
+
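+// SetDashboardRouter exposes the OpenAI-compatible billing subscription and usage endpoints behind token auth.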
+func SetDashboardRouter(router *gin.Engine) {
+ apiRouter := router.Group("/")
+ apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
+ apiRouter.Use(middleware.GlobalAPIRateLimit())
+ apiRouter.Use(middleware.CORS())
+ apiRouter.Use(middleware.TokenAuth())
+ {
+ apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription)
+ apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription)
+ apiRouter.GET("/dashboard/billing/usage", controller.GetUsage)
+ apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage)
+ }
+}
diff --git a/router/main.go b/router/main.go
new file mode 100644
index 00000000..0d2bfdce
--- /dev/null
+++ b/router/main.go
@@ -0,0 +1,31 @@
+package router
+
+import (
+ "embed"
+ "fmt"
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/common"
+ "os"
+ "strings"
+)
+
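+// SetRouter wires up all route groups. When FRONTEND_BASE_URL is set on a non-master node,
+// unmatched routes are redirected to the external frontend; otherwise the embedded web build is served.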
+func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
+ SetApiRouter(router)
+ SetDashboardRouter(router)
+ SetRelayRouter(router)
+ SetVideoRouter(router)
+ frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
+ if common.IsMasterNode && frontendBaseUrl != "" {
+ frontendBaseUrl = ""
+ common.SysLog("FRONTEND_BASE_URL is ignored on master node")
+ }
+ if frontendBaseUrl == "" {
+ SetWebRouter(router, buildFS, indexPage)
+ } else {
+ frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
+ router.NoRoute(func(c *gin.Context) {
+ c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI))
+ })
+ }
+}
diff --git a/router/relay-router.go b/router/relay-router.go
new file mode 100644
index 00000000..5b293dbd
--- /dev/null
+++ b/router/relay-router.go
@@ -0,0 +1,115 @@
+package router
+
+import (
+ "one-api/controller"
+ "one-api/middleware"
+ "one-api/relay"
+
+ "github.com/gin-gonic/gin"
+)
+
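+// SetRelayRouter registers the OpenAI-compatible relay endpoints (/v1, the /pg playground,
+// Midjourney, Suno and the /v1beta Gemini routes) behind token auth, rate limiting and channel distribution.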
+func SetRelayRouter(router *gin.Engine) {
+ router.Use(middleware.CORS())
+ router.Use(middleware.DecompressRequestMiddleware())
+ router.Use(middleware.StatsMiddleware())
+ // https://platform.openai.com/docs/api-reference/introduction
+ modelsRouter := router.Group("/v1/models")
+ modelsRouter.Use(middleware.TokenAuth())
+ {
+ modelsRouter.GET("", controller.ListModels)
+ modelsRouter.GET("/:model", controller.RetrieveModel)
+ }
+ playgroundRouter := router.Group("/pg")
+ playgroundRouter.Use(middleware.UserAuth(), middleware.Distribute())
+ {
+ playgroundRouter.POST("/chat/completions", controller.Playground)
+ }
+ relayV1Router := router.Group("/v1")
+ relayV1Router.Use(middleware.TokenAuth())
+ relayV1Router.Use(middleware.ModelRequestRateLimit())
+ {
+ // WebSocket routes
+ wsRouter := relayV1Router.Group("")
+ wsRouter.Use(middleware.Distribute())
+ wsRouter.GET("/realtime", controller.WssRelay)
+ }
+ {
+ // HTTP routes
+ httpRouter := relayV1Router.Group("")
+ httpRouter.Use(middleware.Distribute())
+ httpRouter.POST("/messages", controller.RelayClaude)
+ httpRouter.POST("/completions", controller.Relay)
+ httpRouter.POST("/chat/completions", controller.Relay)
+ httpRouter.POST("/edits", controller.Relay)
+ httpRouter.POST("/images/generations", controller.Relay)
+ httpRouter.POST("/images/edits", controller.Relay)
+ httpRouter.POST("/images/variations", controller.RelayNotImplemented)
+ httpRouter.POST("/embeddings", controller.Relay)
+ httpRouter.POST("/engines/:model/embeddings", controller.Relay)
+ httpRouter.POST("/audio/transcriptions", controller.Relay)
+ httpRouter.POST("/audio/translations", controller.Relay)
+ httpRouter.POST("/audio/speech", controller.Relay)
+ httpRouter.POST("/responses", controller.Relay)
+ httpRouter.GET("/files", controller.RelayNotImplemented)
+ httpRouter.POST("/files", controller.RelayNotImplemented)
+ httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
+ httpRouter.GET("/files/:id", controller.RelayNotImplemented)
+ httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
+ httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
+ httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
+ httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
+ httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
+ httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
+ httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
+ httpRouter.POST("/moderations", controller.Relay)
+ httpRouter.POST("/rerank", controller.Relay)
+ httpRouter.POST("/models/*path", controller.Relay)
+ }
+
+ relayMjRouter := router.Group("/mj")
+ registerMjRouterGroup(relayMjRouter)
+
+ relayMjModeRouter := router.Group("/:mode/mj")
+ registerMjRouterGroup(relayMjModeRouter)
+ //relayMjRouter.Use()
+
+ relaySunoRouter := router.Group("/suno")
+ relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
+ {
+ relaySunoRouter.POST("/submit/:action", controller.RelayTask)
+ relaySunoRouter.POST("/fetch", controller.RelayTask)
+ relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
+ }
+
+ relayGeminiRouter := router.Group("/v1beta")
+ relayGeminiRouter.Use(middleware.TokenAuth())
+ relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
+ relayGeminiRouter.Use(middleware.Distribute())
+ {
+ // Gemini API path format: /v1beta/models/{model_name}:{action}
+ relayGeminiRouter.POST("/models/*path", controller.Relay)
+ }
+}
+
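+// registerMjRouterGroup registers the Midjourney endpoints on the given group. The image route is
+// added before the auth middleware and is therefore publicly reachable; all other routes require
+// token auth and channel distribution.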
+func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
+ relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage)
+ relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute())
+ {
+ relayMjRouter.POST("/submit/action", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/shorten", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/modal", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/edits", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/video", controller.RelayMidjourney)
+ relayMjRouter.POST("/notify", controller.RelayMidjourney)
+ relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
+ relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
+ relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
+ relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
+ relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
+ }
+}
diff --git a/router/video-router.go b/router/video-router.go
new file mode 100644
index 00000000..9e605d54
--- /dev/null
+++ b/router/video-router.go
@@ -0,0 +1,24 @@
+package router
+
+import (
+ "one-api/controller"
+ "one-api/middleware"
+
+ "github.com/gin-gonic/gin"
+)
+
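+// SetVideoRouter registers the video generation task endpoints, plus Kling-compatible routes that are first normalized by KlingRequestConvert.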
+func SetVideoRouter(router *gin.Engine) {
+ videoV1Router := router.Group("/v1")
+ videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+ {
+ videoV1Router.POST("/video/generations", controller.RelayTask)
+ videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
+ }
+
+ klingV1Router := router.Group("/kling/v1")
+ klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
+ {
+ klingV1Router.POST("/videos/text2video", controller.RelayTask)
+ klingV1Router.POST("/videos/image2video", controller.RelayTask)
+ }
+}
diff --git a/router/web-router.go b/router/web-router.go
new file mode 100644
index 00000000..57cd61ac
--- /dev/null
+++ b/router/web-router.go
@@ -0,0 +1,28 @@
+package router
+
+import (
+ "embed"
+ "github.com/gin-contrib/gzip"
+ "github.com/gin-contrib/static"
+ "github.com/gin-gonic/gin"
+ "net/http"
+ "one-api/common"
+ "one-api/controller"
+ "one-api/middleware"
+ "strings"
+)
+
+func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
+ router.Use(gzip.Gzip(gzip.DefaultCompression))
+ router.Use(middleware.GlobalWebRateLimit())
+ router.Use(middleware.Cache())
+ router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/dist")))
+ router.NoRoute(func(c *gin.Context) {
+ if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") || strings.HasPrefix(c.Request.RequestURI, "/assets") {
+ controller.RelayNotFound(c)
+ return
+ }
+ c.Header("Cache-Control", "no-cache")
+ c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
+ })
+}
diff --git a/service/audio.go b/service/audio.go
new file mode 100644
index 00000000..c4b6f01b
--- /dev/null
+++ b/service/audio.go
@@ -0,0 +1,48 @@
+package service
+
+import (
+ "encoding/base64"
+ "fmt"
+ "strings"
+)
+
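+// parseAudio estimates the duration in seconds of a base64-encoded audio payload,
+// assuming 24kHz 16-bit PCM for "pcm16" and 8kHz 8-bit samples for G.711 and unknown formats.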
+func parseAudio(audioBase64 string, format string) (duration float64, err error) {
+ audioData, err := base64.StdEncoding.DecodeString(audioBase64)
+ if err != nil {
+ return 0, fmt.Errorf("base64 decode error: %v", err)
+ }
+
+ var samplesCount int
+ var sampleRate int
+
+ switch format {
+ case "pcm16":
+ samplesCount = len(audioData) / 2 // 16-bit = 2 bytes per sample
+ sampleRate = 24000 // 24kHz
+ case "g711_ulaw", "g711_alaw":
+ samplesCount = len(audioData) // 8-bit = 1 byte per sample
+ sampleRate = 8000 // 8kHz
+ default:
+ samplesCount = len(audioData) // 8-bit = 1 byte per sample
+ sampleRate = 8000 // 8kHz
+ }
+
+ duration = float64(samplesCount) / float64(sampleRate)
+ return duration, nil
+}
+
+func DecodeBase64AudioData(audioBase64 string) (string, error) {
+ // strip a data:audio/xxx;base64, prefix if present
+ idx := strings.Index(audioBase64, ",")
+ if idx != -1 {
+ audioBase64 = audioBase64[idx+1:]
+ }
+
+ // decode to verify the payload is valid base64; the decoded bytes are discarded
+ _, err := base64.StdEncoding.DecodeString(audioBase64)
+ if err != nil {
+ return "", fmt.Errorf("base64 decode error: %v", err)
+ }
+
+ return audioBase64, nil
+}
diff --git a/service/cf_worker.go b/service/cf_worker.go
new file mode 100644
index 00000000..ae6e1ffe
--- /dev/null
+++ b/service/cf_worker.go
@@ -0,0 +1,57 @@
+package service
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/setting"
+ "strings"
+)
+
+// WorkerRequest is the payload sent to the Worker proxy
+type WorkerRequest struct {
+ URL string `json:"url"`
+ Key string `json:"key"`
+ Method string `json:"method,omitempty"`
+ Headers map[string]string `json:"headers,omitempty"`
+ Body json.RawMessage `json:"body,omitempty"`
+}
+
+// DoWorkerRequest sends a request through the Worker proxy
+func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
+ if !setting.EnableWorker() {
+ return nil, fmt.Errorf("worker not enabled")
+ }
+ if !setting.WorkerAllowHttpImageRequestEnabled && !strings.HasPrefix(req.URL, "https") {
+ return nil, fmt.Errorf("only support https url")
+ }
+
+ workerUrl := setting.WorkerUrl
+ if !strings.HasSuffix(workerUrl, "/") {
+ workerUrl += "/"
+ }
+
+ // serialize the worker request payload
+ workerPayload, err := json.Marshal(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
+ }
+
+ return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
+}
+
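+// DoDownloadRequest downloads a file, going through the Worker proxy when it is enabled and fetching the origin URL directly otherwise.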
+func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
+ if setting.EnableWorker() {
+ common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
+ req := &WorkerRequest{
+ URL: originUrl,
+ Key: setting.WorkerValidKey,
+ }
+ return DoWorkerRequest(req)
+ } else {
+ common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
+ return http.Get(originUrl)
+ }
+}
diff --git a/service/channel.go b/service/channel.go
new file mode 100644
index 00000000..4d38e6ed
--- /dev/null
+++ b/service/channel.go
@@ -0,0 +1,101 @@
+package service
+
+import (
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ "one-api/setting/operation_setting"
+ "one-api/types"
+ "strings"
+)
+
+func formatNotifyType(channelId int, status int) string {
+ return fmt.Sprintf("%s_%d_%d", dto.NotifyTypeChannelUpdate, channelId, status)
+}
+
+// disable & notify
+func DisableChannel(channelError types.ChannelError, reason string) {
+ success := model.UpdateChannelStatus(channelError.ChannelId, channelError.UsingKey, common.ChannelStatusAutoDisabled, reason)
+ if success {
+ subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelError.ChannelName, channelError.ChannelId)
+ content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelError.ChannelName, channelError.ChannelId, reason)
+ NotifyRootUser(formatNotifyType(channelError.ChannelId, common.ChannelStatusAutoDisabled), subject, content)
+ }
+}
+
+func EnableChannel(channelId int, usingKey string, channelName string) {
+ success := model.UpdateChannelStatus(channelId, usingKey, common.ChannelStatusEnabled, "")
+ if success {
+ subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+ content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+ NotifyRootUser(formatNotifyType(channelId, common.ChannelStatusEnabled), subject, content)
+ }
+}
+
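+// ShouldDisableChannel decides whether an upstream error should auto-disable a channel:
+// explicit channel errors, 401 responses, Gemini 403s, known auth/quota error codes and types,
+// or a keyword match against the configured auto-disable keyword list.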
+func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
+ if !common.AutomaticDisableChannelEnabled {
+ return false
+ }
+ if err == nil {
+ return false
+ }
+ if types.IsChannelError(err) {
+ return true
+ }
+ if types.IsLocalError(err) {
+ return false
+ }
+ if err.StatusCode == http.StatusUnauthorized {
+ return true
+ }
+ if err.StatusCode == http.StatusForbidden {
+ switch channelType {
+ case constant.ChannelTypeGemini:
+ return true
+ }
+ }
+ oaiErr := err.ToOpenAIError()
+ switch oaiErr.Code {
+ case "invalid_api_key":
+ return true
+ case "account_deactivated":
+ return true
+ case "billing_not_active":
+ return true
+ case "pre_consume_token_quota_failed":
+ return true
+ }
+ switch oaiErr.Type {
+ case "insufficient_quota":
+ return true
+ case "insufficient_user_quota":
+ return true
+ // https://docs.anthropic.com/claude/reference/errors
+ case "authentication_error":
+ return true
+ case "permission_error":
+ return true
+ case "forbidden":
+ return true
+ }
+
+ lowerMessage := strings.ToLower(err.Error())
+ search, _ := AcSearch(lowerMessage, operation_setting.AutomaticDisableKeywords, true)
+ return search
+}
+
+func ShouldEnableChannel(newAPIError *types.NewAPIError, status int) bool {
+ if !common.AutomaticEnableChannelEnabled {
+ return false
+ }
+ if newAPIError != nil {
+ return false
+ }
+ if status != common.ChannelStatusAutoDisabled {
+ return false
+ }
+ return true
+}
diff --git a/service/convert.go b/service/convert.go
new file mode 100644
index 00000000..593b59d9
--- /dev/null
+++ b/service/convert.go
@@ -0,0 +1,440 @@
+package service
+
+import (
+ "encoding/json"
+ "fmt"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/relay/channel/openrouter"
+ relaycommon "one-api/relay/common"
+ "strings"
+)
+
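+// ClaudeToOpenAIRequest converts a Claude Messages API request into an OpenAI chat completion
+// request, mapping the thinking config, stop sequences, tools, system prompt and message content
+// (text, images, tool_use and tool_result blocks).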
+func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
+ openAIRequest := dto.GeneralOpenAIRequest{
+ Model: claudeRequest.Model,
+ MaxTokens: claudeRequest.MaxTokens,
+ Temperature: claudeRequest.Temperature,
+ TopP: claudeRequest.TopP,
+ Stream: claudeRequest.Stream,
+ }
+
+ isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
+
+ if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
+ if isOpenRouter {
+ reasoning := openrouter.RequestReasoning{
+ MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
+ }
+ reasoningJSON, err := json.Marshal(reasoning)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal reasoning: %w", err)
+ }
+ openAIRequest.Reasoning = reasoningJSON
+ } else {
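+ // carry over the -thinking suffix from the original model name, which is assumed to select the thinking-enabled model variant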
+ thinkingSuffix := "-thinking"
+ if strings.HasSuffix(info.OriginModelName, thinkingSuffix) &&
+ !strings.HasSuffix(openAIRequest.Model, thinkingSuffix) {
+ openAIRequest.Model = openAIRequest.Model + thinkingSuffix
+ }
+ }
+ }
+
+ // Convert stop sequences
+ if len(claudeRequest.StopSequences) == 1 {
+ openAIRequest.Stop = claudeRequest.StopSequences[0]
+ } else if len(claudeRequest.StopSequences) > 1 {
+ openAIRequest.Stop = claudeRequest.StopSequences
+ }
+
+ // Convert tools
+ tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
+ openAITools := make([]dto.ToolCallRequest, 0)
+ for _, claudeTool := range tools {
+ openAITool := dto.ToolCallRequest{
+ Type: "function",
+ Function: dto.FunctionRequest{
+ Name: claudeTool.Name,
+ Description: claudeTool.Description,
+ Parameters: claudeTool.InputSchema,
+ },
+ }
+ openAITools = append(openAITools, openAITool)
+ }
+ openAIRequest.Tools = openAITools
+
+ // Convert messages
+ openAIMessages := make([]dto.Message, 0)
+
+ // Add system message if present
+ if claudeRequest.System != nil {
+ if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
+ openAIMessage := dto.Message{
+ Role: "system",
+ }
+ openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
+ openAIMessages = append(openAIMessages, openAIMessage)
+ } else {
+ systems := claudeRequest.ParseSystem()
+ if len(systems) > 0 {
+ openAIMessage := dto.Message{
+ Role: "system",
+ }
+ isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude")
+ if isOpenRouterClaude {
+ systemMediaMessages := make([]dto.MediaContent, 0, len(systems))
+ for _, system := range systems {
+ message := dto.MediaContent{
+ Type: "text",
+ Text: system.GetText(),
+ CacheControl: system.CacheControl,
+ }
+ systemMediaMessages = append(systemMediaMessages, message)
+ }
+ openAIMessage.SetMediaContent(systemMediaMessages)
+ } else {
+ systemStr := ""
+ for _, system := range systems {
+ if system.Text != nil {
+ systemStr += *system.Text
+ }
+ }
+ openAIMessage.SetStringContent(systemStr)
+ }
+ openAIMessages = append(openAIMessages, openAIMessage)
+ }
+ }
+ }
+ for _, claudeMessage := range claudeRequest.Messages {
+ openAIMessage := dto.Message{
+ Role: claudeMessage.Role,
+ }
+
+ //log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
+ if claudeMessage.IsStringContent() {
+ openAIMessage.SetStringContent(claudeMessage.GetStringContent())
+ } else {
+ content, err := claudeMessage.ParseContent()
+ if err != nil {
+ return nil, err
+ }
+ contents := content
+ var toolCalls []dto.ToolCallRequest
+ mediaMessages := make([]dto.MediaContent, 0, len(contents))
+
+ for _, mediaMsg := range contents {
+ switch mediaMsg.Type {
+ case "text":
+ message := dto.MediaContent{
+ Type: "text",
+ Text: mediaMsg.GetText(),
+ CacheControl: mediaMsg.CacheControl,
+ }
+ mediaMessages = append(mediaMessages, message)
+ case "image":
+ // Handle image conversion (base64 to URL or keep as is)
+ imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
+ //textContent += fmt.Sprintf("[Image: %s]", imageData)
+ mediaMessage := dto.MediaContent{
+ Type: "image_url",
+ ImageUrl: &dto.MessageImageUrl{Url: imageData},
+ }
+ mediaMessages = append(mediaMessages, mediaMessage)
+ case "tool_use":
+ toolCall := dto.ToolCallRequest{
+ ID: mediaMsg.Id,
+ Type: "function",
+ Function: dto.FunctionRequest{
+ Name: mediaMsg.Name,
+ Arguments: toJSONString(mediaMsg.Input),
+ },
+ }
+ toolCalls = append(toolCalls, toolCall)
+ case "tool_result":
+ // Add tool result as a separate message
+ oaiToolMessage := dto.Message{
+ Role: "tool",
+ Name: &mediaMsg.Name,
+ ToolCallId: mediaMsg.ToolUseId,
+ }
+ //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
+ if mediaMsg.IsStringContent() {
+ oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
+ } else {
+ mediaContents := mediaMsg.ParseMediaContent()
+ encodeJson, _ := common.Marshal(mediaContents)
+ oaiToolMessage.SetStringContent(string(encodeJson))
+ }
+ openAIMessages = append(openAIMessages, oaiToolMessage)
+ }
+ }
+
+ if len(toolCalls) > 0 {
+ openAIMessage.SetToolCalls(toolCalls)
+ }
+
+ if len(mediaMessages) > 0 && len(toolCalls) == 0 {
+ openAIMessage.SetMediaContent(mediaMessages)
+ }
+ }
+ if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
+ openAIMessages = append(openAIMessages, openAIMessage)
+ }
+ }
+
+ openAIRequest.Messages = openAIMessages
+
+ return &openAIRequest, nil
+}
+
+func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
+ claudeError := dto.ClaudeError{
+ Type: "new_api_error",
+ Message: openAIError.Error.Message,
+ }
+ return &dto.ClaudeErrorWithStatusCode{
+ Error: claudeError,
+ StatusCode: openAIError.StatusCode,
+ }
+}
+
+func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
+ openAIError := dto.OpenAIError{
+ Message: claudeError.Error.Message,
+ Type: "new_api_error",
+ }
+ return &dto.OpenAIErrorWithStatusCode{
+ Error: openAIError,
+ StatusCode: claudeError.StatusCode,
+ }
+}
+
+func generateStopBlock(index int) *dto.ClaudeResponse {
+ return &dto.ClaudeResponse{
+ Type: "content_block_stop",
+ Index: common.GetPointer[int](index),
+ }
+}
+
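+// StreamResponseOpenAI2Claude converts OpenAI streaming chunks into Claude streaming events
+// (message_start, content_block_start/delta/stop, message_delta, message_stop), tracking the
+// current content block type and index in info.ClaudeConvertInfo.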
+func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
+ var claudeResponses []*dto.ClaudeResponse
+ if info.SendResponseCount == 1 {
+ msg := &dto.ClaudeMediaMessage{
+ Id: openAIResponse.Id,
+ Model: openAIResponse.Model,
+ Type: "message",
+ Role: "assistant",
+ Usage: &dto.ClaudeUsage{
+ InputTokens: info.PromptTokens,
+ OutputTokens: 0,
+ },
+ }
+ msg.SetContent(make([]any, 0))
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Type: "message_start",
+ Message: msg,
+ })
+ //claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ // Type: "ping",
+ //})
+ if openAIResponse.IsToolCall() {
+ resp := &dto.ClaudeResponse{
+ Type: "content_block_start",
+ ContentBlock: &dto.ClaudeMediaMessage{
+ Id: openAIResponse.GetFirstToolCall().ID,
+ Type: "tool_use",
+ Name: openAIResponse.GetFirstToolCall().Function.Name,
+ },
+ }
+ resp.SetIndex(0)
+ claudeResponses = append(claudeResponses, resp)
+ } else {
+ //resp := &dto.ClaudeResponse{
+ // Type: "content_block_start",
+ // ContentBlock: &dto.ClaudeMediaMessage{
+ // Type: "text",
+ // Text: common.GetPointer[string](""),
+ // },
+ //}
+ //resp.SetIndex(0)
+ //claudeResponses = append(claudeResponses, resp)
+ }
+ return claudeResponses
+ }
+
+ if len(openAIResponse.Choices) == 0 {
+ // no choices
+ // TODO: handle this case
+ return claudeResponses
+ } else {
+ chosenChoice := openAIResponse.Choices[0]
+ if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
+ // should be done
+ info.FinishReason = *chosenChoice.FinishReason
+ return claudeResponses
+ }
+ if info.Done {
+ claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+ oaiUsage := info.ClaudeConvertInfo.Usage
+ if oaiUsage != nil {
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Type: "message_delta",
+ Usage: &dto.ClaudeUsage{
+ InputTokens: oaiUsage.PromptTokens,
+ OutputTokens: oaiUsage.CompletionTokens,
+ CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
+ CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
+ },
+ Delta: &dto.ClaudeMediaMessage{
+ StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
+ },
+ })
+ }
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Type: "message_stop",
+ })
+ } else {
+ var claudeResponse dto.ClaudeResponse
+ var isEmpty bool
+ claudeResponse.Type = "content_block_delta"
+ if len(chosenChoice.Delta.ToolCalls) > 0 {
+ if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
+ claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+ info.ClaudeConvertInfo.Index++
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Index: &info.ClaudeConvertInfo.Index,
+ Type: "content_block_start",
+ ContentBlock: &dto.ClaudeMediaMessage{
+ Id: openAIResponse.GetFirstToolCall().ID,
+ Type: "tool_use",
+ Name: openAIResponse.GetFirstToolCall().Function.Name,
+ Input: map[string]interface{}{},
+ },
+ })
+ }
+ info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
+ // tools delta
+ claudeResponse.Delta = &dto.ClaudeMediaMessage{
+ Type: "input_json_delta",
+ PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
+ }
+ } else {
+ reasoning := chosenChoice.Delta.GetReasoningContent()
+ textContent := chosenChoice.Delta.GetContentString()
+ if reasoning != "" || textContent != "" {
+ if reasoning != "" {
+ if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
+ //info.ClaudeConvertInfo.Index++
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Index: &info.ClaudeConvertInfo.Index,
+ Type: "content_block_start",
+ ContentBlock: &dto.ClaudeMediaMessage{
+ Type: "thinking",
+ Thinking: "",
+ },
+ })
+ }
+ info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
+ // text delta
+ claudeResponse.Delta = &dto.ClaudeMediaMessage{
+ Type: "thinking_delta",
+ Thinking: reasoning,
+ }
+ } else {
+ if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
+ if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
+ claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+ info.ClaudeConvertInfo.Index++
+ }
+ claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+ Index: &info.ClaudeConvertInfo.Index,
+ Type: "content_block_start",
+ ContentBlock: &dto.ClaudeMediaMessage{
+ Type: "text",
+ Text: common.GetPointer[string](""),
+ },
+ })
+ }
+ info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
+ // text delta
+ claudeResponse.Delta = &dto.ClaudeMediaMessage{
+ Type: "text_delta",
+ Text: common.GetPointer[string](textContent),
+ }
+ }
+ } else {
+ isEmpty = true
+ }
+ }
+ claudeResponse.Index = &info.ClaudeConvertInfo.Index
+ if !isEmpty {
+ claudeResponses = append(claudeResponses, &claudeResponse)
+ }
+ }
+ }
+
+ return claudeResponses
+}
+
+func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
+ var stopReason string
+ contents := make([]dto.ClaudeMediaMessage, 0)
+ claudeResponse := &dto.ClaudeResponse{
+ Id: openAIResponse.Id,
+ Type: "message",
+ Role: "assistant",
+ Model: openAIResponse.Model,
+ }
+ for _, choice := range openAIResponse.Choices {
+ stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
+ claudeContent := dto.ClaudeMediaMessage{}
+ if choice.FinishReason == "tool_calls" {
+ claudeContent.Type = "tool_use"
+ claudeContent.Id = choice.Message.ToolCallId
+ claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
+ var mapParams map[string]interface{}
+ if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
+ claudeContent.Input = mapParams
+ } else {
+ claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
+ }
+ } else {
+ claudeContent.Type = "text"
+ claudeContent.SetText(choice.Message.StringContent())
+ }
+ contents = append(contents, claudeContent)
+ }
+ claudeResponse.Content = contents
+ claudeResponse.StopReason = stopReason
+ claudeResponse.Usage = &dto.ClaudeUsage{
+ InputTokens: openAIResponse.PromptTokens,
+ OutputTokens: openAIResponse.CompletionTokens,
+ }
+
+ return claudeResponse
+}
+
+func stopReasonOpenAI2Claude(reason string) string {
+ switch reason {
+ case "stop":
+ return "end_turn"
+ case "stop_sequence":
+ return "stop_sequence"
+ case "max_tokens":
+ return "max_tokens"
+ case "tool_calls":
+ return "tool_use"
+ default:
+ return reason
+ }
+}
+
+func toJSONString(v interface{}) string {
+ b, err := json.Marshal(v)
+ if err != nil {
+ return "{}"
+ }
+ return string(b)
+}
diff --git a/service/epay.go b/service/epay.go
new file mode 100644
index 00000000..a8259d21
--- /dev/null
+++ b/service/epay.go
@@ -0,0 +1,12 @@
+package service
+
+import (
+ "one-api/setting"
+)
+
+func GetCallbackAddress() string {
+ if setting.CustomCallbackAddress == "" {
+ return setting.ServerAddress
+ }
+ return setting.CustomCallbackAddress
+}
diff --git a/service/error.go b/service/error.go
new file mode 100644
index 00000000..a0713b55
--- /dev/null
+++ b/service/error.go
@@ -0,0 +1,155 @@
+package service
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/types"
+ "strconv"
+ "strings"
+)
+
+func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse {
+ return &dto.MidjourneyResponse{
+ Code: code,
+ Description: desc,
+ }
+}
+
+func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode {
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: *MidjourneyErrorWrapper(code, desc),
+ }
+}
+
+//// OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode
+//func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
+// text := err.Error()
+// lowerText := strings.ToLower(text)
+// if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
+// if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+// common.SysLog(fmt.Sprintf("error: %s", text))
+// text = "请求上游地址失败"
+// }
+// }
+// openAIError := dto.OpenAIError{
+// Message: text,
+// Type: "new_api_error",
+// Code: code,
+// }
+// return &dto.OpenAIErrorWithStatusCode{
+// Error: openAIError,
+// StatusCode: statusCode,
+// }
+//}
+//
+//func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
+// openaiErr := OpenAIErrorWrapper(err, code, statusCode)
+// openaiErr.LocalError = true
+// return openaiErr
+//}
+
+func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
+ text := err.Error()
+ lowerText := strings.ToLower(text)
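+	// mask connection-level failures (post/dial/http errors) so upstream addresses are not leaked to clients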
+ if !strings.HasPrefix(lowerText, "get file base64 from url") {
+ if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+ common.SysLog(fmt.Sprintf("error: %s", text))
+ text = "请求上游地址失败"
+ }
+ }
+ claudeError := dto.ClaudeError{
+ Message: text,
+ Type: "new_api_error",
+ }
+ return &dto.ClaudeErrorWithStatusCode{
+ Error: claudeError,
+ StatusCode: statusCode,
+ }
+}
+
+func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
+ claudeErr := ClaudeErrorWrapper(err, code, statusCode)
+ claudeErr.LocalError = true
+ return claudeErr
+}
+
+func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
+ newApiErr = &types.NewAPIError{
+ StatusCode: resp.StatusCode,
+ ErrorType: types.ErrorTypeOpenAIError,
+ }
+
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return
+ }
+ common.CloseResponseBodyGracefully(resp)
+ var errResponse dto.GeneralErrorResponse
+
+ err = common.Unmarshal(responseBody, &errResponse)
+ if err != nil {
+ if showBodyWhenFail {
+ newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
+ } else {
+ newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
+ }
+ return
+ }
+ if errResponse.Error.Message != "" {
+ // General format error (OpenAI, Anthropic, Gemini, etc.)
+ newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
+ } else {
+ newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
+ newApiErr.ErrorType = types.ErrorTypeOpenAIError
+ }
+ return
+}
+
+func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) {
+ if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
+ return
+ }
+ statusCodeMapping := make(map[string]string)
+ err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
+ if err != nil {
+ return
+ }
+ if newApiErr.StatusCode == http.StatusOK {
+ return
+ }
+ codeStr := strconv.Itoa(newApiErr.StatusCode)
+ if _, ok := statusCodeMapping[codeStr]; ok {
+ intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
+ newApiErr.StatusCode = intCode
+ }
+}
+
+func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
+ openaiErr := TaskErrorWrapper(err, code, statusCode)
+ openaiErr.LocalError = true
+ return openaiErr
+}
+
+func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
+ text := err.Error()
+ lowerText := strings.ToLower(text)
+ if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
+ common.SysLog(fmt.Sprintf("error: %s", text))
+ text = "请求上游地址失败"
+ }
+	// avoid exposing internal error details
+ taskError := &dto.TaskError{
+ Code: code,
+ Message: text,
+ StatusCode: statusCode,
+ Error: err,
+ }
+
+ return taskError
+}
diff --git a/service/file_decoder.go b/service/file_decoder.go
new file mode 100644
index 00000000..c1d4fb0c
--- /dev/null
+++ b/service/file_decoder.go
@@ -0,0 +1,135 @@
+package service
+
+import (
+ "encoding/base64"
+ "fmt"
+ "io"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "strings"
+)
+
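+// GetFileBase64FromUrl downloads a file (capped at MaxFileDownloadMB) and returns its base64 payload together with a best-effort MIME type.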
+func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
+ var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
+
+ resp, err := DoDownloadRequest(url)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ // Always use LimitReader to prevent oversized downloads
+ fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
+ if err != nil {
+ return nil, err
+ }
+ // Check actual size after reading
+ if len(fileBytes) > maxFileSize {
+ return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
+ }
+
+ // Convert to base64
+ base64Data := base64.StdEncoding.EncodeToString(fileBytes)
+
+ mimeType := resp.Header.Get("Content-Type")
+ if len(strings.Split(mimeType, ";")) > 1 {
+ // If Content-Type has parameters, take the first part
+ mimeType = strings.Split(mimeType, ";")[0]
+ }
+ if mimeType == "application/octet-stream" {
+ if common.DebugEnabled {
+ println("MIME type is application/octet-stream, trying to guess from URL or filename")
+ }
+ // try to guess the MIME type from the url last segment
+ urlParts := strings.Split(url, "/")
+ if len(urlParts) > 0 {
+ lastSegment := urlParts[len(urlParts)-1]
+ if strings.Contains(lastSegment, ".") {
+ // Extract the file extension
+ filename := strings.Split(lastSegment, ".")
+ if len(filename) > 1 {
+ ext := strings.ToLower(filename[len(filename)-1])
+ // Guess MIME type based on file extension
+ mimeType = GetMimeTypeByExtension(ext)
+ }
+ }
+ } else {
+ // try to guess the MIME type from the file extension
+ fileName := resp.Header.Get("Content-Disposition")
+ if fileName != "" {
+ // Extract the filename from the Content-Disposition header
+ parts := strings.Split(fileName, ";")
+ for _, part := range parts {
+ if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
+ fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
+ // Remove quotes if present
+ if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
+ fileName = fileName[1 : len(fileName)-1]
+ }
+						// Guess the MIME type from the file extension (the part after the last dot)
+						if dotIdx := strings.LastIndex(fileName, "."); dotIdx != -1 && dotIdx+1 < len(fileName) {
+							mimeType = GetMimeTypeByExtension(fileName[dotIdx+1:])
+						}
+ break
+ }
+ }
+ }
+ }
+ }
+
+ return &dto.LocalFileData{
+ Base64Data: base64Data,
+ MimeType: mimeType,
+ Size: int64(len(fileBytes)),
+ }, nil
+}
+
+func GetMimeTypeByExtension(ext string) string {
+ // Convert to lowercase for case-insensitive comparison
+ ext = strings.ToLower(ext)
+ switch ext {
+ // Text files
+ case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
+ return "text/plain"
+
+ // Image files
+ case "jpg", "jpeg":
+ return "image/jpeg"
+ case "png":
+ return "image/png"
+ case "gif":
+ return "image/gif"
+
+ // Audio files
+ case "mp3":
+ return "audio/mp3"
+ case "wav":
+ return "audio/wav"
+ case "mpeg":
+ return "audio/mpeg"
+
+ // Video files
+ case "mp4":
+ return "video/mp4"
+ case "wmv":
+ return "video/wmv"
+ case "flv":
+ return "video/flv"
+ case "mov":
+ return "video/mov"
+ case "mpg":
+ return "video/mpg"
+ case "avi":
+ return "video/avi"
+ case "mpegps":
+ return "video/mpegps"
+
+ // Document files
+ case "pdf":
+ return "application/pdf"
+
+ default:
+ return "application/octet-stream" // Default for unknown types
+ }
+}
diff --git a/service/http_client.go b/service/http_client.go
new file mode 100644
index 00000000..b191ddd7
--- /dev/null
+++ b/service/http_client.go
@@ -0,0 +1,81 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "one-api/common"
+ "time"
+
+ "golang.org/x/net/proxy"
+)
+
+var httpClient *http.Client
+
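+// InitHttpClient builds the shared relay client; a RelayTimeout of 0 leaves the client without a timeout.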
+func InitHttpClient() {
+ if common.RelayTimeout == 0 {
+ httpClient = &http.Client{}
+ } else {
+ httpClient = &http.Client{
+ Timeout: time.Duration(common.RelayTimeout) * time.Second,
+ }
+ }
+}
+
+func GetHttpClient() *http.Client {
+ return httpClient
+}
+
+// NewProxyHttpClient creates an HTTP client that routes requests through the given proxy URL
+func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
+ if proxyURL == "" {
+ return http.DefaultClient, nil
+ }
+
+ parsedURL, err := url.Parse(proxyURL)
+ if err != nil {
+ return nil, err
+ }
+
+ switch parsedURL.Scheme {
+ case "http", "https":
+ return &http.Client{
+ Transport: &http.Transport{
+ Proxy: http.ProxyURL(parsedURL),
+ },
+ }, nil
+
+ case "socks5", "socks5h":
+		// extract proxy authentication credentials, if any
+ var auth *proxy.Auth
+ if parsedURL.User != nil {
+ auth = &proxy.Auth{
+ User: parsedURL.User.Username(),
+ Password: "",
+ }
+ if password, ok := parsedURL.User.Password(); ok {
+ auth.Password = password
+ }
+ }
+
+		// create the SOCKS5 proxy dialer
+		// proxy.SOCKS5 with the "tcp" network sends all TCP connections, including DNS lookups, through the proxy, matching socks5h behaviour
+ dialer, err := proxy.SOCKS5("tcp", parsedURL.Host, auth, proxy.Direct)
+ if err != nil {
+ return nil, err
+ }
+
+ return &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return dialer.Dial(network, addr)
+ },
+ },
+ }, nil
+
+ default:
+ return nil, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
+ }
+}
diff --git a/service/image.go b/service/image.go
new file mode 100644
index 00000000..252093f1
--- /dev/null
+++ b/service/image.go
@@ -0,0 +1,172 @@
+package service
+
+import (
+ "bytes"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "image"
+ "io"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "strings"
+
+ "golang.org/x/image/webp"
+)
+
+func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
+	// strip the data URL prefix (if any) from the base64 payload
+ if idx := strings.Index(base64String, ","); idx != -1 {
+ base64String = base64String[idx+1:]
+ }
+
+	// decode the base64 string into a byte slice
+ decodedData, err := base64.StdEncoding.DecodeString(base64String)
+ if err != nil {
+ fmt.Println("Error: Failed to decode base64 string")
+ return image.Config{}, "", "", fmt.Errorf("failed to decode base64 string: %s", err.Error())
+ }
+
+	// wrap the decoded bytes in a reader so the image header can be parsed
+ reader := bytes.NewReader(decodedData)
+ config, format, err := getImageConfig(reader)
+ return config, format, base64String, err
+}
+
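+// DecodeBase64FileData parses a data URL of the form "data:<mime>;base64,<payload>"; when no MIME prefix is present it falls back to sniffing the payload as an image.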
+func DecodeBase64FileData(base64String string) (string, string, error) {
+ var mimeType string
+ var idx int
+ idx = strings.Index(base64String, ",")
+ if idx == -1 {
+ _, file_type, base64, err := DecodeBase64ImageData(base64String)
+ return "image/" + file_type, base64, err
+ }
+ mimeType = base64String[:idx]
+ base64String = base64String[idx+1:]
+ idx = strings.Index(mimeType, ";")
+ if idx == -1 {
+ _, file_type, base64, err := DecodeBase64ImageData(base64String)
+ return "image/" + file_type, base64, err
+ }
+ mimeType = mimeType[:idx]
+ idx = strings.Index(mimeType, ":")
+ if idx == -1 {
+ _, file_type, base64, err := DecodeBase64ImageData(base64String)
+ return "image/" + file_type, base64, err
+ }
+ mimeType = mimeType[idx+1:]
+ return mimeType, base64String, nil
+}
+
+// GetImageFromUrl downloads an image and returns its MIME type and base64-encoded data
+func GetImageFromUrl(url string) (mimeType string, data string, err error) {
+ resp, err := DoDownloadRequest(url)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to download image: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Check HTTP status code
+ if resp.StatusCode != http.StatusOK {
+ return "", "", fmt.Errorf("failed to download image: HTTP %d", resp.StatusCode)
+ }
+
+ contentType := resp.Header.Get("Content-Type")
+ if contentType != "application/octet-stream" && !strings.HasPrefix(contentType, "image/") {
+ return "", "", fmt.Errorf("invalid content type: %s, required image/*", contentType)
+ }
+ maxImageSize := int64(constant.MaxFileDownloadMB * 1024 * 1024)
+
+ // Check Content-Length if available
+ if resp.ContentLength > maxImageSize {
+ return "", "", fmt.Errorf("image size %d exceeds maximum allowed size of %d bytes", resp.ContentLength, maxImageSize)
+ }
+
+ // Use LimitReader to prevent reading oversized images
+ limitReader := io.LimitReader(resp.Body, maxImageSize)
+ buffer := &bytes.Buffer{}
+
+ written, err := io.Copy(buffer, limitReader)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to read image data: %w", err)
+ }
+ if written >= maxImageSize {
+ return "", "", fmt.Errorf("image size exceeds maximum allowed size of %d bytes", maxImageSize)
+ }
+
+ data = base64.StdEncoding.EncodeToString(buffer.Bytes())
+ mimeType = contentType
+
+ // Handle application/octet-stream type
+ if mimeType == "application/octet-stream" {
+ _, format, _, err := DecodeBase64ImageData(data)
+ if err != nil {
+ return "", "", err
+ }
+ mimeType = "image/" + format
+ }
+
+ return mimeType, data, nil
+}
+
+func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
+ response, err := DoDownloadRequest(imageUrl)
+ if err != nil {
+ common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
+ return image.Config{}, "", err
+ }
+ defer response.Body.Close()
+
+ if response.StatusCode != 200 {
+ err = errors.New(fmt.Sprintf("fail to get image from url: %s", response.Status))
+ return image.Config{}, "", err
+ }
+
+ mimeType := response.Header.Get("Content-Type")
+
+ if mimeType != "application/octet-stream" && !strings.HasPrefix(mimeType, "image/") {
+ return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
+ }
+
+ var readData []byte
+ for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
+ common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
+
+		// read more data from response.Body until the current limit is reached
+ additionalData := make([]byte, limit-int64(len(readData)))
+ n, _ := io.ReadFull(response.Body, additionalData)
+ readData = append(readData, additionalData[:n]...)
+
+		// combine the bytes already read with the rest of response.Body via io.MultiReader
+ limitReader := io.MultiReader(bytes.NewReader(readData), response.Body)
+
+ var config image.Config
+ var format string
+ config, format, err = getImageConfig(limitReader)
+ if err == nil {
+ return config, format, nil
+ }
+ }
+
+	return image.Config{}, "", err // return the last decode error
+}
+
+func getImageConfig(reader io.Reader) (image.Config, string, error) {
+	// decode the image header to obtain its dimensions and format
+ config, format, err := image.DecodeConfig(reader)
+ if err != nil {
+ err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
+ common.SysLog(err.Error())
+ config, err = webp.DecodeConfig(reader)
+ if err != nil {
+ err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
+ common.SysLog(err.Error())
+ }
+ format = "webp"
+ }
+ if err != nil {
+ return image.Config{}, "", err
+ }
+ return config, format, nil
+}
diff --git a/service/log_info_generate.go b/service/log_info_generate.go
new file mode 100644
index 00000000..020a2ba9
--- /dev/null
+++ b/service/log_info_generate.go
@@ -0,0 +1,83 @@
+package service
+
+import (
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
+ cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
+ other := make(map[string]interface{})
+ other["model_ratio"] = modelRatio
+ other["group_ratio"] = groupRatio
+ other["completion_ratio"] = completionRatio
+ other["cache_tokens"] = cacheTokens
+ other["cache_ratio"] = cacheRatio
+ other["model_price"] = modelPrice
+ other["user_group_ratio"] = userGroupRatio
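+	// frt: time from request start to the first upstream response, in milliseconds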
+ other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
+ if relayInfo.ReasoningEffort != "" {
+ other["reasoning_effort"] = relayInfo.ReasoningEffort
+ }
+ if relayInfo.IsModelMapped {
+ other["is_model_mapped"] = true
+ other["upstream_model_name"] = relayInfo.UpstreamModelName
+ }
+ adminInfo := make(map[string]interface{})
+ adminInfo["use_channel"] = ctx.GetStringSlice("use_channel")
+ isMultiKey := common.GetContextKeyBool(ctx, constant.ContextKeyChannelIsMultiKey)
+ if isMultiKey {
+ adminInfo["is_multi_key"] = true
+ adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
+ }
+ other["admin_info"] = adminInfo
+ return other
+}
+
+func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
+ info["ws"] = true
+ info["audio_input"] = usage.InputTokenDetails.AudioTokens
+ info["audio_output"] = usage.OutputTokenDetails.AudioTokens
+ info["text_input"] = usage.InputTokenDetails.TextTokens
+ info["text_output"] = usage.OutputTokenDetails.TextTokens
+ info["audio_ratio"] = audioRatio
+ info["audio_completion_ratio"] = audioCompletionRatio
+ return info
+}
+
+func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
+ info["audio"] = true
+ info["audio_input"] = usage.PromptTokensDetails.AudioTokens
+ info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
+ info["text_input"] = usage.PromptTokensDetails.TextTokens
+ info["text_output"] = usage.CompletionTokenDetails.TextTokens
+ info["audio_ratio"] = audioRatio
+ info["audio_completion_ratio"] = audioCompletionRatio
+ return info
+}
+
+func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
+ cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
+ info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
+ info["claude"] = true
+ info["cache_creation_tokens"] = cacheCreationTokens
+ info["cache_creation_ratio"] = cacheCreationRatio
+ return info
+}
+
+func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} {
+ other := make(map[string]interface{})
+ other["model_price"] = priceData.ModelPrice
+ other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
+ if priceData.GroupRatioInfo.HasSpecialRatio {
+ other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
+ }
+ return other
+}
diff --git a/service/midjourney.go b/service/midjourney.go
new file mode 100644
index 00000000..1fc19682
--- /dev/null
+++ b/service/midjourney.go
@@ -0,0 +1,258 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "log"
+ "net/http"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relayconstant "one-api/relay/constant"
+ "one-api/setting"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+func CoverActionToModelName(mjAction string) string {
+ modelName := "mj_" + strings.ToLower(mjAction)
+ if mjAction == constant.MjActionSwapFace {
+ modelName = "swap_face"
+ }
+ return modelName
+}
+
+func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
+ action := ""
+ if relayMode == relayconstant.RelayModeMidjourneyAction {
+ // plus request
+ err := CoverPlusActionToNormalAction(midjRequest)
+ if err != nil {
+ return "", err, false
+ }
+ action = midjRequest.Action
+ } else {
+ switch relayMode {
+ case relayconstant.RelayModeMidjourneyImagine:
+ action = constant.MjActionImagine
+ case relayconstant.RelayModeMidjourneyVideo:
+ action = constant.MjActionVideo
+ case relayconstant.RelayModeMidjourneyEdits:
+ action = constant.MjActionEdits
+ case relayconstant.RelayModeMidjourneyDescribe:
+ action = constant.MjActionDescribe
+ case relayconstant.RelayModeMidjourneyBlend:
+ action = constant.MjActionBlend
+ case relayconstant.RelayModeMidjourneyShorten:
+ action = constant.MjActionShorten
+ case relayconstant.RelayModeMidjourneyChange:
+ action = midjRequest.Action
+ case relayconstant.RelayModeMidjourneyModal:
+ action = constant.MjActionModal
+ case relayconstant.RelayModeSwapFace:
+ action = constant.MjActionSwapFace
+ case relayconstant.RelayModeMidjourneyUpload:
+ action = constant.MjActionUpload
+ case relayconstant.RelayModeMidjourneySimpleChange:
+ params := ConvertSimpleChangeParams(midjRequest.Content)
+ if params == nil {
+ return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
+ }
+ action = params.Action
+ case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
+ return "", nil, true
+ default:
+ return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
+ }
+ }
+ modelName := CoverActionToModelName(action)
+ return modelName, nil, true
+}
+
+func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
+ // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
+ customId := midjRequest.CustomId
+ if customId == "" {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
+ }
+ splits := strings.Split(customId, "::")
+ var action string
+ if splits[1] == "JOB" {
+ action = splits[2]
+ } else {
+ action = splits[1]
+ }
+
+ if action == "" {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
+ }
+ if strings.Contains(action, "upsample") {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionUpscale
+ } else if strings.Contains(action, "variation") {
+ midjRequest.Index = 1
+ if action == "variation" {
+ index, err := strconv.Atoi(splits[3])
+ if err != nil {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
+ }
+ midjRequest.Index = index
+ midjRequest.Action = constant.MjActionVariation
+ } else if action == "low_variation" {
+ midjRequest.Action = constant.MjActionLowVariation
+ } else if action == "high_variation" {
+ midjRequest.Action = constant.MjActionHighVariation
+ }
+ } else if strings.Contains(action, "pan") {
+ midjRequest.Action = constant.MjActionPan
+ midjRequest.Index = 1
+ } else if strings.Contains(action, "reroll") {
+ midjRequest.Action = constant.MjActionReRoll
+ midjRequest.Index = 1
+ } else if action == "Outpaint" {
+ midjRequest.Action = constant.MjActionZoom
+ midjRequest.Index = 1
+ } else if action == "CustomZoom" {
+ midjRequest.Action = constant.MjActionCustomZoom
+ midjRequest.Index = 1
+ } else if action == "Inpaint" {
+ midjRequest.Action = constant.MjActionInPaint
+ midjRequest.Index = 1
+ } else {
+ return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
+ }
+ return nil
+}
+
+func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
+ split := strings.Split(content, " ")
+ if len(split) != 2 {
+ return nil
+ }
+
+ action := strings.ToLower(split[1])
+ changeParams := &dto.MidjourneyRequest{}
+ changeParams.TaskId = split[0]
+
+ if action[0] == 'u' {
+ changeParams.Action = "UPSCALE"
+ } else if action[0] == 'v' {
+ changeParams.Action = "VARIATION"
+ } else if action == "r" {
+ changeParams.Action = "REROLL"
+ return changeParams
+ } else {
+ return nil
+ }
+
+ index, err := strconv.Atoi(action[1:2])
+ if err != nil || index < 1 || index > 4 {
+ return nil
+ }
+ changeParams.Index = index
+ return changeParams
+}
+
+func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
+ var nullBytes []byte
+ //var requestBody io.Reader
+ //requestBody = c.Request.Body
+ // read request body to json, delete accountFilter and notifyHook
+ var mapResult map[string]interface{}
+ // if get request, no need to read request body
+ if c.Request.Method != "GET" {
+ err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ if !setting.MjAccountFilterEnabled {
+ delete(mapResult, "accountFilter")
+ }
+ if !setting.MjNotifyEnabled {
+ delete(mapResult, "notifyHook")
+ }
+ //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+ // make new request with mapResult
+ }
+ if setting.MjModeClearEnabled {
+ if prompt, ok := mapResult["prompt"].(string); ok {
+ prompt = strings.Replace(prompt, "--fast", "", -1)
+ prompt = strings.Replace(prompt, "--relax", "", -1)
+ prompt = strings.Replace(prompt, "--turbo", "", -1)
+
+ mapResult["prompt"] = prompt
+ }
+ }
+ reqBody, err := json.Marshal(mapResult)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	// attach the timeout context to the outgoing request
+ req = req.WithContext(ctx)
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+ req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+ auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
+ if auth != "" {
+ auth = strings.TrimPrefix(auth, "Bearer ")
+ req.Header.Set("mj-api-secret", auth)
+ }
+ defer cancel()
+ resp, err := GetHttpClient().Do(req)
+ if err != nil {
+ common.SysError("do request failed: " + err.Error())
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
+ }
+ statusCode := resp.StatusCode
+ //if statusCode != 200 {
+ // return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
+ //}
+ err = req.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ err = c.Request.Body.Close()
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
+ }
+ var midjResponse dto.MidjourneyResponse
+ var midjourneyUploadsResponse dto.MidjourneyUploadResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
+ }
+ common.CloseResponseBodyGracefully(resp)
+ respStr := string(responseBody)
+ log.Printf("respStr: %s", respStr)
+ if respStr == "" {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
+ } else {
+ err = json.Unmarshal(responseBody, &midjResponse)
+ if err != nil {
+ err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
+ if err2 != nil {
+ return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
+ }
+ }
+ }
+ //log.Printf("midjResponse: %v", midjResponse)
+ //for k, v := range resp.Header {
+ // c.Writer.Header().Set(k, v[0])
+ //}
+ return &dto.MidjourneyResponseWithStatusCode{
+ StatusCode: statusCode,
+ Response: midjResponse,
+ }, responseBody, nil
+}
diff --git a/service/notify-limit.go b/service/notify-limit.go
new file mode 100644
index 00000000..309ea54d
--- /dev/null
+++ b/service/notify-limit.go
@@ -0,0 +1,117 @@
+package service
+
+import (
+ "fmt"
+ "github.com/bytedance/gopkg/util/gopool"
+ "one-api/common"
+ "one-api/constant"
+ "strconv"
+ "sync"
+ "time"
+)
+
+// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
+var (
+ notifyLimitStore sync.Map
+ cleanupOnce sync.Once
+)
+
+type limitCount struct {
+ Count int
+ Timestamp time.Time
+}
+
+func getDuration() time.Duration {
+ minute := constant.NotificationLimitDurationMinute
+ return time.Duration(minute) * time.Minute
+}
+
+// startCleanupTask starts a background task to clean up expired entries
+func startCleanupTask() {
+ gopool.Go(func() {
+ for {
+ time.Sleep(time.Hour)
+ now := time.Now()
+ notifyLimitStore.Range(func(key, value interface{}) bool {
+ if limit, ok := value.(limitCount); ok {
+ if now.Sub(limit.Timestamp) >= getDuration() {
+ notifyLimitStore.Delete(key)
+ }
+ }
+ return true
+ })
+ }
+ })
+}
+
+// CheckNotificationLimit checks if the user has exceeded their notification limit
+// Returns true if the user can send notification, false if limit exceeded
+func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
+ if common.RedisEnabled {
+ return checkRedisLimit(userId, notifyType)
+ }
+ return checkMemoryLimit(userId, notifyType)
+}
+
+func checkRedisLimit(userId int, notifyType string) (bool, error) {
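+	// fixed-window limiter: the key is scoped to user, notification type and the current clock hour ("2006010215" = yyyymmddhh)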
+ key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+
+ // Get current count
+ count, err := common.RedisGet(key)
+ if err != nil && err.Error() != "redis: nil" {
+ return false, fmt.Errorf("failed to get notification count: %w", err)
+ }
+
+ // If key doesn't exist, initialize it
+ if count == "" {
+ err = common.RedisSet(key, "1", getDuration())
+ return true, err
+ }
+
+ currentCount, _ := strconv.Atoi(count)
+ limit := constant.NotifyLimitCount
+
+ // Check if limit is already reached
+ if currentCount >= limit {
+ return false, nil
+ }
+
+ // Only increment if under limit
+ err = common.RedisIncr(key, 1)
+ if err != nil {
+ return false, fmt.Errorf("failed to increment notification count: %w", err)
+ }
+
+ return true, nil
+}
+
+func checkMemoryLimit(userId int, notifyType string) (bool, error) {
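+	// in-memory fallback when Redis is disabled; counters live in this process only and reset on restart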
+ // Ensure cleanup task is started
+ cleanupOnce.Do(startCleanupTask)
+
+ key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+ now := time.Now()
+
+ // Get current limit count or initialize new one
+ var currentLimit limitCount
+ if value, ok := notifyLimitStore.Load(key); ok {
+ currentLimit = value.(limitCount)
+ // Check if the entry has expired
+ if now.Sub(currentLimit.Timestamp) >= getDuration() {
+ currentLimit = limitCount{Count: 0, Timestamp: now}
+ }
+ } else {
+ currentLimit = limitCount{Count: 0, Timestamp: now}
+ }
+
+ // Increment count
+ currentLimit.Count++
+
+ // Check against limits
+ limit := constant.NotifyLimitCount
+
+ // Store updated count
+ notifyLimitStore.Store(key, currentLimit)
+
+ return currentLimit.Count <= limit, nil
+}
diff --git a/service/quota.go b/service/quota.go
new file mode 100644
index 00000000..0f618402
--- /dev/null
+++ b/service/quota.go
@@ -0,0 +1,510 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "math"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ "one-api/model"
+ relaycommon "one-api/relay/common"
+ "one-api/relay/helper"
+ "one-api/setting"
+ "one-api/setting/ratio_setting"
+ "strings"
+ "time"
+
+ "github.com/bytedance/gopkg/util/gopool"
+
+ "github.com/gin-gonic/gin"
+ "github.com/shopspring/decimal"
+)
+
+type TokenDetails struct {
+ TextTokens int
+ AudioTokens int
+}
+
+type QuotaInfo struct {
+ InputDetails TokenDetails
+ OutputDetails TokenDetails
+ ModelName string
+ UsePrice bool
+ ModelPrice float64
+ ModelRatio float64
+ GroupRatio float64
+}
+
+func calculateAudioQuota(info QuotaInfo) int {
+ if info.UsePrice {
+ modelPrice := decimal.NewFromFloat(info.ModelPrice)
+ quotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
+ groupRatio := decimal.NewFromFloat(info.GroupRatio)
+
+ quota := modelPrice.Mul(quotaPerUnit).Mul(groupRatio)
+ return int(quota.IntPart())
+ }
+
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
+
+ groupRatio := decimal.NewFromFloat(info.GroupRatio)
+ modelRatio := decimal.NewFromFloat(info.ModelRatio)
+ ratio := groupRatio.Mul(modelRatio)
+
+ inputTextTokens := decimal.NewFromInt(int64(info.InputDetails.TextTokens))
+ outputTextTokens := decimal.NewFromInt(int64(info.OutputDetails.TextTokens))
+ inputAudioTokens := decimal.NewFromInt(int64(info.InputDetails.AudioTokens))
+ outputAudioTokens := decimal.NewFromInt(int64(info.OutputDetails.AudioTokens))
+
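+	// token-based billing: quota = (textIn + textOut*completionRatio + audioIn*audioRatio + audioOut*audioRatio*audioCompletionRatio) * modelRatio * groupRatio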
+ quota := decimal.Zero
+ quota = quota.Add(inputTextTokens)
+ quota = quota.Add(outputTextTokens.Mul(completionRatio))
+ quota = quota.Add(inputAudioTokens.Mul(audioRatio))
+ quota = quota.Add(outputAudioTokens.Mul(audioRatio).Mul(audioCompletionRatio))
+
+ quota = quota.Mul(ratio)
+
+ // If ratio is not zero and quota is less than or equal to zero, set quota to 1
+ if !ratio.IsZero() && quota.LessThanOrEqual(decimal.Zero) {
+ quota = decimal.NewFromInt(1)
+ }
+
+ return int(quota.Round(0).IntPart())
+}
+
+func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error {
+ if relayInfo.UsePrice {
+ return nil
+ }
+ userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
+ if err != nil {
+ return err
+ }
+
+ token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
+ if err != nil {
+ return err
+ }
+
+ modelName := relayInfo.OriginModelName
+ textInputTokens := usage.InputTokenDetails.TextTokens
+ textOutTokens := usage.OutputTokenDetails.TextTokens
+ audioInputTokens := usage.InputTokenDetails.AudioTokens
+ audioOutTokens := usage.OutputTokenDetails.AudioTokens
+ groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
+ modelRatio, _, _ := ratio_setting.GetModelRatio(modelName)
+
+ autoGroup, exists := ctx.Get("auto_group")
+ if exists {
+ groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
+ log.Printf("final group ratio: %f", groupRatio)
+ relayInfo.UsingGroup = autoGroup.(string)
+ }
+
+ actualGroupRatio := groupRatio
+ userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
+ if ok {
+ actualGroupRatio = userGroupRatio
+ }
+
+ quotaInfo := QuotaInfo{
+ InputDetails: TokenDetails{
+ TextTokens: textInputTokens,
+ AudioTokens: audioInputTokens,
+ },
+ OutputDetails: TokenDetails{
+ TextTokens: textOutTokens,
+ AudioTokens: audioOutTokens,
+ },
+ ModelName: modelName,
+ UsePrice: relayInfo.UsePrice,
+ ModelRatio: modelRatio,
+ GroupRatio: actualGroupRatio,
+ }
+
+ quota := calculateAudioQuota(quotaInfo)
+
+ if userQuota < quota {
+ return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
+ }
+
+ if !token.UnlimitedQuota && token.RemainQuota < quota {
+ return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+ }
+
+ err = PostConsumeQuota(relayInfo, quota, 0, false)
+ if err != nil {
+ return err
+ }
+ common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
+ return nil
+}
+
+func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
+ usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+
+ useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+ textInputTokens := usage.InputTokenDetails.TextTokens
+ textOutTokens := usage.OutputTokenDetails.TextTokens
+
+ audioInputTokens := usage.InputTokenDetails.AudioTokens
+ audioOutTokens := usage.OutputTokenDetails.AudioTokens
+
+ tokenName := ctx.GetString("token_name")
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
+
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
+
+ quotaInfo := QuotaInfo{
+ InputDetails: TokenDetails{
+ TextTokens: textInputTokens,
+ AudioTokens: audioInputTokens,
+ },
+ OutputDetails: TokenDetails{
+ TextTokens: textOutTokens,
+ AudioTokens: audioOutTokens,
+ },
+ ModelName: modelName,
+ UsePrice: usePrice,
+ ModelRatio: modelRatio,
+ GroupRatio: groupRatio,
+ }
+
+ quota := calculateAudioQuota(quotaInfo)
+
+ totalTokens := usage.TotalTokens
+ var logContent string
+ if !usePrice {
+ logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
+ modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
+ } else {
+ logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
+ }
+
+ // record all the consume log even if quota is 0
+ if totalTokens == 0 {
+ // in this case, must be some error happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ logContent += fmt.Sprintf("(可能是上游超时)")
+ common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ } else {
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+ }
+
+ logModel := modelName
+ if extraContent != "" {
+ logContent += ", " + extraContent
+ }
+ other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: usage.InputTokens,
+ CompletionTokens: usage.OutputTokens,
+ ModelName: logModel,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UserQuota: userQuota,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+}
+
+func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+
+ useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+ promptTokens := usage.PromptTokens
+ completionTokens := usage.CompletionTokens
+ modelName := relayInfo.OriginModelName
+
+ tokenName := ctx.GetString("token_name")
+ completionRatio := priceData.CompletionRatio
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
+ modelPrice := priceData.ModelPrice
+ cacheRatio := priceData.CacheRatio
+ cacheTokens := usage.PromptTokensDetails.CachedTokens
+
+ cacheCreationRatio := priceData.CacheCreationRatio
+ cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
+
+ if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
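+		// OpenRouter's prompt_tokens already include cached tokens (and, when derivable from cost, cache-creation tokens); subtract them so they are billed at their own ratios below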
+ promptTokens -= cacheTokens
+ if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
+ maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
+ if promptTokens >= maybeCacheCreationTokens {
+ cacheCreationTokens = maybeCacheCreationTokens
+ }
+ }
+ promptTokens -= cacheCreationTokens
+ }
+
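+	// token-based pricing: (prompt + cached*cacheRatio + cacheCreation*cacheCreationRatio + completion*completionRatio) * groupRatio * modelRatio; per-call pricing uses modelPrice * QuotaPerUnit * groupRatio instead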
+ calculateQuota := 0.0
+ if !priceData.UsePrice {
+ calculateQuota = float64(promptTokens)
+ calculateQuota += float64(cacheTokens) * cacheRatio
+ calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
+ calculateQuota += float64(completionTokens) * completionRatio
+ calculateQuota = calculateQuota * groupRatio * modelRatio
+ } else {
+ calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
+ }
+
+ if modelRatio != 0 && calculateQuota <= 0 {
+ calculateQuota = 1
+ }
+
+ quota := int(calculateQuota)
+
+ totalTokens := promptTokens + completionTokens
+
+ var logContent string
+ // record all the consume log even if quota is 0
+ if totalTokens == 0 {
+ // in this case, must be some error happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ logContent += fmt.Sprintf("(可能是上游出错)")
+ common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
+ } else {
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+ }
+
+ quotaDelta := quota - preConsumedQuota
+ if quotaDelta != 0 {
+ err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ if err != nil {
+ common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ }
+ }
+
+ other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
+ cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: promptTokens,
+ CompletionTokens: completionTokens,
+ ModelName: modelName,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UserQuota: userQuota,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+
+}
+
+func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
+ if priceData.CacheCreationRatio == 1 {
+ return 0
+ }
+ quotaPrice := priceData.ModelRatio / common.QuotaPerUnit
+ promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio
+ promptCacheReadPrice := quotaPrice * priceData.CacheRatio
+ completionPrice := quotaPrice * priceData.CompletionRatio
+
+ cost, _ := usage.Cost.(float64)
+ totalPromptTokens := float64(usage.PromptTokens)
+ completionTokens := float64(usage.CompletionTokens)
+ promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens)
+
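+	// OpenRouter reports the monetary cost but not cache-creation tokens; assuming the billing model
+	//   cost = (prompt - cacheRead - cacheCreate)*p + cacheRead*readPrice + cacheCreate*createPrice + completion*completionPrice
+	// solving for cacheCreate yields the expression below.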
+ return int(math.Round((cost -
+ totalPromptTokens*quotaPrice +
+ promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) -
+ completionTokens*completionPrice) /
+ (promptCacheCreatePrice - quotaPrice)))
+}
+
+func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
+ usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
+
+ useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
+ textInputTokens := usage.PromptTokensDetails.TextTokens
+ textOutTokens := usage.CompletionTokenDetails.TextTokens
+
+ audioInputTokens := usage.PromptTokensDetails.AudioTokens
+ audioOutTokens := usage.CompletionTokenDetails.AudioTokens
+
+ tokenName := ctx.GetString("token_name")
+ completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
+ audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
+ audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
+
+ modelRatio := priceData.ModelRatio
+ groupRatio := priceData.GroupRatioInfo.GroupRatio
+ modelPrice := priceData.ModelPrice
+ usePrice := priceData.UsePrice
+
+ quotaInfo := QuotaInfo{
+ InputDetails: TokenDetails{
+ TextTokens: textInputTokens,
+ AudioTokens: audioInputTokens,
+ },
+ OutputDetails: TokenDetails{
+ TextTokens: textOutTokens,
+ AudioTokens: audioOutTokens,
+ },
+ ModelName: relayInfo.OriginModelName,
+ UsePrice: usePrice,
+ ModelRatio: modelRatio,
+ GroupRatio: groupRatio,
+ }
+
+ quota := calculateAudioQuota(quotaInfo)
+
+ totalTokens := usage.TotalTokens
+ var logContent string
+ if !usePrice {
+ logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f",
+ modelRatio, completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), groupRatio)
+ } else {
+ logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
+ }
+
+ // record all the consume log even if quota is 0
+ if totalTokens == 0 {
+ // in this case, must be some error happened
+ // we cannot just return, because we may have to return the pre-consumed quota
+ quota = 0
+ logContent += fmt.Sprintf("(可能是上游超时)")
+ common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
+ } else {
+ model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+ model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+ }
+
+ quotaDelta := quota - preConsumedQuota
+ if quotaDelta != 0 {
+ err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
+ if err != nil {
+ common.LogError(ctx, "error consuming token remain quota: "+err.Error())
+ }
+ }
+
+ logModel := relayInfo.OriginModelName
+ if extraContent != "" {
+ logContent += ", " + extraContent
+ }
+ other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
+ completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
+ model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
+ ChannelId: relayInfo.ChannelId,
+ PromptTokens: usage.PromptTokens,
+ CompletionTokens: usage.CompletionTokens,
+ ModelName: logModel,
+ TokenName: tokenName,
+ Quota: quota,
+ Content: logContent,
+ TokenId: relayInfo.TokenId,
+ UserQuota: userQuota,
+ UseTimeSeconds: int(useTimeSeconds),
+ IsStream: relayInfo.IsStream,
+ Group: relayInfo.UsingGroup,
+ Other: other,
+ })
+}
+
+func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
+ if quota < 0 {
+ return errors.New("quota 不能为负数!")
+ }
+ if relayInfo.IsPlayground {
+ return nil
+ }
+ //if relayInfo.TokenUnlimited {
+ // return nil
+ //}
+ token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
+ if err != nil {
+ return err
+ }
+ if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
+ return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
+ }
+ err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
+
+ if quota > 0 {
+ err = model.DecreaseUserQuota(relayInfo.UserId, quota)
+ } else {
+ err = model.IncreaseUserQuota(relayInfo.UserId, -quota, false)
+ }
+ if err != nil {
+ return err
+ }
+
+ if !relayInfo.IsPlayground {
+ if quota > 0 {
+ err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+ } else {
+ err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if sendEmail {
+ if (quota + preConsumedQuota) != 0 {
+ checkAndSendQuotaNotify(relayInfo, quota, preConsumedQuota)
+ }
+ }
+
+ return nil
+}
+
+func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int) {
+ gopool.Go(func() {
+ userSetting := relayInfo.UserSetting
+ threshold := common.QuotaRemindThreshold
+ if userSetting.QuotaWarningThreshold != 0 {
+ threshold = int(userSetting.QuotaWarningThreshold)
+ }
+
+ //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
+ quotaTooLow := false
+ consumeQuota := quota + preConsumedQuota
+ if relayInfo.UserQuota-consumeQuota < threshold {
+ quotaTooLow = true
+ }
+ if quotaTooLow {
+ prompt := "您的额度即将用尽"
+ topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
+			content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。\n充值链接:{{value}}"
+ err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
+ }
+ }
+ })
+}
diff --git a/service/sensitive.go b/service/sensitive.go
new file mode 100644
index 00000000..b3e3c4d6
--- /dev/null
+++ b/service/sensitive.go
@@ -0,0 +1,94 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "one-api/dto"
+ "one-api/setting"
+ "strings"
+)
+
+func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
+ if len(messages) == 0 {
+ return nil, nil
+ }
+
+ for _, message := range messages {
+ arrayContent := message.ParseContent()
+ for _, m := range arrayContent {
+ if m.Type == "image_url" {
+ // TODO: check image url
+ continue
+ }
+			// skip empty text parts
+ if m.Text == "" {
+ continue
+ }
+ if ok, words := SensitiveWordContains(m.Text); ok {
+ return words, errors.New("sensitive words detected")
+ }
+ }
+ }
+ return nil, nil
+}
+
+func CheckSensitiveText(text string) ([]string, error) {
+ if ok, words := SensitiveWordContains(text); ok {
+ return words, errors.New("sensitive words detected")
+ }
+ return nil, nil
+}
+
+func CheckSensitiveInput(input any) ([]string, error) {
+ switch v := input.(type) {
+ case string:
+ return CheckSensitiveText(v)
+ case []string:
+ var builder strings.Builder
+ for _, s := range v {
+ builder.WriteString(s)
+ }
+ return CheckSensitiveText(builder.String())
+ }
+ return CheckSensitiveText(fmt.Sprintf("%v", input))
+}
+
+// SensitiveWordContains reports whether the text contains sensitive words and returns the matched words
+func SensitiveWordContains(text string) (bool, []string) {
+ if len(setting.SensitiveWords) == 0 {
+ return false, nil
+ }
+ if len(text) == 0 {
+ return false, nil
+ }
+ checkText := strings.ToLower(text)
+ return AcSearch(checkText, setting.SensitiveWords, true)
+}
+
+// SensitiveWordReplace masks sensitive words in the text, returning whether any were found, the matched words, and the masked text
+func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
+ if len(setting.SensitiveWords) == 0 {
+ return false, nil, text
+ }
+ checkText := strings.ToLower(text)
+ m := InitAc(setting.SensitiveWords)
+ hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
+ if len(hits) > 0 {
+ words := make([]string, 0, len(hits))
+ var builder strings.Builder
+ builder.Grow(len(text))
+ lastPos := 0
+
+ for _, hit := range hits {
+ pos := hit.Pos
+ word := string(hit.Word)
+ builder.WriteString(text[lastPos:pos])
+ builder.WriteString("**###**")
+ lastPos = pos + len(word)
+ words = append(words, word)
+ }
+ builder.WriteString(text[lastPos:])
+ return true, words, builder.String()
+ }
+ return false, nil, text
+}
diff --git a/service/str.go b/service/str.go
new file mode 100644
index 00000000..4390e99b
--- /dev/null
+++ b/service/str.go
@@ -0,0 +1,101 @@
+package service
+
+import (
+ "bytes"
+ "fmt"
+ goahocorasick "github.com/anknown/ahocorasick"
+ "strings"
+)
+
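+// SundaySearch reports whether pattern occurs in text using the Sunday string-matching algorithm, which keys its skip table on the character just past the current window.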
+func SundaySearch(text string, pattern string) bool {
+	// build the shift table: distance from each pattern character to the end of the pattern
+	offset := make(map[rune]int)
+	for i, c := range pattern {
+		offset[c] = len(pattern) - i
+	}
+
+	// lengths of the text and the pattern
+	n, m := len(text), len(pattern)
+
+	// main loop; i is the current alignment of the pattern against the text
+	for i := 0; i <= n-m; {
+		// compare the candidate substring
+		j := 0
+		for j < m && text[i+j] == pattern[j] {
+			j++
+		}
+		// pattern fully matched
+		if j == m {
+			return true
+		}
+
+		// otherwise look up the character just past the window in the shift table
+		if i+m < n {
+			next := rune(text[i+m])
+			if val, ok := offset[next]; ok {
+				i += val // character occurs in the pattern: jump by its shift value
+			} else {
+				i += len(pattern) + 1 // character not in the pattern: skip past it entirely
+			}
+		} else {
+			break
+		}
+	}
+	return false // no match found
+}
+
+func RemoveDuplicate(s []string) []string {
+ result := make([]string, 0, len(s))
+ temp := map[string]struct{}{}
+ for _, item := range s {
+ if _, ok := temp[item]; !ok {
+ temp[item] = struct{}{}
+ result = append(result, item)
+ }
+ }
+ return result
+}
+
+// InitAc builds an Aho-Corasick machine from the word list; it returns nil if the dictionary fails to build
+func InitAc(words []string) *goahocorasick.Machine {
+ m := new(goahocorasick.Machine)
+ dict := readRunes(words)
+ if err := m.Build(dict); err != nil {
+ fmt.Println(err)
+ return nil
+ }
+ return m
+}
+
+func readRunes(words []string) [][]rune {
+ var dict [][]rune
+
+ for _, word := range words {
+ word = strings.ToLower(word)
+ l := bytes.TrimSpace([]byte(word))
+ dict = append(dict, bytes.Runes(l))
+ }
+
+ return dict
+}
+
+// AcSearch runs an Aho-Corasick multi-pattern search of dict over findText; it returns whether any pattern matched and the matched words
+func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
+ if len(dict) == 0 {
+ return false, nil
+ }
+ if len(findText) == 0 {
+ return false, nil
+ }
+ m := InitAc(dict)
+ if m == nil {
+ return false, nil
+ }
+ hits := m.MultiPatternSearch([]rune(findText), stopImmediately)
+ if len(hits) > 0 {
+ words := make([]string, 0)
+ for _, hit := range hits {
+ words = append(words, string(hit.Word))
+ }
+ return true, words
+ }
+ return false, nil
+}
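+
+// Usage sketch:
+//
+//	found, words := AcSearch("the quick brown fox", []string{"quick", "fox"}, false)
+//	// found == true, words == ["quick", "fox"]
+//
+// Callers are expected to lowercase the input themselves (as
+// SensitiveWordContains does), since AcSearch only lowercases the dictionary.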
diff --git a/service/task.go b/service/task.go
new file mode 100644
index 00000000..c2501fe2
--- /dev/null
+++ b/service/task.go
@@ -0,0 +1,10 @@
+package service
+
+import (
+ "one-api/constant"
+ "strings"
+)
+
+func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string {
+ return strings.ToLower(string(platform)) + "_" + strings.ToLower(action)
+}
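+
+// Example (sketch, assuming a Midjourney task platform and an "IMAGINE" action
+// exist in the caller's configuration):
+//
+//	CoverTaskActionToModelName(constant.TaskPlatform("midjourney"), "IMAGINE")
+//	// -> "midjourney_imagine"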
diff --git a/service/token_counter.go b/service/token_counter.go
new file mode 100644
index 00000000..eed5b5ca
--- /dev/null
+++ b/service/token_counter.go
@@ -0,0 +1,474 @@
+package service
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/tiktoken-go/tokenizer"
+ "github.com/tiktoken-go/tokenizer/codec"
+ "image"
+ "log"
+ "math"
+ "one-api/common"
+ "one-api/constant"
+ "one-api/dto"
+ relaycommon "one-api/relay/common"
+ "strings"
+ "sync"
+ "unicode/utf8"
+)
+
+// defaultTokenEncoder is the fallback cl100k_base encoder used when no model-specific encoder is available
+var defaultTokenEncoder tokenizer.Codec
+
+// tokenEncoderMap is used to store token encoders for different models
+var tokenEncoderMap = make(map[string]tokenizer.Codec)
+
+// tokenEncoderMutex protects tokenEncoderMap for concurrent access
+var tokenEncoderMutex sync.RWMutex
+
+func InitTokenEncoders() {
+ common.SysLog("initializing token encoders")
+ defaultTokenEncoder = codec.NewCl100kBase()
+ common.SysLog("token encoders initialized")
+}
+
+func getTokenEncoder(model string) tokenizer.Codec {
+ // First, try to get the encoder from cache with read lock
+ tokenEncoderMutex.RLock()
+ if encoder, exists := tokenEncoderMap[model]; exists {
+ tokenEncoderMutex.RUnlock()
+ return encoder
+ }
+ tokenEncoderMutex.RUnlock()
+
+ // If not in cache, create new encoder with write lock
+ tokenEncoderMutex.Lock()
+ defer tokenEncoderMutex.Unlock()
+
+ // Double-check if another goroutine already created the encoder
+ if encoder, exists := tokenEncoderMap[model]; exists {
+ return encoder
+ }
+
+ // Create new encoder
+ modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
+ if err != nil {
+ // Cache the default encoder for this model to avoid repeated failures
+ tokenEncoderMap[model] = defaultTokenEncoder
+ return defaultTokenEncoder
+ }
+
+ // Cache the new encoder
+ tokenEncoderMap[model] = modelCodec
+ return modelCodec
+}
+
+func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
+ if text == "" {
+ return 0
+ }
+ tkm, _ := tokenEncoder.Count(text)
+ return tkm
+}
+
+func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
+ if imageUrl == nil {
+ return 0, fmt.Errorf("image_url_is_nil")
+ }
+ baseTokens := 85
+ if model == "glm-4v" {
+ return 1047, nil
+ }
+ if imageUrl.Detail == "low" {
+ return baseTokens, nil
+ }
+ if !constant.GetMediaTokenNotStream && !stream {
+ return 3 * baseTokens, nil
+ }
+
+ // mirror One API's image billing logic: treat "auto"/empty detail as "high"
+ if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
+ imageUrl.Detail = "high"
+ }
+
+ tileTokens := 170
+ if strings.HasPrefix(model, "gpt-4o-mini") {
+ tileTokens = 5667
+ baseTokens = 2833
+ }
+ // if image token counting is disabled, return a flat estimate
+ if !constant.GetMediaToken {
+ return 3 * baseTokens, nil
+ }
+ if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
+ return 3 * baseTokens, nil
+ }
+ var config image.Config
+ var err error
+ var format string
+ var b64str string
+ if strings.HasPrefix(imageUrl.Url, "http") {
+ config, format, err = DecodeUrlImageData(imageUrl.Url)
+ } else {
+ common.SysLog("decoding base64 image data")
+ config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
+ }
+ if err != nil {
+ return 0, err
+ }
+ imageUrl.MimeType = format
+
+ if config.Width == 0 || config.Height == 0 {
+ // not an image
+ if format != "" && b64str != "" {
+ // file type
+ return 3 * baseTokens, nil
+ }
+ return 0, fmt.Errorf("fail to decode base64 config: %s", imageUrl.Url)
+ }
+
+ shortSide := config.Width
+ otherSide := config.Height
+ log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
+ // scale factor
+ scale := 1.0
+ if config.Height < shortSide {
+ shortSide = config.Height
+ otherSide = config.Width
+ }
+
+ // shrink the shorter side to at most 768 px, recording the scale factor
+ if shortSide > 768 {
+ scale = float64(shortSide) / 768
+ shortSide = 768
+ }
+ // shrink the other side by the same factor, rounding up
+ otherSide = int(math.Ceil(float64(otherSide) / scale))
+ log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
+ // number of 512px tiles (each side divided by 512, rounded up)
+ tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
+ log.Printf("tiles: %d", tiles)
+ return tiles*tileTokens + baseTokens, nil
+}
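+
+// Worked example of the tile math above (sketch): a 1024x2048 image in "high"
+// detail has shortSide 1024 > 768, so scale = 1024/768 ≈ 1.333, shortSide
+// becomes 768 and otherSide = ceil(2048/1.333) = 1536. That yields
+// ceil(768/512) * ceil(1536/512) = 2 * 3 = 6 tiles, i.e. 6*170 + 85 = 1105
+// tokens for a standard model.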
+
+func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
+ tkm := 0
+ msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
+ if err != nil {
+ return 0, err
+ }
+ tkm += msgTokens
+ if request.Tools != nil {
+ openaiTools := request.Tools
+ countStr := ""
+ for _, tool := range openaiTools {
+ countStr = tool.Function.Name
+ if tool.Function.Description != "" {
+ countStr += tool.Function.Description
+ }
+ if tool.Function.Parameters != nil {
+ countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+ }
+ }
+ toolTokens := CountTokenInput(countStr, request.Model)
+ tkm += 8
+ tkm += toolTokens
+ }
+
+ return tkm, nil
+}
+
+func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
+ tkm := 0
+
+ // Count tokens in messages
+ msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
+ if err != nil {
+ return 0, err
+ }
+ tkm += msgTokens
+
+ // Count tokens in system message
+ if request.System != "" {
+ systemTokens := CountTokenInput(request.System, model)
+ tkm += systemTokens
+ }
+
+ if request.Tools != nil {
+ // tools must be provided as a list
+ if tools, ok := request.Tools.([]any); ok {
+ if len(tools) > 0 {
+ parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
+ if err1 != nil {
+ return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
+ }
+ toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
+ if err2 != nil {
+ return 0, fmt.Errorf("tools: %v", err)
+ }
+ tkm += toolTokens
+ }
+ } else {
+ return 0, errors.New("tools: Input should be a valid list")
+ }
+ }
+
+ return tkm, nil
+}
+
+func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
+ tokenEncoder := getTokenEncoder(model)
+ tokenNum := 0
+
+ for _, message := range messages {
+ // Count tokens for role
+ tokenNum += getTokenNum(tokenEncoder, message.Role)
+ if message.IsStringContent() {
+ tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
+ } else {
+ content, err := message.ParseContent()
+ if err != nil {
+ return 0, err
+ }
+ for _, mediaMessage := range content {
+ switch mediaMessage.Type {
+ case "text":
+ tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
+ case "image":
+ //imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
+ //if err != nil {
+ // return 0, err
+ //}
+ tokenNum += 1000
+ case "tool_use":
+ if mediaMessage.Input != nil {
+ tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
+ inputJSON, _ := json.Marshal(mediaMessage.Input)
+ tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
+ }
+ case "tool_result":
+ if mediaMessage.Content != nil {
+ contentJSON, _ := json.Marshal(mediaMessage.Content)
+ tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
+ }
+ }
+ }
+ }
+ }
+
+ // Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
+ tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
+
+ return tokenNum, nil
+}
+
+func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
+ tokenEncoder := getTokenEncoder(model)
+ tokenNum := 0
+
+ for _, tool := range tools {
+ tokenNum += getTokenNum(tokenEncoder, tool.Name)
+ tokenNum += getTokenNum(tokenEncoder, tool.Description)
+
+ schemaJSON, err := json.Marshal(tool.InputSchema)
+ if err != nil {
+ return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
+ }
+ tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
+ }
+
+ // Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
+ tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
+
+ return tokenNum, nil
+}
+
+func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
+ audioToken := 0
+ textToken := 0
+ switch request.Type {
+ case dto.RealtimeEventTypeSessionUpdate:
+ if request.Session != nil {
+ msgTokens := CountTextToken(request.Session.Instructions, model)
+ textToken += msgTokens
+ }
+ case dto.RealtimeEventResponseAudioDelta:
+ // count audio token
+ atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
+ if err != nil {
+ return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+ }
+ audioToken += atk
+ case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
+ // count text token
+ tkm := CountTextToken(request.Delta, model)
+ textToken += tkm
+ case dto.RealtimeEventInputAudioBufferAppend:
+ // count audio token
+ atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
+ if err != nil {
+ return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+ }
+ audioToken += atk
+ case dto.RealtimeEventConversationItemCreated:
+ if request.Item != nil {
+ switch request.Item.Type {
+ case "message":
+ for _, content := range request.Item.Content {
+ if content.Type == "input_text" {
+ tokens := CountTextToken(content.Text, model)
+ textToken += tokens
+ }
+ }
+ }
+ }
+ case dto.RealtimeEventTypeResponseDone:
+ // count tools token
+ if !info.IsFirstRequest {
+ if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
+ for _, tool := range info.RealtimeTools {
+ toolTokens := CountTokenInput(tool, model)
+ textToken += 8
+ textToken += toolTokens
+ }
+ }
+ }
+ }
+ return textToken, audioToken, nil
+}
+
+func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
+ tokenEncoder := getTokenEncoder(model)
+ // Reference:
+ // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+ // https://github.com/pkoukk/tiktoken-go/issues/6
+ //
+ // Every message follows <|start|>{role/name}\n{content}<|end|>\n
+ var tokensPerMessage int
+ var tokensPerName int
+ if model == "gpt-3.5-turbo-0301" {
+ tokensPerMessage = 4
+ tokensPerName = -1 // If there's a name, the role is omitted
+ } else {
+ tokensPerMessage = 3
+ tokensPerName = 1
+ }
+ tokenNum := 0
+ for _, message := range messages {
+ tokenNum += tokensPerMessage
+ tokenNum += getTokenNum(tokenEncoder, message.Role)
+ if message.Content != nil {
+ if message.Name != nil {
+ tokenNum += tokensPerName
+ tokenNum += getTokenNum(tokenEncoder, *message.Name)
+ }
+ arrayContent := message.ParseContent()
+ for _, m := range arrayContent {
+ if m.Type == dto.ContentTypeImageURL {
+ imageUrl := m.GetImageMedia()
+ imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
+ if err != nil {
+ return 0, err
+ }
+ tokenNum += imageTokenNum
+ log.Printf("image token num: %d", imageTokenNum)
+ } else if m.Type == dto.ContentTypeInputAudio {
+ // TODO: proper audio token counting; use a rough constant for now
+ tokenNum += 100
+ } else if m.Type == dto.ContentTypeFile {
+ tokenNum += 5000
+ } else if m.Type == dto.ContentTypeVideoUrl {
+ tokenNum += 5000
+ } else {
+ tokenNum += getTokenNum(tokenEncoder, m.Text)
+ }
+ }
+ }
+ }
+ tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+ return tokenNum, nil
+}
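+
+// Worked example (sketch): for a default model (tokensPerMessage = 3), a
+// single user message whose text encodes to 10 tokens counts as
+// 3 (message overhead) + 1 (the role "user") + 10 (content) + 3 (reply priming)
+// = 17 tokens.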
+
+func CountTokenInput(input any, model string) int {
+ switch v := input.(type) {
+ case string:
+ return CountTextToken(v, model)
+ case []string:
+ text := ""
+ for _, s := range v {
+ text += s
+ }
+ return CountTextToken(text, model)
+ case []interface{}:
+ text := ""
+ for _, item := range v {
+ text += fmt.Sprintf("%v", item)
+ }
+ return CountTextToken(text, model)
+ }
+ return CountTokenInput(fmt.Sprintf("%v", input), model)
+}
+
+func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
+ tokens := 0
+ for _, message := range messages {
+ tkm := CountTokenInput(message.Delta.GetContentString(), model)
+ tokens += tkm
+ if message.Delta.ToolCalls != nil {
+ for _, tool := range message.Delta.ToolCalls {
+ tkm := CountTokenInput(tool.Function.Name, model)
+ tokens += tkm
+ tkm = CountTokenInput(tool.Function.Arguments, model)
+ tokens += tkm
+ }
+ }
+ }
+ return tokens
+}
+
+func CountTTSToken(text string, model string) int {
+ if strings.HasPrefix(model, "tts") {
+ return utf8.RuneCountInString(text)
+ } else {
+ return CountTextToken(text, model)
+ }
+}
+
+func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
+ if audioBase64 == "" {
+ return 0, nil
+ }
+ duration, err := parseAudio(audioBase64, audioFormat)
+ if err != nil {
+ return 0, err
+ }
+ return int(duration / 60 * 100 / 0.06), nil
+}
+
+func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
+ if audioBase64 == "" {
+ return 0, nil
+ }
+ duration, err := parseAudio(audioBase64, audioFormat)
+ if err != nil {
+ return 0, err
+ }
+ return int(duration / 60 * 200 / 0.24), nil
+}
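+
+// The audio formulas above convert duration to tokens at a fixed per-minute
+// rate: one minute of input audio maps to int(1 * 100 / 0.06) = 1666 tokens
+// and one minute of output audio to int(1 * 200 / 0.24) = 833 tokens. The
+// constants appear to encode a per-minute vs. per-token price ratio rather
+// than an exact acoustic token count.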
+
+//func CountAudioToken(sec float64, audioType string) {
+// if audioType == "input" {
+//
+// }
+//}
+
+// CountTextToken counts the number of tokens in text for the given model
+func CountTextToken(text string, model string) int {
+ if text == "" {
+ return 0
+ }
+ tokenEncoder := getTokenEncoder(model)
+ return getTokenNum(tokenEncoder, text)
+}
diff --git a/service/usage_helpr.go b/service/usage_helpr.go
new file mode 100644
index 00000000..ca9c0830
--- /dev/null
+++ b/service/usage_helpr.go
@@ -0,0 +1,30 @@
+package service
+
+import (
+ "one-api/dto"
+)
+
+//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
+// switch relayMode {
+// case constant.RelayModeChatCompletions:
+// return CountTokenMessages(textRequest.Messages, textRequest.Model)
+// case constant.RelayModeCompletions:
+// return CountTokenInput(textRequest.Prompt, textRequest.Model), nil
+// case constant.RelayModeModerations:
+// return CountTokenInput(textRequest.Input, textRequest.Model), nil
+// }
+// return 0, errors.New("unknown relay mode")
+//}
+
+func ResponseText2Usage(responseText string, modelName string, promptTokens int) *dto.Usage {
+ usage := &dto.Usage{}
+ usage.PromptTokens = promptTokens
+ ctkm := CountTextToken(responseText, modelName)
+ usage.CompletionTokens = ctkm
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ return usage
+}
+
+func ValidUsage(usage *dto.Usage) bool {
+ return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
+}
diff --git a/service/user_notify.go b/service/user_notify.go
new file mode 100644
index 00000000..96664007
--- /dev/null
+++ b/service/user_notify.go
@@ -0,0 +1,66 @@
+package service
+
+import (
+ "fmt"
+ "one-api/common"
+ "one-api/dto"
+ "one-api/model"
+ "strings"
+)
+
+func NotifyRootUser(t string, subject string, content string) {
+ user := model.GetRootUser().ToBaseUser()
+ err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
+ }
+}
+
+func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error {
+ notifyType := userSetting.NotifyType
+ if notifyType == "" {
+ notifyType = dto.NotifyTypeEmail
+ }
+
+ // Check notification limit
+ canSend, err := CheckNotificationLimit(userId, data.Type)
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
+ return err
+ }
+ if !canSend {
+ return fmt.Errorf("notification limit exceeded for user %d with type %s", userId, notifyType)
+ }
+
+ switch notifyType {
+ case dto.NotifyTypeEmail:
+ // check setting email
+ userEmail = userSetting.NotificationEmail
+ if userEmail == "" {
+ common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId))
+ return nil
+ }
+ return sendEmailNotify(userEmail, data)
+ case dto.NotifyTypeWebhook:
+ webhookURLStr := userSetting.WebhookUrl
+ if webhookURLStr == "" {
+ common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
+ return nil
+ }
+
+ // webhook secret used for request signing
+ webhookSecret := userSetting.WebhookSecret
+ return SendWebhookNotify(webhookURLStr, webhookSecret, data)
+ }
+ return nil
+}
+
+func sendEmailNotify(userEmail string, data dto.Notify) error {
+ // make email content
+ content := data.Content
+ // substitute placeholder values
+ for _, value := range data.Values {
+ content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
+ }
+ return common.SendEmail(data.Title, userEmail, content)
+}
diff --git a/service/webhook.go b/service/webhook.go
new file mode 100644
index 00000000..8faccda3
--- /dev/null
+++ b/service/webhook.go
@@ -0,0 +1,118 @@
+package service
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/dto"
+ "one-api/setting"
+ "time"
+)
+
+// WebhookPayload is the payload sent in webhook notifications
+type WebhookPayload struct {
+ Type string `json:"type"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Values []interface{} `json:"values,omitempty"`
+ Timestamp int64 `json:"timestamp"`
+}
+
+// generateSignature computes the HMAC-SHA256 signature of the payload using the shared secret
+func generateSignature(secret string, payload []byte) string {
+ h := hmac.New(sha256.New, []byte(secret))
+ h.Write(payload)
+ return hex.EncodeToString(h.Sum(nil))
+}
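+
+// A receiver can verify the X-Webhook-Signature header by recomputing the HMAC
+// over the raw request body (sketch, assuming the same shared secret):
+//
+//	func verifyWebhookSignature(secret string, body []byte, signature string) bool {
+//		mac := hmac.New(sha256.New, []byte(secret))
+//		mac.Write(body)
+//		expected := hex.EncodeToString(mac.Sum(nil))
+//		return hmac.Equal([]byte(expected), []byte(signature))
+//	}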
+
+// SendWebhookNotify delivers a notification to the user's webhook URL
+func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
+ // substitute placeholder values into the content, mirroring sendEmailNotify
+ content := data.Content
+ for _, value := range data.Values {
+ content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
+ }
+
+ // build the webhook payload
+ payload := WebhookPayload{
+ Type: data.Type,
+ Title: data.Title,
+ Content: content,
+ Values: data.Values,
+ Timestamp: time.Now().Unix(),
+ }
+
+ // serialize the payload
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return fmt.Errorf("failed to marshal webhook payload: %v", err)
+ }
+
+ // create the HTTP request
+ var req *http.Request
+ var resp *http.Response
+
+ if setting.EnableWorker() {
+ // build the worker request
+ workerReq := &WorkerRequest{
+ URL: webhookURL,
+ Key: setting.WorkerValidKey,
+ Method: http.MethodPost,
+ Headers: map[string]string{
+ "Content-Type": "application/json",
+ },
+ Body: payloadBytes,
+ }
+
+ // if a secret is configured, add the signature headers
+ if secret != "" {
+ signature := generateSignature(secret, payloadBytes)
+ workerReq.Headers["X-Webhook-Signature"] = signature
+ workerReq.Headers["Authorization"] = "Bearer " + secret
+ }
+
+ resp, err = DoWorkerRequest(workerReq)
+ if err != nil {
+ return fmt.Errorf("failed to send webhook request through worker: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // check the response status
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+ }
+ } else {
+ req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
+ if err != nil {
+ return fmt.Errorf("failed to create webhook request: %v", err)
+ }
+
+ // set request headers
+ req.Header.Set("Content-Type", "application/json")
+
+ // if a secret is configured, sign the payload
+ if secret != "" {
+ signature := generateSignature(secret, payloadBytes)
+ req.Header.Set("X-Webhook-Signature", signature)
+ }
+
+ // send the request
+ client := GetHttpClient()
+ resp, err = client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send webhook request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // check the response status
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+ }
+ }
+
+ return nil
+}
diff --git a/setting/auto_group.go b/setting/auto_group.go
new file mode 100644
index 00000000..5a87ae56
--- /dev/null
+++ b/setting/auto_group.go
@@ -0,0 +1,31 @@
+package setting
+
+import "encoding/json"
+
+var AutoGroups = []string{
+ "default",
+}
+
+var DefaultUseAutoGroup = false
+
+func ContainsAutoGroup(group string) bool {
+ for _, autoGroup := range AutoGroups {
+ if autoGroup == group {
+ return true
+ }
+ }
+ return false
+}
+
+func UpdateAutoGroupsByJsonString(jsonString string) error {
+ AutoGroups = make([]string, 0)
+ return json.Unmarshal([]byte(jsonString), &AutoGroups)
+}
+
+func AutoGroups2JsonString() string {
+ jsonBytes, err := json.Marshal(AutoGroups)
+ if err != nil {
+ return "[]"
+ }
+ return string(jsonBytes)
+}
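+
+// Example (sketch): with the hypothetical option value `["default","vip"]`,
+// UpdateAutoGroupsByJsonString replaces AutoGroups with those two groups and
+// ContainsAutoGroup("vip") then returns true.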
diff --git a/setting/chat.go b/setting/chat.go
new file mode 100644
index 00000000..53cb655a
--- /dev/null
+++ b/setting/chat.go
@@ -0,0 +1,41 @@
+package setting
+
+import (
+ "encoding/json"
+ "one-api/common"
+)
+
+var Chats = []map[string]string{
+ //{
+ // "ChatGPT Next Web 官方示例": "https://app.nextchat.dev/#/?settings={\"key\":\"{key}\",\"url\":\"{address}\"}",
+ //},
+ {
+ "Cherry Studio": "cherrystudio://providers/api-keys?v=1&data={cherryConfig}",
+ },
+ {
+ "Lobe Chat 官方示例": "https://chat-preview.lobehub.com/?settings={\"keyVaults\":{\"openai\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\"}}}",
+ },
+ {
+ "AI as Workspace": "https://aiaw.app/set-provider?provider={\"type\":\"openai\",\"settings\":{\"apiKey\":\"{key}\",\"baseURL\":\"{address}/v1\",\"compatibility\":\"strict\"}}",
+ },
+ {
+ "AMA 问天": "ama://set-api-key?server={address}&key={key}",
+ },
+ {
+ "OpenCat": "opencat://team/join?domain={address}&token={key}",
+ },
+}
+
+func UpdateChatsByJsonString(jsonString string) error {
+ Chats = make([]map[string]string, 0)
+ return json.Unmarshal([]byte(jsonString), &Chats)
+}
+
+func Chats2JsonString() string {
+ jsonBytes, err := json.Marshal(Chats)
+ if err != nil {
+ common.SysError("error marshalling chats: " + err.Error())
+ return "[]"
+ }
+ return string(jsonBytes)
+}
diff --git a/setting/config/config.go b/setting/config/config.go
new file mode 100644
index 00000000..3af51b14
--- /dev/null
+++ b/setting/config/config.go
@@ -0,0 +1,259 @@
+package config
+
+import (
+ "encoding/json"
+ "one-api/common"
+ "reflect"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+// ConfigManager provides unified management of all registered configuration modules
+type ConfigManager struct {
+ configs map[string]interface{}
+ mutex sync.RWMutex
+}
+
+var GlobalConfig = NewConfigManager()
+
+func NewConfigManager() *ConfigManager {
+ return &ConfigManager{
+ configs: make(map[string]interface{}),
+ }
+}
+
+// Register registers a configuration module under the given name
+func (cm *ConfigManager) Register(name string, config interface{}) {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+ cm.configs[name] = config
+}
+
+// Get returns the configuration module registered under the given name
+func (cm *ConfigManager) Get(name string) interface{} {
+ cm.mutex.RLock()
+ defer cm.mutex.RUnlock()
+ return cm.configs[name]
+}
+
+// LoadFromDB loads configuration values from the database options map
+func (cm *ConfigManager) LoadFromDB(options map[string]string) error {
+ cm.mutex.Lock()
+ defer cm.mutex.Unlock()
+
+ for name, config := range cm.configs {
+ prefix := name + "."
+ configMap := make(map[string]string)
+
+ // collect all options belonging to this module
+ for key, value := range options {
+ if strings.HasPrefix(key, prefix) {
+ configKey := strings.TrimPrefix(key, prefix)
+ configMap[configKey] = value
+ }
+ }
+
+ // update the config if any matching options were found
+ if len(configMap) > 0 {
+ if err := updateConfigFromMap(config, configMap); err != nil {
+ common.SysError("failed to update config " + name + ": " + err.Error())
+ continue
+ }
+ }
+ }
+
+ return nil
+}
+
+// SaveToDB persists all registered configurations via the provided update function
+func (cm *ConfigManager) SaveToDB(updateFunc func(key, value string) error) error {
+ cm.mutex.RLock()
+ defer cm.mutex.RUnlock()
+
+ for name, config := range cm.configs {
+ configMap, err := configToMap(config)
+ if err != nil {
+ return err
+ }
+
+ for key, value := range configMap {
+ dbKey := name + "." + key
+ if err := updateFunc(dbKey, value); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// configToMap converts a config struct into a flat map of string values
+func configToMap(config interface{}) (map[string]string, error) {
+ result := make(map[string]string)
+
+ val := reflect.ValueOf(config)
+ if val.Kind() == reflect.Ptr {
+ val = val.Elem()
+ }
+
+ if val.Kind() != reflect.Struct {
+ return nil, nil
+ }
+
+ typ := val.Type()
+ for i := 0; i < val.NumField(); i++ {
+ field := val.Field(i)
+ fieldType := typ.Field(i)
+
+ // skip unexported fields
+ if !fieldType.IsExported() {
+ continue
+ }
+
+ // use the json tag as the key name
+ key := fieldType.Tag.Get("json")
+ if key == "" || key == "-" {
+ key = fieldType.Name
+ }
+
+ // handle fields by kind
+ var strValue string
+ switch field.Kind() {
+ case reflect.String:
+ strValue = field.String()
+ case reflect.Bool:
+ strValue = strconv.FormatBool(field.Bool())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ strValue = strconv.FormatInt(field.Int(), 10)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ strValue = strconv.FormatUint(field.Uint(), 10)
+ case reflect.Float32, reflect.Float64:
+ strValue = strconv.FormatFloat(field.Float(), 'f', -1, 64)
+ case reflect.Map, reflect.Slice, reflect.Struct:
+ // complex types are serialized as JSON
+ bytes, err := json.Marshal(field.Interface())
+ if err != nil {
+ return nil, err
+ }
+ strValue = string(bytes)
+ default:
+ // skip unsupported kinds
+ continue
+ }
+
+ result[key] = strValue
+ }
+
+ return result, nil
+}
+
+// updateConfigFromMap updates a config struct in place from a flat map of string values
+func updateConfigFromMap(config interface{}, configMap map[string]string) error {
+ val := reflect.ValueOf(config)
+ if val.Kind() != reflect.Ptr {
+ return nil
+ }
+ val = val.Elem()
+
+ if val.Kind() != reflect.Struct {
+ return nil
+ }
+
+ typ := val.Type()
+ for i := 0; i < val.NumField(); i++ {
+ field := val.Field(i)
+ fieldType := typ.Field(i)
+
+ // skip unexported fields
+ if !fieldType.IsExported() {
+ continue
+ }
+
+ // use the json tag as the key name
+ key := fieldType.Tag.Get("json")
+ if key == "" || key == "-" {
+ key = fieldType.Name
+ }
+
+ // check whether the map has a value for this key
+ strValue, ok := configMap[key]
+ if !ok {
+ continue
+ }
+
+ // set the value according to the field kind
+ if !field.CanSet() {
+ continue
+ }
+
+ switch field.Kind() {
+ case reflect.String:
+ field.SetString(strValue)
+ case reflect.Bool:
+ boolValue, err := strconv.ParseBool(strValue)
+ if err != nil {
+ continue
+ }
+ field.SetBool(boolValue)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ intValue, err := strconv.ParseInt(strValue, 10, 64)
+ if err != nil {
+ continue
+ }
+ field.SetInt(intValue)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ uintValue, err := strconv.ParseUint(strValue, 10, 64)
+ if err != nil {
+ continue
+ }
+ field.SetUint(uintValue)
+ case reflect.Float32, reflect.Float64:
+ floatValue, err := strconv.ParseFloat(strValue, 64)
+ if err != nil {
+ continue
+ }
+ field.SetFloat(floatValue)
+ case reflect.Map, reflect.Slice, reflect.Struct:
+ // complex types are deserialized from JSON
+ err := json.Unmarshal([]byte(strValue), field.Addr().Interface())
+ if err != nil {
+ continue
+ }
+ }
+ }
+
+ return nil
+}
+
+// ConfigToMap converts a config struct into a flat map (exported wrapper)
+func ConfigToMap(config interface{}) (map[string]string, error) {
+ return configToMap(config)
+}
+
+// UpdateConfigFromMap updates a config struct from a flat map (exported wrapper)
+func UpdateConfigFromMap(config interface{}, configMap map[string]string) error {
+ return updateConfigFromMap(config, configMap)
+}
+
+// ExportAllConfigs exports all registered configurations as a flat map
+func (cm *ConfigManager) ExportAllConfigs() map[string]string {
+ cm.mutex.RLock()
+ defer cm.mutex.RUnlock()
+
+ result := make(map[string]string)
+
+ for name, cfg := range cm.configs {
+ configMap, err := ConfigToMap(cfg)
+ if err != nil {
+ continue
+ }
+
+ // 使用 "模块名.配置项" 的格式添加到结果中
+ for key, value := range configMap {
+ result[name+"."+key] = value
+ }
+ }
+
+ return result
+}
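+
+// Keys follow the "module.key" convention used by LoadFromDB. For example, the
+// console_setting module (see setting/console_setting/config.go) exports
+// entries such as "console_setting.api_info_enabled" = "true", which
+// LoadFromDB maps back onto the ApiInfoEnabled field via its json tag.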
diff --git a/setting/console_setting/config.go b/setting/console_setting/config.go
new file mode 100644
index 00000000..6327e558
--- /dev/null
+++ b/setting/console_setting/config.go
@@ -0,0 +1,39 @@
+package console_setting
+
+import "one-api/setting/config"
+
+type ConsoleSetting struct {
+ ApiInfo string `json:"api_info"` // console API info (JSON array string)
+ UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma group configuration (JSON array string)
+ Announcements string `json:"announcements"` // system announcements (JSON array string)
+ FAQ string `json:"faq"` // frequently asked questions (JSON array string)
+ ApiInfoEnabled bool `json:"api_info_enabled"` // whether the API info panel is enabled
+ UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // whether the Uptime Kuma panel is enabled
+ AnnouncementsEnabled bool `json:"announcements_enabled"` // whether the announcements panel is enabled
+ FAQEnabled bool `json:"faq_enabled"` // whether the FAQ panel is enabled
+}
+
+// default configuration
+var defaultConsoleSetting = ConsoleSetting{
+ ApiInfo: "",
+ UptimeKumaGroups: "",
+ Announcements: "",
+ FAQ: "",
+ ApiInfoEnabled: true,
+ UptimeKumaEnabled: true,
+ AnnouncementsEnabled: true,
+ FAQEnabled: true,
+}
+
+// global instance
+var consoleSetting = defaultConsoleSetting
+
+func init() {
+ // register with the global config manager under the key "console_setting"
+ config.GlobalConfig.Register("console_setting", &consoleSetting)
+}
+
+// GetConsoleSetting returns the ConsoleSetting configuration instance
+func GetConsoleSetting() *ConsoleSetting {
+ return &consoleSetting
+}
\ No newline at end of file
diff --git a/setting/console_setting/validation.go b/setting/console_setting/validation.go
new file mode 100644
index 00000000..fda6453d
--- /dev/null
+++ b/setting/console_setting/validation.go
@@ -0,0 +1,304 @@
+package console_setting
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/url"
+ "regexp"
+ "strings"
+ "time"
+ "sort"
+)
+
+var (
+ urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
+ dangerousChars = []string{"
+