添加完整项目文件
包含Go API项目的所有源代码、配置文件、Docker配置、文档和前端资源 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
15
.claude/settings.local.json
Normal file
15
.claude/settings.local.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(git init:*)",
|
||||
"Bash(touch:*)",
|
||||
"Bash(git checkout:*)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(git remote add:*)",
|
||||
"Bash(git push:*)"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
}
|
||||
}
|
||||
7
.dockerignore
Normal file
7
.dockerignore
Normal file
@@ -0,0 +1,7 @@
|
||||
.github
|
||||
.git
|
||||
*.md
|
||||
.vscode
|
||||
.gitignore
|
||||
Makefile
|
||||
docs
|
||||
75
.env.example
Normal file
75
.env.example
Normal file
@@ -0,0 +1,75 @@
|
||||
# 端口号
|
||||
# PORT=3000
|
||||
# 前端基础URL
|
||||
# FRONTEND_BASE_URL=https://your-frontend-url.com
|
||||
|
||||
|
||||
# 调试相关配置
|
||||
# 启用pprof
|
||||
# ENABLE_PPROF=true
|
||||
# 启用调试模式
|
||||
# DEBUG=true
|
||||
|
||||
# 数据库相关配置
|
||||
# 数据库连接字符串
|
||||
# SQL_DSN=user:password@tcp(127.0.0.1:3306)/dbname?parseTime=true
|
||||
# 日志数据库连接字符串
|
||||
# LOG_SQL_DSN=user:password@tcp(127.0.0.1:3306)/logdb?parseTime=true
|
||||
# SQLite数据库路径
|
||||
# SQLITE_PATH=/path/to/sqlite.db
|
||||
# 数据库最大空闲连接数
|
||||
# SQL_MAX_IDLE_CONNS=100
|
||||
# 数据库最大打开连接数
|
||||
# SQL_MAX_OPEN_CONNS=1000
|
||||
# 数据库连接最大生命周期(秒)
|
||||
# SQL_MAX_LIFETIME=60
|
||||
|
||||
|
||||
# 缓存相关配置
|
||||
# Redis连接字符串
|
||||
# REDIS_CONN_STRING=redis://user:password@localhost:6379/0
|
||||
# 同步频率(单位:秒)
|
||||
# SYNC_FREQUENCY=60
|
||||
# 内存缓存启用
|
||||
# MEMORY_CACHE_ENABLED=true
|
||||
# 渠道更新频率(单位:秒)
|
||||
# CHANNEL_UPDATE_FREQUENCY=30
|
||||
# 批量更新启用
|
||||
# BATCH_UPDATE_ENABLED=true
|
||||
# 批量更新间隔(单位:秒)
|
||||
# BATCH_UPDATE_INTERVAL=5
|
||||
|
||||
# 任务和功能配置
|
||||
# 更新任务启用
|
||||
# UPDATE_TASK=true
|
||||
|
||||
# 对话超时设置
|
||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||
# RELAY_TIMEOUT=0
|
||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||
# STREAMING_TIMEOUT=120
|
||||
|
||||
# Gemini 识别图片 最大图片数量
|
||||
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||
|
||||
# 会话密钥
|
||||
# SESSION_SECRET=random_string
|
||||
|
||||
# 其他配置
|
||||
# 渠道测试频率(单位:秒)
|
||||
# CHANNEL_TEST_FREQUENCY=10
|
||||
# 生成默认token
|
||||
# GENERATE_DEFAULT_TOKEN=false
|
||||
# Cohere 安全设置
|
||||
# COHERE_SAFETY_SETTING=NONE
|
||||
# 是否统计图片token
|
||||
# GET_MEDIA_TOKEN=true
|
||||
# 是否在非流(stream=false)情况下统计图片token
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
|
||||
|
||||
# 节点类型
|
||||
# 如果是主节点则为master
|
||||
# NODE_TYPE=master
|
||||
12
.github/FUNDING.yml
vendored
Normal file
12
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
custom: ['https://afdian.com/a/new-api'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
26
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
26
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
---
|
||||
name: 报告问题
|
||||
about: 使用简练详细的语言描述你遇到的问题
|
||||
title: ''
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**例行检查**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已确认我已升级到最新版本
|
||||
+ [ ] 我已完整查看过项目 README,尤其是常见问题部分
|
||||
+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈
|
||||
+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
|
||||
|
||||
**问题描述**
|
||||
|
||||
**复现步骤**
|
||||
|
||||
**预期结果**
|
||||
|
||||
**相关截图**
|
||||
如果没有的话,请删除此节。
|
||||
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 项目群聊
|
||||
url: https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg
|
||||
about: QQ 群:629454374
|
||||
21
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
21
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
---
|
||||
name: 功能请求
|
||||
about: 使用简练详细的语言描述希望加入的新功能
|
||||
title: ''
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**例行检查**
|
||||
|
||||
[//]: # (方框内删除已有的空格,填 x 号)
|
||||
+ [ ] 我已确认目前没有类似 issue
|
||||
+ [ ] 我已确认我已升级到最新版本
|
||||
+ [ ] 我已完整查看过项目 README,已确定现有版本无法满足需求
|
||||
+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈
|
||||
+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
|
||||
|
||||
**功能描述**
|
||||
|
||||
**应用场景**
|
||||
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
### PR 类型
|
||||
|
||||
- [ ] Bug 修复
|
||||
- [ ] 新功能
|
||||
- [ ] 文档更新
|
||||
- [ ] 其他
|
||||
|
||||
### PR 是否包含破坏性更新?
|
||||
|
||||
- [ ] 是
|
||||
- [ ] 否
|
||||
|
||||
### PR 描述
|
||||
|
||||
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||
|
||||
### **重要提示**
|
||||
|
||||
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
|
||||
62
.github/workflows/docker-image-alpha.yml
vendored
Normal file
62
.github/workflows/docker-image-alpha.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
name: Publish Docker image (alpha)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- alpha
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: "reason"
|
||||
required: false
|
||||
|
||||
jobs:
|
||||
push_to_registries:
|
||||
name: Push Docker image to multiple registries
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Save version info
|
||||
run: |
|
||||
echo "alpha-$(date +'%Y%m%d')-$(git rev-parse --short HEAD)" > VERSION
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
tags: |
|
||||
type=raw,value=alpha
|
||||
type=raw,value=alpha-{{date 'YYYYMMDD'}}-{{sha}}
|
||||
|
||||
- name: Build and push Docker images
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
56
.github/workflows/docker-image-arm64.yml
vendored
Normal file
56
.github/workflows/docker-image-arm64.yml
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
name: Publish Docker image (Multi Registries)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
jobs:
|
||||
push_to_registries:
|
||||
name: Push Docker image to multiple registries
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
packages: write
|
||||
contents: read
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Save version info
|
||||
run: |
|
||||
git describe --tags > VERSION
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
calciumion/new-api
|
||||
ghcr.io/${{ github.repository }}
|
||||
|
||||
- name: Build and push Docker images
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
59
.github/workflows/linux-release.yml
vendored
Normal file
59
.github/workflows/linux-release.yml
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
name: Linux Release
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '>=1.18.0'
|
||||
- name: Build Backend (amd64)
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
|
||||
|
||||
- name: Build Backend (arm64)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
|
||||
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
one-api
|
||||
one-api-arm64
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
51
.github/workflows/macos-release.yml
vendored
Normal file
51
.github/workflows/macos-release.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: macOS Release
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
jobs:
|
||||
release:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '>=1.18.0'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-X 'one-api/common.Version=$(git describe --tags)'" -o one-api-macos
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: one-api-macos
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Check PR Branching Strategy
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, edited]
|
||||
|
||||
jobs:
|
||||
check-branching-strategy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Enforce branching strategy
|
||||
run: |
|
||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "Branching strategy check passed."
|
||||
53
.github/workflows/windows-release.yml
vendored
Normal file
53
.github/workflows/windows-release.yml
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
name: Windows Release
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
jobs:
|
||||
release:
|
||||
runs-on: windows-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: '>=1.18.0'
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)'" -o one-api.exe
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v1
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: one-api.exe
|
||||
draft: true
|
||||
generate_release_notes: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
13
.gitignore
vendored
Normal file
13
.gitignore
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
.idea
|
||||
.vscode
|
||||
upload
|
||||
*.exe
|
||||
*.db
|
||||
build
|
||||
*.db-journal
|
||||
logs
|
||||
web/dist
|
||||
.env
|
||||
one-api
|
||||
.DS_Store
|
||||
tiktoken_cache
|
||||
35
Dockerfile
Normal file
35
Dockerfile
Normal file
@@ -0,0 +1,35 @@
|
||||
FROM oven/bun:latest AS builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
COPY web/bun.lock .
|
||||
RUN bun install
|
||||
COPY ./web .
|
||||
COPY ./VERSION .
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(cat VERSION) bun run build
|
||||
|
||||
FROM golang:alpine AS builder2
|
||||
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0 \
|
||||
GOOS=linux
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
COPY --from=builder /build/dist ./web/dist
|
||||
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-api
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||
&& update-ca-certificates
|
||||
|
||||
COPY --from=builder2 /build/one-api /
|
||||
EXPOSE 3000
|
||||
WORKDIR /data
|
||||
ENTRYPOINT ["/one-api"]
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
216
README.en.md
Normal file
216
README.en.md
Normal file
@@ -0,0 +1,216 @@
|
||||
<p align="right">
|
||||
<a href="./README.md">中文</a> | <strong>English</strong>
|
||||
</p>
|
||||
<div align="center">
|
||||
|
||||

|
||||
|
||||
# New API
|
||||
|
||||
🍥 Next-Generation Large Model Gateway and AI Asset Management System
|
||||
|
||||
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://raw.githubusercontent.com/Calcium-Ion/new-api/main/LICENSE">
|
||||
<img src="https://img.shields.io/github/license/Calcium-Ion/new-api?color=brightgreen" alt="license">
|
||||
</a>
|
||||
<a href="https://github.com/Calcium-Ion/new-api/releases/latest">
|
||||
<img src="https://img.shields.io/github/v/release/Calcium-Ion/new-api?color=brightgreen&include_prereleases" alt="release">
|
||||
</a>
|
||||
<a href="https://github.com/users/Calcium-Ion/packages/container/package/new-api">
|
||||
<img src="https://img.shields.io/badge/docker-ghcr.io-blue" alt="docker">
|
||||
</a>
|
||||
<a href="https://hub.docker.com/r/CalciumIon/new-api">
|
||||
<img src="https://img.shields.io/badge/docker-dockerHub-blue" alt="docker">
|
||||
</a>
|
||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
## 📝 Project Description
|
||||
|
||||
> [!NOTE]
|
||||
> This is an open-source project developed based on [One API](https://github.com/songquanpeng/one-api)
|
||||
|
||||
> [!IMPORTANT]
|
||||
> - This project is for personal learning purposes only, with no guarantee of stability or technical support.
|
||||
> - Users must comply with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**, and must not use it for illegal purposes.
|
||||
> - According to the [《Interim Measures for the Management of Generative Artificial Intelligence Services》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm), please do not provide any unregistered generative AI services to the public in China.
|
||||
|
||||
<h2>🤝 Trusted Partners</h2>
|
||||
<p id="premium-sponsors"> </p>
|
||||
<p align="center"><strong>No particular order</strong></p>
|
||||
<p align="center">
|
||||
<a href="https://www.cherry-ai.com/" target=_blank><img
|
||||
src="./docs/images/cherry-studio.png" alt="Cherry Studio" height="120"
|
||||
/></a>
|
||||
<a href="https://bda.pku.edu.cn/" target=_blank><img
|
||||
src="./docs/images/pku.png" alt="Peking University" height="120"
|
||||
/></a>
|
||||
<a href="https://www.compshare.cn/?ytag=GPU_yy_gh_newapi" target=_blank><img
|
||||
src="./docs/images/ucloud.png" alt="UCloud" height="120"
|
||||
/></a>
|
||||
<a href="https://www.aliyun.com/" target=_blank><img
|
||||
src="./docs/images/aliyun.png" alt="Alibaba Cloud" height="120"
|
||||
/></a>
|
||||
<a href="https://io.net/" target=_blank><img
|
||||
src="./docs/images/io-net.png" alt="IO.NET" height="120"
|
||||
/></a>
|
||||
</p>
|
||||
<p> </p>
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
For detailed documentation, please visit our official Wiki: [https://docs.newapi.pro/](https://docs.newapi.pro/)
|
||||
|
||||
You can also access the AI-generated DeepWiki:
|
||||
[](https://deepwiki.com/QuantumNous/new-api)
|
||||
|
||||
## ✨ Key Features
|
||||
|
||||
New API offers a wide range of features, please refer to [Features Introduction](https://docs.newapi.pro/wiki/features-introduction) for details:
|
||||
|
||||
1. 🎨 Brand new UI interface
|
||||
2. 🌍 Multi-language support
|
||||
3. 💰 Online recharge functionality (YiPay)
|
||||
4. 🔍 Support for querying usage quotas with keys (works with [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool))
|
||||
5. 🔄 Compatible with the original One API database
|
||||
6. 💵 Support for pay-per-use model pricing
|
||||
7. ⚖️ Support for weighted random channel selection
|
||||
8. 📈 Data dashboard (console)
|
||||
9. 🔒 Token grouping and model restrictions
|
||||
10. 🤖 Support for more authorization login methods (LinuxDO, Telegram, OIDC)
|
||||
11. 🔄 Support for Rerank models (Cohere and Jina), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
12. ⚡ Support for OpenAI Realtime API (including Azure channels), [API Documentation](https://docs.newapi.pro/api/openai-realtime)
|
||||
13. ⚡ Support for Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
|
||||
14. Support for entering chat interface via /chat2link route
|
||||
15. 🧠 Support for setting reasoning effort through model name suffixes:
|
||||
1. OpenAI o-series models
|
||||
- Add `-high` suffix for high reasoning effort (e.g.: `o3-mini-high`)
|
||||
- Add `-medium` suffix for medium reasoning effort (e.g.: `o3-mini-medium`)
|
||||
- Add `-low` suffix for low reasoning effort (e.g.: `o3-mini-low`)
|
||||
2. Claude thinking models
|
||||
- Add `-thinking` suffix to enable thinking mode (e.g.: `claude-3-7-sonnet-20250219-thinking`)
|
||||
16. 🔄 Thinking-to-content functionality
|
||||
17. 🔄 Model rate limiting for users
|
||||
18. 💰 Cache billing support, which allows billing at a set ratio when cache is hit:
|
||||
1. Set the `Prompt Cache Ratio` option in `System Settings-Operation Settings`
|
||||
2. Set `Prompt Cache Ratio` in the channel, range 0-1, e.g., setting to 0.5 means billing at 50% when cache is hit
|
||||
3. Supported channels:
|
||||
- [x] OpenAI
|
||||
- [x] Azure
|
||||
- [x] DeepSeek
|
||||
- [x] Claude
|
||||
|
||||
## Model Support
|
||||
|
||||
This version supports multiple models, please refer to [API Documentation-Relay Interface](https://docs.newapi.pro/api) for details:
|
||||
|
||||
1. Third-party models **gpts** (gpt-4-gizmo-*)
|
||||
2. Third-party channel [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy) interface, [API Documentation](https://docs.newapi.pro/api/midjourney-proxy-image)
|
||||
3. Third-party channel [Suno API](https://github.com/Suno-API/Suno-API) interface, [API Documentation](https://docs.newapi.pro/api/suno-music)
|
||||
4. Custom channels, supporting full call address input
|
||||
5. Rerank models ([Cohere](https://cohere.ai/) and [Jina](https://jina.ai/)), [API Documentation](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
6. Claude Messages format, [API Documentation](https://docs.newapi.pro/api/anthropic-chat)
|
||||
7. Dify, currently only supports chatflow
|
||||
|
||||
## Environment Variable Configuration
|
||||
|
||||
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 300 seconds
|
||||
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
||||
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
||||
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
||||
- `GET_MEDIA_TOKEN_NOT_STREAM`: Whether to count image tokens in non-streaming cases, default is `true`
|
||||
- `UPDATE_TASK`: Whether to update asynchronous tasks (Midjourney, Suno), default is `true`
|
||||
- `COHERE_SAFETY_SETTING`: Cohere model safety settings, options are `NONE`, `CONTEXTUAL`, `STRICT`, default is `NONE`
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`: Maximum number of images for Gemini models, default is `16`
|
||||
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default is `20`
|
||||
- `CRYPTO_SECRET`: Encryption key used for encrypting database content
|
||||
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, default is `2025-04-01-preview`
|
||||
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Notification limit duration, default is `10` minutes
|
||||
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications within the specified duration, default is `2`
|
||||
- `ERROR_LOG_ENABLED=true`: Whether to record and display error logs, default is `false`
|
||||
|
||||
## Deployment
|
||||
|
||||
For detailed deployment guides, please refer to [Installation Guide-Deployment Methods](https://docs.newapi.pro/installation):
|
||||
|
||||
> [!TIP]
|
||||
> Latest Docker image: `calciumion/new-api:latest`
|
||||
|
||||
### Multi-machine Deployment Considerations
|
||||
- Environment variable `SESSION_SECRET` must be set, otherwise login status will be inconsistent across multiple machines
|
||||
- If sharing Redis, `CRYPTO_SECRET` must be set, otherwise Redis content cannot be accessed across multiple machines
|
||||
|
||||
### Deployment Requirements
|
||||
- Local database (default): SQLite (Docker deployment must mount the `/data` directory)
|
||||
- Remote database: MySQL version >= 5.7.8, PgSQL version >= 9.6
|
||||
|
||||
### Deployment Methods
|
||||
|
||||
#### Using BaoTa Panel Docker Feature
|
||||
Install BaoTa Panel (version **9.2.0** or above), find **New-API** in the application store and install it.
|
||||
[Tutorial with images](./docs/BT.md)
|
||||
|
||||
#### Using Docker Compose (Recommended)
|
||||
```shell
|
||||
# Download the project
|
||||
git clone https://github.com/Calcium-Ion/new-api.git
|
||||
cd new-api
|
||||
# Edit docker-compose.yml as needed
|
||||
# Start
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### Using Docker Image Directly
|
||||
```shell
|
||||
# Using SQLite
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
|
||||
# Using MySQL
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
```
|
||||
|
||||
## Channel Retry and Cache
|
||||
Channel retry functionality has been implemented, you can set the number of retries in `Settings->Operation Settings->General Settings`. It is **recommended to enable caching**.
|
||||
|
||||
### Cache Configuration Method
|
||||
1. `REDIS_CONN_STRING`: Set Redis as cache
|
||||
2. `MEMORY_CACHE_ENABLED`: Enable memory cache (no need to set manually if Redis is set)
|
||||
|
||||
## API Documentation
|
||||
|
||||
For detailed API documentation, please refer to [API Documentation](https://docs.newapi.pro/api):
|
||||
|
||||
- [Chat API](https://docs.newapi.pro/api/openai-chat)
|
||||
- [Image API](https://docs.newapi.pro/api/openai-image)
|
||||
- [Rerank API](https://docs.newapi.pro/api/jinaai-rerank)
|
||||
- [Realtime API](https://docs.newapi.pro/api/openai-realtime)
|
||||
- [Claude Chat API (messages)](https://docs.newapi.pro/api/anthropic-chat)
|
||||
|
||||
## Related Projects
|
||||
- [One API](https://github.com/songquanpeng/one-api): Original project
|
||||
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy): Midjourney interface support
|
||||
- [chatnio](https://github.com/Deeptrain-Community/chatnio): Next-generation AI one-stop B/C-end solution
|
||||
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool): Query usage quota with key
|
||||
|
||||
Other projects based on New API:
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon): High-performance optimized version of New API
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI): Frontend beautified version based on New API
|
||||
|
||||
## Help and Support
|
||||
|
||||
If you have any questions, please refer to [Help and Support](https://docs.newapi.pro/support):
|
||||
- [Community Interaction](https://docs.newapi.pro/support/community-interaction)
|
||||
- [Issue Feedback](https://docs.newapi.pro/support/feedback-issues)
|
||||
- [FAQ](https://docs.newapi.pro/support/faq)
|
||||
|
||||
## 🌟 Star History
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
6
bin/migration_v0.2-v0.3.sql
Normal file
6
bin/migration_v0.2-v0.3.sql
Normal file
@@ -0,0 +1,6 @@
|
||||
UPDATE users
|
||||
SET quota = quota + (
|
||||
SELECT SUM(remain_quota)
|
||||
FROM tokens
|
||||
WHERE tokens.user_id = users.id
|
||||
)
|
||||
17
bin/migration_v0.3-v0.4.sql
Normal file
17
bin/migration_v0.3-v0.4.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
INSERT INTO abilities (`group`, model, channel_id, enabled)
|
||||
SELECT c.`group`, m.model, c.id, 1
|
||||
FROM channels c
|
||||
CROSS JOIN (
|
||||
SELECT 'gpt-3.5-turbo' AS model UNION ALL
|
||||
SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL
|
||||
SELECT 'gpt-4' AS model UNION ALL
|
||||
SELECT 'gpt-4-0314' AS model
|
||||
) AS m
|
||||
WHERE c.status = 1
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM abilities a
|
||||
WHERE a.`group` = c.`group`
|
||||
AND a.model = m.model
|
||||
AND a.channel_id = c.id
|
||||
);
|
||||
40
bin/time_test.sh
Normal file
40
bin/time_test.sh
Normal file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# -lt 3 ]; then
|
||||
echo "Usage: time_test.sh <domain> <key> <count> [<model>]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
domain=$1
|
||||
key=$2
|
||||
count=$3
|
||||
model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo
|
||||
|
||||
total_time=0
|
||||
times=()
|
||||
|
||||
for ((i=1; i<=count; i++)); do
|
||||
result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \
|
||||
https://"$domain"/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $key" \
|
||||
-d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}')
|
||||
http_code=$(echo "$result" | awk '{print $1}')
|
||||
time=$(echo "$result" | awk '{print $2}')
|
||||
echo "HTTP status code: $http_code, Time taken: $time"
|
||||
total_time=$(bc <<< "$total_time + $time")
|
||||
times+=("$time")
|
||||
done
|
||||
|
||||
average_time=$(echo "scale=4; $total_time / $count" | bc)
|
||||
|
||||
sum_of_squares=0
|
||||
for time in "${times[@]}"; do
|
||||
difference=$(echo "scale=4; $time - $average_time" | bc)
|
||||
square=$(echo "scale=4; $difference * $difference" | bc)
|
||||
sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc)
|
||||
done
|
||||
|
||||
standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc)
|
||||
|
||||
echo "Average time: $average_time±$standard_deviation"
|
||||
73
common/api_type.go
Normal file
73
common/api_type.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType := -1
|
||||
switch channelType {
|
||||
case constant.ChannelTypeOpenAI:
|
||||
apiType = constant.APITypeOpenAI
|
||||
case constant.ChannelTypeAnthropic:
|
||||
apiType = constant.APITypeAnthropic
|
||||
case constant.ChannelTypeBaidu:
|
||||
apiType = constant.APITypeBaidu
|
||||
case constant.ChannelTypePaLM:
|
||||
apiType = constant.APITypePaLM
|
||||
case constant.ChannelTypeZhipu:
|
||||
apiType = constant.APITypeZhipu
|
||||
case constant.ChannelTypeAli:
|
||||
apiType = constant.APITypeAli
|
||||
case constant.ChannelTypeXunfei:
|
||||
apiType = constant.APITypeXunfei
|
||||
case constant.ChannelTypeAIProxyLibrary:
|
||||
apiType = constant.APITypeAIProxyLibrary
|
||||
case constant.ChannelTypeTencent:
|
||||
apiType = constant.APITypeTencent
|
||||
case constant.ChannelTypeGemini:
|
||||
apiType = constant.APITypeGemini
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
apiType = constant.APITypeZhipuV4
|
||||
case constant.ChannelTypeOllama:
|
||||
apiType = constant.APITypeOllama
|
||||
case constant.ChannelTypePerplexity:
|
||||
apiType = constant.APITypePerplexity
|
||||
case constant.ChannelTypeAws:
|
||||
apiType = constant.APITypeAws
|
||||
case constant.ChannelTypeCohere:
|
||||
apiType = constant.APITypeCohere
|
||||
case constant.ChannelTypeDify:
|
||||
apiType = constant.APITypeDify
|
||||
case constant.ChannelTypeJina:
|
||||
apiType = constant.APITypeJina
|
||||
case constant.ChannelCloudflare:
|
||||
apiType = constant.APITypeCloudflare
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
apiType = constant.APITypeSiliconFlow
|
||||
case constant.ChannelTypeVertexAi:
|
||||
apiType = constant.APITypeVertexAi
|
||||
case constant.ChannelTypeMistral:
|
||||
apiType = constant.APITypeMistral
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
apiType = constant.APITypeDeepSeek
|
||||
case constant.ChannelTypeMokaAI:
|
||||
apiType = constant.APITypeMokaAI
|
||||
case constant.ChannelTypeVolcEngine:
|
||||
apiType = constant.APITypeVolcEngine
|
||||
case constant.ChannelTypeBaiduV2:
|
||||
apiType = constant.APITypeBaiduV2
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
apiType = constant.APITypeOpenRouter
|
||||
case constant.ChannelTypeXinference:
|
||||
apiType = constant.APITypeXinference
|
||||
case constant.ChannelTypeXai:
|
||||
apiType = constant.APITypeXai
|
||||
case constant.ChannelTypeCoze:
|
||||
apiType = constant.APITypeCoze
|
||||
case constant.ChannelTypeJimeng:
|
||||
apiType = constant.APITypeJimeng
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
}
|
||||
return apiType, true
|
||||
}
|
||||
201
common/constants.go
Normal file
201
common/constants.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
//"os"
|
||||
//"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var StartTime = time.Now().Unix() // unit: second
|
||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||
var SystemName = "New API"
|
||||
var Footer = ""
|
||||
var Logo = ""
|
||||
var TopUpLink = ""
|
||||
|
||||
// var ChatLink = ""
|
||||
// var ChatLink2 = ""
|
||||
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||
var DisplayInCurrencyEnabled = true
|
||||
var DisplayTokenStatEnabled = true
|
||||
var DrawingEnabled = true
|
||||
var TaskEnabled = true
|
||||
var DataExportEnabled = true
|
||||
var DataExportInterval = 5 // unit: minute
|
||||
var DataExportDefaultTime = "hour" // unit: minute
|
||||
var DefaultCollapseSidebar = false // default value of collapse sidebar
|
||||
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
var SessionSecret = uuid.New().String()
|
||||
var CryptoSecret = uuid.New().String()
|
||||
|
||||
var OptionMap map[string]string
|
||||
var OptionMapRWMutex sync.RWMutex
|
||||
|
||||
var ItemsPerPage = 10
|
||||
var MaxRecentItems = 100
|
||||
|
||||
var PasswordLoginEnabled = true
|
||||
var PasswordRegisterEnabled = true
|
||||
var EmailVerificationEnabled = false
|
||||
var GitHubOAuthEnabled = false
|
||||
var LinuxDOOAuthEnabled = false
|
||||
var WeChatAuthEnabled = false
|
||||
var TelegramOAuthEnabled = false
|
||||
var TurnstileCheckEnabled = false
|
||||
var RegisterEnabled = true
|
||||
|
||||
var EmailDomainRestrictionEnabled = false // 是否启用邮箱域名限制
|
||||
var EmailAliasRestrictionEnabled = false // 是否启用邮箱别名限制
|
||||
var EmailDomainWhitelist = []string{
|
||||
"gmail.com",
|
||||
"163.com",
|
||||
"126.com",
|
||||
"qq.com",
|
||||
"outlook.com",
|
||||
"hotmail.com",
|
||||
"icloud.com",
|
||||
"yahoo.com",
|
||||
"foxmail.com",
|
||||
}
|
||||
var EmailLoginAuthServerList = []string{
|
||||
"smtp.sendcloud.net",
|
||||
"smtp.azurecomm.net",
|
||||
}
|
||||
|
||||
var DebugEnabled bool
|
||||
var MemoryCacheEnabled bool
|
||||
|
||||
var LogConsumeEnabled = true
|
||||
|
||||
var SMTPServer = ""
|
||||
var SMTPPort = 587
|
||||
var SMTPSSLEnabled = false
|
||||
var SMTPAccount = ""
|
||||
var SMTPFrom = ""
|
||||
var SMTPToken = ""
|
||||
|
||||
var GitHubClientId = ""
|
||||
var GitHubClientSecret = ""
|
||||
var LinuxDOClientId = ""
|
||||
var LinuxDOClientSecret = ""
|
||||
|
||||
var WeChatServerAddress = ""
|
||||
var WeChatServerToken = ""
|
||||
var WeChatAccountQRCodeImageURL = ""
|
||||
|
||||
var TurnstileSiteKey = ""
|
||||
var TurnstileSecretKey = ""
|
||||
|
||||
var TelegramBotToken = ""
|
||||
var TelegramBotName = ""
|
||||
|
||||
var QuotaForNewUser = 0
|
||||
var QuotaForInviter = 0
|
||||
var QuotaForInvitee = 0
|
||||
var ChannelDisableThreshold = 5.0
|
||||
var AutomaticDisableChannelEnabled = false
|
||||
var AutomaticEnableChannelEnabled = false
|
||||
var QuotaRemindThreshold = 1000
|
||||
var PreConsumedQuota = 500
|
||||
|
||||
var RetryTimes = 0
|
||||
|
||||
//var RootUserEmail = ""
|
||||
|
||||
var IsMasterNode bool
|
||||
|
||||
var requestInterval int
|
||||
var RequestInterval time.Duration
|
||||
|
||||
var SyncFrequency int // unit is second
|
||||
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval int
|
||||
|
||||
var RelayTimeout int // unit is second
|
||||
|
||||
var GeminiSafetySetting string
|
||||
|
||||
// https://docs.cohere.com/docs/safety-modes Type; NONE/CONTEXTUAL/STRICT
|
||||
var CohereSafetySetting string
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
)
|
||||
|
||||
const (
|
||||
RoleGuestUser = 0
|
||||
RoleCommonUser = 1
|
||||
RoleAdminUser = 10
|
||||
RoleRootUser = 100
|
||||
)
|
||||
|
||||
func IsValidateRole(role int) bool {
|
||||
return role == RoleGuestUser || role == RoleCommonUser || role == RoleAdminUser || role == RoleRootUser
|
||||
}
|
||||
|
||||
var (
|
||||
FileUploadPermission = RoleGuestUser
|
||||
FileDownloadPermission = RoleGuestUser
|
||||
ImageUploadPermission = RoleGuestUser
|
||||
ImageDownloadPermission = RoleGuestUser
|
||||
)
|
||||
|
||||
// All duration's unit is seconds
|
||||
// Shouldn't larger then RateLimitKeyExpirationDuration
|
||||
var (
|
||||
GlobalApiRateLimitEnable bool
|
||||
GlobalApiRateLimitNum int
|
||||
GlobalApiRateLimitDuration int64
|
||||
|
||||
GlobalWebRateLimitEnable bool
|
||||
GlobalWebRateLimitNum int
|
||||
GlobalWebRateLimitDuration int64
|
||||
|
||||
UploadRateLimitNum = 10
|
||||
UploadRateLimitDuration int64 = 60
|
||||
|
||||
DownloadRateLimitNum = 10
|
||||
DownloadRateLimitDuration int64 = 60
|
||||
|
||||
CriticalRateLimitNum = 20
|
||||
CriticalRateLimitDuration int64 = 20 * 60
|
||||
)
|
||||
|
||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||
|
||||
const (
|
||||
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
UserStatusDisabled = 2 // also don't use 0
|
||||
)
|
||||
|
||||
const (
|
||||
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
TokenStatusDisabled = 2 // also don't use 0
|
||||
TokenStatusExpired = 3
|
||||
TokenStatusExhausted = 4
|
||||
)
|
||||
|
||||
const (
|
||||
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
RedemptionCodeStatusDisabled = 2 // also don't use 0
|
||||
RedemptionCodeStatusUsed = 3 // also don't use 0
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelStatusUnknown = 0
|
||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
TopUpStatusPending = "pending"
|
||||
TopUpStatusSuccess = "success"
|
||||
TopUpStatusExpired = "expired"
|
||||
)
|
||||
31
common/crypto.go
Normal file
31
common/crypto.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func GenerateHMACWithKey(key []byte, data string) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func GenerateHMAC(data string) string {
|
||||
h := hmac.New(sha256.New, []byte(CryptoSecret))
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func Password2Hash(password string) (string, error) {
|
||||
passwordBytes := []byte(password)
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost)
|
||||
return string(hashedPassword), err
|
||||
}
|
||||
|
||||
func ValidatePasswordAndHash(password string, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
82
common/custom-event.go
Normal file
82
common/custom-event.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
|
||||
// Use of this source code is governed by a MIT style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type stringWriter interface {
|
||||
io.Writer
|
||||
writeString(string) (int, error)
|
||||
}
|
||||
|
||||
type stringWrapper struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (w stringWrapper) writeString(str string) (int, error) {
|
||||
return w.Writer.Write([]byte(str))
|
||||
}
|
||||
|
||||
func checkWriter(writer io.Writer) stringWriter {
|
||||
if w, ok := writer.(stringWriter); ok {
|
||||
return w
|
||||
} else {
|
||||
return stringWrapper{writer}
|
||||
}
|
||||
}
|
||||
|
||||
// Server-Sent Events
|
||||
// W3C Working Draft 29 October 2009
|
||||
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
|
||||
|
||||
var contentType = []string{"text/event-stream"}
|
||||
var noCache = []string{"no-cache"}
|
||||
|
||||
var fieldReplacer = strings.NewReplacer(
|
||||
"\n", "\\n",
|
||||
"\r", "\\r")
|
||||
|
||||
var dataReplacer = strings.NewReplacer(
|
||||
"\n", "\n",
|
||||
"\r", "\\r")
|
||||
|
||||
type CustomEvent struct {
|
||||
Event string
|
||||
Id string
|
||||
Retry uint
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
func encode(writer io.Writer, event CustomEvent) error {
|
||||
w := checkWriter(writer)
|
||||
return writeData(w, event.Data)
|
||||
}
|
||||
|
||||
func writeData(w stringWriter, data interface{}) error {
|
||||
dataReplacer.WriteString(w, fmt.Sprint(data))
|
||||
if strings.HasPrefix(data.(string), "data") {
|
||||
w.writeString("\n\n")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r CustomEvent) Render(w http.ResponseWriter) error {
|
||||
r.WriteContentType(w)
|
||||
return encode(w, r)
|
||||
}
|
||||
|
||||
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
|
||||
header := w.Header()
|
||||
header["Content-Type"] = contentType
|
||||
|
||||
if _, exist := header["Cache-Control"]; !exist {
|
||||
header["Cache-Control"] = noCache
|
||||
}
|
||||
}
|
||||
15
common/database.go
Normal file
15
common/database.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package common
|
||||
|
||||
const (
|
||||
DatabaseTypeMySQL = "mysql"
|
||||
DatabaseTypeSQLite = "sqlite"
|
||||
DatabaseTypePostgreSQL = "postgres"
|
||||
)
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
var SQLitePath = "one-api.db?_busy_timeout=5000"
|
||||
40
common/email-outlook-auth.go
Normal file
40
common/email-outlook-auth.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type outlookAuth struct {
|
||||
username, password string
|
||||
}
|
||||
|
||||
func LoginAuth(username, password string) smtp.Auth {
|
||||
return &outlookAuth{username, password}
|
||||
}
|
||||
|
||||
func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
|
||||
return "LOGIN", []byte{}, nil
|
||||
}
|
||||
|
||||
func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if more {
|
||||
switch string(fromServer) {
|
||||
case "Username:":
|
||||
return []byte(a.username), nil
|
||||
case "Password:":
|
||||
return []byte(a.password), nil
|
||||
default:
|
||||
return nil, errors.New("unknown fromServer")
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func isOutlookServer(server string) bool {
|
||||
// 兼容多地区的outlook邮箱和ofb邮箱
|
||||
// 其实应该加一个Option来区分是否用LOGIN的方式登录
|
||||
// 先临时兼容一下
|
||||
return strings.Contains(server, "outlook") || strings.Contains(server, "onmicrosoft")
|
||||
}
|
||||
90
common/email.go
Normal file
90
common/email.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func generateMessageID() (string, error) {
|
||||
split := strings.Split(SMTPFrom, "@")
|
||||
if len(split) < 2 {
|
||||
return "", fmt.Errorf("invalid SMTP account")
|
||||
}
|
||||
domain := strings.Split(SMTPFrom, "@")[1]
|
||||
return fmt.Sprintf("<%d.%s@%s>", time.Now().UnixNano(), GetRandomString(12), domain), nil
|
||||
}
|
||||
|
||||
func SendEmail(subject string, receiver string, content string) error {
|
||||
if SMTPFrom == "" { // for compatibility
|
||||
SMTPFrom = SMTPAccount
|
||||
}
|
||||
id, err2 := generateMessageID()
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
if SMTPServer == "" && SMTPAccount == "" {
|
||||
return fmt.Errorf("SMTP 服务器未配置")
|
||||
}
|
||||
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
|
||||
mail := []byte(fmt.Sprintf("To: %s\r\n"+
|
||||
"From: %s<%s>\r\n"+
|
||||
"Subject: %s\r\n"+
|
||||
"Date: %s\r\n"+
|
||||
"Message-ID: %s\r\n"+ // 添加 Message-ID 头
|
||||
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
|
||||
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), id, content))
|
||||
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
|
||||
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
|
||||
to := strings.Split(receiver, ";")
|
||||
var err error
|
||||
if SMTPPort == 465 || SMTPSSLEnabled {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: SMTPServer,
|
||||
}
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client, err := smtp.NewClient(conn, SMTPServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = client.Mail(SMTPFrom); err != nil {
|
||||
return err
|
||||
}
|
||||
receiverEmails := strings.Split(receiver, ";")
|
||||
for _, receiver := range receiverEmails {
|
||||
if err = client.Rcpt(receiver); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w, err := client.Data()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(mail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) {
|
||||
auth = LoginAuth(SMTPAccount, SMTPToken)
|
||||
err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
|
||||
} else {
|
||||
err = smtp.SendMail(addr, auth, SMTPFrom, to, mail)
|
||||
}
|
||||
return err
|
||||
}
|
||||
32
common/embed-file-system.go
Normal file
32
common/embed-file-system.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"github.com/gin-contrib/static"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Credit: https://github.com/gin-contrib/static/issues/19
|
||||
|
||||
type embedFileSystem struct {
|
||||
http.FileSystem
|
||||
}
|
||||
|
||||
func (e embedFileSystem) Exists(prefix string, path string) bool {
|
||||
_, err := e.Open(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
|
||||
efs, err := fs.Sub(fsEmbed, targetPath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return embedFileSystem{
|
||||
FileSystem: http.FS(efs),
|
||||
}
|
||||
}
|
||||
41
common/endpoint_type.go
Normal file
41
common/endpoint_type.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||
var endpointTypes []constant.EndpointType
|
||||
switch channelType {
|
||||
case constant.ChannelTypeJina:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
|
||||
//case constant.ChannelTypeSunoAPI:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
|
||||
//case constant.ChannelTypeKling:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
|
||||
//case constant.ChannelTypeJimeng:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
|
||||
case constant.ChannelTypeAws:
|
||||
fallthrough
|
||||
case constant.ChannelTypeAnthropic:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
fallthrough
|
||||
case constant.ChannelTypeGemini:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
default:
|
||||
if IsOpenAIResponseOnlyModel(modelName) {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||
} else {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
}
|
||||
}
|
||||
if IsImageGenerationModel(modelName) {
|
||||
// add to first
|
||||
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
|
||||
}
|
||||
return endpointTypes
|
||||
}
|
||||
38
common/env.go
Normal file
38
common/env.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetEnvOrDefault(env string, defaultValue int) int {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
num, err := strconv.Atoi(os.Getenv(env))
|
||||
if err != nil {
|
||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
|
||||
return defaultValue
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func GetEnvOrDefaultString(env string, defaultValue string) string {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return os.Getenv(env)
|
||||
}
|
||||
|
||||
func GetEnvOrDefaultBool(env string, defaultValue bool) bool {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
b, err := strconv.ParseBool(os.Getenv(env))
|
||||
if err != nil {
|
||||
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %t", env, err.Error(), defaultValue))
|
||||
return defaultValue
|
||||
}
|
||||
return b
|
||||
}
|
||||
111
common/gin.go
Normal file
111
common/gin.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
|
||||
func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
requestBody, _ := c.Get(KeyRequestBody)
|
||||
if requestBody != nil {
|
||||
return requestBody.([]byte), nil
|
||||
}
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = c.Request.Body.Close()
|
||||
c.Set(KeyRequestBody, requestBody)
|
||||
return requestBody.([]byte), nil
|
||||
}
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
requestBody, err := GetRequestBody(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Reset request body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||
c.Set(string(key), value)
|
||||
}
|
||||
|
||||
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||
return c.Get(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||
return c.GetString(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||
return c.GetInt(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||
return c.GetBool(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||
return c.GetStringSlice(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||
return c.GetStringMap(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||
return c.GetTime(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
||||
if value, ok := c.Get(string(key)); ok {
|
||||
if v, ok := value.(T); ok {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
var t T
|
||||
return t, false
|
||||
}
|
||||
|
||||
func ApiError(c *gin.Context, err error) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ApiErrorMsg(c *gin.Context, msg string) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": msg,
|
||||
})
|
||||
}
|
||||
|
||||
func ApiSuccess(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
53
common/go-channel.go
Normal file
53
common/go-channel.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func SafeSendBool(ch chan bool, value bool) (closed bool) {
|
||||
defer func() {
|
||||
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||
if recover() != nil {
|
||||
closed = true
|
||||
}
|
||||
}()
|
||||
|
||||
// This will panic if the channel is closed.
|
||||
ch <- value
|
||||
|
||||
// If the code reaches here, then the channel was not closed.
|
||||
return false
|
||||
}
|
||||
|
||||
func SafeSendString(ch chan string, value string) (closed bool) {
|
||||
defer func() {
|
||||
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||
if recover() != nil {
|
||||
closed = true
|
||||
}
|
||||
}()
|
||||
|
||||
// This will panic if the channel is closed.
|
||||
ch <- value
|
||||
|
||||
// If the code reaches here, then the channel was not closed.
|
||||
return false
|
||||
}
|
||||
|
||||
// SafeSendStringTimeout send, return true, else return false
|
||||
func SafeSendStringTimeout(ch chan string, value string, timeout int) (closed bool) {
|
||||
defer func() {
|
||||
// Recover from panic if one occured. A panic would mean the channel was closed.
|
||||
if recover() != nil {
|
||||
closed = false
|
||||
}
|
||||
}()
|
||||
|
||||
// This will panic if the channel is closed.
|
||||
select {
|
||||
case ch <- value:
|
||||
return true
|
||||
case <-time.After(time.Duration(timeout) * time.Second):
|
||||
return false
|
||||
}
|
||||
}
|
||||
24
common/gopool.go
Normal file
24
common/gopool.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"math"
|
||||
)
|
||||
|
||||
var relayGoPool gopool.Pool
|
||||
|
||||
func init() {
|
||||
relayGoPool = gopool.NewPool("gopool.RelayPool", math.MaxInt32, gopool.NewConfig())
|
||||
relayGoPool.SetPanicHandler(func(ctx context.Context, i interface{}) {
|
||||
if stopChan, ok := ctx.Value("stop_chan").(chan bool); ok {
|
||||
SafeSendBool(stopChan, true)
|
||||
}
|
||||
SysError(fmt.Sprintf("panic in gopool.RelayPool: %v", i))
|
||||
})
|
||||
}
|
||||
|
||||
func RelayCtxGo(ctx context.Context, f func()) {
|
||||
relayGoPool.CtxGo(ctx, f)
|
||||
}
|
||||
34
common/hash.go
Normal file
34
common/hash.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Sha256Raw(data []byte) []byte {
|
||||
h := sha256.New()
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1Raw(data []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func Sha1(data []byte) string {
|
||||
return hex.EncodeToString(Sha1Raw(data))
|
||||
}
|
||||
|
||||
func HmacSha256Raw(message, key []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(message)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func HmacSha256(message, key string) string {
|
||||
return hex.EncodeToString(HmacSha256Raw([]byte(message), []byte(key)))
|
||||
}
|
||||
57
common/http.go
Normal file
57
common/http.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
||||
if httpResponse == nil || httpResponse.Body == nil {
|
||||
return
|
||||
}
|
||||
err := httpResponse.Body.Close()
|
||||
if err != nil {
|
||||
SysError("failed to close response body: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||
if c.Writer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
body := io.NopCloser(bytes.NewBuffer(data))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
if src != nil {
|
||||
for k, v := range src.Header {
|
||||
// avoid setting Content-Length
|
||||
if k == "Content-Length" {
|
||||
continue
|
||||
}
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
}
|
||||
|
||||
// set Content-Length header manually BEFORE calling WriteHeader
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
|
||||
// Write header with status code (this sends the headers)
|
||||
if src != nil {
|
||||
c.Writer.WriteHeader(src.StatusCode)
|
||||
} else {
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
_, err := io.Copy(c.Writer, body)
|
||||
if err != nil {
|
||||
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
120
common/init.go
Normal file
120
common/init.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
Port = flag.Int("port", 3000, "the listening port")
|
||||
PrintVersion = flag.Bool("version", false, "print version and exit")
|
||||
PrintHelp = flag.Bool("help", false, "print help and exit")
|
||||
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
|
||||
)
|
||||
|
||||
func printHelp() {
|
||||
fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
|
||||
fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
|
||||
fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
}
|
||||
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
fmt.Println(Version)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if *PrintHelp {
|
||||
printHelp()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if os.Getenv("SESSION_SECRET") != "" {
|
||||
ss := os.Getenv("SESSION_SECRET")
|
||||
if ss == "random_string" {
|
||||
log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
|
||||
log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
|
||||
log.Fatal("Please set SESSION_SECRET to a random string.")
|
||||
} else {
|
||||
SessionSecret = ss
|
||||
}
|
||||
}
|
||||
if os.Getenv("CRYPTO_SECRET") != "" {
|
||||
CryptoSecret = os.Getenv("CRYPTO_SECRET")
|
||||
} else {
|
||||
CryptoSecret = SessionSecret
|
||||
}
|
||||
if os.Getenv("SQLITE_PATH") != "" {
|
||||
SQLitePath = os.Getenv("SQLITE_PATH")
|
||||
}
|
||||
if *LogDir != "" {
|
||||
var err error
|
||||
*LogDir, err = filepath.Abs(*LogDir)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
|
||||
err = os.Mkdir(*LogDir, 0777)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize variables from constants.go that were using environment variables
|
||||
DebugEnabled = os.Getenv("DEBUG") == "true"
|
||||
MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
|
||||
IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||
|
||||
// Parse requestInterval and set RequestInterval
|
||||
requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
|
||||
RequestInterval = time.Duration(requestInterval) * time.Second
|
||||
|
||||
// Initialize variables with GetEnvOrDefault
|
||||
SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60)
|
||||
BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0)
|
||||
|
||||
// Initialize string variables with GetEnvOrDefaultString
|
||||
GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
CohereSafetySetting = GetEnvOrDefaultString("COHERE_SAFETY_SETTING", "NONE")
|
||||
|
||||
// Initialize rate limit variables
|
||||
GlobalApiRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_API_RATE_LIMIT_ENABLE", true)
|
||||
GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180)
|
||||
GlobalApiRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_API_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
initConstantEnv()
|
||||
}
|
||||
|
||||
func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
}
|
||||
22
common/json.go
Normal file
22
common/json.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func UnmarshalJsonStr(data string, v any) error {
|
||||
return json.Unmarshal(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
89
common/limiter/limiter.go
Normal file
89
common/limiter/limiter.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package limiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//go:embed lua/rate_limit.lua
|
||||
var rateLimitScript string
|
||||
|
||||
type RedisLimiter struct {
|
||||
client *redis.Client
|
||||
limitScriptSHA string
|
||||
}
|
||||
|
||||
var (
|
||||
instance *RedisLimiter
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func New(ctx context.Context, r *redis.Client) *RedisLimiter {
|
||||
once.Do(func() {
|
||||
// 预加载脚本
|
||||
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
|
||||
}
|
||||
instance = &RedisLimiter{
|
||||
client: r,
|
||||
limitScriptSHA: limitSHA,
|
||||
}
|
||||
})
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
|
||||
// 默认配置
|
||||
config := &Config{
|
||||
Capacity: 10,
|
||||
Rate: 1,
|
||||
Requested: 1,
|
||||
}
|
||||
|
||||
// 应用选项模式
|
||||
for _, opt := range opts {
|
||||
opt(config)
|
||||
}
|
||||
|
||||
// 执行限流
|
||||
result, err := rl.client.EvalSha(
|
||||
ctx,
|
||||
rl.limitScriptSHA,
|
||||
[]string{key},
|
||||
config.Requested,
|
||||
config.Rate,
|
||||
config.Capacity,
|
||||
).Int()
|
||||
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("rate limit failed: %w", err)
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// Config 配置选项模式
|
||||
type Config struct {
|
||||
Capacity int64
|
||||
Rate int64
|
||||
Requested int64
|
||||
}
|
||||
|
||||
type Option func(*Config)
|
||||
|
||||
func WithCapacity(c int64) Option {
|
||||
return func(cfg *Config) { cfg.Capacity = c }
|
||||
}
|
||||
|
||||
func WithRate(r int64) Option {
|
||||
return func(cfg *Config) { cfg.Rate = r }
|
||||
}
|
||||
|
||||
func WithRequested(n int64) Option {
|
||||
return func(cfg *Config) { cfg.Requested = n }
|
||||
}
|
||||
44
common/limiter/lua/rate_limit.lua
Normal file
44
common/limiter/lua/rate_limit.lua
Normal file
@@ -0,0 +1,44 @@
|
||||
-- 令牌桶限流器
|
||||
-- KEYS[1]: 限流器唯一标识
|
||||
-- ARGV[1]: 请求令牌数 (通常为1)
|
||||
-- ARGV[2]: 令牌生成速率 (每秒)
|
||||
-- ARGV[3]: 桶容量
|
||||
|
||||
local key = KEYS[1]
|
||||
local requested = tonumber(ARGV[1])
|
||||
local rate = tonumber(ARGV[2])
|
||||
local capacity = tonumber(ARGV[3])
|
||||
|
||||
-- 获取当前时间(Redis服务器时间)
|
||||
local now = redis.call('TIME')
|
||||
local nowInSeconds = tonumber(now[1])
|
||||
|
||||
-- 获取桶状态
|
||||
local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
|
||||
local tokens = tonumber(bucket[1])
|
||||
local last_time = tonumber(bucket[2])
|
||||
|
||||
-- 初始化桶(首次请求或过期)
|
||||
if not tokens or not last_time then
|
||||
tokens = capacity
|
||||
last_time = nowInSeconds
|
||||
else
|
||||
-- 计算新增令牌
|
||||
local elapsed = nowInSeconds - last_time
|
||||
local add_tokens = elapsed * rate
|
||||
tokens = math.min(capacity, tokens + add_tokens)
|
||||
last_time = nowInSeconds
|
||||
end
|
||||
|
||||
-- 判断是否允许请求
|
||||
local allowed = false
|
||||
if tokens >= requested then
|
||||
tokens = tokens - requested
|
||||
allowed = true
|
||||
end
|
||||
|
||||
---- 更新桶状态并设置过期时间
|
||||
redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
|
||||
--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
|
||||
|
||||
return allowed and 1 or 0
|
||||
123
common/logger.go
Normal file
123
common/logger.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
loggerINFO = "INFO"
|
||||
loggerWarn = "WARN"
|
||||
loggerError = "ERR"
|
||||
)
|
||||
|
||||
const maxLogCount = 1000000
|
||||
|
||||
var logCount int
|
||||
var setupLogLock sync.Mutex
|
||||
var setupLogWorking bool
|
||||
|
||||
func SetupLogger() {
|
||||
if *LogDir != "" {
|
||||
ok := setupLogLock.TryLock()
|
||||
if !ok {
|
||||
log.Println("setup log is already working")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
setupLogLock.Unlock()
|
||||
setupLogWorking = false
|
||||
}()
|
||||
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
|
||||
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatal("failed to open log file")
|
||||
}
|
||||
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
|
||||
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
|
||||
}
|
||||
}
|
||||
|
||||
func SysLog(s string) {
|
||||
t := time.Now()
|
||||
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||
}
|
||||
|
||||
func SysError(s string) {
|
||||
t := time.Now()
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
|
||||
}
|
||||
|
||||
func LogInfo(ctx context.Context, msg string) {
|
||||
logHelper(ctx, loggerINFO, msg)
|
||||
}
|
||||
|
||||
func LogWarn(ctx context.Context, msg string) {
|
||||
logHelper(ctx, loggerWarn, msg)
|
||||
}
|
||||
|
||||
func LogError(ctx context.Context, msg string) {
|
||||
logHelper(ctx, loggerError, msg)
|
||||
}
|
||||
|
||||
func logHelper(ctx context.Context, level string, msg string) {
|
||||
writer := gin.DefaultErrorWriter
|
||||
if level == loggerINFO {
|
||||
writer = gin.DefaultWriter
|
||||
}
|
||||
id := ctx.Value(RequestIdKey)
|
||||
if id == nil {
|
||||
id = "SYSTEM"
|
||||
}
|
||||
now := time.Now()
|
||||
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
|
||||
logCount++ // we don't need accurate count, so no lock here
|
||||
if logCount > maxLogCount && !setupLogWorking {
|
||||
logCount = 0
|
||||
setupLogWorking = true
|
||||
gopool.Go(func() {
|
||||
SetupLogger()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FatalLog(v ...any) {
|
||||
t := time.Now()
|
||||
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func LogQuota(quota int) string {
|
||||
if DisplayInCurrencyEnabled {
|
||||
return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit)
|
||||
} else {
|
||||
return fmt.Sprintf("%d 点额度", quota)
|
||||
}
|
||||
}
|
||||
|
||||
func FormatQuota(quota int) string {
|
||||
if DisplayInCurrencyEnabled {
|
||||
return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
|
||||
} else {
|
||||
return fmt.Sprintf("%d", quota)
|
||||
}
|
||||
}
|
||||
|
||||
// LogJson 仅供测试使用 only for test
|
||||
func LogJson(ctx context.Context, msg string, obj any) {
|
||||
jsonStr, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
}
|
||||
42
common/model.go
Normal file
42
common/model.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package common
|
||||
|
||||
import "strings"
|
||||
|
||||
var (
|
||||
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||
OpenAIResponseOnlyModels = []string{
|
||||
"o3-pro",
|
||||
"o3-deep-research",
|
||||
"o4-mini-deep-research",
|
||||
}
|
||||
ImageGenerationModels = []string{
|
||||
"dall-e-3",
|
||||
"dall-e-2",
|
||||
"gpt-image-1",
|
||||
"prefix:imagen-",
|
||||
"flux-",
|
||||
"flux.1-",
|
||||
}
|
||||
)
|
||||
|
||||
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||
for _, m := range OpenAIResponseOnlyModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsImageGenerationModel(modelName string) bool {
|
||||
modelName = strings.ToLower(modelName)
|
||||
for _, m := range ImageGenerationModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
82
common/page_info.go
Normal file
82
common/page_info.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PageInfo struct {
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
|
||||
Total int `json:"total"` // 总条数,后设置
|
||||
Items any `json:"items"` // 数据,后设置
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetStartIdx() int {
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetEndIdx() int {
|
||||
return p.Page * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPageSize() int {
|
||||
return p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPage() int {
|
||||
return p.Page
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetTotal(total int) {
|
||||
p.Total = total
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetItems(items any) {
|
||||
p.Items = items
|
||||
}
|
||||
|
||||
func GetPageQuery(c *gin.Context) *PageInfo {
|
||||
pageInfo := &PageInfo{}
|
||||
// 手动获取并处理每个参数
|
||||
if page, err := strconv.Atoi(c.Query("page")); err == nil {
|
||||
pageInfo.Page = page
|
||||
}
|
||||
if pageSize, err := strconv.Atoi(c.Query("page_size")); err == nil {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.Page < 1 {
|
||||
// 兼容
|
||||
page, _ := strconv.Atoi(c.Query("p"))
|
||||
if page != 0 {
|
||||
pageInfo.Page = page
|
||||
} else {
|
||||
pageInfo.Page = 1
|
||||
}
|
||||
}
|
||||
|
||||
if pageInfo.PageSize == 0 {
|
||||
// 兼容
|
||||
pageSize, _ := strconv.Atoi(c.Query("ps"))
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageSize, _ = strconv.Atoi(c.Query("size")) // token page
|
||||
if pageSize != 0 {
|
||||
pageInfo.PageSize = pageSize
|
||||
}
|
||||
}
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
}
|
||||
}
|
||||
|
||||
if pageInfo.PageSize > 100 {
|
||||
pageInfo.PageSize = 100
|
||||
}
|
||||
|
||||
return pageInfo
|
||||
}
|
||||
44
common/pprof.go
Normal file
44
common/pprof.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/shirou/gopsutil/cpu"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Monitor 定时监控cpu使用率,超过阈值输出pprof文件
|
||||
func Monitor() {
|
||||
for {
|
||||
percent, err := cpu.Percent(time.Second, false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if percent[0] > 80 {
|
||||
fmt.Println("cpu usage too high")
|
||||
// write pprof file
|
||||
if _, err := os.Stat("./pprof"); os.IsNotExist(err) {
|
||||
err := os.Mkdir("./pprof", os.ModePerm)
|
||||
if err != nil {
|
||||
SysLog("创建pprof文件夹失败 " + err.Error())
|
||||
continue
|
||||
}
|
||||
}
|
||||
f, err := os.Create("./pprof/" + fmt.Sprintf("cpu-%s.pprof", time.Now().Format("20060102150405")))
|
||||
if err != nil {
|
||||
SysLog("创建pprof文件失败 " + err.Error())
|
||||
continue
|
||||
}
|
||||
err = pprof.StartCPUProfile(f)
|
||||
if err != nil {
|
||||
SysLog("启动pprof失败 " + err.Error())
|
||||
continue
|
||||
}
|
||||
time.Sleep(10 * time.Second) // profile for 30 seconds
|
||||
pprof.StopCPUProfile()
|
||||
f.Close()
|
||||
}
|
||||
time.Sleep(30 * time.Second)
|
||||
}
|
||||
}
|
||||
70
common/rate-limit.go
Normal file
70
common/rate-limit.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type InMemoryRateLimiter struct {
|
||||
store map[string]*[]int64
|
||||
mutex sync.Mutex
|
||||
expirationDuration time.Duration
|
||||
}
|
||||
|
||||
func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) {
|
||||
if l.store == nil {
|
||||
l.mutex.Lock()
|
||||
if l.store == nil {
|
||||
l.store = make(map[string]*[]int64)
|
||||
l.expirationDuration = expirationDuration
|
||||
if expirationDuration > 0 {
|
||||
go l.clearExpiredItems()
|
||||
}
|
||||
}
|
||||
l.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *InMemoryRateLimiter) clearExpiredItems() {
|
||||
for {
|
||||
time.Sleep(l.expirationDuration)
|
||||
l.mutex.Lock()
|
||||
now := time.Now().Unix()
|
||||
for key := range l.store {
|
||||
queue := l.store[key]
|
||||
size := len(*queue)
|
||||
if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) {
|
||||
delete(l.store, key)
|
||||
}
|
||||
}
|
||||
l.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Request parameter duration's unit is seconds
|
||||
func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool {
|
||||
l.mutex.Lock()
|
||||
defer l.mutex.Unlock()
|
||||
// [old <-- new]
|
||||
queue, ok := l.store[key]
|
||||
now := time.Now().Unix()
|
||||
if ok {
|
||||
if len(*queue) < maxRequestNum {
|
||||
*queue = append(*queue, now)
|
||||
return true
|
||||
} else {
|
||||
if now-(*queue)[0] >= duration {
|
||||
*queue = (*queue)[1:]
|
||||
*queue = append(*queue, now)
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s := make([]int64, 0, maxRequestNum)
|
||||
l.store[key] = &s
|
||||
*(l.store[key]) = append(*(l.store[key]), now)
|
||||
}
|
||||
return true
|
||||
}
|
||||
327
common/redis.go
Normal file
327
common/redis.go
Normal file
@@ -0,0 +1,327 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var RDB *redis.Client
|
||||
var RedisEnabled = true
|
||||
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return SyncFrequency
|
||||
}
|
||||
|
||||
// InitRedisClient This function is called after init()
|
||||
func InitRedisClient() (err error) {
|
||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||
RedisEnabled = false
|
||||
SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
|
||||
return nil
|
||||
}
|
||||
if os.Getenv("SYNC_FREQUENCY") == "" {
|
||||
SysLog("SYNC_FREQUENCY not set, use default value 60")
|
||||
SyncFrequency = 60
|
||||
}
|
||||
SysLog("Redis is enabled")
|
||||
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
||||
if err != nil {
|
||||
FatalLog("failed to parse Redis connection string: " + err.Error())
|
||||
}
|
||||
opt.PoolSize = GetEnvOrDefault("REDIS_POOL_SIZE", 10)
|
||||
RDB = redis.NewClient(opt)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err = RDB.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
FatalLog("Redis ping test failed: " + err.Error())
|
||||
}
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis connected to %s", opt.Addr))
|
||||
SysLog(fmt.Sprintf("Redis database: %d", opt.DB))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ParseRedisOption() *redis.Options {
|
||||
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
|
||||
if err != nil {
|
||||
FatalLog("failed to parse Redis connection string: " + err.Error())
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
func RedisSet(key string, value string, expiration time.Duration) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis SET: key=%s, value=%s, expiration=%v", key, value, expiration))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.Set(ctx, key, value, expiration).Err()
|
||||
}
|
||||
|
||||
func RedisGet(key string) (string, error) {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis GET: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
val, err := RDB.Get(ctx, key).Result()
|
||||
return val, err
|
||||
}
|
||||
|
||||
//func RedisExpire(key string, expiration time.Duration) error {
|
||||
// ctx := context.Background()
|
||||
// return RDB.Expire(ctx, key, expiration).Err()
|
||||
//}
|
||||
//
|
||||
//func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
||||
// ctx := context.Background()
|
||||
// return RDB.GetSet(ctx, key, expiration).Result()
|
||||
//}
|
||||
|
||||
func RedisDel(key string) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis DEL: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisDelKey(key string) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis DEL Key: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
return RDB.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HSET: key=%s, obj=%+v, expiration=%v", key, obj, expiration))
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
data := make(map[string]interface{})
|
||||
|
||||
// 使用反射遍历结构体字段
|
||||
v := reflect.ValueOf(obj).Elem()
|
||||
t := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
value := v.Field(i)
|
||||
|
||||
// Skip DeletedAt field
|
||||
if field.Type.String() == "gorm.DeletedAt" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理指针类型
|
||||
if value.Kind() == reflect.Ptr {
|
||||
if value.IsNil() {
|
||||
data[field.Name] = ""
|
||||
continue
|
||||
}
|
||||
value = value.Elem()
|
||||
}
|
||||
|
||||
// 处理布尔类型
|
||||
if value.Kind() == reflect.Bool {
|
||||
data[field.Name] = strconv.FormatBool(value.Bool())
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他类型直接转换为字符串
|
||||
data[field.Name] = fmt.Sprintf("%v", value.Interface())
|
||||
}
|
||||
|
||||
txn := RDB.TxPipeline()
|
||||
txn.HSet(ctx, key, data)
|
||||
|
||||
// 只有在 expiration 大于 0 时才设置过期时间
|
||||
if expiration > 0 {
|
||||
txn.Expire(ctx, key, expiration)
|
||||
}
|
||||
|
||||
_, err := txn.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHGetObj(key string, obj interface{}) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HGETALL: key=%s", key))
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := RDB.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load hash from Redis: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return fmt.Errorf("key %s not found in Redis", key)
|
||||
}
|
||||
|
||||
// Handle both pointer and non-pointer values
|
||||
val := reflect.ValueOf(obj)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
|
||||
}
|
||||
|
||||
v := val.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldName := field.Name
|
||||
if value, ok := result[fieldName]; ok {
|
||||
fieldValue := v.Field(i)
|
||||
|
||||
// Handle pointer types
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
// Enhanced type handling for Token struct
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.String:
|
||||
fieldValue.SetString(value)
|
||||
case reflect.Int, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.SetInt(intValue)
|
||||
case reflect.Bool:
|
||||
boolValue, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.SetBool(boolValue)
|
||||
case reflect.Struct:
|
||||
// Special handling for gorm.DeletedAt
|
||||
if fieldValue.Type().String() == "gorm.DeletedAt" {
|
||||
if value != "" {
|
||||
timeValue, err := time.Parse(time.RFC3339, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RedisIncr Add this function to handle atomic increments
|
||||
func RedisIncr(key string, delta int64) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis INCR: key=%s, delta=%d", key, delta))
|
||||
}
|
||||
// 检查键的剩余生存时间
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
// 只有在 key 存在且有 TTL 时才需要特殊处理
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
// 开始一个Redis事务
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
// 减少余额
|
||||
decrCmd := txn.IncrBy(ctx, key, delta)
|
||||
if err := decrCmd.Err(); err != nil {
|
||||
return err // 如果减少失败,则直接返回错误
|
||||
}
|
||||
|
||||
// 重新设置过期时间,使用原来的过期时间
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
// 执行事务
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHIncrBy(key, field string, delta int64) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HINCRBY: key=%s, field=%s, delta=%d", key, field, delta))
|
||||
}
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
incrCmd := txn.HIncrBy(ctx, key, field, delta)
|
||||
if err := incrCmd.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHSetField(key, field string, value interface{}) error {
|
||||
if DebugEnabled {
|
||||
SysLog(fmt.Sprintf("Redis HSET field: key=%s, field=%s, value=%v", key, field, value))
|
||||
}
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
hsetCmd := txn.HSet(ctx, key, field, value)
|
||||
if err := hsetCmd.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
97
common/str.go
Normal file
97
common/str.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
if str == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func GetRandomString(length int) string {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
key := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
key[i] = keyChars[rand.Intn(len(keyChars))]
|
||||
}
|
||||
return string(key)
|
||||
}
|
||||
|
||||
func MapToJsonStr(m map[string]interface{}) string {
|
||||
bytes, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) (map[string]interface{}, error) {
|
||||
m := make(map[string]interface{})
|
||||
err := Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||
var js []interface{}
|
||||
err := json.Unmarshal([]byte(str), &js)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func IsJsonArray(str string) bool {
|
||||
var js []interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
func IsJsonObject(str string) bool {
|
||||
var js map[string]interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
func String2Int(str string) int {
|
||||
num, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
||||
func StringsContains(strs []string, str string) bool {
|
||||
for _, s := range strs {
|
||||
if s == str {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StringToByteSlice []byte only read, panic on append
|
||||
func StringToByteSlice(s string) []byte {
|
||||
tmp1 := (*[2]uintptr)(unsafe.Pointer(&s))
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func EncodeBase64(str string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(str))
|
||||
}
|
||||
|
||||
func GetJsonString(data any) string {
|
||||
if data == nil {
|
||||
return ""
|
||||
}
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
33
common/topup-ratio.go
Normal file
33
common/topup-ratio.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
var TopupGroupRatio = map[string]float64{
|
||||
"default": 1,
|
||||
"vip": 1,
|
||||
"svip": 1,
|
||||
}
|
||||
|
||||
func TopupGroupRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(TopupGroupRatio)
|
||||
if err != nil {
|
||||
SysError("error marshalling model ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
|
||||
TopupGroupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
|
||||
}
|
||||
|
||||
func GetTopupGroupRatio(name string) float64 {
|
||||
ratio, ok := TopupGroupRatio[name]
|
||||
if !ok {
|
||||
SysError("topup group ratio not found: " + name)
|
||||
return 1
|
||||
}
|
||||
return ratio
|
||||
}
|
||||
304
common/utils.go
Normal file
304
common/utils.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func OpenBrowser(url string) {
|
||||
var err error
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
err = exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
case "darwin":
|
||||
err = exec.Command("open", url).Start()
|
||||
}
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
func GetIp() (ip string) {
|
||||
ips, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return ip
|
||||
}
|
||||
|
||||
for _, a := range ips {
|
||||
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
||||
if ipNet.IP.To4() != nil {
|
||||
ip = ipNet.IP.String()
|
||||
if strings.HasPrefix(ip, "10") {
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(ip, "172") {
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(ip, "192.168") {
|
||||
return
|
||||
}
|
||||
ip = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var sizeKB = 1024
|
||||
var sizeMB = sizeKB * 1024
|
||||
var sizeGB = sizeMB * 1024
|
||||
|
||||
func Bytes2Size(num int64) string {
|
||||
numStr := ""
|
||||
unit := "B"
|
||||
if num/int64(sizeGB) > 1 {
|
||||
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
|
||||
unit = "GB"
|
||||
} else if num/int64(sizeMB) > 1 {
|
||||
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
|
||||
unit = "MB"
|
||||
} else if num/int64(sizeKB) > 1 {
|
||||
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
|
||||
unit = "KB"
|
||||
} else {
|
||||
numStr = fmt.Sprintf("%d", num)
|
||||
}
|
||||
return numStr + " " + unit
|
||||
}
|
||||
|
||||
func Seconds2Time(num int) (time string) {
|
||||
if num/31104000 > 0 {
|
||||
time += strconv.Itoa(num/31104000) + " 年 "
|
||||
num %= 31104000
|
||||
}
|
||||
if num/2592000 > 0 {
|
||||
time += strconv.Itoa(num/2592000) + " 个月 "
|
||||
num %= 2592000
|
||||
}
|
||||
if num/86400 > 0 {
|
||||
time += strconv.Itoa(num/86400) + " 天 "
|
||||
num %= 86400
|
||||
}
|
||||
if num/3600 > 0 {
|
||||
time += strconv.Itoa(num/3600) + " 小时 "
|
||||
num %= 3600
|
||||
}
|
||||
if num/60 > 0 {
|
||||
time += strconv.Itoa(num/60) + " 分钟 "
|
||||
num %= 60
|
||||
}
|
||||
time += strconv.Itoa(num) + " 秒"
|
||||
return
|
||||
}
|
||||
|
||||
func Interface2String(inter interface{}) string {
|
||||
switch inter.(type) {
|
||||
case string:
|
||||
return inter.(string)
|
||||
case int:
|
||||
return fmt.Sprintf("%d", inter.(int))
|
||||
case float64:
|
||||
return fmt.Sprintf("%f", inter.(float64))
|
||||
}
|
||||
return "Not Implemented"
|
||||
}
|
||||
|
||||
func UnescapeHTML(x string) interface{} {
|
||||
return template.HTML(x)
|
||||
}
|
||||
|
||||
func IntMax(a int, b int) int {
|
||||
if a >= b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func IsIP(s string) bool {
|
||||
ip := net.ParseIP(s)
|
||||
return ip != nil
|
||||
}
|
||||
|
||||
func GetUUID() string {
|
||||
code := uuid.New().String()
|
||||
code = strings.Replace(code, "-", "", -1)
|
||||
return code
|
||||
}
|
||||
|
||||
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
func init() {
|
||||
rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
}
|
||||
|
||||
func GenerateRandomCharsKey(length int) (string, error) {
|
||||
b := make([]byte, length)
|
||||
maxI := big.NewInt(int64(len(keyChars)))
|
||||
|
||||
for i := range b {
|
||||
n, err := crand.Int(crand.Reader, maxI)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b[i] = keyChars[n.Int64()]
|
||||
}
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func GenerateRandomKey(length int) (string, error) {
|
||||
bytes := make([]byte, length*3/4) // 对于48位的输出,这里应该是36
|
||||
if _, err := crand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func GenerateKey() (string, error) {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
return GenerateRandomCharsKey(48)
|
||||
}
|
||||
|
||||
func GetRandomInt(max int) int {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
return rand.Intn(max)
|
||||
}
|
||||
|
||||
func GetTimestamp() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
func GetTimeString() string {
|
||||
now := time.Now()
|
||||
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
|
||||
}
|
||||
|
||||
func Max(a int, b int) int {
|
||||
if a >= b {
|
||||
return a
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
|
||||
func MessageWithRequestId(message string, id string) string {
|
||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||
}
|
||||
|
||||
func RandomSleep() {
|
||||
// Sleep for 0-3000 ms
|
||||
time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond)
|
||||
}
|
||||
|
||||
func GetPointer[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func Any2Type[T any](data any) (T, error) {
|
||||
var zero T
|
||||
bytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
var res T
|
||||
err = json.Unmarshal(bytes, &res)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string.
|
||||
func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||
f, err := os.CreateTemp(os.TempDir(), filename)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to create temporary file %s", filename)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
_, err = io.Copy(f, data)
|
||||
if err != nil {
|
||||
return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename)
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
|
||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
|
||||
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||
}
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||
func BuildURL(base string, endpoint string) string {
|
||||
u, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
end := endpoint
|
||||
if end == "" {
|
||||
end = "/"
|
||||
}
|
||||
ref, err := url.Parse(end)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
return u.ResolveReference(ref).String()
|
||||
}
|
||||
9
common/validate.go
Normal file
9
common/validate.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package common
|
||||
|
||||
import "github.com/go-playground/validator/v10"
|
||||
|
||||
var Validate *validator.Validate
|
||||
|
||||
func init() {
|
||||
Validate = validator.New()
|
||||
}
|
||||
77
common/verification.go
Normal file
77
common/verification.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type verificationValue struct {
|
||||
code string
|
||||
time time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
EmailVerificationPurpose = "v"
|
||||
PasswordResetPurpose = "r"
|
||||
)
|
||||
|
||||
var verificationMutex sync.Mutex
|
||||
var verificationMap map[string]verificationValue
|
||||
var verificationMapMaxSize = 10
|
||||
var VerificationValidMinutes = 10
|
||||
|
||||
func GenerateVerificationCode(length int) string {
|
||||
code := uuid.New().String()
|
||||
code = strings.Replace(code, "-", "", -1)
|
||||
if length == 0 {
|
||||
return code
|
||||
}
|
||||
return code[:length]
|
||||
}
|
||||
|
||||
func RegisterVerificationCodeWithKey(key string, code string, purpose string) {
|
||||
verificationMutex.Lock()
|
||||
defer verificationMutex.Unlock()
|
||||
verificationMap[purpose+key] = verificationValue{
|
||||
code: code,
|
||||
time: time.Now(),
|
||||
}
|
||||
if len(verificationMap) > verificationMapMaxSize {
|
||||
removeExpiredPairs()
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyCodeWithKey(key string, code string, purpose string) bool {
|
||||
verificationMutex.Lock()
|
||||
defer verificationMutex.Unlock()
|
||||
value, okay := verificationMap[purpose+key]
|
||||
now := time.Now()
|
||||
if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 {
|
||||
return false
|
||||
}
|
||||
return code == value.code
|
||||
}
|
||||
|
||||
func DeleteKey(key string, purpose string) {
|
||||
verificationMutex.Lock()
|
||||
defer verificationMutex.Unlock()
|
||||
delete(verificationMap, purpose+key)
|
||||
}
|
||||
|
||||
// no lock inside, so the caller must lock the verificationMap before calling!
|
||||
func removeExpiredPairs() {
|
||||
now := time.Now()
|
||||
for key := range verificationMap {
|
||||
if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 {
|
||||
delete(verificationMap, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
verificationMutex.Lock()
|
||||
defer verificationMutex.Unlock()
|
||||
verificationMap = make(map[string]verificationValue)
|
||||
}
|
||||
26
constant/README.md
Normal file
26
constant/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# constant 包 (`/constant`)
|
||||
|
||||
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||
|
||||
## 当前文件
|
||||
|
||||
| 文件 | 说明 |
|
||||
|----------------------|---------------------------------------------------------------------|
|
||||
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||
|
||||
## 使用约定
|
||||
|
||||
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||
|
||||
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||
35
constant/api_type.go
Normal file
35
constant/api_type.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipuV4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeJimeng
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
5
constant/azure.go
Normal file
5
constant/azure.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package constant
|
||||
|
||||
import "time"
|
||||
|
||||
var AzureNoRemoveDotTime = time.Date(2025, time.May, 10, 0, 0, 0, 0, time.UTC).Unix()
|
||||
14
constant/cache_key.go
Normal file
14
constant/cache_key.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package constant
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
UserGroupKeyFmt = "user_group:%d"
|
||||
UserQuotaKeyFmt = "user_quota:%d"
|
||||
UserEnabledKeyFmt = "user_enabled:%d"
|
||||
UserUsernameKeyFmt = "user_name:%d"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenFiledRemainQuota = "RemainQuota"
|
||||
TokenFieldGroup = "Group"
|
||||
)
|
||||
109
constant/channel.go
Normal file
109
constant/channel.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
44
constant/context_key.go
Normal file
44
constant/context_key.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package constant
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
|
||||
/* token related keys */
|
||||
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||
ContextKeyTokenKey ContextKey = "token_key"
|
||||
ContextKeyTokenId ContextKey = "token_id"
|
||||
ContextKeyTokenGroup ContextKey = "token_group"
|
||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelName ContextKey = "channel_name"
|
||||
ContextKeyChannelCreateTime ContextKey = "channel_create_time"
|
||||
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||
ContextKeyChannelMultiKeyIndex ContextKey = "channel_multi_key_index"
|
||||
ContextKeyChannelKey ContextKey = "channel_key"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
ContextKeyUserSetting ContextKey = "user_setting"
|
||||
ContextKeyUserQuota ContextKey = "user_quota"
|
||||
ContextKeyUserStatus ContextKey = "user_status"
|
||||
ContextKeyUserEmail ContextKey = "user_email"
|
||||
ContextKeyUserGroup ContextKey = "user_group"
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
)
|
||||
16
constant/endpoint_type.go
Normal file
16
constant/endpoint_type.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
type EndpointType string
|
||||
|
||||
const (
|
||||
EndpointTypeOpenAI EndpointType = "openai"
|
||||
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||
EndpointTypeGemini EndpointType = "gemini"
|
||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||
//EndpointTypeKling EndpointType = "kling"
|
||||
//EndpointTypeJimeng EndpointType = "jimeng"
|
||||
)
|
||||
15
constant/env.go
Normal file
15
constant/env.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package constant
|
||||
|
||||
var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
var ForceStreamOption bool
|
||||
var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
var AzureDefaultAPIVersion string
|
||||
var GeminiVisionMaxImageNum int
|
||||
var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
9
constant/finish_reason.go
Normal file
9
constant/finish_reason.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
FinishReasonStop = "stop"
|
||||
FinishReasonToolCalls = "tool_calls"
|
||||
FinishReasonLength = "length"
|
||||
FinishReasonFunctionCall = "function_call"
|
||||
FinishReasonContentFilter = "content_filter"
|
||||
)
|
||||
48
constant/midjourney.go
Normal file
48
constant/midjourney.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
)
|
||||
|
||||
const (
|
||||
MjActionImagine = "IMAGINE"
|
||||
MjActionDescribe = "DESCRIBE"
|
||||
MjActionBlend = "BLEND"
|
||||
MjActionUpscale = "UPSCALE"
|
||||
MjActionVariation = "VARIATION"
|
||||
MjActionReRoll = "REROLL"
|
||||
MjActionInPaint = "INPAINT"
|
||||
MjActionModal = "MODAL"
|
||||
MjActionZoom = "ZOOM"
|
||||
MjActionCustomZoom = "CUSTOM_ZOOM"
|
||||
MjActionShorten = "SHORTEN"
|
||||
MjActionHighVariation = "HIGH_VARIATION"
|
||||
MjActionLowVariation = "LOW_VARIATION"
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
MjActionUpload = "UPLOAD"
|
||||
MjActionVideo = "VIDEO"
|
||||
MjActionEdits = "EDITS"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
"mj_imagine": MjActionImagine,
|
||||
"mj_describe": MjActionDescribe,
|
||||
"mj_blend": MjActionBlend,
|
||||
"mj_upscale": MjActionUpscale,
|
||||
"mj_variation": MjActionVariation,
|
||||
"mj_reroll": MjActionReRoll,
|
||||
"mj_modal": MjActionModal,
|
||||
"mj_inpaint": MjActionInPaint,
|
||||
"mj_zoom": MjActionZoom,
|
||||
"mj_custom_zoom": MjActionCustomZoom,
|
||||
"mj_shorten": MjActionShorten,
|
||||
"mj_high_variation": MjActionHighVariation,
|
||||
"mj_low_variation": MjActionLowVariation,
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
"mj_upload": MjActionUpload,
|
||||
"mj_video": MjActionVideo,
|
||||
"mj_edits": MjActionEdits,
|
||||
}
|
||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package constant
|
||||
|
||||
type MultiKeyMode string
|
||||
|
||||
const (
|
||||
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||
)
|
||||
3
constant/setup.go
Normal file
3
constant/setup.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package constant
|
||||
|
||||
var Setup = false
|
||||
23
constant/task.go
Normal file
23
constant/task.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package constant
|
||||
|
||||
type TaskPlatform string
|
||||
|
||||
const (
|
||||
TaskPlatformSuno TaskPlatform = "suno"
|
||||
TaskPlatformMidjourney = "mj"
|
||||
TaskPlatformKling TaskPlatform = "kling"
|
||||
TaskPlatformJimeng TaskPlatform = "jimeng"
|
||||
)
|
||||
|
||||
const (
|
||||
SunoActionMusic = "MUSIC"
|
||||
SunoActionLyrics = "LYRICS"
|
||||
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
"suno_music": SunoActionMusic,
|
||||
"suno_lyrics": SunoActionLyrics,
|
||||
}
|
||||
92
controller/billing.go
Normal file
92
controller/billing.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
)
|
||||
|
||||
func GetSubscription(c *gin.Context) {
|
||||
var remainQuota int
|
||||
var usedQuota int
|
||||
var err error
|
||||
var token *model.Token
|
||||
var expiredTime int64
|
||||
if common.DisplayTokenStatEnabled {
|
||||
tokenId := c.GetInt("token_id")
|
||||
token, err = model.GetTokenById(tokenId)
|
||||
expiredTime = token.ExpiredTime
|
||||
remainQuota = token.RemainQuota
|
||||
usedQuota = token.UsedQuota
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
remainQuota, err = model.GetUserQuota(userId, false)
|
||||
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if expiredTime <= 0 {
|
||||
expiredTime = 0
|
||||
}
|
||||
if err != nil {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: "upstream_error",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": openAIError,
|
||||
})
|
||||
return
|
||||
}
|
||||
quota := remainQuota + usedQuota
|
||||
amount := float64(quota)
|
||||
if common.DisplayInCurrencyEnabled {
|
||||
amount /= common.QuotaPerUnit
|
||||
}
|
||||
if token != nil && token.UnlimitedQuota {
|
||||
amount = 100000000
|
||||
}
|
||||
subscription := OpenAISubscriptionResponse{
|
||||
Object: "billing_subscription",
|
||||
HasPaymentMethod: true,
|
||||
SoftLimitUSD: amount,
|
||||
HardLimitUSD: amount,
|
||||
SystemHardLimitUSD: amount,
|
||||
AccessUntil: expiredTime,
|
||||
}
|
||||
c.JSON(200, subscription)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUsage(c *gin.Context) {
|
||||
var quota int
|
||||
var err error
|
||||
var token *model.Token
|
||||
if common.DisplayTokenStatEnabled {
|
||||
tokenId := c.GetInt("token_id")
|
||||
token, err = model.GetTokenById(tokenId)
|
||||
quota = token.UsedQuota
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
quota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if err != nil {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: "new_api_error",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": openAIError,
|
||||
})
|
||||
return
|
||||
}
|
||||
amount := float64(quota)
|
||||
if common.DisplayInCurrencyEnabled {
|
||||
amount /= common.QuotaPerUnit
|
||||
}
|
||||
usage := OpenAIUsageResponse{
|
||||
Object: "list",
|
||||
TotalUsage: amount * 100,
|
||||
}
|
||||
c.JSON(200, usage)
|
||||
return
|
||||
}
|
||||
492
controller/channel-billing.go
Normal file
492
controller/channel-billing.go
Normal file
@@ -0,0 +1,492 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://github.com/songquanpeng/one-api/issues/79
|
||||
|
||||
type OpenAISubscriptionResponse struct {
|
||||
Object string `json:"object"`
|
||||
HasPaymentMethod bool `json:"has_payment_method"`
|
||||
SoftLimitUSD float64 `json:"soft_limit_usd"`
|
||||
HardLimitUSD float64 `json:"hard_limit_usd"`
|
||||
SystemHardLimitUSD float64 `json:"system_hard_limit_usd"`
|
||||
AccessUntil int64 `json:"access_until"`
|
||||
}
|
||||
|
||||
type OpenAIUsageDailyCost struct {
|
||||
Timestamp float64 `json:"timestamp"`
|
||||
LineItems []struct {
|
||||
Name string `json:"name"`
|
||||
Cost float64 `json:"cost"`
|
||||
}
|
||||
}
|
||||
|
||||
type OpenAICreditGrants struct {
|
||||
Object string `json:"object"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
}
|
||||
|
||||
type OpenAIUsageResponse struct {
|
||||
Object string `json:"object"`
|
||||
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
|
||||
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
|
||||
}
|
||||
|
||||
type OpenAISBUsageResponse struct {
|
||||
Msg string `json:"msg"`
|
||||
Data *struct {
|
||||
Credit string `json:"credit"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type AIProxyUserOverviewResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
ErrorCode int `json:"error_code"`
|
||||
Data struct {
|
||||
TotalPoints float64 `json:"totalPoints"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type API2GPTUsageResponse struct {
|
||||
Object string `json:"object"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
TotalRemaining float64 `json:"total_remaining"`
|
||||
}
|
||||
|
||||
type APGC2DGPTUsageResponse struct {
|
||||
//Grants interface{} `json:"grants"`
|
||||
Object string `json:"object"`
|
||||
TotalAvailable float64 `json:"total_available"`
|
||||
TotalGranted float64 `json:"total_granted"`
|
||||
TotalUsed float64 `json:"total_used"`
|
||||
}
|
||||
|
||||
type SiliconFlowUsageResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status bool `json:"status"`
|
||||
Data struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Image string `json:"image"`
|
||||
Email string `json:"email"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Balance string `json:"balance"`
|
||||
Status string `json:"status"`
|
||||
Introduction string `json:"introduction"`
|
||||
Role string `json:"role"`
|
||||
ChargeBalance string `json:"chargeBalance"`
|
||||
TotalBalance string `json:"totalBalance"`
|
||||
Category string `json:"category"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type DeepSeekUsageResponse struct {
|
||||
IsAvailable bool `json:"is_available"`
|
||||
BalanceInfos []struct {
|
||||
Currency string `json:"currency"`
|
||||
TotalBalance string `json:"total_balance"`
|
||||
GrantedBalance string `json:"granted_balance"`
|
||||
ToppedUpBalance string `json:"topped_up_balance"`
|
||||
} `json:"balance_infos"`
|
||||
}
|
||||
|
||||
type OpenRouterCreditResponse struct {
|
||||
Data struct {
|
||||
TotalCredits float64 `json:"total_credits"`
|
||||
TotalUsage float64 `json:"total_usage"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// GetAuthHeader get auth header
|
||||
func GetAuthHeader(token string) http.Header {
|
||||
h := http.Header{}
|
||||
h.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
return h
|
||||
}
|
||||
|
||||
func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) {
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k := range headers {
|
||||
req.Header.Add(k, headers.Get(k))
|
||||
}
|
||||
res, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("status code: %d", res.StatusCode)
|
||||
}
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
||||
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := OpenAICreditGrants{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
return response.TotalAvailable, nil
|
||||
}
|
||||
|
||||
func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) {
|
||||
url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key)
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := OpenAISBUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if response.Data == nil {
|
||||
return 0, errors.New(response.Msg)
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.Data.Credit, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://aiproxy.io/api/report/getUserOverview"
|
||||
headers := http.Header{}
|
||||
headers.Add("Api-Key", channel.Key)
|
||||
body, err := GetResponseBody("GET", url, channel, headers)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := AIProxyUserOverviewResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Success {
|
||||
return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message)
|
||||
}
|
||||
channel.UpdateBalance(response.Data.TotalPoints)
|
||||
return response.Data.TotalPoints, nil
|
||||
}
|
||||
|
||||
func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.api2gpt.com/dashboard/billing/credit_grants"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := API2GPTUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(response.TotalRemaining)
|
||||
return response.TotalRemaining, nil
|
||||
}
|
||||
|
||||
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.siliconflow.cn/v1/user/info"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := SiliconFlowUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if response.Code != 20000 {
|
||||
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.deepseek.com/user/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := DeepSeekUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
index := -1
|
||||
for i, balanceInfo := range response.BalanceInfos {
|
||||
if balanceInfo.Currency == "CNY" {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if index == -1 {
|
||||
return 0, errors.New("currency CNY not found")
|
||||
}
|
||||
balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.aigc2d.com/dashboard/billing/credit_grants"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := APGC2DGPTUsageResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
channel.UpdateBalance(response.TotalAvailable)
|
||||
return response.TotalAvailable, nil
|
||||
}
|
||||
|
||||
func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://openrouter.ai/api/v1/credits"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
response := OpenRouterCreditResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
balance := response.Data.TotalCredits - response.Data.TotalUsage
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
type MoonshotBalanceData struct {
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
VoucherBalance float64 `json:"voucher_balance"`
|
||||
CashBalance float64 `json:"cash_balance"`
|
||||
}
|
||||
|
||||
type MoonshotBalanceResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data MoonshotBalanceData `json:"data"`
|
||||
Scode string `json:"scode"`
|
||||
Status bool `json:"status"`
|
||||
}
|
||||
|
||||
response := MoonshotBalanceResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Status || response.Code != 0 {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
case constant.ChannelTypeAzure:
|
||||
return 0, errors.New("尚未实现")
|
||||
case constant.ChannelTypeCustom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
//case common.ChannelTypeOpenAISB:
|
||||
// return updateChannelOpenAISBBalance(channel)
|
||||
case constant.ChannelTypeAIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case constant.ChannelTypeAPI2GPT:
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case constant.ChannelTypeAIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
return updateChannelSiliconFlowBalance(channel)
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL)
|
||||
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
subscription := OpenAISubscriptionResponse{}
|
||||
err = json.Unmarshal(body, &subscription)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
now := time.Now()
|
||||
startDate := fmt.Sprintf("%s-01", now.Format("2006-01"))
|
||||
endDate := now.Format("2006-01-02")
|
||||
if !subscription.HasPaymentMethod {
|
||||
startDate = now.AddDate(0, 0, -100).Format("2006-01-02")
|
||||
}
|
||||
url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate)
|
||||
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
usage := OpenAIUsageResponse{}
|
||||
err = json.Unmarshal(body, &usage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
||||
channel.UpdateBalance(balance)
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func UpdateChannelBalance(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.CacheGetChannel(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "多密钥渠道不支持余额查询",
|
||||
})
|
||||
return
|
||||
}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"balance": balance,
|
||||
})
|
||||
}
|
||||
|
||||
func updateAllChannelsBalance() error {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, channel := range channels {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
continue // skip multi-key channels
|
||||
}
|
||||
// TODO: support Azure
|
||||
//if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
|
||||
// continue
|
||||
//}
|
||||
balance, err := updateChannelBalance(channel)
|
||||
if err != nil {
|
||||
continue
|
||||
} else {
|
||||
// err is nil & balance <= 0 means quota is used up
|
||||
if balance <= 0 {
|
||||
service.DisableChannel(*types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, "", channel.GetAutoBan()), "余额不足")
|
||||
}
|
||||
}
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateAllChannelsBalance(c *gin.Context) {
|
||||
// TODO: make it async
|
||||
err := updateAllChannelsBalance()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func AutomaticallyUpdateChannels(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("updating all channels")
|
||||
_ = updateAllChannelsBalance()
|
||||
common.SysLog("channels update done")
|
||||
}
|
||||
}
|
||||
464
controller/channel-test.go
Normal file
464
controller/channel-test.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testResult struct {
|
||||
context *gin.Context
|
||||
localErr error
|
||||
newAPIError *types.NewAPIError
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
tik := time.Now()
|
||||
if channel.Type == constant.ChannelTypeMidjourney {
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||
return testResult{
|
||||
localErr: errors.New("midjourney plus channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||
return testResult{
|
||||
localErr: errors.New("suno channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeKling {
|
||||
return testResult{
|
||||
localErr: errors.New("kling channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeJimeng {
|
||||
return testResult{
|
||||
localErr: errors.New("jimeng channel test is not supported"),
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
requestPath := "/v1/chat/completions"
|
||||
|
||||
// 先判断是否为 Embedding 模型
|
||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: requestPath}, // 使用动态路径
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
|
||||
if testModel == "" {
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
} else {
|
||||
if len(channel.GetModels()) > 0 {
|
||||
testModel = channel.GetModels()[0]
|
||||
} else {
|
||||
testModel = "gpt-4o-mini"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cache, err := model.GetUserCache(1)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
localErr: err,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
cache.WriteContext(c)
|
||||
|
||||
//c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
group, _ := model.GetUserGroup(1, false)
|
||||
c.Set("group", group)
|
||||
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
if newAPIError != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: newAPIError,
|
||||
newAPIError: newAPIError,
|
||||
}
|
||||
}
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
||||
}
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
||||
newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
||||
}
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
// 创建一个用于日志的 info 副本,移除 ApiKey
|
||||
logInfo := *info
|
||||
logInfo.ApiKey = ""
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
||||
}
|
||||
}
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
var convertedRequest any
|
||||
// 根据 RelayMode 选择正确的转换函数
|
||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||
// 创建一个 EmbeddingRequest
|
||||
embeddingRequest := dto.EmbeddingRequest{
|
||||
Input: request.Input,
|
||||
Model: request.Model,
|
||||
}
|
||||
// 调用专门用于 Embedding 的转换函数
|
||||
convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
|
||||
} else {
|
||||
// 对其他所有请求类型(如 Chat),保持原有逻辑
|
||||
convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||
}
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||
}
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
|
||||
}
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
|
||||
}
|
||||
}
|
||||
}
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||
if respErr != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: respErr,
|
||||
newAPIError: respErr,
|
||||
}
|
||||
}
|
||||
if usageA == nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: errors.New("usage is nil"),
|
||||
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
|
||||
}
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
|
||||
}
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
|
||||
quota := 0
|
||||
if !priceData.UsePrice {
|
||||
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
||||
quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
||||
if priceData.ModelRatio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
||||
}
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
ModelName: info.OriginModelName,
|
||||
TokenName: "模型测试",
|
||||
Quota: quota,
|
||||
Content: "模型测试",
|
||||
UseTimeSeconds: int(consumedTime),
|
||||
IsStream: false,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: nil,
|
||||
newAPIError: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
testRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: "", // this will be set later
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
// 先判断是否为 Embedding 模型
|
||||
if strings.Contains(strings.ToLower(model), "embedding") || // 其他 embedding 模型
|
||||
strings.HasPrefix(model, "m3e") || // m3e 系列模型
|
||||
strings.Contains(model, "bge-") {
|
||||
testRequest.Model = model
|
||||
// Embedding 请求
|
||||
testRequest.Input = []any{"hello world"} // 修改为any,因为dto/openai_request.go 的ParseInput方法无法处理[]string类型
|
||||
return testRequest
|
||||
}
|
||||
// 并非Embedding 模型
|
||||
if strings.HasPrefix(model, "o") {
|
||||
testRequest.MaxCompletionTokens = 10
|
||||
} else if strings.Contains(model, "thinking") {
|
||||
if !strings.Contains(model, "claude") {
|
||||
testRequest.MaxTokens = 50
|
||||
}
|
||||
} else if strings.Contains(model, "gemini") {
|
||||
testRequest.MaxTokens = 3000
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
|
||||
testMessage := dto.Message{
|
||||
Role: "user",
|
||||
Content: "hi",
|
||||
}
|
||||
testRequest.Model = model
|
||||
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||||
return testRequest
|
||||
}
|
||||
|
||||
func TestChannel(c *gin.Context) {
|
||||
channelId, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
//defer func() {
|
||||
// if channel.ChannelInfo.IsMultiKey {
|
||||
// go func() { _ = channel.SaveChannelInfo() }()
|
||||
// }
|
||||
//}()
|
||||
testModel := c.Query("model")
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, testModel)
|
||||
if result.localErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": result.localErr.Error(),
|
||||
"time": 0.0,
|
||||
})
|
||||
return
|
||||
}
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
if result.newAPIError != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": result.newAPIError.Error(),
|
||||
"time": consumedTime,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"time": consumedTime,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
|
||||
testAllChannelsLock.Lock()
|
||||
if testAllChannelsRunning {
|
||||
testAllChannelsLock.Unlock()
|
||||
return errors.New("测试已在运行中")
|
||||
}
|
||||
testAllChannelsRunning = true
|
||||
testAllChannelsLock.Unlock()
|
||||
channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
||||
if getChannelErr != nil {
|
||||
return getChannelErr
|
||||
}
|
||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
}
|
||||
gopool.Go(func() {
|
||||
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
|
||||
defer func() {
|
||||
testAllChannelsLock.Lock()
|
||||
testAllChannelsRunning = false
|
||||
testAllChannelsLock.Unlock()
|
||||
}()
|
||||
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
result := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
shouldBanChannel := false
|
||||
newAPIError := result.newAPIError
|
||||
// request error disables the channel
|
||||
if newAPIError != nil {
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
||||
}
|
||||
|
||||
// 当错误检查通过,才检查响应时间
|
||||
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
||||
if milliseconds > disableThreshold {
|
||||
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
|
||||
shouldBanChannel = true
|
||||
}
|
||||
}
|
||||
|
||||
// disable channel
|
||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
}
|
||||
|
||||
// enable channel
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
||||
service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
||||
}
|
||||
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
||||
if notify {
|
||||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAllChannels(c *gin.Context) {
|
||||
err := testAllChannels(true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
if frequency <= 0 {
|
||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("testing all channels")
|
||||
_ = testAllChannels(false)
|
||||
common.SysLog("channel test finished")
|
||||
}
|
||||
}
|
||||
916
controller/channel.go
Normal file
916
controller/channel.go
Normal file
@@ -0,0 +1,916 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
||||
AllowSampling bool `json:"allow_sampling"`
|
||||
AllowLogprobs bool `json:"allow_logprobs"`
|
||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
||||
AllowView bool `json:"allow_view"`
|
||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
||||
Organization string `json:"organization"`
|
||||
Group string `json:"group"`
|
||||
IsBlocking bool `json:"is_blocking"`
|
||||
} `json:"permission"`
|
||||
Root string `json:"root"`
|
||||
Parent string `json:"parent"`
|
||||
}
|
||||
|
||||
type OpenAIModelsResponse struct {
|
||||
Data []OpenAIModel `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
func parseStatusFilter(statusParam string) int {
|
||||
switch strings.ToLower(statusParam) {
|
||||
case "enabled", "1":
|
||||
return common.ChannelStatusEnabled
|
||||
case "disabled", "0":
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
statusParam := c.Query("status")
|
||||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
// type filter
|
||||
typeStr := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeStr != "" {
|
||||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag == nil || *tag == "" {
|
||||
continue
|
||||
}
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
filtered := make([]*model.Channel, 0)
|
||||
for _, ch := range tagChannels {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if typeFilter >= 0 && ch.Type != typeFilter {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = append(channelData, filtered...)
|
||||
}
|
||||
total, _ = model.CountAllTags()
|
||||
} else {
|
||||
baseQuery := model.DB.Model(&model.Channel{})
|
||||
if typeFilter >= 0 {
|
||||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||||
}
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
|
||||
baseQuery.Count(&total)
|
||||
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
var results []struct {
|
||||
Type int64
|
||||
Count int64
|
||||
}
|
||||
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
typeCounts[r.Type] = r.Count
|
||||
}
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": pageInfo.GetPage(),
|
||||
"page_size": pageInfo.GetPageSize(),
|
||||
"type_counts": typeCounts,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func FetchUpstreamModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var result OpenAIModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var ids []string
|
||||
for _, model := range result.Data {
|
||||
id := model.ID
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
id = strings.TrimPrefix(id, "models/")
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ids,
|
||||
})
|
||||
}
|
||||
|
||||
func FixChannelsAbilities(c *gin.Context) {
|
||||
success, fails, err := model.FixAbility()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"success": success,
|
||||
"fails": fails,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func SearchChannels(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
modelKeyword := c.Query("model")
|
||||
statusParam := c.Query("status")
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
channelData := make([]*model.Channel, 0)
|
||||
if enableTagMode {
|
||||
tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag != nil && *tag != "" {
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err == nil {
|
||||
channelData = append(channelData, tagChannel...)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
|
||||
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
// calculate type counts for search results
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, channel := range channelData {
|
||||
typeCounts[int64(channel.Type)]++
|
||||
}
|
||||
|
||||
typeParam := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeParam != "" {
|
||||
if tp, err := strconv.Atoi(typeParam); err == nil {
|
||||
typeFilter = tp
|
||||
}
|
||||
}
|
||||
|
||||
if typeFilter >= 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if ch.Type == typeFilter {
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
total := len(channelData)
|
||||
startIdx := (page - 1) * pageSize
|
||||
if startIdx > total {
|
||||
startIdx = total
|
||||
}
|
||||
endIdx := startIdx + pageSize
|
||||
if endIdx > total {
|
||||
endIdx = total
|
||||
}
|
||||
|
||||
pagedData := channelData[startIdx:endIdx]
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": pagedData,
|
||||
"total": total,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
channel, err := model.GetChannelById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channel,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// validateChannel 通用的渠道校验函数
|
||||
func validateChannel(channel *model.Channel, isAdd bool) error {
|
||||
// 校验 channel settings
|
||||
if err := channel.ValidateSettings(); err != nil {
|
||||
return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
|
||||
}
|
||||
|
||||
// 如果是添加操作,检查 channel 和 key 是否为空
|
||||
if isAdd {
|
||||
if channel == nil || channel.Key == "" {
|
||||
return fmt.Errorf("channel cannot be empty")
|
||||
}
|
||||
|
||||
// 检查模型名称长度是否超过 255
|
||||
for _, m := range channel.GetModels() {
|
||||
if len(m) > 255 {
|
||||
return fmt.Errorf("模型名称过长: %s", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// VertexAI 特殊校验
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
return fmt.Errorf("部署地区不能为空")
|
||||
}
|
||||
|
||||
regionMap, err := common.StrToMap(channel.Other)
|
||||
if err != nil {
|
||||
return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
|
||||
}
|
||||
|
||||
if regionMap["default"] == nil {
|
||||
return fmt.Errorf("部署地区必须包含default字段")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AddChannelRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
}
|
||||
|
||||
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||
if keys == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var keyArray []interface{}
|
||||
err := common.Unmarshal([]byte(keys), &keyArray)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
|
||||
}
|
||||
cleanKeys := make([]string, 0, len(keyArray))
|
||||
for _, key := range keyArray {
|
||||
var keyStr string
|
||||
switch v := key.(type) {
|
||||
case string:
|
||||
keyStr = strings.TrimSpace(v)
|
||||
default:
|
||||
bytes, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
|
||||
}
|
||||
keyStr = string(bytes)
|
||||
}
|
||||
if keyStr != "" {
|
||||
cleanKeys = append(cleanKeys, keyStr)
|
||||
}
|
||||
}
|
||||
if len(cleanKeys) == 0 {
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
|
||||
}
|
||||
return cleanKeys, nil
|
||||
}
|
||||
|
||||
func AddChannel(c *gin.Context) {
|
||||
addChannelRequest := AddChannelRequest{}
|
||||
err := c.ShouldBindJSON(&addChannelRequest)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用统一的校验函数
|
||||
if err := validateChannel(addChannelRequest.Channel, true); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
||||
keys := make([]string, 0)
|
||||
switch addChannelRequest.Mode {
|
||||
case "multi_to_single":
|
||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(array)
|
||||
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||||
} else {
|
||||
cleanKeys := make([]string, 0)
|
||||
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
cleanKeys = append(cleanKeys, key)
|
||||
}
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeySize = len(cleanKeys)
|
||||
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||||
}
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
case "batch":
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
// multi json
|
||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
||||
}
|
||||
case "single":
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不支持的添加模式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channels := make([]model.Channel, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
localChannel := addChannelRequest.Channel
|
||||
localChannel.Key = key
|
||||
channels = append(channels, *localChannel)
|
||||
}
|
||||
err = model.BatchInsertChannels(channels)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteChannel(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
channel := model.Channel{Id: id}
|
||||
err := channel.Delete()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteDisabledChannel(c *gin.Context) {
|
||||
rows, err := model.DeleteDisabledChannel()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type ChannelTag struct {
|
||||
Tag string `json:"tag"`
|
||||
NewTag *string `json:"new_tag"`
|
||||
Priority *int64 `json:"priority"`
|
||||
Weight *uint `json:"weight"`
|
||||
ModelMapping *string `json:"model_mapping"`
|
||||
Models *string `json:"models"`
|
||||
Groups *string `json:"groups"`
|
||||
}
|
||||
|
||||
func DisableTagChannels(c *gin.Context) {
|
||||
channelTag := ChannelTag{}
|
||||
err := c.ShouldBindJSON(&channelTag)
|
||||
if err != nil || channelTag.Tag == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.DisableChannelByTag(channelTag.Tag)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func EnableTagChannels(c *gin.Context) {
|
||||
channelTag := ChannelTag{}
|
||||
err := c.ShouldBindJSON(&channelTag)
|
||||
if err != nil || channelTag.Tag == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.EnableChannelByTag(channelTag.Tag)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func EditTagChannels(c *gin.Context) {
|
||||
channelTag := ChannelTag{}
|
||||
err := c.ShouldBindJSON(&channelTag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
if channelTag.Tag == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "tag不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type ChannelBatch struct {
|
||||
Ids []int `json:"ids"`
|
||||
Tag *string `json:"tag"`
|
||||
}
|
||||
|
||||
func DeleteChannelBatch(c *gin.Context) {
|
||||
channelBatch := ChannelBatch{}
|
||||
err := c.ShouldBindJSON(&channelBatch)
|
||||
if err != nil || len(channelBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.BatchDeleteChannels(channelBatch.Ids)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": len(channelBatch.Ids),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type PatchChannel struct {
|
||||
model.Channel
|
||||
MultiKeyMode *string `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
channel := PatchChannel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 使用统一的校验函数
|
||||
if err := validateChannel(&channel.Channel, false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
|
||||
originChannel, err := model.GetChannelById(channel.Id, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
|
||||
channel.ChannelInfo = originChannel.ChannelInfo
|
||||
|
||||
// If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
|
||||
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
|
||||
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
|
||||
}
|
||||
err = channel.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
channel.Key = ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channel,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func FetchModels(c *gin.Context) {
|
||||
var req struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Type int `json:"type"`
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// remove line breaks and extra spaces.
|
||||
key := strings.TrimSpace(req.Key)
|
||||
// If the key contains a line break, only take the first part.
|
||||
key = strings.Split(key, "\n")[0]
|
||||
request.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
//check status code
|
||||
if response.StatusCode != http.StatusOK {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "Failed to fetch models",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range result.Data {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
func BatchSetChannelTag(c *gin.Context) {
|
||||
channelBatch := ChannelBatch{}
|
||||
err := c.ShouldBindJSON(&channelBatch)
|
||||
if err != nil || len(channelBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": len(channelBatch.Ids),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetTagModels(c *gin.Context) {
|
||||
tag := c.Query("tag")
|
||||
if tag == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "tag不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var longestModels string
|
||||
maxLength := 0
|
||||
|
||||
// Find the longest models string among all channels with the given tag
|
||||
for _, channel := range channels {
|
||||
if channel.Models != "" {
|
||||
currentModels := strings.Split(channel.Models, ",")
|
||||
if len(currentModels) > maxLength {
|
||||
maxLength = len(currentModels)
|
||||
longestModels = channel.Models
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": longestModels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// CopyChannel handles cloning an existing channel with its key.
|
||||
// POST /api/channel/copy/:id
|
||||
// Optional query params:
|
||||
//
|
||||
// suffix - string appended to the original name (default "_复制")
|
||||
// reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
|
||||
func CopyChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
|
||||
return
|
||||
}
|
||||
|
||||
suffix := c.DefaultQuery("suffix", "_复制")
|
||||
resetBalance := true
|
||||
if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
|
||||
if v, err := strconv.ParseBool(rbStr); err == nil {
|
||||
resetBalance = v
|
||||
}
|
||||
}
|
||||
|
||||
// fetch original channel with key
|
||||
origin, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// clone channel
|
||||
clone := *origin // shallow copy is sufficient as we will overwrite primitives
|
||||
clone.Id = 0 // let DB auto-generate
|
||||
clone.CreatedTime = common.GetTimestamp()
|
||||
clone.Name = origin.Name + suffix
|
||||
clone.TestTime = 0
|
||||
clone.ResponseTime = 0
|
||||
if resetBalance {
|
||||
clone.Balance = 0
|
||||
clone.UsedQuota = 0
|
||||
}
|
||||
|
||||
// insert
|
||||
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
model.InitChannelCache()
|
||||
// success
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
|
||||
}
|
||||
103
controller/console_migrate.go
Normal file
103
controller/console_migrate.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// 用于迁移检测的旧键,该文件下个版本会删除
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||
func MigrateConsoleSetting(c *gin.Context) {
|
||||
// 读取全部 option
|
||||
opts, err := model.AllOption()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// 建立 map
|
||||
valMap := map[string]string{}
|
||||
for _, o := range opts {
|
||||
valMap[o.Key] = o.Value
|
||||
}
|
||||
|
||||
// 处理 APIInfo
|
||||
if v := valMap["ApiInfo"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
if len(arr) > 50 {
|
||||
arr = arr[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(arr)
|
||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||
}
|
||||
model.UpdateOption("ApiInfo", "")
|
||||
}
|
||||
// Announcements 直接搬
|
||||
if v := valMap["Announcements"]; v != "" {
|
||||
model.UpdateOption("console_setting.announcements", v)
|
||||
model.UpdateOption("Announcements", "")
|
||||
}
|
||||
// FAQ 转换
|
||||
if v := valMap["FAQ"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
out := []map[string]interface{}{}
|
||||
for _, item := range arr {
|
||||
q, _ := item["question"].(string)
|
||||
if q == "" {
|
||||
q, _ = item["title"].(string)
|
||||
}
|
||||
a, _ := item["answer"].(string)
|
||||
if a == "" {
|
||||
a, _ = item["content"].(string)
|
||||
}
|
||||
if q != "" && a != "" {
|
||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||
}
|
||||
}
|
||||
if len(out) > 50 {
|
||||
out = out[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(out)
|
||||
model.UpdateOption("console_setting.faq", string(bytes))
|
||||
}
|
||||
model.UpdateOption("FAQ", "")
|
||||
}
|
||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||
url := valMap["UptimeKumaUrl"]
|
||||
slug := valMap["UptimeKumaSlug"]
|
||||
if url != "" && slug != "" {
|
||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||
groups := []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"categoryName": "old",
|
||||
"url": url,
|
||||
"slug": slug,
|
||||
"description": "",
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(groups)
|
||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||
}
|
||||
// 清空旧键内容
|
||||
if url != "" {
|
||||
model.UpdateOption("UptimeKumaUrl", "")
|
||||
}
|
||||
if slug != "" {
|
||||
model.UpdateOption("UptimeKumaSlug", "")
|
||||
}
|
||||
|
||||
// 删除旧键记录
|
||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||
|
||||
// 重新加载 OptionMap
|
||||
model.InitOptionMap()
|
||||
common.SysLog("console setting migrated")
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||
}
|
||||
239
controller/github.go
Normal file
239
controller/github.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GitHubOAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Scope string `json:"scope"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type GitHubUser struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
|
||||
jsonData, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oAuthResponse GitHubOAuthResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
var githubUser GitHubUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&githubUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if githubUser.Login == "" {
|
||||
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
|
||||
}
|
||||
return &githubUser, nil
|
||||
}
|
||||
|
||||
func GitHubOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
GitHubBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
// IsGitHubIdAlreadyTaken is unscoped
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
// FillUserByGitHubId is scoped
|
||||
err := user.FillUserByGitHubId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// if user.Id == 0 , user has been deleted
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
if githubUser.Name != "" {
|
||||
user.DisplayName = githubUser.Name
|
||||
} else {
|
||||
user.DisplayName = "GitHub User"
|
||||
}
|
||||
user.Email = githubUser.Email
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func GitHubBind(c *gin.Context) {
|
||||
if !common.GitHubOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 GitHub 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
githubUser, err := getGitHubUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
GitHubId: githubUser.Login,
|
||||
}
|
||||
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 GitHub 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.GitHubId = githubUser.Login
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateOAuthCode(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := common.GetRandomString(12)
|
||||
affCode := c.Query("aff")
|
||||
if affCode != "" {
|
||||
session.Set("aff", affCode)
|
||||
}
|
||||
session.Set("oauth_state", state)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": state,
|
||||
})
|
||||
}
|
||||
50
controller/group.go
Normal file
50
controller/group.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": groupNames,
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserGroups(c *gin.Context) {
|
||||
usableGroups := make(map[string]map[string]interface{})
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = map[string]interface{}{
|
||||
"ratio": ratio,
|
||||
"desc": desc,
|
||||
}
|
||||
}
|
||||
}
|
||||
if setting.GroupInUserUsableGroups("auto") {
|
||||
usableGroups["auto"] = map[string]interface{}{
|
||||
"ratio": "自动",
|
||||
"desc": setting.GetUsableGroupDescription("auto"),
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": usableGroups,
|
||||
})
|
||||
}
|
||||
9
controller/image.go
Normal file
9
controller/image.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetImage(c *gin.Context) {
|
||||
|
||||
}
|
||||
259
controller/linuxdo.go
Normal file
259
controller/linuxdo.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LinuxdoUser struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Active bool `json:"active"`
|
||||
TrustLevel int `json:"trust_level"`
|
||||
Silenced bool `json:"silenced"`
|
||||
}
|
||||
|
||||
func LinuxDoBind(c *gin.Context) {
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Linux DO 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
|
||||
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("invalid code")
|
||||
}
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := "https://connect.linux.do/oauth2/token"
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
// Get redirect URI from request
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", redirectURI)
|
||||
|
||||
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", basicAuth)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to connect to Linux DO server")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var tokenRes struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenRes.AccessToken == "" {
|
||||
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
|
||||
}
|
||||
|
||||
// Get user info
|
||||
userEndpoint := "https://connect.linux.do/api/user"
|
||||
req, err = http.NewRequest("GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to get user info from Linux DO")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
|
||||
var linuxdoUser LinuxdoUser
|
||||
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if linuxdoUser.Id == 0 {
|
||||
return nil, errors.New("invalid user info returned")
|
||||
}
|
||||
|
||||
return &linuxdoUser, nil
|
||||
}
|
||||
|
||||
func LinuxdoOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
|
||||
errorCode := c.Query("error")
|
||||
if errorCode != "" {
|
||||
errorDescription := c.Query("error_description")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": errorDescription,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
LinuxDoBind(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !common.LinuxDOOAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Linux DO 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
code := c.Query("code")
|
||||
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
|
||||
}
|
||||
|
||||
// Check if user exists
|
||||
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
|
||||
err := user.FillUserByLinuxDOId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = linuxdoUser.Name
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
affCode := session.Get("aff")
|
||||
inviterId := 0
|
||||
if affCode != nil {
|
||||
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
|
||||
}
|
||||
|
||||
if err := user.Insert(inviterId); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
168
controller/log.go
Normal file
168
controller/log.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllLogs(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
username := c.Query("username")
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserLogs(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
userId := c.GetInt("id")
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(logs)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUserLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
userId := c.GetInt("id")
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogByKey(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
logs, err := model.GetLogByKey(key)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
}
|
||||
|
||||
func GetLogsStat(c *gin.Context) {
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
username := c.Query("username")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"quota": stat.Quota,
|
||||
"rpm": stat.Rpm,
|
||||
"tpm": stat.Tpm,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogsSelfStat(c *gin.Context) {
|
||||
username := c.GetString("username")
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"quota": quotaNum.Quota,
|
||||
"rpm": quotaNum.Rpm,
|
||||
"tpm": quotaNum.Tpm,
|
||||
//"token": tokenNum,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteHistoryLogs(c *gin.Context) {
|
||||
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
|
||||
if targetTimestamp == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "target timestamp is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
count, err := model.DeleteOldLog(c.Request.Context(), targetTimestamp, 100)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
return
|
||||
}
|
||||
263
controller/midjourney.go
Normal file
263
controller/midjourney.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
//imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Midjourney)
|
||||
nullTaskIds := make([]int, 0)
|
||||
for _, task := range tasks {
|
||||
if task.MjId == "" {
|
||||
// 统计失败的未完成任务
|
||||
nullTaskIds = append(nullTaskIds, task.Id)
|
||||
continue
|
||||
}
|
||||
taskM[task.MjId] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
|
||||
}
|
||||
if len(nullTaskIds) > 0 {
|
||||
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
|
||||
} else {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
continue
|
||||
}
|
||||
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"ids": taskIds,
|
||||
})
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 15
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []dto.MidjourneyDto
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
req.Body.Close()
|
||||
cancel()
|
||||
|
||||
for _, responseItem := range responseItems {
|
||||
task := taskM[responseItem.MjId]
|
||||
|
||||
useTime := (time.Now().UnixNano() / int64(time.Millisecond)) - task.SubmitTime
|
||||
// 如果时间超过一小时,且进度不是100%,则认为任务失败
|
||||
if useTime > 3600000 && task.Progress != "100%" {
|
||||
responseItem.FailReason = "上游任务超时(超过1小时)"
|
||||
responseItem.Status = "FAILURE"
|
||||
}
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
task.State = responseItem.State
|
||||
task.SubmitTime = responseItem.SubmitTime
|
||||
task.StartTime = responseItem.StartTime
|
||||
task.FinishTime = responseItem.FinishTime
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if responseItem.Properties != nil {
|
||||
propertiesStr, _ := json.Marshal(responseItem.Properties)
|
||||
task.Properties = string(propertiesStr)
|
||||
}
|
||||
if responseItem.Buttons != nil {
|
||||
buttonStr, _ := json.Marshal(responseItem.Buttons)
|
||||
task.Buttons = string(buttonStr)
|
||||
}
|
||||
shouldReturnQuota := false
|
||||
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
if task.Quota != 0 {
|
||||
shouldReturnQuota = true
|
||||
}
|
||||
}
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
} else {
|
||||
if shouldReturnQuota {
|
||||
err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) bool {
|
||||
if oldTask.Code != 1 {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != newTask.Progress {
|
||||
return true
|
||||
}
|
||||
if oldTask.PromptEn != newTask.PromptEn {
|
||||
return true
|
||||
}
|
||||
if oldTask.State != newTask.State {
|
||||
return true
|
||||
}
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.ImageUrl != newTask.ImageUrl {
|
||||
return true
|
||||
}
|
||||
if oldTask.Status != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
// 解析其他查询参数
|
||||
queryParams := model.TaskQueryParams{
|
||||
ChannelID: c.Query("channel_id"),
|
||||
MjID: c.Query("mj_id"),
|
||||
StartTimestamp: c.Query("start_timestamp"),
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
items := model.GetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllTasks(queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
queryParams := model.TaskQueryParams{
|
||||
MjID: c.Query("mj_id"),
|
||||
StartTimestamp: c.Query("start_timestamp"),
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
items := model.GetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.CountAllUserTask(userId, queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
302
controller/misc.go
Normal file
302
controller/misc.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestStatus(c *gin.Context) {
|
||||
err := model.PingDB()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"success": false,
|
||||
"message": "数据库连接失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 获取HTTP统计信息
|
||||
httpStats := middleware.GetStats()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Server is running",
|
||||
"http_stats": httpStats,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
}
|
||||
|
||||
// 根据启用状态注入可选内容
|
||||
if cs.ApiInfoEnabled {
|
||||
data["api_info"] = console_setting.GetApiInfo()
|
||||
}
|
||||
if cs.AnnouncementsEnabled {
|
||||
data["announcements"] = console_setting.GetAnnouncements()
|
||||
}
|
||||
if cs.FAQEnabled {
|
||||
data["faq"] = console_setting.GetFAQ()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetNotice(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": common.OptionMap["Notice"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetAbout(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": common.OptionMap["About"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetMidjourney(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": common.OptionMap["Midjourney"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetHomePageContent(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": common.OptionMap["HomePageContent"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendEmailVerification(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
if err := common.Validate.Var(email, "required,email"); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的邮箱地址",
|
||||
})
|
||||
return
|
||||
}
|
||||
localPart := parts[0]
|
||||
domainPart := parts[1]
|
||||
if common.EmailDomainRestrictionEnabled {
|
||||
allowed := false
|
||||
for _, domain := range common.EmailDomainWhitelist {
|
||||
if domainPart == domain {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "The administrator has enabled the email domain name whitelist, and your email address is not allowed due to special symbols or it's not in the whitelist.",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if common.EmailAliasRestrictionEnabled {
|
||||
containsSpecialSymbols := strings.Contains(localPart, "+") || strings.Contains(localPart, ".")
|
||||
if containsSpecialSymbols {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员已启用邮箱地址别名限制,您的邮箱地址由于包含特殊符号而被拒绝。",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "邮箱地址已被占用",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := common.GenerateVerificationCode(6)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
|
||||
subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
|
||||
"<p>您的验证码为: <strong>%s</strong></p>"+
|
||||
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func SendPasswordResetEmail(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
if err := common.Validate.Var(email, "required,email"); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !model.IsEmailAlreadyTaken(email) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该邮箱地址未注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
|
||||
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
|
||||
err := common.SendEmail(subject, email, content)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type PasswordResetRequest struct {
|
||||
Email string `json:"email"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func ResetPassword(c *gin.Context) {
|
||||
var req PasswordResetRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&req)
|
||||
if req.Email == "" || req.Token == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "重置链接非法或已过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
password := common.GenerateVerificationCode(12)
|
||||
err = model.ResetUserPasswordByEmail(req.Email, password)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
common.DeleteKey(req.Email, common.PasswordResetPurpose)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": password,
|
||||
})
|
||||
return
|
||||
}
|
||||
216
controller/model.go
Normal file
216
controller/model.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel/ai360"
|
||||
"one-api/relay/channel/lingyiwanwu"
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
|
||||
var openAIModels []dto.OpenAIModels
|
||||
var openAIModelsMap map[string]dto.OpenAIModels
|
||||
var channelId2Models map[int][]string
|
||||
|
||||
func init() {
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
for i := 0; i < constant.APITypeDummy; i++ {
|
||||
if i == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
channelName := adaptor.GetChannelName()
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, modelName := range ai360.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range moonshot.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range lingyiwanwu.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range minimax.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
})
|
||||
}
|
||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
})
|
||||
}
|
||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||
for _, aiModel := range openAIModels {
|
||||
openAIModelsMap[aiModel.Id] = aiModel
|
||||
}
|
||||
channelId2Models = make(map[int][]string)
|
||||
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||
apiType, success := common.ChannelType2APIType(i)
|
||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||
return m.Id
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
} else {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "get user group failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
}
|
||||
for _, modelName := range models {
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": userOpenAiModels,
|
||||
})
|
||||
}
|
||||
|
||||
func ChannelListModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": openAIModels,
|
||||
})
|
||||
}
|
||||
|
||||
func DashboardListModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": channelId2Models,
|
||||
})
|
||||
}
|
||||
|
||||
func EnabledListModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": model.GetEnabledModels(),
|
||||
})
|
||||
}
|
||||
|
||||
func RetrieveModel(c *gin.Context) {
|
||||
modelId := c.Param("model")
|
||||
if aiModel, ok := openAIModelsMap[modelId]; ok {
|
||||
c.JSON(200, aiModel)
|
||||
} else {
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||
Type: "invalid_request_error",
|
||||
Param: "model",
|
||||
Code: "model_not_found",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": openAIError,
|
||||
})
|
||||
}
|
||||
}
|
||||
228
controller/oidc.go
Normal file
228
controller/oidc.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OidcResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type OidcUser struct {
|
||||
OpenID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
|
||||
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var oidcResponse OidcResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if oidcResponse.AccessToken == "" {
|
||||
common.SysError("OIDC 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("OIDC 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var oidcUser OidcUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oidcUser.OpenID == "" || oidcUser.Email == "" {
|
||||
common.SysError("OIDC 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &oidcUser, nil
|
||||
}
|
||||
|
||||
func OidcAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
OidcBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
err := user.FillUserByOidcId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Email = oidcUser.Email
|
||||
if oidcUser.PreferredUsername != "" {
|
||||
user.Username = oidcUser.PreferredUsername
|
||||
} else {
|
||||
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if oidcUser.Name != "" {
|
||||
user.DisplayName = oidcUser.Name
|
||||
} else {
|
||||
user.DisplayName = "OIDC User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func OidcBind(c *gin.Context) {
|
||||
if !system_setting.GetOIDCSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 OIDC 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
oidcUser, err := getOidcUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
OidcId: oidcUser.OpenID,
|
||||
}
|
||||
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 OIDC 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
// id := c.GetInt("id") // critical bug!
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.OidcId = oidcUser.OpenID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
return
|
||||
}
|
||||
171
controller/option.go
Normal file
171
controller/option.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetOptions(c *gin.Context) {
|
||||
var options []*model.Option
|
||||
common.OptionMapRWMutex.Lock()
|
||||
for k, v := range common.OptionMap {
|
||||
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
|
||||
continue
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
Key: k,
|
||||
Value: common.Interface2String(v),
|
||||
})
|
||||
}
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": options,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateOption(c *gin.Context) {
|
||||
var option model.Option
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
switch option.Key {
|
||||
case "GitHubOAuthEnabled":
|
||||
if option.Value == "true" && common.GitHubClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "oidc.enabled":
|
||||
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "LinuxDOOAuthEnabled":
|
||||
if option.Value == "true" && common.LinuxDOClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 LinuxDO OAuth,请先填入 LinuxDO Client Id 以及 LinuxDO Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "EmailDomainRestrictionEnabled":
|
||||
if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "WeChatAuthEnabled":
|
||||
if option.Value == "true" && common.WeChatServerAddress == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用微信登录,请先填入微信登录相关配置信息!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "TurnstileCheckEnabled":
|
||||
if option.Value == "true" && common.TurnstileSiteKey == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
case "TelegramOAuthEnabled":
|
||||
if option.Value == "true" && common.TelegramBotToken == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 Telegram OAuth,请先填入 Telegram Bot Token!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ModelRequestRateLimitGroup":
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.api_info":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.announcements":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.faq":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.uptime_kuma_groups":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
84
controller/playground.go
Normal file
84
controller/playground.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
defer func() {
|
||||
if newAPIError != nil {
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
group := playgroundRequest.Group
|
||||
userGroup := c.GetString("group")
|
||||
|
||||
if group == "" {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
|
||||
tempToken := &model.Token{
|
||||
UserId: userId,
|
||||
Name: fmt.Sprintf("playground-%s", group),
|
||||
Group: group,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, newAPIError = getChannel(c, group, playgroundRequest.Model, 0)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
Relay(c)
|
||||
}
|
||||
71
controller/pricing.go
Normal file
71
controller/pricing.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
pricing := model.GetPricing()
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
if exists {
|
||||
user, err := model.GetUserCache(userId.(int))
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
for g := range groupRatio {
|
||||
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||
if ok {
|
||||
groupRatio[g] = ratio
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": groupRatio,
|
||||
"usable_group": usableGroup,
|
||||
})
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "重置模型倍率成功",
|
||||
})
|
||||
}
|
||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
474
controller/ratio_sync.go
Normal file
474
controller/ratio_sync.go
Normal file
@@ -0,0 +1,474 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
)
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
|
||||
var upstreams []dto.UpstreamDTO
|
||||
|
||||
if len(req.Upstreams) > 0 {
|
||||
for _, u := range req.Upstreams {
|
||||
if strings.HasPrefix(u.BaseURL, "http") {
|
||||
if u.Endpoint == "" {
|
||||
u.Endpoint = defaultEndpoint
|
||||
}
|
||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
}
|
||||
} else if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
ID: ch.Id,
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isType1 {
|
||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
}
|
||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||
return
|
||||
}
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
|
||||
if len(modelRatioMap) > 0 {
|
||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
ratioAny[k] = v
|
||||
}
|
||||
converted["model_ratio"] = ratioAny
|
||||
}
|
||||
|
||||
if len(completionRatioMap) > 0 {
|
||||
compAny := make(map[string]any, len(completionRatioMap))
|
||||
for k, v := range completionRatioMap {
|
||||
compAny[k] = v
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
priceAny[k] = v
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}) map[string]map[string]dto.DifferenceItem {
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
confidenceMap := make(map[string]map[string]bool)
|
||||
|
||||
// 预处理阶段:检查pricing接口的可信度
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
|
||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||
for modelName := range allModels {
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
confidenceValues := make(map[string]bool)
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
|
||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||
}
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
Confidence: confidenceValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
delete(item.Confidence, chName)
|
||||
}
|
||||
}
|
||||
|
||||
allSame := true
|
||||
for _, v := range item.Upstreams {
|
||||
if v != "same" {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(item.Upstreams) == 0 || allSame {
|
||||
delete(ratioMap, ratioType)
|
||||
} else {
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
|
||||
if len(ratioMap) == 0 {
|
||||
delete(differences, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
return differences
|
||||
}
|
||||
|
||||
func GetSyncableChannels(c *gin.Context) {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
193
controller/redemption.go
Normal file
193
controller/redemption.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllRedemptions(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.GetAllRedemptions(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchRedemptions(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
redemptions, total, err := model.SearchRedemptions(keyword, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(redemptions)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetRedemption(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
redemption, err := model.GetRedemptionById(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": redemption,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func AddRedemption(c *gin.Context) {
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码名称长度必须在1-20之间",
|
||||
})
|
||||
return
|
||||
}
|
||||
if redemption.Count <= 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "兑换码个数必须大于0",
|
||||
})
|
||||
return
|
||||
}
|
||||
if redemption.Count > 100 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "一次兑换码批量生成的个数不能大于 100",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < redemption.Count; i++ {
|
||||
key := common.GetUUID()
|
||||
cleanRedemption := model.Redemption{
|
||||
UserId: c.GetInt("id"),
|
||||
Name: redemption.Name,
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
Quota: redemption.Quota,
|
||||
ExpiredTime: redemption.ExpiredTime,
|
||||
}
|
||||
err = cleanRedemption.Insert()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"data": keys,
|
||||
})
|
||||
return
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": keys,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteRedemption(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
err := model.DeleteRedemptionById(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateRedemption(c *gin.Context) {
|
||||
statusOnly := c.Query("status_only")
|
||||
redemption := model.Redemption{}
|
||||
err := c.ShouldBindJSON(&redemption)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
cleanRedemption, err := model.GetRedemptionById(redemption.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if statusOnly == "" {
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// If you add more fields, please also update redemption.Update()
|
||||
cleanRedemption.Name = redemption.Name
|
||||
cleanRedemption.Quota = redemption.Quota
|
||||
cleanRedemption.ExpiredTime = redemption.ExpiredTime
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
}
|
||||
err = cleanRedemption.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": cleanRedemption,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteInvalidRedemption(c *gin.Context) {
|
||||
rows, err := model.DeleteInvalidRedemptions()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func validateExpiredTime(expired int64) error {
|
||||
if expired != 0 && expired < common.GetTimestamp() {
|
||||
return errors.New("过期时间不能早于当前时间")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
476
controller/relay.go
Normal file
476
controller/relay.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
var err *types.NewAPIError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err = relay.ImageHelper(c)
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err = relay.AudioHelper(c)
|
||||
case relayconstant.RelayModeRerank:
|
||||
err = relay.RerankHelper(c, relayMode)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
err = relay.EmbeddingHelper(c)
|
||||
case relayconstant.RelayModeResponses:
|
||||
err = relay.ResponsesHelper(c)
|
||||
case relayconstant.RelayModeGemini:
|
||||
err = relay.GeminiHelper(c)
|
||||
default:
|
||||
err = relay.TextHelper(c)
|
||||
}
|
||||
|
||||
if constant2.ErrorLogEnabled && err != nil {
|
||||
// 保存错误日志到mysql中
|
||||
userId := c.GetInt("id")
|
||||
tokenName := c.GetString("token_name")
|
||||
modelName := c.GetString("original_model")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.ErrorType
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
newAPIError = relayRequest(c, relayMode, channel)
|
||||
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 允许跨域
|
||||
},
|
||||
}
|
||||
|
||||
func WssRelay(c *gin.Context) {
|
||||
// 将 HTTP 连接升级为 WebSocket 连接
|
||||
|
||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if newAPIError != nil {
|
||||
//if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
//}
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||
}
|
||||
}
|
||||
|
||||
func RelayClaude(c *gin.Context) {
|
||||
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
newAPIError = err
|
||||
break
|
||||
}
|
||||
|
||||
newAPIError = claudeRequest(c, channel)
|
||||
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if newAPIError != nil {
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.ClaudeHelper(c)
|
||||
}
|
||||
|
||||
func addUsedChannel(c *gin.Context, channelId int) {
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
}
|
||||
|
||||
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, *types.NewAPIError) {
|
||||
if retryCount == 0 {
|
||||
autoBan := c.GetBool("auto_ban")
|
||||
autoBanInt := 1
|
||||
if !autoBan {
|
||||
autoBanInt = 0
|
||||
}
|
||||
return &model.Channel{
|
||||
Id: c.GetInt("channel_id"),
|
||||
Type: c.GetInt("channel_type"),
|
||||
Name: c.GetString("channel_name"),
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
if group == "auto" {
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
}
|
||||
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
|
||||
}
|
||||
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
if newAPIError != nil {
|
||||
return nil, newAPIError
|
||||
}
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
return false
|
||||
}
|
||||
if _, ok := c.Get("specific_channel_id"); ok {
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == 307 {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 5 {
|
||||
// 超时不重试
|
||||
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType == constant.ChannelTypeAnthropic {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode == 408 {
|
||||
// azure处理超时不重试
|
||||
return false
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 2 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func RelayMidjourney(c *gin.Context) {
|
||||
relayMode := c.GetInt("relay_mode")
|
||||
var err *dto.MidjourneyResponse
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeMidjourneyNotify:
|
||||
err = relay.RelayMidjourneyNotify(c)
|
||||
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||||
err = relay.RelayMidjourneyTask(c, relayMode)
|
||||
case relayconstant.RelayModeMidjourneyTaskImageSeed:
|
||||
err = relay.RelayMidjourneyTaskImageSeed(c)
|
||||
case relayconstant.RelayModeSwapFace:
|
||||
err = relay.RelaySwapFace(c)
|
||||
default:
|
||||
err = relay.RelayMidjourneySubmit(c, relayMode)
|
||||
}
|
||||
//err = relayMidjourneySubmit(c, relayMode)
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
statusCode := http.StatusBadRequest
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
statusCode = http.StatusTooManyRequests
|
||||
}
|
||||
c.JSON(statusCode, gin.H{
|
||||
"description": fmt.Sprintf("%s %s", err.Description, err.Result),
|
||||
"type": "upstream_error",
|
||||
"code": err.Code,
|
||||
})
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result)))
|
||||
}
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := dto.OpenAIError{
|
||||
Message: "API not implemented",
|
||||
Type: "new_api_error",
|
||||
Param: "",
|
||||
Code: "api_not_implemented",
|
||||
}
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
|
||||
func RelayNotFound(c *gin.Context) {
|
||||
err := dto.OpenAIError{
|
||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
Param: "",
|
||||
Code: "",
|
||||
}
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
|
||||
func RelayTask(c *gin.Context) {
|
||||
retryTimes := common.RetryTimes
|
||||
channelId := c.GetInt("channel_id")
|
||||
relayMode := c.GetInt("relay_mode")
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
||||
taskErr := taskRelayHandler(c, relayMode)
|
||||
if taskErr == nil {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, newAPIError := getChannel(c, group, originalModel, i)
|
||||
if newAPIError != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
taskErr = taskRelayHandler(c, relayMode)
|
||||
}
|
||||
useChannel := c.GetStringSlice("use_channel")
|
||||
if len(useChannel) > 1 {
|
||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
if taskErr != nil {
|
||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
c.JSON(taskErr.StatusCode, taskErr)
|
||||
}
|
||||
}
|
||||
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
|
||||
if taskErr == nil {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
return false
|
||||
}
|
||||
if _, ok := c.Get("specific_channel_id"); ok {
|
||||
return false
|
||||
}
|
||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if taskErr.StatusCode == 307 {
|
||||
return true
|
||||
}
|
||||
if taskErr.StatusCode/100 == 5 {
|
||||
// 超时不重试
|
||||
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
if taskErr.StatusCode == http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
if taskErr.StatusCode == 408 {
|
||||
// azure处理超时不重试
|
||||
return false
|
||||
}
|
||||
if taskErr.LocalError {
|
||||
return false
|
||||
}
|
||||
if taskErr.StatusCode/100 == 2 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
181
controller/setup.go
Normal file
181
controller/setup.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/setting/operation_setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Setup struct {
|
||||
Status bool `json:"status"`
|
||||
RootInit bool `json:"root_init"`
|
||||
DatabaseType string `json:"database_type"`
|
||||
}
|
||||
|
||||
type SetupRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
ConfirmPassword string `json:"confirmPassword"`
|
||||
SelfUseModeEnabled bool `json:"SelfUseModeEnabled"`
|
||||
DemoSiteEnabled bool `json:"DemoSiteEnabled"`
|
||||
}
|
||||
|
||||
func GetSetup(c *gin.Context) {
|
||||
setup := Setup{
|
||||
Status: constant.Setup,
|
||||
}
|
||||
if constant.Setup {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": setup,
|
||||
})
|
||||
return
|
||||
}
|
||||
setup.RootInit = model.RootUserExists()
|
||||
if common.UsingMySQL {
|
||||
setup.DatabaseType = "mysql"
|
||||
}
|
||||
if common.UsingPostgreSQL {
|
||||
setup.DatabaseType = "postgres"
|
||||
}
|
||||
if common.UsingSQLite {
|
||||
setup.DatabaseType = "sqlite"
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": setup,
|
||||
})
|
||||
}
|
||||
|
||||
func PostSetup(c *gin.Context) {
|
||||
// Check if setup is already completed
|
||||
if constant.Setup {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "系统已经初始化完成",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if root user already exists
|
||||
rootExists := model.RootUserExists()
|
||||
|
||||
var req SetupRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "请求参数有误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// If root doesn't exist, validate and create admin account
|
||||
if !rootExists {
|
||||
// Validate username length: max 12 characters to align with model.User validation
|
||||
if len(req.Username) > 12 {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名长度不能超过12个字符",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Validate password
|
||||
if req.Password != req.ConfirmPassword {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "两次输入的密码不一致",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Password) < 8 {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "密码长度至少为8个字符",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create root user
|
||||
hashedPassword, err := common.Password2Hash(req.Password)
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "系统错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rootUser := model.User{
|
||||
Username: req.Username,
|
||||
Password: hashedPassword,
|
||||
Role: common.RoleRootUser,
|
||||
Status: common.UserStatusEnabled,
|
||||
DisplayName: "Root User",
|
||||
AccessToken: nil,
|
||||
Quota: 100000000,
|
||||
}
|
||||
err = model.DB.Create(&rootUser).Error
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "创建管理员账号失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set operation modes
|
||||
operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
|
||||
operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
|
||||
|
||||
// Save operation modes to database for persistence
|
||||
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "保存自用模式设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "保存演示站点模式设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update setup status
|
||||
constant.Setup = true
|
||||
|
||||
setup := model.Setup{
|
||||
Version: common.Version,
|
||||
InitializedAt: time.Now().Unix(),
|
||||
}
|
||||
err = model.DB.Create(&setup).Error
|
||||
if err != nil {
|
||||
c.JSON(500, gin.H{
|
||||
"success": false,
|
||||
"message": "系统初始化失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "系统初始化成功",
|
||||
})
|
||||
}
|
||||
|
||||
func boolToString(b bool) string {
|
||||
if b {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
116
controller/swag_video.go
Normal file
116
controller/swag_video.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// VideoGenerations
|
||||
// @Summary 生成视频
|
||||
// @Description 调用视频生成接口生成视频
|
||||
// @Description 支持多种视频生成服务:
|
||||
// @Description - 可灵AI (Kling): https://app.klingai.com/cn/dev/document-api/apiReference/commonInfo
|
||||
// @Description - 即梦 (Jimeng): https://www.volcengine.com/docs/85621/1538636
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body dto.VideoRequest true "视频生成请求参数"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /v1/video/generations [post]
|
||||
func VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
// VideoGenerationsTaskId
|
||||
// @Summary 查询视频
|
||||
// @Description 根据任务ID查询视频生成任务的状态和结果
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param task_id path string true "Task ID"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /v1/video/generations/{task_id} [get]
|
||||
func VideoGenerationsTaskId(c *gin.Context) {
|
||||
}
|
||||
|
||||
// KlingText2VideoGenerations
|
||||
// @Summary 可灵文生视频
|
||||
// @Description 调用可灵AI文生视频接口,生成视频内容
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body KlingText2VideoRequest true "视频生成请求参数"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /kling/v1/videos/text2video [post]
|
||||
func KlingText2VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
type KlingText2VideoRequest struct {
|
||||
ModelName string `json:"model_name,omitempty" example:"kling-v1"`
|
||||
Prompt string `json:"prompt" binding:"required" example:"A cat playing piano in the garden"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||
Mode string `json:"mode,omitempty" example:"std"`
|
||||
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||
Duration string `json:"duration,omitempty" example:"5"`
|
||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-001"`
|
||||
}
|
||||
|
||||
type KlingCameraControl struct {
|
||||
Type string `json:"type,omitempty" example:"simple"`
|
||||
Config *KlingCameraConfig `json:"config,omitempty"`
|
||||
}
|
||||
|
||||
type KlingCameraConfig struct {
|
||||
Horizontal float64 `json:"horizontal,omitempty" example:"2.5"`
|
||||
Vertical float64 `json:"vertical,omitempty" example:"0"`
|
||||
Pan float64 `json:"pan,omitempty" example:"0"`
|
||||
Tilt float64 `json:"tilt,omitempty" example:"0"`
|
||||
Roll float64 `json:"roll,omitempty" example:"0"`
|
||||
Zoom float64 `json:"zoom,omitempty" example:"0"`
|
||||
}
|
||||
|
||||
// KlingImage2VideoGenerations
|
||||
// @Summary 可灵官方-图生视频
|
||||
// @Description 调用可灵AI图生视频接口,生成视频内容
|
||||
// @Tags Video
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "用户认证令牌 (Aeess-Token: sk-xxxx)"
|
||||
// @Param request body KlingImage2VideoRequest true "图生视频请求参数"
|
||||
// @Success 200 {object} dto.VideoTaskResponse "任务状态和结果"
|
||||
// @Failure 400 {object} dto.OpenAIError "请求参数错误"
|
||||
// @Failure 401 {object} dto.OpenAIError "未授权"
|
||||
// @Failure 403 {object} dto.OpenAIError "无权限"
|
||||
// @Failure 500 {object} dto.OpenAIError "服务器内部错误"
|
||||
// @Router /kling/v1/videos/image2video [post]
|
||||
func KlingImage2VideoGenerations(c *gin.Context) {
|
||||
}
|
||||
|
||||
type KlingImage2VideoRequest struct {
|
||||
ModelName string `json:"model_name,omitempty" example:"kling-v2-master"`
|
||||
Image string `json:"image" binding:"required" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
|
||||
Prompt string `json:"prompt,omitempty" example:"A cat playing piano in the garden"`
|
||||
NegativePrompt string `json:"negative_prompt,omitempty" example:"blurry, low quality"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty" example:"0.7"`
|
||||
Mode string `json:"mode,omitempty" example:"std"`
|
||||
CameraControl *KlingCameraControl `json:"camera_control,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty" example:"16:9"`
|
||||
Duration string `json:"duration,omitempty" example:"5"`
|
||||
CallbackURL string `json:"callback_url,omitempty" example:"https://your.domain/callback"`
|
||||
ExternalTaskId string `json:"external_task_id,omitempty" example:"custom-task-002"`
|
||||
}
|
||||
273
controller/task.go
Normal file
273
controller/task.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
func UpdateTaskBulk() {
|
||||
//revocer
|
||||
//imageModel := "midjourney"
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
allTasks := model.GetAllUnFinishSyncTasks(500)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
for _, t := range allTasks {
|
||||
platformTask[t.Platform] = append(platformTask[t.Platform], t)
|
||||
}
|
||||
for platform, tasks := range platformTask {
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Task)
|
||||
nullTaskIds := make([]int64, 0)
|
||||
for _, task := range tasks {
|
||||
if task.TaskID == "" {
|
||||
// 统计失败的未完成任务
|
||||
nullTaskIds = append(nullTaskIds, task.ID)
|
||||
continue
|
||||
}
|
||||
taskM[task.TaskID] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
|
||||
}
|
||||
if len(nullTaskIds) > 0 {
|
||||
err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
} else {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
||||
}
|
||||
common.SysLog("任务进度轮询完成")
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
|
||||
switch platform {
|
||||
case constant.TaskPlatformMidjourney:
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
|
||||
if adaptor == nil {
|
||||
return errors.New("adaptor not found")
|
||||
}
|
||||
resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
|
||||
"ids": taskIds,
|
||||
})
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
return err
|
||||
}
|
||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
return err
|
||||
}
|
||||
if !responseItems.IsSuccess() {
|
||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
for _, responseItem := range responseItems.Data {
|
||||
task := taskM[responseItem.TaskID]
|
||||
if !checkTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
|
||||
task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
|
||||
task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
|
||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
if responseItem.Status == model.TaskStatusSuccess {
|
||||
task.Progress = "100%"
|
||||
}
|
||||
task.Data = responseItem.Data
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
|
||||
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if string(oldTask.Status) != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
|
||||
if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
|
||||
return true
|
||||
}
|
||||
|
||||
oldData, _ := json.Marshal(oldTask.Data)
|
||||
newData, _ := json.Marshal(newTask.Data)
|
||||
|
||||
sort.Slice(oldData, func(i, j int) bool {
|
||||
return oldData[i] < oldData[j]
|
||||
})
|
||||
sort.Slice(newData, func(i, j int) bool {
|
||||
return newData[i] < newData[j]
|
||||
})
|
||||
|
||||
if string(oldData) != string(newData) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
// 解析其他查询参数
|
||||
queryParams := model.SyncTaskQueryParams{
|
||||
Platform: constant.TaskPlatform(c.Query("platform")),
|
||||
TaskID: c.Query("task_id"),
|
||||
Status: c.Query("status"),
|
||||
Action: c.Query("action"),
|
||||
StartTimestamp: startTimestamp,
|
||||
EndTimestamp: endTimestamp,
|
||||
ChannelID: c.Query("channel_id"),
|
||||
}
|
||||
|
||||
items := model.TaskGetAllTasks(pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
|
||||
func GetUserTask(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
|
||||
queryParams := model.SyncTaskQueryParams{
|
||||
Platform: constant.TaskPlatform(c.Query("platform")),
|
||||
TaskID: c.Query("task_id"),
|
||||
Status: c.Query("status"),
|
||||
Action: c.Query("action"),
|
||||
StartTimestamp: startTimestamp,
|
||||
EndTimestamp: endTimestamp,
|
||||
}
|
||||
|
||||
items := model.TaskGetAllUserTask(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(items)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
}
|
||||
138
controller/task_video.go
Normal file
138
controller/task_video.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if taskResult.Code != 0 {
|
||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
//}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
return fmt.Errorf("task %s status is empty", taskId)
|
||||
}
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
124
controller/telegram.go
Normal file
124
controller/telegram.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"sort"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TelegramBind(c *gin.Context) {
|
||||
if !common.TelegramOAuthEnabled {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "管理员未开启通过 Telegram 登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
params := c.Request.URL.Query()
|
||||
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "无效的请求",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
telegramId := params["id"][0]
|
||||
if model.IsTelegramIdAlreadyTaken(telegramId) {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "该 Telegram 账户已被绑定",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user := model.User{Id: id.(int)}
|
||||
if err := user.FillUserById(); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.TelegramId = telegramId
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(302, "/setting")
|
||||
}
|
||||
|
||||
func TelegramLogin(c *gin.Context) {
|
||||
if !common.TelegramOAuthEnabled {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "管理员未开启通过 Telegram 登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
params := c.Request.URL.Query()
|
||||
if !checkTelegramAuthorization(params, common.TelegramBotToken) {
|
||||
c.JSON(200, gin.H{
|
||||
"message": "无效的请求",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
telegramId := params["id"][0]
|
||||
user := model.User{TelegramId: telegramId}
|
||||
if err := user.FillUserByTelegramId(); err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func checkTelegramAuthorization(params map[string][]string, token string) bool {
|
||||
strs := []string{}
|
||||
var hash = ""
|
||||
for k, v := range params {
|
||||
if k == "hash" {
|
||||
hash = v[0]
|
||||
continue
|
||||
}
|
||||
strs = append(strs, k+"="+v[0])
|
||||
}
|
||||
sort.Strings(strs)
|
||||
var imploded = ""
|
||||
for _, s := range strs {
|
||||
if imploded != "" {
|
||||
imploded += "\n"
|
||||
}
|
||||
imploded += s
|
||||
}
|
||||
sha256hash := sha256.New()
|
||||
io.WriteString(sha256hash, token)
|
||||
hmachash := hmac.New(sha256.New, sha256hash.Sum(nil))
|
||||
io.WriteString(hmachash, imploded)
|
||||
ss := hex.EncodeToString(hmachash.Sum(nil))
|
||||
return hash == ss
|
||||
}
|
||||
236
controller/token.go
Normal file
236
controller/token.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
tokens, err := model.GetAllUserTokens(userId, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
total, _ := model.CountUserTokens(userId)
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
keyword := c.Query("keyword")
|
||||
token := c.Query("token")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetToken(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
token, err := model.GetTokenByIds(id, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": token,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetTokenStatus(c *gin.Context) {
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
token, err := model.GetTokenByIds(tokenId, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
expiredAt := token.ExpiredTime
|
||||
if expiredAt == -1 {
|
||||
expiredAt = 0
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "credit_summary",
|
||||
"total_granted": token.RemainQuota,
|
||||
"total_used": 0, // not supported currently
|
||||
"total_available": token.RemainQuota,
|
||||
"expires_at": expiredAt * 1000,
|
||||
})
|
||||
}
|
||||
|
||||
func AddToken(c *gin.Context) {
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
})
|
||||
return
|
||||
}
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成令牌失败",
|
||||
})
|
||||
common.SysError("failed to generate token key: " + err.Error())
|
||||
return
|
||||
}
|
||||
cleanToken := model.Token{
|
||||
UserId: c.GetInt("id"),
|
||||
Name: token.Name,
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
AccessedTime: common.GetTimestamp(),
|
||||
ExpiredTime: token.ExpiredTime,
|
||||
RemainQuota: token.RemainQuota,
|
||||
UnlimitedQuota: token.UnlimitedQuota,
|
||||
ModelLimitsEnabled: token.ModelLimitsEnabled,
|
||||
ModelLimits: token.ModelLimits,
|
||||
AllowIps: token.AllowIps,
|
||||
Group: token.Group,
|
||||
}
|
||||
err = cleanToken.Insert()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteToken(c *gin.Context) {
|
||||
id, _ := strconv.Atoi(c.Param("id"))
|
||||
userId := c.GetInt("id")
|
||||
err := model.DeleteTokenById(id, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateToken(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
statusOnly := c.Query("status_only")
|
||||
token := model.Token{}
|
||||
err := c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(token.Name) > 30 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌名称过长",
|
||||
})
|
||||
return
|
||||
}
|
||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if token.Status == common.TokenStatusEnabled {
|
||||
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanToken.Status = token.Status
|
||||
} else {
|
||||
// If you add more fields, please also update token.Update()
|
||||
cleanToken.Name = token.Name
|
||||
cleanToken.ExpiredTime = token.ExpiredTime
|
||||
cleanToken.RemainQuota = token.RemainQuota
|
||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
|
||||
cleanToken.ModelLimits = token.ModelLimits
|
||||
cleanToken.AllowIps = token.AllowIps
|
||||
cleanToken.Group = token.Group
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": cleanToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type TokenBatch struct {
|
||||
Ids []int `json:"ids"`
|
||||
}
|
||||
|
||||
func DeleteTokenBatch(c *gin.Context) {
|
||||
tokenBatch := TokenBatch{}
|
||||
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
}
|
||||
265
controller/topup.go
Normal file
265
controller/topup.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Calcium-Ion/go-epay/epay"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
type EpayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
type AmountRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
||||
return nil
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: setting.EpayId,
|
||||
Key: setting.EpayKey,
|
||||
}, setting.PayAddress)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return withUrl
|
||||
}
|
||||
|
||||
func getPayMoney(amount int64, group string) float64 {
|
||||
dAmount := decimal.NewFromInt(amount)
|
||||
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
dAmount = dAmount.Div(dQuotaPerUnit)
|
||||
}
|
||||
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
|
||||
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
||||
dPrice := decimal.NewFromFloat(setting.Price)
|
||||
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
|
||||
|
||||
return payMoney.InexactFloat64()
|
||||
}
|
||||
|
||||
func getMinTopup() int64 {
|
||||
minTopup := setting.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
minTopup = int(dMinTopup.Mul(dQuotaPerUnit).IntPart())
|
||||
}
|
||||
return int64(minTopup)
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(req.Amount, group)
|
||||
if payMoney < 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: req.PaymentMethod,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
Device: epay.PC,
|
||||
NotifyUrl: notifyUrl,
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
amount := req.Amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
dAmount := decimal.NewFromInt(int64(amount))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
amount = dAmount.Div(dQuotaPerUnit).IntPart()
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: amount,
|
||||
Money: payMoney,
|
||||
TradeNo: tradeNo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: "pending",
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
// tradeNo lock
|
||||
var orderLocks sync.Map
|
||||
var createLock sync.Mutex
|
||||
|
||||
// LockOrder 尝试对给定订单号加锁
|
||||
func LockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
createLock.Lock()
|
||||
defer createLock.Unlock()
|
||||
lock, ok = orderLocks.Load(tradeNo)
|
||||
if !ok {
|
||||
lock = new(sync.Mutex)
|
||||
orderLocks.Store(tradeNo, lock)
|
||||
}
|
||||
}
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
// UnlockOrder 释放给定订单号的锁
|
||||
func UnlockOrder(tradeNo string) {
|
||||
lock, ok := orderLocks.Load(tradeNo)
|
||||
if ok {
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
||||
r[t] = c.Request.URL.Query().Get(t)
|
||||
return r
|
||||
}, map[string]string{})
|
||||
client := GetEpayClient()
|
||||
if client == nil {
|
||||
log.Println("易支付回调失败 未找到配置信息")
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
return
|
||||
}
|
||||
}
|
||||
verifyInfo, err := client.Verify(params)
|
||||
if err == nil && verifyInfo.VerifyStatus {
|
||||
_, err := c.Writer.Write([]byte("success"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
} else {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
log.Println("易支付回调签名验证失败")
|
||||
return
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
LockOrder(verifyInfo.ServiceTradeNo)
|
||||
defer UnlockOrder(verifyInfo.ServiceTradeNo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp == nil {
|
||||
log.Printf("易支付回调未找到订单: %v", verifyInfo)
|
||||
return
|
||||
}
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新订单失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
dAmount := decimal.NewFromInt(int64(topUp.Amount))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
quotaToAdd := int(dAmount.Mul(dQuotaPerUnit).IntPart())
|
||||
err = model.IncreaseUserQuota(topUp.UserId, quotaToAdd, true)
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v,支付金额:%f", common.LogQuota(quotaToAdd), topUp.Money))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func RequestAmount(c *gin.Context) {
|
||||
var req AmountRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Amount < getMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getPayMoney(req.Amount, group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
275
controller/topup_stripe.go
Normal file
275
controller/topup_stripe.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
"github.com/stripe/stripe-go/v81/checkout/session"
|
||||
"github.com/stripe/stripe-go/v81/webhook"
|
||||
"github.com/thanhpk/randstr"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentMethodStripe = "stripe"
|
||||
)
|
||||
|
||||
var stripeAdaptor = &StripeAdaptor{}
|
||||
|
||||
type StripePayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
}
|
||||
|
||||
type StripeAdaptor struct {
|
||||
}
|
||||
|
||||
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
}
|
||||
payMoney := getStripePayMoney(float64(req.Amount), group)
|
||||
if payMoney <= 0.01 {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
||||
}
|
||||
|
||||
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
||||
if req.PaymentMethod != PaymentMethodStripe {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
||||
return
|
||||
}
|
||||
if req.Amount < getStripeMinTopup() {
|
||||
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 10000 {
|
||||
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
||||
|
||||
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
|
||||
referenceId := "ref_" + common.Sha1([]byte(reference))
|
||||
|
||||
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
|
||||
if err != nil {
|
||||
log.Println("获取Stripe Checkout支付链接失败", err)
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: chargedMoney,
|
||||
TradeNo: referenceId,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: common.TopUpStatusPending,
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"message": "success",
|
||||
"data": gin.H{
|
||||
"pay_link": payLink,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func RequestStripeAmount(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestAmount(c, &req)
|
||||
}
|
||||
|
||||
func RequestStripePay(c *gin.Context) {
|
||||
var req StripePayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
||||
return
|
||||
}
|
||||
stripeAdaptor.RequestPay(c, &req)
|
||||
}
|
||||
|
||||
func StripeWebhook(c *gin.Context) {
|
||||
payload, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
signature := c.GetHeader("Stripe-Signature")
|
||||
endpointSecret := setting.StripeWebhookSecret
|
||||
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
||||
IgnoreAPIVersionMismatch: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case stripe.EventTypeCheckoutSessionCompleted:
|
||||
sessionCompleted(event)
|
||||
case stripe.EventTypeCheckoutSessionExpired:
|
||||
sessionExpired(event)
|
||||
default:
|
||||
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
||||
}
|
||||
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
func sessionCompleted(event stripe.Event) {
|
||||
customerId := event.GetObjectValue("customer")
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "complete" != status {
|
||||
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
err := model.Recharge(referenceId, customerId)
|
||||
if err != nil {
|
||||
log.Println(err.Error(), referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
||||
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
||||
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
||||
}
|
||||
|
||||
func sessionExpired(event stripe.Event) {
|
||||
referenceId := event.GetObjectValue("client_reference_id")
|
||||
status := event.GetObjectValue("status")
|
||||
if "expired" != status {
|
||||
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if len(referenceId) == 0 {
|
||||
log.Println("未提供支付单号")
|
||||
return
|
||||
}
|
||||
|
||||
topUp := model.GetTopUpByTradeNo(referenceId)
|
||||
if topUp == nil {
|
||||
log.Println("充值订单不存在", referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
if topUp.Status != common.TopUpStatusPending {
|
||||
log.Println("充值订单状态错误", referenceId)
|
||||
}
|
||||
|
||||
topUp.Status = common.TopUpStatusExpired
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("充值订单已过期", referenceId)
|
||||
}
|
||||
|
||||
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
||||
if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
|
||||
return "", fmt.Errorf("无效的Stripe API密钥")
|
||||
}
|
||||
|
||||
stripe.Key = setting.StripeApiSecret
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(setting.StripePriceId),
|
||||
Quantity: stripe.Int64(amount),
|
||||
},
|
||||
},
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
||||
}
|
||||
|
||||
if "" == customerId {
|
||||
if "" != email {
|
||||
params.CustomerEmail = stripe.String(email)
|
||||
}
|
||||
|
||||
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
||||
} else {
|
||||
params.Customer = stripe.String(customerId)
|
||||
}
|
||||
|
||||
result, err := session.New(params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func GetChargedAmount(count float64, user model.User) float64 {
|
||||
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
||||
if topUpGroupRatio == 0 {
|
||||
topUpGroupRatio = 1
|
||||
}
|
||||
|
||||
return count * topUpGroupRatio
|
||||
}
|
||||
|
||||
func getStripePayMoney(amount float64, group string) float64 {
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
// Using float64 for monetary calculations is acceptable here due to the small amounts involved
|
||||
topupGroupRatio := common.GetTopupGroupRatio(group)
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
||||
return payMoney
|
||||
}
|
||||
|
||||
func getStripeMinTopup() int64 {
|
||||
minTopup := setting.StripeMinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
}
|
||||
return int64(minTopup)
|
||||
}
|
||||
154
controller/uptime_kuma.go
Normal file
154
controller/uptime_kuma.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/setting/console_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
requestTimeout = 30 * time.Second
|
||||
httpTimeout = 10 * time.Second
|
||||
uptimeKeySuffix = "_24"
|
||||
apiStatusPath = "/api/status-page/"
|
||||
apiHeartbeatPath = "/api/status-page/heartbeat/"
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
Name string `json:"name"`
|
||||
Uptime float64 `json:"uptime"`
|
||||
Status int `json:"status"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type UptimeGroupResult struct {
|
||||
CategoryName string `json:"categoryName"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
}
|
||||
|
||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New("non-200 status")
|
||||
}
|
||||
|
||||
return json.NewDecoder(resp.Body).Decode(dest)
|
||||
}
|
||||
|
||||
func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
|
||||
url, _ := groupConfig["url"].(string)
|
||||
slug, _ := groupConfig["slug"].(string)
|
||||
categoryName, _ := groupConfig["categoryName"].(string)
|
||||
|
||||
result := UptimeGroupResult{
|
||||
CategoryName: categoryName,
|
||||
Monitors: []Monitor{},
|
||||
}
|
||||
|
||||
if url == "" || slug == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(url, "/")
|
||||
|
||||
var statusData struct {
|
||||
PublicGroupList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MonitorList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"monitorList"`
|
||||
} `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
var heartbeatData struct {
|
||||
HeartbeatList map[string][]struct {
|
||||
Status int `json:"status"`
|
||||
} `json:"heartbeatList"`
|
||||
UptimeList map[string]float64 `json:"uptimeList"`
|
||||
}
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
})
|
||||
|
||||
if g.Wait() != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, pg := range statusData.PublicGroupList {
|
||||
if len(pg.MonitorList) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range pg.MonitorList {
|
||||
monitor := Monitor{
|
||||
Name: m.Name,
|
||||
Group: pg.Name,
|
||||
}
|
||||
|
||||
monitorID := strconv.Itoa(m.ID)
|
||||
|
||||
if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
|
||||
monitor.Uptime = uptime
|
||||
}
|
||||
|
||||
if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
|
||||
monitor.Status = heartbeats[0].Status
|
||||
}
|
||||
|
||||
result.Monitors = append(result.Monitors, monitor)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func GetUptimeKumaStatus(c *gin.Context) {
|
||||
groups := console_setting.GetUptimeKumaGroups()
|
||||
if len(groups) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
client := &http.Client{Timeout: httpTimeout}
|
||||
results := make([]UptimeGroupResult, len(groups))
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
for i, group := range groups {
|
||||
i, group := i, group
|
||||
g.Go(func() error {
|
||||
results[i] = fetchGroupData(gCtx, client, group)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
g.Wait()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||
}
|
||||
52
controller/usedata.go
Normal file
52
controller/usedata.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllQuotaDates(c *gin.Context) {
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
username := c.Query("username")
|
||||
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": dates,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserQuotaDates(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
// 判断时间跨度是否超过 1 个月
|
||||
if endTimestamp-startTimestamp > 2592000 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "时间跨度不能超过 1 个月",
|
||||
})
|
||||
return
|
||||
}
|
||||
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": dates,
|
||||
})
|
||||
return
|
||||
}
|
||||
956
controller/user.go
Normal file
956
controller/user.go
Normal file
@@ -0,0 +1,956 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"one-api/constant"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func Login(c *gin.Context) {
|
||||
if !common.PasswordLoginEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员关闭了密码登录",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
var loginRequest LoginRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&loginRequest)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无效的参数",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
username := loginRequest.Username
|
||||
password := loginRequest.Password
|
||||
if username == "" || password == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无效的参数",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
err = user.ValidateAndFill()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
// setup session & cookies and then return user info
|
||||
func setupLogin(user *model.User, c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Set("id", user.Id)
|
||||
session.Set("username", user.Username)
|
||||
session.Set("role", user.Role)
|
||||
session.Set("status", user.Status)
|
||||
session.Set("group", user.Group)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "无法保存会话信息,请重试",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
cleanUser := model.User{
|
||||
Id: user.Id,
|
||||
Username: user.Username,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
Group: user.Group,
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "",
|
||||
"success": true,
|
||||
"data": cleanUser,
|
||||
})
|
||||
}
|
||||
|
||||
func Logout(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Clear()
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "",
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
func Register(c *gin.Context) {
|
||||
if !common.RegisterEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员关闭了新用户注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
if !common.PasswordRegisterEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := common.Validate.Struct(&user); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "输入不合法 " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if common.EmailVerificationEnabled {
|
||||
if user.Email == "" || user.VerificationCode == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员开启了邮箱验证,请输入邮箱地址和验证码",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !common.VerifyCodeWithKey(user.Email, user.VerificationCode, common.EmailVerificationPurpose) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码错误或已过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "数据库错误,请稍后重试",
|
||||
})
|
||||
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
|
||||
return
|
||||
}
|
||||
if exist {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名已存在,或已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
affCode := user.AffCode // this code is the inviter's code, not the user's own code
|
||||
inviterId, _ := model.GetUserIdByAffCode(affCode)
|
||||
cleanUser := model.User{
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.Username,
|
||||
InviterId: inviterId,
|
||||
}
|
||||
if common.EmailVerificationEnabled {
|
||||
cleanUser.Email = user.Email
|
||||
}
|
||||
if err := cleanUser.Insert(inviterId); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取插入后的用户ID
|
||||
var insertedUser model.User
|
||||
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户注册失败或用户ID获取失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 生成默认令牌
|
||||
if constant.GenerateDefaultToken {
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成默认令牌失败",
|
||||
})
|
||||
common.SysError("failed to generate token key: " + err.Error())
|
||||
return
|
||||
}
|
||||
// 生成默认令牌
|
||||
token := model.Token{
|
||||
UserId: insertedUser.Id, // 使用插入后的用户ID
|
||||
Name: cleanUser.Username + "的初始令牌",
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
AccessedTime: common.GetTimestamp(),
|
||||
ExpiredTime: -1, // 永不过期
|
||||
RemainQuota: 500000, // 示例额度
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
}
|
||||
if setting.DefaultUseAutoGroup {
|
||||
token.Group = "auto"
|
||||
}
|
||||
if err := token.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "创建默认令牌失败",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetAllUsers(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.GetAllUsers(pageInfo)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func SearchUsers(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
users, total, err := model.SearchUsers(keyword, group, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
func GetUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= user.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权获取同级或更高等级用户的信息",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateAccessToken(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// get rand int 28-32
|
||||
randI := common.GetRandomInt(4)
|
||||
key, err := common.GenerateRandomKey(29 + randI)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "生成失败",
|
||||
})
|
||||
common.SysError("failed to generate key: " + err.Error())
|
||||
return
|
||||
}
|
||||
user.SetAccessToken(key)
|
||||
|
||||
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "请重试,系统生成的 UUID 竟然重复了!",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user.AccessToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type TransferAffQuotaRequest struct {
|
||||
Quota int `json:"quota" binding:"required"`
|
||||
}
|
||||
|
||||
func TransferAffQuota(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
tran := TransferAffQuotaRequest{}
|
||||
if err := c.ShouldBindJSON(&tran); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
err = user.TransferAffQuotaToQuota(tran.Quota)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "划转失败 " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "划转成功",
|
||||
})
|
||||
}
|
||||
|
||||
func GetAffCode(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if user.AffCode == "" {
|
||||
user.AffCode = common.GetRandomString(4)
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user.AffCode,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
user.Remark = ""
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
id = c.GetInt("id")
|
||||
}
|
||||
user, err := model.GetUserCache(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
var models []string
|
||||
for group := range groups {
|
||||
for _, g := range model.GetGroupEnabledModels(group) {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": models,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateUser(c *gin.Context) {
|
||||
var updatedUser model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
|
||||
if err != nil || updatedUser.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
if updatedUser.Password == "" {
|
||||
updatedUser.Password = "$I_LOVE_U" // make Validator happy :)
|
||||
}
|
||||
if err := common.Validate.Struct(&updatedUser); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "输入不合法 " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(updatedUser.Id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= originUser.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||
})
|
||||
return
|
||||
}
|
||||
if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
|
||||
})
|
||||
return
|
||||
}
|
||||
if updatedUser.Password == "$I_LOVE_U" {
|
||||
updatedUser.Password = "" // rollback to what it should be
|
||||
}
|
||||
updatePassword := updatedUser.Password != ""
|
||||
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if originUser.Quota != updatedUser.Quota {
|
||||
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateSelf(c *gin.Context) {
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Password == "" {
|
||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||
}
|
||||
if err := common.Validate.Struct(&user); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "输入不合法 " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cleanUser := model.User{
|
||||
Id: c.GetInt("id"),
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
}
|
||||
if user.Password == "$I_LOVE_U" {
|
||||
user.Password = "" // rollback to what it should be
|
||||
cleanUser.Password = ""
|
||||
}
|
||||
updatePassword, err := checkUpdatePassword(user.OriginalPassword, user.Password, cleanUser.Id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if err := cleanUser.Update(updatePassword); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func checkUpdatePassword(originalPassword string, newPassword string, userId int) (updatePassword bool, err error) {
|
||||
var currentUser *model.User
|
||||
currentUser, err = model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) {
|
||||
err = fmt.Errorf("原密码错误")
|
||||
return
|
||||
}
|
||||
if newPassword == "" {
|
||||
return
|
||||
}
|
||||
updatePassword = true
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteUser(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
originUser, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= originUser.Role {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权删除同权限等级或更高权限等级的用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.HardDeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, _ := model.GetUserById(id, false)
|
||||
|
||||
if user.Role == common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不能删除超级管理员账户",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err := model.DeleteUserById(id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func CreateUser(c *gin.Context) {
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
user.Username = strings.TrimSpace(user.Username)
|
||||
if err != nil || user.Username == "" || user.Password == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := common.Validate.Struct(&user); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "输入不合法 " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.DisplayName == "" {
|
||||
user.DisplayName = user.Username
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if user.Role >= myRole {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法创建权限大于等于自己的用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Even for admin users, we cannot fully trust them!
|
||||
cleanUser := model.User{
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
}
|
||||
if err := cleanUser.Insert(0); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type ManageRequest struct {
|
||||
Id int `json:"id"`
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
// ManageUser Only admin user can do this
|
||||
func ManageUser(c *gin.Context) {
|
||||
var req ManageRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&req)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
Id: req.Id,
|
||||
}
|
||||
// Fill attributes
|
||||
model.DB.Unscoped().Where(&user).First(&user)
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户不存在",
|
||||
})
|
||||
return
|
||||
}
|
||||
myRole := c.GetInt("role")
|
||||
if myRole <= user.Role && myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无权更新同权限等级或更高权限等级的用户信息",
|
||||
})
|
||||
return
|
||||
}
|
||||
switch req.Action {
|
||||
case "disable":
|
||||
user.Status = common.UserStatusDisabled
|
||||
if user.Role == common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法禁用超级管理员用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "enable":
|
||||
user.Status = common.UserStatusEnabled
|
||||
case "delete":
|
||||
if user.Role == common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法删除超级管理员用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.Delete(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "promote":
|
||||
if myRole != common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "普通管理员用户无法提升其他用户为管理员",
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role >= common.RoleAdminUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户已经是管理员",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.Role = common.RoleAdminUser
|
||||
case "demote":
|
||||
if user.Role == common.RoleRootUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法降级超级管理员用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Role == common.RoleCommonUser {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该用户已经是普通用户",
|
||||
})
|
||||
return
|
||||
}
|
||||
user.Role = common.RoleCommonUser
|
||||
}
|
||||
|
||||
if err := user.Update(false); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
clearUser := model.User{
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": clearUser,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func EmailBind(c *gin.Context) {
|
||||
email := c.Query("email")
|
||||
code := c.Query("code")
|
||||
if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "验证码错误或已过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user := model.User{
|
||||
Id: id.(int),
|
||||
}
|
||||
err := user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.Email = email
|
||||
// no need to check if this email already taken, because we have used verification code to check it
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type topUpRequest struct {
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
var topUpLock = sync.Mutex{}
|
||||
|
||||
func TopUp(c *gin.Context) {
|
||||
topUpLock.Lock()
|
||||
defer topUpLock.Unlock()
|
||||
req := topUpRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
quota, err := model.Redeem(req.Key, id)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": quota,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type UpdateUserSettingRequest struct {
|
||||
QuotaWarningType string `json:"notify_type"`
|
||||
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
|
||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
RecordIpLog bool `json:"record_ip_log"`
|
||||
}
|
||||
|
||||
func UpdateUserSetting(c *gin.Context) {
|
||||
var req UpdateUserSettingRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证预警类型
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的预警类型",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证预警阈值
|
||||
if req.QuotaWarningThreshold <= 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "预警阈值必须大于0",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是webhook类型,验证webhook地址
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
if req.WebhookUrl == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Webhook地址不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 验证URL格式
|
||||
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的Webhook地址",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是邮件类型,验证邮箱地址
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
// 验证邮箱格式
|
||||
if !strings.Contains(req.NotificationEmail, "@") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的邮箱地址",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建设置
|
||||
settings := dto.UserSetting{
|
||||
NotifyType: req.QuotaWarningType,
|
||||
QuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
|
||||
RecordIpLog: req.RecordIpLog,
|
||||
}
|
||||
|
||||
// 如果是webhook类型,添加webhook相关设置
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
settings.WebhookUrl = req.WebhookUrl
|
||||
if req.WebhookSecret != "" {
|
||||
settings.WebhookSecret = req.WebhookSecret
|
||||
}
|
||||
}
|
||||
|
||||
// 如果提供了通知邮箱,添加到设置中
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
settings.NotificationEmail = req.NotificationEmail
|
||||
}
|
||||
|
||||
// 更新用户设置
|
||||
user.SetSetting(settings)
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "更新设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "设置已更新",
|
||||
})
|
||||
}
|
||||
168
controller/wechat.go
Normal file
168
controller/wechat.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type wechatLoginResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
func getWeChatIdByCode(code string) (string, error) {
|
||||
if code == "" {
|
||||
return "", errors.New("无效的参数")
|
||||
}
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", common.WeChatServerToken)
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
httpResponse, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer httpResponse.Body.Close()
|
||||
var res wechatLoginResponse
|
||||
err = json.NewDecoder(httpResponse.Body).Decode(&res)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !res.Success {
|
||||
return "", errors.New(res.Message)
|
||||
}
|
||||
if res.Data == "" {
|
||||
return "", errors.New("验证码错误或已过期")
|
||||
}
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func WeChatAuth(c *gin.Context) {
|
||||
if !common.WeChatAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员未开启通过微信登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
wechatId, err := getWeChatIdByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
WeChatId: wechatId,
|
||||
}
|
||||
if model.IsWeChatIdAlreadyTaken(wechatId) {
|
||||
err := user.FillUserByWeChatId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
user.DisplayName = "WeChat User"
|
||||
user.Role = common.RoleCommonUser
|
||||
user.Status = common.UserStatusEnabled
|
||||
|
||||
if err := user.Insert(0); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func WeChatBind(c *gin.Context) {
|
||||
if !common.WeChatAuthEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "管理员未开启通过微信登录以及注册",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
wechatId, err := getWeChatIdByCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
if model.IsWeChatIdAlreadyTaken(wechatId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该微信账号已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user := model.User{
|
||||
Id: id.(int),
|
||||
}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.WeChatId = wechatId
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
})
|
||||
return
|
||||
}
|
||||
52
docker-compose.yml
Normal file
52
docker-compose.yml
Normal file
@@ -0,0 +1,52 @@
|
||||
version: '3.4'
|
||||
|
||||
services:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(mysql:3306)/new-api # Point to the mysql service
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
# - FRONTEND_BASE_URL=https://openai.justsong.cn # Uncomment for multi-node deployment with front-end URL
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- mysql
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
redis:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
restart: always
|
||||
|
||||
mysql:
|
||||
image: mysql:8.2
|
||||
container_name: mysql
|
||||
restart: always
|
||||
environment:
|
||||
MYSQL_ROOT_PASSWORD: 123456 # Ensure this matches the password in SQL_DSN
|
||||
MYSQL_DATABASE: new-api
|
||||
volumes:
|
||||
- mysql_data:/var/lib/mysql
|
||||
# ports:
|
||||
# - "3306:3306" # If you want to access MySQL from outside Docker, uncomment
|
||||
|
||||
volumes:
|
||||
mysql_data:
|
||||
53
docs/api/api_auth.md
Normal file
53
docs/api/api_auth.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# API 鉴权文档
|
||||
|
||||
## 认证方式
|
||||
|
||||
### Access Token
|
||||
|
||||
对于需要鉴权的 API 接口,必须同时提供以下两个请求头来进行 Access Token 认证:
|
||||
|
||||
1. **请求头中的 `Authorization` 字段**
|
||||
|
||||
将 Access Token 放置于 HTTP 请求头部的 `Authorization` 字段中,格式如下:
|
||||
|
||||
```
|
||||
Authorization: <your_access_token>
|
||||
```
|
||||
|
||||
其中 `<your_access_token>` 需要替换为实际的 Access Token 值。
|
||||
|
||||
2. **请求头中的 `New-Api-User` 字段**
|
||||
|
||||
将用户 ID 放置于 HTTP 请求头部的 `New-Api-User` 字段中,格式如下:
|
||||
|
||||
```
|
||||
New-Api-User: <your_user_id>
|
||||
```
|
||||
|
||||
其中 `<your_user_id>` 需要替换为实际的用户 ID。
|
||||
|
||||
**注意:**
|
||||
|
||||
* **必须同时提供 `Authorization` 和 `New-Api-User` 两个请求头才能通过鉴权。**
|
||||
* 如果只提供其中一个请求头,或者两个请求头都未提供,则会返回 `401 Unauthorized` 错误。
|
||||
* 如果 `Authorization` 中的 Access Token 无效,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,access token 无效”。
|
||||
* 如果 `New-Api-User` 中的用户 ID 与 Access Token 不匹配,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,与登录用户不匹配,请重新登录”。
|
||||
* 如果没有提供 `New-Api-User` 请求头,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,未提供 New-Api-User”。
|
||||
* 如果 `New-Api-User` 请求头格式错误,则会返回 `401 Unauthorized` 错误,并提示“无权进行此操作,New-Api-User 格式错误”。
|
||||
* 如果用户已被禁用,则会返回 `403 Forbidden` 错误,并提示“用户已被封禁”。
|
||||
* 如果用户权限不足,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,权限不足”。
|
||||
* 如果用户信息无效,则会返回 `403 Forbidden` 错误,并提示“无权进行此操作,用户信息无效”。
|
||||
|
||||
## Curl 示例
|
||||
|
||||
假设您的 Access Token 为 `access_token`,用户 ID 为 `123`,要访问的 API 接口为 `/api/user/self`,则可以使用以下 curl 命令:
|
||||
|
||||
```bash
|
||||
curl -X GET \
|
||||
-H "Authorization: access_token" \
|
||||
-H "New-Api-User: 123" \
|
||||
https://your-domain.com/api/user/self
|
||||
```
|
||||
|
||||
请将 `access_token`、`123` 和 `https://your-domain.com` 替换为实际的值。
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user